From 8ca6bf3f112b0b4714f5208adaa14b671e163489 Mon Sep 17 00:00:00 2001 From: MiaoMiao Li <94794519+jsntcheng@users.noreply.github.com> Date: Tue, 28 Oct 2025 16:08:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=9C=89=E6=97=B6=E5=80=99?= =?UTF-8?q?=E4=BC=9A=E5=87=BA=E7=8E=B0=E9=95=BF=E6=AE=B5=E6=97=A0=E6=84=8F?= =?UTF-8?q?=E4=B9=89=E9=9F=B3=E9=A2=91=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added logic to handle bad tokens and adjust logits based on repeated tokens during decoding. --- GPT_SoVITS/AR/models/t2s_model.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 7196d6ab..c8964823 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -698,6 +698,10 @@ class Text2SemanticDecoder(nn.Module): y_list = [None] * y.shape[0] batch_idx_map = list(range(y.shape[0])) idx_list = [None] * y.shape[0] + bad_tokens_list = [809, 207,411,679,676,25,23,7] + bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long) + last_2_token = 0 + repeat_count = 0 for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) @@ -710,7 +714,30 @@ class Text2SemanticDecoder(nn.Module): logits = logits[:, :-1] else: attn_mask = F.pad(attn_mask, (0, 1), value=False) + # 1. 获取上一步生成的 token + last_token = y[:, -1] + # 2. 检查对于批次中的每个序列,其上一步生成的token是否已经重复,如果重复,降低选中概率 + # is_last_token_bad 的形状是 [batch_size], dtype=torch.bool + is_last_token_bad = torch.isin(last_token, bad_tokens) + if last_token == last_2_token: + repeat_count += 1 + if 0 < logits[:, last_2_token] < 35: + logits[:, last_2_token] *= 0.618**repeat_count + elif logits[:, last_2_token] < 0: + logits[:, last_2_token] *= 1.414**repeat_count + elif logits[:, last_2_token] >= 35: + logits[:, last_2_token] *= 0 + else: + logits[:, last_2_token] = -float('inf') + else: + repeat_count = 0 + last_2_token = last_token + # 3. 如果某个序列的上一个token是恶性token,则在这一步的logits中, + # 将所有恶性token的概率都设为-inf,以禁止选择它们。 + # 我们使用高级索引(advanced indexing)来只修改需要修改的行。 + if is_last_token_bad.any(): + logits[is_last_token_bad, bad_tokens] = -float('inf') samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0]