diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index 4f7b50a3..b381e58d 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy
 from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
 from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
+from tqdm import tqdm
 
 default_config = {
     "embedding_dim": 512,
@@ -179,7 +180,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -239,7 +240,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -315,26 +316,31 @@ class Text2SemanticDecoder(nn.Module):
         )
 
     def forward(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         early_stop_num = self.early_stop_num
         prefix_len = prompts.shape[1]
 
         x = self.onnx_encoder(x, bert_feature)
-        y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
 
         stop = False
-        for idx in range(1, 1500):
-            enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
-            y, k, v, y_emb, stage, logits, samples = enco
+        for idx in tqdm(range(1, 1500)):
+            enco = self.stage_decoder(y, k, v, y_emb, x_example)
+            y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
+                y = y[:,:-1]
                 break
-        y[0, -1] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx
 
     def infer(self, x, prompts, bert_feature):
+        # torch.manual_seed(42)
+        # torch.use_deterministic_algorithms(True)
         top_k = self.top_k
         early_stop_num = self.early_stop_num
 
@@ -356,11 +362,14 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
-        for idx in range(1500):
+        for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
                 y_emb = self.ar_audio_embedding(y)
             else:
                 y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
+                for i in range(len(cache["k"])):
+                    cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
+                    cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
             if cache["first_infer"] == 1:
@@ -380,15 +389,14 @@ class Text2SemanticDecoder(nn.Module):
                 xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
             xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
             logits = self.ar_predict_layer(xy_dec[:, -1])
-            samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+            samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
                 stop = True
             if stop:
-                if prompts.shape[1] == y.shape[1]:
-                    y = torch.concat([y, torch.zeros_like(samples)], dim=1)
                 break
             y = torch.concat([y, samples], dim=1)
             cache["first_infer"] = 0
+        # torch.use_deterministic_algorithms(False)
         return y, idx
diff --git a/playground/freerun.py b/playground/freerun.py
index f914e6dc..18736c93 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -68,7 +68,7 @@ def preprocess_text(text:str):
 # ref_phones = np.load("playground/ref/ref_phones.npy")
 # ref_bert = np.load("playground/ref/ref_bert.npy").T.astype(np.float32)
 
-[ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
+[ref_phones, ref_bert] = preprocess_text("近日江苏苏州荷花市集开张热闹与浪漫交织")
 
 [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
 
@@ -113,7 +113,7 @@ for idx in tqdm(range(1, 1500)):
     })
     if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
         break
-y[0, -1] = 0
+y = y[:,:-1]
 
 pred_semantic = np.expand_dims(y[:, -idx:], axis=0)
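
Note on the new per-step cache padding in infer(): each decode step appends one token, and the added loop grows every layer's k/v cache by one sequence position before the attention pass, presumably so the layers can write the new step into a pre-allocated slot instead of concatenating inside the graph. A minimal sketch of the pad semantics, assuming a (seq_len, batch, hidden) cache layout (the shapes and names here are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

# F.pad reads the pad spec from the last dim backwards, so for a 3-D tensor
# (0, 0, 0, 0, 0, 1) leaves the hidden and batch dims untouched and appends
# one zero slot at the end of the sequence dim.
k_cache = torch.zeros(10, 1, 512)        # assumed (seq_len, batch, hidden)
k_cache = F.pad(k_cache, (0, 0, 0, 0, 0, 1))
print(k_cache.shape)                      # torch.Size([11, 1, 512])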
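The y[0, -1] = 0 to y = y[:,:-1] change appears in both files and alters how the EOS token is handled when generation stops. A small numpy sketch of the difference, assuming y holds the sampled semantic ids and the loop broke right after EOS (id 1024 in freerun.py) was appended:

import numpy as np

y = np.array([[3, 7, 1024]])  # assumed: last sampled token is EOS

# Old behavior: overwrite EOS in place with semantic id 0, keeping the length.
# y[0, -1] = 0
# New behavior: drop the EOS token entirely, so no substitute id-0 token is
# passed downstream to the vocoder stage.
y = y[:, :-1]                 # -> array([[3, 7]])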