From 254b9b0b55475a7d281f1ec535af5d62629d679e Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Fri, 13 Jun 2025 17:31:09 +0800 Subject: [PATCH] =?UTF-8?q?fix(export=5Ftorch=5Fscript):=20=E5=AF=B9?= =?UTF-8?q?=E9=BD=90=20T2S=20=E5=AE=9E=E7=8E=B0=E5=92=8C=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/export_torch_script.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index 6a13c2d4..66999c0a 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -103,7 +103,7 @@ def logits_to_probs( @torch.jit.script def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization - q = torch.randn_like(probs_sort) + q = torch.empty_like(probs_sort).exponential_(1.0) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) @@ -114,7 +114,7 @@ def sample( temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[int] = None, - repetition_penalty: float = 1.0, + repetition_penalty: float = 1.35, ): probs = logits_to_probs( logits=logits, @@ -309,8 +309,9 @@ class T2SBlock: attn = F.scaled_dot_product_attention(q, k, v) - attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) - attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1) attn = F.linear(attn, self.out_w, self.out_b) x = x + attn