From 7338302ee4381be81acf6cf5ce7e3449a5974472 Mon Sep 17 00:00:00 2001 From: lcc <694625452@qq.com> Date: Mon, 3 Nov 2025 15:23:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbatch=E4=B8=8B=E7=9A=84?= =?UTF-8?q?=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 152 ++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 40 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index 89eb3d0e..5e278978 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -700,8 +700,8 @@ class Text2SemanticDecoder(nn.Module): idx_list = [None] * y.shape[0] bad_tokens_list = [809, 207,411,679,676,25,23,7] bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long) - last_2_token = 0 - repeat_count = 0 + last_2_token = torch.full((bsz,), -1, dtype=torch.long, device=x.device) + repeat_count = torch.zeros(bsz, dtype=torch.float, device=x.device) for idx in tqdm(range(1500)): if idx == 0: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None) @@ -718,26 +718,62 @@ class Text2SemanticDecoder(nn.Module): last_token = y[:, -1] # 2. 检查对于批次中的每个序列,其上一步生成的token是否已经重复,如果重复,降低选中概率 - # is_last_token_bad 的形状是 [batch_size], dtype=torch.bool + is_repeated = (last_token == last_2_token) + repeat_count[is_repeated] += 1 + # 3. 对没有重复的序列,计数器清零 + repeat_count[~is_repeated] = 0 # ~是布尔 'not' 操作 + if is_repeated.any(): + # 获取重复序列的行索引 + repeated_rows = torch.where(is_repeated)[0] + + # 获取这些重复序列对应的 token ID + repeated_tokens = last_2_token[repeated_rows] + + # 获取这些重复 token 当前的 logits 值 + # 这是一种 "gather" 操作,精确选取 logits[row, token_id] + current_logits_values = logits[repeated_rows, repeated_tokens] + + # 获取这些重复序列的重复次数 + current_repeat_counts = repeat_count[repeated_rows] + + # --- 创建掩码来模拟 if/elif/else --- + # 条件1: 0 < logit < 35 + mask1 = (current_logits_values > 0) & (current_logits_values < 35) + # 条件2: logit < 0 + mask2 = current_logits_values < 0 + # 条件3: logit >= 35 + mask3 = current_logits_values >= 35 + + # 4. --- 应用惩罚 --- + # 对满足条件1的 logits 应用惩罚 + if mask1.any(): + rows_to_update = repeated_rows[mask1] + tokens_to_update = repeated_tokens[mask1] + counts_to_use = current_repeat_counts[mask1] + logits[rows_to_update, tokens_to_update] *= (0.618 ** counts_to_use.unsqueeze(1)).squeeze(1) + + # 对满足条件2的 logits 应用惩罚 + if mask2.any(): + rows_to_update = repeated_rows[mask2] + tokens_to_update = repeated_tokens[mask2] + counts_to_use = current_repeat_counts[mask2] + logits[rows_to_update, tokens_to_update] *= (1.414 ** counts_to_use.unsqueeze(1)).squeeze(1) + + # 对满足条件3的 logits 应用惩罚 + if mask3.any(): + rows_to_update = repeated_rows[mask3] + tokens_to_update = repeated_tokens[mask3] + logits[rows_to_update, tokens_to_update] = -float('inf') # 你的逻辑是 *= 0,设为 -inf 更能禁止采样 + + # 5. 更新 last_2_token 以备下一次迭代 + # 使用 .clone() 以免在下一次循环中意外修改 last_token + last_2_token = last_token.clone() + + # 6. 处理 bad_tokens 列表 is_last_token_bad = torch.isin(last_token, bad_tokens) - if last_token == last_2_token: - repeat_count += 1 - if 0 < logits[:, last_2_token] < 35: - logits[:, last_2_token] *= 0.618**repeat_count - elif logits[:, last_2_token] < 0: - logits[:, last_2_token] *= 1.414**repeat_count - elif logits[:, last_2_token] >= 35: - logits[:, last_2_token] *= 0 - else: - logits[:, last_2_token] = -float('inf') - else: - repeat_count = 0 - last_2_token = last_token - # 3. 如果某个序列的上一个token是恶性token,则在这一步的logits中, - # 将所有恶性token的概率都设为-inf,以禁止选择它们。 - # 我们使用高级索引(advanced indexing)来只修改需要修改的行。 if is_last_token_bad.any(): - logits[is_last_token_bad, bad_tokens] = -float('inf') + bad_rows = is_last_token_bad.nonzero(as_tuple=True)[0] + logits[bad_rows.unsqueeze(1), bad_tokens] = -float('inf') samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0] @@ -903,8 +939,8 @@ class Text2SemanticDecoder(nn.Module): ) bad_tokens_list = [809, 207,411,679,676,25,23,7] bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long) - last_2_token = 0 - repeat_count = 0 + last_2_token = torch.full((bsz,), -1, dtype=torch.long, device=x.device) + repeat_count = torch.zeros(bsz, dtype=torch.float, device=x.device) for idx in tqdm(range(1500)): if xy_attn_mask is not None: xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) @@ -921,26 +957,62 @@ class Text2SemanticDecoder(nn.Module): last_token = y[:, -1] # 2. 检查对于批次中的每个序列,其上一步生成的token是否已经重复,如果重复,降低选中概率 - # is_last_token_bad 的形状是 [batch_size], dtype=torch.bool + is_repeated = (last_token == last_2_token) + repeat_count[is_repeated] += 1 + # 3. 对没有重复的序列,计数器清零 + repeat_count[~is_repeated] = 0 # ~是布尔 'not' 操作 + if is_repeated.any(): + # 获取重复序列的行索引 + repeated_rows = torch.where(is_repeated)[0] + + # 获取这些重复序列对应的 token ID + repeated_tokens = last_2_token[repeated_rows] + + # 获取这些重复 token 当前的 logits 值 + # 这是一种 "gather" 操作,精确选取 logits[row, token_id] + current_logits_values = logits[repeated_rows, repeated_tokens] + + # 获取这些重复序列的重复次数 + current_repeat_counts = repeat_count[repeated_rows] + + # --- 创建掩码来模拟 if/elif/else --- + # 条件1: 0 < logit < 35 + mask1 = (current_logits_values > 0) & (current_logits_values < 35) + # 条件2: logit < 0 + mask2 = current_logits_values < 0 + # 条件3: logit >= 35 + mask3 = current_logits_values >= 35 + + # 4. --- 应用惩罚 --- + # 对满足条件1的 logits 应用惩罚 + if mask1.any(): + rows_to_update = repeated_rows[mask1] + tokens_to_update = repeated_tokens[mask1] + counts_to_use = current_repeat_counts[mask1] + logits[rows_to_update, tokens_to_update] *= (0.618 ** counts_to_use.unsqueeze(1)).squeeze(1) + + # 对满足条件2的 logits 应用惩罚 + if mask2.any(): + rows_to_update = repeated_rows[mask2] + tokens_to_update = repeated_tokens[mask2] + counts_to_use = current_repeat_counts[mask2] + logits[rows_to_update, tokens_to_update] *= (1.414 ** counts_to_use.unsqueeze(1)).squeeze(1) + + # 对满足条件3的 logits 应用惩罚 + if mask3.any(): + rows_to_update = repeated_rows[mask3] + tokens_to_update = repeated_tokens[mask3] + logits[rows_to_update, tokens_to_update] = -float('inf') # 你的逻辑是 *= 0,设为 -inf 更能禁止采样 + + # 5. 更新 last_2_token 以备下一次迭代 + # 使用 .clone() 以免在下一次循环中意外修改 last_token + last_2_token = last_token.clone() + + # 6. 处理 bad_tokens 列表 is_last_token_bad = torch.isin(last_token, bad_tokens) - if last_token == last_2_token: - repeat_count += 1 - if 0 < logits[:, last_2_token] < 35: - logits[:, last_2_token] *= 0.618**repeat_count - elif logits[:, last_2_token] < 0: - logits[:, last_2_token] *= 1.414**repeat_count - elif logits[:, last_2_token] >= 35: - logits[:, last_2_token] *= 0 - else: - logits[:, last_2_token] = -float('inf') - else: - repeat_count = 0 - last_2_token = last_token - # 3. 如果某个序列的上一个token是恶性token,则在这一步的logits中, - # 将所有恶性token的概率都设为-inf,以禁止选择它们。 - # 我们使用高级索引(advanced indexing)来只修改需要修改的行。 if is_last_token_bad.any(): - logits[is_last_token_bad, bad_tokens] = -float('inf') + bad_rows = is_last_token_bad.nonzero(as_tuple=True)[0] + logits[bad_rows.unsqueeze(1), bad_tokens] = -float('inf') samples = sample( logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature )[0]