fix(export_torch_script): 对齐 T2S 实现和默认参数

This commit is contained in:
csh 2025-06-13 17:31:09 +08:00
parent 5c91e66d2e
commit 254b9b0b55

View File

@ -103,7 +103,7 @@ def logits_to_probs(
@torch.jit.script
def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
q = torch.empty_like(probs_sort).exponential_(1.0)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@ -114,7 +114,7 @@ def sample(
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
repetition_penalty: float = 1.35,
):
probs = logits_to_probs(
logits=logits,
@ -309,8 +309,9 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn