From 9ed42daa88880511453ad3911aa62d7bd75d1254 Mon Sep 17 00:00:00 2001
From: zpeng11
Date: Sat, 23 Aug 2025 12:17:04 -0400
Subject: [PATCH] feat: allow fsdec and sdec to take sampling params

---
 GPT_SoVITS/AR/models/t2s_model_onnx.py | 90 +++++++++++++++-----------
 1 file changed, 54 insertions(+), 36 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index b381e58d..65c95d48 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -28,45 +28,45 @@ def logits_to_probs(
     logits,
     previous_tokens=None,
     temperature: float = 1.0,
-    top_k=None,
-    top_p=None,
+    top_k=15,
+    top_p=1.0,
     repetition_penalty: float = 1.0,
 ):
     previous_tokens = previous_tokens.squeeze()
-    if previous_tokens is not None and repetition_penalty != 1.0:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0,
-            score * repetition_penalty,
-            score / repetition_penalty,
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+    # if previous_tokens is not None and repetition_penalty != 1.0:  # always captured by ONNX
+    previous_tokens = previous_tokens.long()
+    score = torch.gather(logits, dim=0, index=previous_tokens)
+    score = torch.where(
+        score < 0,
+        score * repetition_penalty,
+        score / repetition_penalty,
+    )
+    logits.scatter_(dim=0, index=previous_tokens, src=score)
 
-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(
-                sorted_logits,
-                dim=-1,
-            ),
+    # if top_p is not None and top_p < 1.0:  # to be captured by ONNX
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(
+        torch.nn.functional.softmax(
+            sorted_logits,
             dim=-1,
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0,
-            index=sorted_indices,
-            src=sorted_indices_to_remove,
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+        ),
+        dim=-1,
+    )
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0,
+        index=sorted_indices,
+        src=sorted_indices_to_remove,
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
 
     logits = logits / max(temperature, 1e-5)
 
-    if top_k is not None:
-        v, _ = torch.topk(logits, top_k)
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, inf_tensor_value, logits)
+    # if top_k is not None:  # to be captured by ONNX
+    v, _ = torch.topk(logits, top_k)
+    pivot = v.select(-1, -1).unsqueeze(-1)
+    logits = torch.where(logits < pivot, inf_tensor_value, logits)
 
     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
@@ -130,7 +130,16 @@ class T2SFirstStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, x, prompt):
+    def forward(self, x, prompt, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=x.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=x.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=x.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=x.device)
+
         y = prompt
         x_example = x[:, :, 0] * 0.0  # N, 1, 512
 
@@ -180,7 +189,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
 
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -211,7 +220,16 @@ class T2SStageDecoder(nn.Module):
         self.early_stop_num = early_stop_num
         self.num_layers = num_layers
 
-    def forward(self, y, k, v, y_emb, x_example):
+    def forward(self, y, k, v, y_emb, x_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
+        if top_k is None:
+            top_k = torch.LongTensor([15]).to(device=y.device)
+        if top_p is None:
+            top_p = torch.FloatTensor([1.0]).to(device=y.device)
+        if repetition_penalty is None:
+            repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
+        if temperature is None:
+            temperature = torch.FloatTensor([1.0]).to(device=y.device)
+
         cache = {
             "all_stage": self.num_layers,
             "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
@@ -240,7 +258,7 @@ class T2SStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
 
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
 
@@ -389,7 +407,7 @@ class Text2SemanticDecoder(nn.Module):
         xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
-        samples = sample(logits[0], y, top_k=15, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
+        samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=1.0)[0].unsqueeze(0)
         if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
             stop = True
         if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
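
Reviewer note: the sampling knobs are threaded through as 1-element tensors,
and the `if` guards in `logits_to_probs` are flattened into always-executed
tensor ops, because `torch.onnx.export` traces the model: Python branches and
plain int/float arguments get folded into the graph as constants, while
tensor arguments survive as real graph inputs that can vary per call. Below
is a minimal export sketch under assumptions; the `export_fsdec` helper, the
example tensors, file name, dynamic axes, and opset are illustrative and not
part of this patch.

    import torch
    from torch import nn

    def export_fsdec(fsdec: nn.Module, x: torch.Tensor, prompt: torch.Tensor) -> None:
        # `fsdec` is assumed to be an instance of the patched T2SFirstStageDecoder;
        # `x` (text features) and `prompt` (semantic prompt tokens) come from the
        # surrounding GPT-SoVITS export pipeline and are not defined by this patch.

        # 1-element tensors: tracing records these as dynamic graph inputs,
        # whereas plain Python numbers would be baked in as constants.
        top_k = torch.LongTensor([15])
        top_p = torch.FloatTensor([1.0])
        repetition_penalty = torch.FloatTensor([1.35])
        temperature = torch.FloatTensor([1.0])

        torch.onnx.export(
            fsdec,
            (x, prompt, top_k, top_p, repetition_penalty, temperature),
            "t2s_fsdec.onnx",  # illustrative file name
            input_names=["x", "prompt", "top_k", "top_p",
                         "repetition_penalty", "temperature"],
            dynamic_axes={"x": {1: "x_len"}, "prompt": {1: "prompt_len"}},  # assumed axes
            opset_version=16,
        )

Once exported this way, the knobs are ordinary ONNX inputs, so an onnxruntime
caller can pass, say, np.array([20], dtype=np.int64) for "top_k" at inference
time without re-exporting the model.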