Restore t2s_model.py and move the changes to export_torch_script.py

csh 2024-09-24 15:48:15 +08:00
parent 41dbc179c3
commit 4ed0b8bdcc
2 changed files with 267 additions and 71 deletions

t2s_model.py

@@ -83,7 +83,7 @@ class T2SMLP:
 class T2SBlock:
     def __init__(
         self,
-        num_heads: int,
+        num_heads,
         hidden_dim: int,
         mlp: T2SMLP,
         qkv_w,
@@ -92,12 +92,12 @@ class T2SBlock:
         out_b,
         norm_w1,
         norm_b1,
-        norm_eps1: float,
+        norm_eps1,
         norm_w2,
         norm_b2,
-        norm_eps2: float,
+        norm_eps2,
     ):
-        self.num_heads:int = num_heads
+        self.num_heads = num_heads
         self.mlp = mlp
         self.hidden_dim: int = hidden_dim
         self.qkv_w = qkv_w
@@ -266,7 +266,7 @@ class Text2SemanticDecoder(nn.Module):
         self.norm_first = norm_first
         self.vocab_size = config["model"]["vocab_size"]
         self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = float(config["model"]["dropout"])
+        self.p_dropout = config["model"]["dropout"]
         self.EOS = config["model"]["EOS"]
         self.norm_first = norm_first
         assert self.EOS == self.vocab_size - 1

export_torch_script.py

@@ -34,6 +34,7 @@ default_config = {
 def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
     config = dict_s1["config"]
+    config["model"]["dropout"] = float(config["model"]["dropout"])
     t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
     t2s_model.load_state_dict(dict_s1["weight"])
     t2s_model = t2s_model.eval()
@@ -105,7 +106,7 @@ def sample(
 @torch.jit.script
-def spectrogram_torch(y, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
+def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
     hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
     y = torch.nn.functional.pad(
         y.unsqueeze(1),
@@ -156,6 +157,179 @@ class DictToAttrRecursive(dict):
         except KeyError:
             raise AttributeError(f"Attribute {item} not found")
+
+@torch.jit.script
+class T2SMLP:
+    def __init__(self, w1, b1, w2, b2):
+        self.w1 = w1
+        self.b1 = b1
+        self.w2 = w2
+        self.b2 = b2
+
+    def forward(self, x):
+        x = F.relu(F.linear(x, self.w1, self.b1))
+        x = F.linear(x, self.w2, self.b2)
+        return x
+
+
+@torch.jit.script
+class T2SBlock:
+    def __init__(
+        self,
+        num_heads: int,
+        hidden_dim: int,
+        mlp: T2SMLP,
+        qkv_w,
+        qkv_b,
+        out_w,
+        out_b,
+        norm_w1,
+        norm_b1,
+        norm_eps1: float,
+        norm_w2,
+        norm_b2,
+        norm_eps2: float,
+    ):
+        self.num_heads = num_heads
+        self.mlp = mlp
+        self.hidden_dim: int = hidden_dim
+        self.qkv_w = qkv_w
+        self.qkv_b = qkv_b
+        self.out_w = out_w
+        self.out_b = out_b
+        self.norm_w1 = norm_w1
+        self.norm_b1 = norm_b1
+        self.norm_eps1 = norm_eps1
+        self.norm_w2 = norm_w2
+        self.norm_b2 = norm_b2
+        self.norm_eps2 = norm_eps2
+        self.false = torch.tensor(False, dtype=torch.bool)
+
+    @torch.jit.ignore
+    def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
+        if padding_mask is None:
+            return x
+        if padding_mask.dtype == torch.bool:
+            return x.masked_fill(padding_mask, 0)
+        else:
+            return x * padding_mask
+
+    def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
+        q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k.shape[1]
+
+        q = self.to_mask(q, padding_mask)
+        k_cache = self.to_mask(k, padding_mask)
+        v_cache = self.to_mask(v, padding_mask)
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
+
+        if padding_mask is not None:
+            for i in range(batch_size):
+                # mask = padding_mask[i,:,0]
+                if self.false.device != padding_mask.device:
+                    self.false = self.false.to(padding_mask.device)
+                idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
+                x_item = x[i, idx, :].unsqueeze(0)
+                attn_item = attn[i, idx, :].unsqueeze(0)
+                x_item = x_item + attn_item
+                x_item = F.layer_norm(
+                    x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+                )
+                x_item = x_item + self.mlp.forward(x_item)
+                x_item = F.layer_norm(
+                    x_item,
+                    [self.hidden_dim],
+                    self.norm_w2,
+                    self.norm_b2,
+                    self.norm_eps2,
+                )
+                x[i, idx, :] = x_item.squeeze(0)
+            x = self.to_mask(x, padding_mask)
+        else:
+            x = x + attn
+            x = F.layer_norm(
+                x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+            )
+            x = x + self.mlp.forward(x)
+            x = F.layer_norm(
+                x,
+                [self.hidden_dim],
+                self.norm_w2,
+                self.norm_b2,
+                self.norm_eps2,
+            )
+        return x, k_cache, v_cache
+
+    def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
+        q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
+
+        k_cache = torch.cat([k_cache, k], dim=1)
+        v_cache = torch.cat([v_cache, v], dim=1)
+
+        batch_size = q.shape[0]
+        q_len = q.shape[1]
+        kv_len = k_cache.shape[1]
+
+        q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
+        k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+        v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
+
+        attn = F.scaled_dot_product_attention(q, k, v)
+
+        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = F.linear(attn, self.out_w, self.out_b)
+
+        x = x + attn
+        x = F.layer_norm(
+            x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
+        )
+        x = x + self.mlp.forward(x)
+        x = F.layer_norm(
+            x,
+            [self.hidden_dim],
+            self.norm_w2,
+            self.norm_b2,
+            self.norm_eps2,
+        )
+        return x, k_cache, v_cache
+
+
+@torch.jit.script
+class T2STransformer:
+    def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
+        self.num_blocks: int = num_blocks
+        self.blocks = blocks
+
+    def process_prompt(
+        self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
+        k_cache: list[torch.Tensor] = []
+        v_cache: list[torch.Tensor] = []
+        for i in range(self.num_blocks):
+            x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
+            k_cache.append(k_cache_)
+            v_cache.append(v_cache_)
+        return x, k_cache, v_cache
+
+    def decode_next_token(
+        self, x: torch.Tensor,
+        k_cache: list[torch.Tensor],
+        v_cache: list[torch.Tensor]):
+        for i in range(self.num_blocks):
+            x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
+        return x, k_cache, v_cache
+
 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
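The three classes added above reimplement the attention stack in a TorchScript-friendly form: process_prompt runs full self-attention over the prompt and returns one K/V cache per block, and decode_next_token appends a single new step to those caches and attends over them. As a minimal driver sketch (illustrative only, not part of the commit; decode_sketch, transformer, prompt_emb and next_emb are placeholder names, and in the real model the next-step embedding comes from ar_audio_embedding / ar_audio_position as in T2SModel.forward further down):

import torch

def decode_sketch(transformer, prompt_emb: torch.Tensor, steps: int):
    # prompt_emb: [1, seq_len, hidden_dim] embeddings of the concatenated prompt.
    seq_len = prompt_emb.shape[1]
    # Causal mask; True marks positions that must NOT be attended,
    # matching the ~attn_mask usage inside process_prompt above.
    attn_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=prompt_emb.device), diagonal=1
    )
    x, k_cache, v_cache = transformer.process_prompt(prompt_emb, attn_mask)
    next_emb = x[:, -1:]  # placeholder: the real loop embeds the sampled token instead
    for _ in range(steps):
        x, k_cache, v_cache = transformer.decode_next_token(next_emb, k_cache, v_cache)
        next_emb = x[:, -1:]
    return x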
@@ -189,34 +363,19 @@ class VitsModel(nn.Module):
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
 
 class T2SModel(nn.Module):
-    def __init__(self, config,raw_t2s:Text2SemanticLightningModule, norm_first=False, top_k=3):
+    def __init__(self,raw_t2s:Text2SemanticLightningModule):
         super(T2SModel, self).__init__()
-        self.model_dim = config["model"]["hidden_dim"]
-        self.embedding_dim = config["model"]["embedding_dim"]
-        self.num_head = config["model"]["head"]
-        self.num_layers = config["model"]["n_layer"]
-        self.vocab_size = config["model"]["vocab_size"]
-        self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = float(config["model"]["dropout"])
-        self.EOS:int = config["model"]["EOS"]
-        self.norm_first = norm_first
+        self.model_dim = raw_t2s.model.model_dim
+        self.embedding_dim = raw_t2s.model.embedding_dim
+        self.num_head = raw_t2s.model.num_head
+        self.num_layers = raw_t2s.model.num_layers
+        self.vocab_size = raw_t2s.model.vocab_size
+        self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
+        # self.p_dropout = float(raw_t2s.model.p_dropout)
+        self.EOS:int = int(raw_t2s.model.EOS)
+        self.norm_first = raw_t2s.model.norm_first
         assert self.EOS == self.vocab_size - 1
         self.hz = 50
-        self.config = config
-
-        # self.bert_proj = nn.Linear(1024, self.embedding_dim)
-        # self.ar_text_embedding = TokenEmbedding(
-        #     self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
-        # )
-        # self.ar_text_position = SinePositionalEmbedding(
-        #     self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        # )
-        # self.ar_audio_embedding = TokenEmbedding(
-        #     self.embedding_dim, self.vocab_size, self.p_dropout
-        # )
-        # self.ar_audio_position = SinePositionalEmbedding(
-        #     self.embedding_dim, dropout=0.1, scale=False, alpha=True
-        # )
 
         self.bert_proj = raw_t2s.model.bert_proj
         self.ar_text_embedding = raw_t2s.model.ar_text_embedding
@@ -225,13 +384,45 @@ class T2SModel(nn.Module):
         self.ar_audio_position = raw_t2s.model.ar_audio_position
 
         # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
-        self.t2s_transformer = raw_t2s.model.t2s_transformer
+        # self.t2s_transformer = raw_t2s.model.t2s_transformer
+        blocks = []
+        h = raw_t2s.model.h
+
+        for i in range(self.num_layers):
+            layer = h.layers[i]
+            t2smlp = T2SMLP(
+                layer.linear1.weight,
+                layer.linear1.bias,
+                layer.linear2.weight,
+                layer.linear2.bias
+            )
+
+            block = T2SBlock(
+                self.num_head,
+                self.model_dim,
+                t2smlp,
+                layer.self_attn.in_proj_weight,
+                layer.self_attn.in_proj_bias,
+                layer.self_attn.out_proj.weight,
+                layer.self_attn.out_proj.bias,
+                layer.norm1.weight,
+                layer.norm1.bias,
+                layer.norm1.eps,
+                layer.norm2.weight,
+                layer.norm2.bias,
+                layer.norm2.eps
+            )
+
+            blocks.append(block)
+
+        self.t2s_transformer = T2STransformer(self.num_layers, blocks)
 
         # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
         self.ar_predict_layer = raw_t2s.model.ar_predict_layer
         # self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
-        self.max_sec = self.config["data"]["max_sec"]
-        self.top_k = int(self.config["inference"]["top_k"])
+        self.max_sec = raw_t2s.config["data"]["max_sec"]
+        self.top_k = int(raw_t2s.config["inference"]["top_k"])
         self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
 
     def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
@@ -296,6 +487,10 @@ class T2SModel(nn.Module):
             # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
             xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
+
+            if(idx<11):  # require at least 10 predicted tokens (~0.4 s) before EOS may be emitted
+                logits = logits[:, :-1]
+
             samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
 
             y = torch.concat([y, samples], dim=1)
@@ -305,13 +500,13 @@ class T2SModel(nn.Module):
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
+                if y.shape[1] == 0:
+                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
 
             y_emb = self.ar_audio_embedding(y[:, -1:])
             xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
 
+        y[0, -1] = 0
+
         return y[:, -idx:].unsqueeze(0)
 
 bert_path = os.environ.get(
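The decode loop above delegates token selection to the sample() helper defined earlier in this file. As a rough reference only (a sketch of a typical top-k/top-p sampler with repetition penalty, not necessarily the project's exact implementation; sample_sketch is a hypothetical name):

import torch

def sample_sketch(logits: torch.Tensor, previous_tokens: torch.Tensor,
                  top_k: int = 15, top_p: float = 1.0,
                  repetition_penalty: float = 1.35, temperature: float = 1.0) -> torch.Tensor:
    logits = logits.clone()  # logits: [1, vocab_size]
    # Penalize tokens that were already generated.
    prev = previous_tokens.view(-1)
    score = logits[0, prev]
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits[0, prev] = score
    logits = logits / max(temperature, 1e-5)
    # Keep only the top_k logits.
    topk_vals, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
    logits[logits < topk_vals[..., -1, None]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    # Nucleus (top_p) filtering on the probabilities.
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)  # [1, 1], the sampled next token id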
@@ -362,20 +557,19 @@ class ExportSSLModel(torch.nn.Module):
         audio = resamplex(ref_audio,src_sr,dst_sr).float()
         return audio
 
-def export_bert(tokenizer,ref_text,word2ph):
-    ref_bert_inputs = tokenizer(ref_text, return_tensors="pt")
+def export_bert(ref_bert_inputs):
     ref_bert_inputs = {
-        'input_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['input_ids']),
-        'attention_mask': torch.jit.annotate(torch.Tensor,ref_bert_inputs['attention_mask']),
-        'token_type_ids': torch.jit.annotate(torch.Tensor,ref_bert_inputs['token_type_ids']),
-        'word2ph': torch.jit.annotate(torch.Tensor,word2ph)
+        'input_ids': ref_bert_inputs['input_ids'],
+        'attention_mask': ref_bert_inputs['attention_mask'],
+        'token_type_ids': ref_bert_inputs['token_type_ids'],
+        'word2ph': ref_bert_inputs['word2ph']
     }
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
     my_bert_model = MyBertModel(bert_model)
     my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
     my_bert_model.save("onnx/bert_model.pt")
-    print('exported bert')
+    print('#### exported bert ####')
 
 def export(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
@@ -392,18 +586,14 @@ def export(gpt_path, vits_path):
     bert = MyBertModel(bert_model)
 
-    # export_bert(tokenizer,"声音,是有温度的.夜晚的声音,会发光",ref_bert_inputs['word2ph'])
+    # export_bert(ref_bert_inputs)
 
     ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(torch.jit.annotate(torch.Tensor,ref_audio))))
+    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
     torch.jit.script(s).save("onnx/xw/ssl_model.pt")
-
-    print('exported ssl')
-
-    # ref_seq = torch.LongTensor([cleaned_text_to_sequence(["zh", "ai4", "ch", "an1","j" ,"ia1","r","ua4","s","i3","t","e3","ch","un1","w","an3","d","e1", "sh", "i2", "h", "ou4", "y", "ou3", "r", "en2","w","en4","l","e1","zh","e4","y","ang4","y","i2","g","e4","w","en4","t","i2"],version='v2')])
-    # ref_seq = torch.LongTensor([cleaned_text_to_sequence(['sh','eng1','y','in1',',','sh','i4','y','ou3','w','en1','d','u4','d','e','.','y','e4','w','an3','d','e','sh','eng1','y','in1',',','h','ui4','f','a1','g','uang1'],version='v2')])
-    # text_seq = torch.LongTensor([cleaned_text_to_sequence(["d", "a4", "j", "ia1", "h", "ao3",",","w","o3","y", "ou3","y","i2","g","e4","q","i2","g","uai4","w","en4","t","i2","."],version='v2')])
+    print('#### exported ssl ####')
 
     ref_bert = bert(**ref_bert_inputs)
     text_bert = bert(**text_berf_inputs)
     ssl_content = ssl(ref_audio)
@@ -415,24 +605,30 @@ def export(gpt_path, vits_path):
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s_m = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
+    t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
     t2s = torch.jit.script(t2s_m)
-    print('exported t2s_m')
+    print('#### script t2s_m ####')
+
+    print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
     gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits.eval()
 
-    ref_audio_sr = ssl.resample(ref_audio,16000,32000)
+    ref_audio_sr = s.resample(ref_audio,16000,32000)
+    print('ref_audio_sr:',ref_audio_sr.shape)
 
-    # audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
-    torch.jit.trace(gpt_sovits,example_inputs=(
-        torch.jit.annotate(torch.Tensor,ssl_content),
-        torch.jit.annotate(torch.Tensor,ref_audio_sr),
-        torch.jit.annotate(torch.Tensor,ref_seq),
-        torch.jit.annotate(torch.Tensor,text_seq),
-        torch.jit.annotate(torch.Tensor,ref_bert),
-        torch.jit.annotate(torch.Tensor,text_bert))).save("onnx/xw/gpt_sovits_model.pt")
-    print('exported vits')
+    gpt_sovits_export = torch.jit.trace(
+        gpt_sovits,
+        example_inputs=(
+            ssl_content,
+            ref_audio_sr,
+            ref_seq,
+            text_seq,
+            ref_bert,
+            text_bert),
+        check_trace=False)  # check_trace defaults to True, but the check can feed randomly generated inputs of odd shapes and raise errors
+
+    gpt_sovits_export.save("onnx/xw/gpt_sovits_model.pt")
+    print('#### exported gpt_sovits ####')
 
 @torch.jit.script
 def parse_audio(ref_audio):
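Once export() has produced onnx/xw/gpt_sovits_model.pt, the traced module can be reloaded without any of the Python class definitions via TorchScript's standard loader. A hedged usage sketch (the tensor arguments must be prepared the same way as in export(); run_exported is a hypothetical helper name, not part of the commit):

import torch

def run_exported(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert,
                 model_path: str = "onnx/xw/gpt_sovits_model.pt"):
    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()
    with torch.no_grad():
        # Argument order mirrors the example_inputs used for torch.jit.trace above.
        return model(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)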
@@ -459,20 +655,20 @@ class GPT_SoVITS(nn.Module):
         audio = self.vits(text_seq, pred_semantic, ref_audio_sr)
         return audio
 
-def test():
+def test(gpt_path, vits_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
     bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
     bert = MyBertModel(bert_model)
     # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
 
-    gpt_path = "GPT_weights_v2/xw-e15.ckpt"
+    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
     dict_s1 = torch.load(gpt_path, map_location="cpu")
     raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(dict_s1['config'],raw_t2s,top_k=3)
+    t2s = T2SModel(raw_t2s)
     t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
 
-    vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
+    # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
     vits = VitsModel(vits_path)
     vits.eval()
 
@@ -506,7 +702,7 @@ def test():
         text_berf_inputs['word2ph'])
 
     #[1,N]
-    ref_audio = torch.tensor(load_audio("output/denoise_opt/xw.mp3_0000000000_0000156480.wav", 16000)).float().unsqueeze(0)
+    ref_audio = torch.tensor([load_audio("output/denoise_opt/chen1.mp4_0000033600_0000192000.wav", 16000)]).float()
     print('ref_audio:',ref_audio.shape)
     ref_audio_sr = ssl.resample(ref_audio,16000,32000)
 
@@ -537,5 +733,5 @@ def export_symbel(version='v2'):
 
 if __name__ == "__main__":
     export(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
-    # test()
+    # test(gpt_path="GPT_weights_v2/chen1-e15.ckpt", vits_path="SoVITS_weights_v2/chen1_e8_s208.pth")
     # export_symbel()