mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-03 14:00:10 +08:00
fix(export_torch_script): 对齐 T2S 实现和默认参数
This commit is contained in:
parent
5c91e66d2e
commit
254b9b0b55
@ -103,7 +103,7 @@ def logits_to_probs(
|
|||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def multinomial_sample_one_no_sync(probs_sort):
|
def multinomial_sample_one_no_sync(probs_sort):
|
||||||
# Does multinomial sampling without a cuda synchronization
|
# Does multinomial sampling without a cuda synchronization
|
||||||
q = torch.randn_like(probs_sort)
|
q = torch.empty_like(probs_sort).exponential_(1.0)
|
||||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ def sample(
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
top_p: Optional[int] = None,
|
top_p: Optional[int] = None,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.35,
|
||||||
):
|
):
|
||||||
probs = logits_to_probs(
|
probs = logits_to_probs(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
@ -309,8 +309,9 @@ class T2SBlock:
|
|||||||
|
|
||||||
attn = F.scaled_dot_product_attention(q, k, v)
|
attn = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||||
|
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||||
attn = F.linear(attn, self.out_w, self.out_b)
|
attn = F.linear(attn, self.out_w, self.out_b)
|
||||||
|
|
||||||
x = x + attn
|
x = x + attn
|
||||||
|
Loading…
x
Reference in New Issue
Block a user