修复有时候会出现长段无意义音频的bug

Added logic to handle bad tokens and prevent repetition in token generation.
This commit is contained in:
MiaoMiao Li 2025-10-28 16:12:36 +08:00 committed by GitHub
parent 8ca6bf3f11
commit 94f157c8cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -901,7 +901,10 @@ class Text2SemanticDecoder(nn.Module):
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
bad_tokens_list = [809, 207,411,679,676,25,23,7]
bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long)
last_2_token = 0
repeat_count = 0
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
@ -914,7 +917,30 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = None
if idx < 11: ###至少预测出10个token不然不给停止0.4s
logits = logits[:, :-1]
# 1. Get the token generated in the previous step.
last_token = y[:, -1]
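# 2. For each sequence in the batch, flag whether the previously generated token is one of the bad tokens;
# is_last_token_bad has shape [batch_size], dtype=torch.bool.
# The check that follows then detects a token that repeats the one before it and, if so, lowers its logit.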
is_last_token_bad = torch.isin(last_token, bad_tokens)
if last_token == last_2_token:
repeat_count += 1
if 0 < logits[:, last_2_token] < 35:
logits[:, last_2_token] *= 0.618**repeat_count
elif logits[:, last_2_token] < 0:
logits[:, last_2_token] *= 1.414**repeat_count
elif logits[:, last_2_token] >= 35:
logits[:, last_2_token] *= 0
else:
logits[:, last_2_token] = -float('inf')
else:
repeat_count = 0
last_2_token = last_token
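# 3. If the previous token of a sequence is a bad token, set the logits of all
# bad tokens to -inf at this step so that none of them can be selected.
# Advanced indexing is used so that only the rows that need it are modified.
# NOTE(review): the boolean row mask and bad_tokens are broadcast against each other here —
# confirm this behaves as intended when more than one row is flagged (batch size > 1).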
if is_last_token_bad.any():
logits[is_last_token_bad, bad_tokens] = -float('inf')
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]