mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-16 17:40:09 +08:00
修复有时候会出现长段无意义音频的bug
Added logic to handle bad tokens and adjust logits based on repeated tokens during decoding.
This commit is contained in:
parent
11aa78bd9b
commit
8ca6bf3f11
@ -698,6 +698,10 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None] * y.shape[0]
|
||||
bad_tokens_list = [809, 207,411,679,676,25,23,7]
|
||||
bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long)
|
||||
last_2_token = 0
|
||||
repeat_count = 0
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
@ -710,7 +714,30 @@ class Text2SemanticDecoder(nn.Module):
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
# 1. 获取上一步生成的 token
|
||||
last_token = y[:, -1]
|
||||
|
||||
# 2. 检查对于批次中的每个序列,其上一步生成的token是否已经重复,如果重复,降低选中概率
|
||||
# is_last_token_bad 的形状是 [batch_size], dtype=torch.bool
|
||||
is_last_token_bad = torch.isin(last_token, bad_tokens)
|
||||
if last_token == last_2_token:
|
||||
repeat_count += 1
|
||||
if 0 < logits[:, last_2_token] < 35:
|
||||
logits[:, last_2_token] *= 0.618**repeat_count
|
||||
elif logits[:, last_2_token] < 0:
|
||||
logits[:, last_2_token] *= 1.414**repeat_count
|
||||
elif logits[:, last_2_token] >= 35:
|
||||
logits[:, last_2_token] *= 0
|
||||
else:
|
||||
logits[:, last_2_token] = -float('inf')
|
||||
else:
|
||||
repeat_count = 0
|
||||
last_2_token = last_token
|
||||
# 3. 如果某个序列的上一个token是恶性token,则在这一步的logits中,
|
||||
# 将所有恶性token的概率都设为-inf,以禁止选择它们。
|
||||
# 我们使用高级索引(advanced indexing)来只修改需要修改的行。
|
||||
if is_last_token_bad.any():
|
||||
logits[is_last_token_bad, bad_tokens] = -float('inf')
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user