diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index 6a13c2d4..66999c0a 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -103,7 +103,7 @@ def logits_to_probs(
 @torch.jit.script
 def multinomial_sample_one_no_sync(probs_sort):
     # Does multinomial sampling without a cuda synchronization
-    q = torch.randn_like(probs_sort)
+    q = torch.empty_like(probs_sort).exponential_(1.0)
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 
@@ -114,7 +114,7 @@ def sample(
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
+    repetition_penalty: float = 1.35,
 ):
     probs = logits_to_probs(
         logits=logits,
@@ -309,8 +309,9 @@ class T2SBlock:
 
         attn = F.scaled_dot_product_attention(q, k, v)
 
-        attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
-        attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
+        # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
+        attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
 
         attn = F.linear(attn, self.out_w, self.out_b)
         x = x + attn
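For reference (not part of the patch), a minimal sketch of why the exponential_(1.0) draw is the right one: with q ~ Exponential(1), argmax(probs / q) is the Gumbel-max trick in disguise (-log q is standard Gumbel noise), so the returned index is distributed exactly according to probs, whereas the old torch.randn_like draw did not give a valid sampler. The quick empirical check below assumes only that PyTorch is available; the distribution [0.1, 0.2, 0.7] and the 20000-draw count are arbitrary illustration values.

import torch

def multinomial_sample_one_no_sync(probs_sort):
    # Same body as the patched function: exponential noise instead of Gaussian.
    q = torch.empty_like(probs_sort).exponential_(1.0)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

# Empirical frequencies should approach the target distribution.
probs = torch.tensor([0.1, 0.2, 0.7])
draws = torch.stack([multinomial_sample_one_no_sync(probs) for _ in range(20000)])
freq = torch.bincount(draws.flatten().long(), minlength=3).float() / draws.numel()
print(freq)  # roughly tensor([0.10, 0.20, 0.70])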