From 5982080939d9d4ee1bd1fcde968e180da260da69 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Sat, 23 Aug 2025 17:35:21 -0400 Subject: [PATCH] feat:updated fsdecode and decoder interface --- GPT_SoVITS/AR/models/t2s_model_onnx.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index d0b50449..438b843e 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -61,7 +61,7 @@ def logits_to_probs( ) logits = logits.masked_fill(indices_to_remove, -float("Inf")) - logits = logits / torch.max(temperature, torch.tensor(1e-5, device=temperature.device, dtype=temperature.dtype)) + logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float)) # if top_k is not None: # To be captured by onnx v, _ = torch.topk(logits, top_k) @@ -115,7 +115,6 @@ class T2SFirstStageDecoder(nn.Module): ar_predict_layer, loss_fct, ar_accuracy_metric, - top_k, early_stop_num, num_layers, ): @@ -126,7 +125,6 @@ class T2SFirstStageDecoder(nn.Module): self.ar_predict_layer = ar_predict_layer self.loss_fct = loss_fct self.ar_accuracy_metric = ar_accuracy_metric - self.top_k = top_k self.early_stop_num = early_stop_num self.num_layers = num_layers @@ -205,7 +203,6 @@ class T2SStageDecoder(nn.Module): ar_predict_layer, loss_fct, ar_accuracy_metric, - top_k, early_stop_num, num_layers, ): @@ -216,7 +213,6 @@ class T2SStageDecoder(nn.Module): self.ar_predict_layer = ar_predict_layer self.loss_fct = loss_fct self.ar_accuracy_metric = ar_accuracy_metric - self.top_k = top_k self.early_stop_num = early_stop_num self.num_layers = num_layers @@ -317,7 +313,6 @@ class Text2SemanticDecoder(nn.Module): self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, - self.top_k, self.early_stop_num, self.num_layers, ) @@ -328,23 +323,24 @@ class Text2SemanticDecoder(nn.Module): self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, - self.top_k, self.early_stop_num, self.num_layers, ) - def forward(self, x, prompts, bert_feature): + def forward(self, x, prompts, bert_feature, top_k = None): # torch.manual_seed(42) # torch.use_deterministic_algorithms(True) + if top_k is None: + top_k = self.top_k early_stop_num = self.early_stop_num prefix_len = prompts.shape[1] x = self.onnx_encoder(x, bert_feature) - y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k) stop = False for idx in tqdm(range(1, 1500)): - enco = self.stage_decoder(y, k, v, y_emb, x_example) + enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k) y, k, v, y_emb, logits, samples = enco if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True @@ -356,10 +352,11 @@ class Text2SemanticDecoder(nn.Module): # torch.use_deterministic_algorithms(False) return y, idx - def infer(self, x, prompts, bert_feature): + def infer(self, x, prompts, bert_feature, top_k=None): # torch.manual_seed(42) # torch.use_deterministic_algorithms(True) - top_k = self.top_k + if top_k is None: + top_k = self.top_k early_stop_num = self.early_stop_num x = self.onnx_encoder(x, bert_feature) @@ -407,7 +404,7 @@ class Text2SemanticDecoder(nn.Module): xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool) xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0) + samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0) if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: