mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-12-17 01:59:08 +08:00

Fix the error raised in batch mode (修复batch下的报错)

commit 7338302ee4, parent 94f157c8cf
@@ -700,8 +700,8 @@ class Text2SemanticDecoder(nn.Module):
         idx_list = [None] * y.shape[0]
         bad_tokens_list = [809, 207,411,679,676,25,23,7]
         bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long)
-        last_2_token = 0
-        repeat_count = 0
+        last_2_token = torch.full((bsz,), -1, dtype=torch.long, device=x.device)
+        repeat_count = torch.zeros(bsz, dtype=torch.float, device=x.device)
         for idx in tqdm(range(1500)):
             if idx == 0:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
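For context on what this initialization change fixes: with scalar bookkeeping, the comparison `last_token == last_2_token` inside the decoding loop yields a multi-element boolean tensor whenever bsz > 1, and Python's `if` cannot evaluate that, which is the batch-mode error this commit addresses. A minimal sketch of the failure and the per-sequence state that replaces it (toy values, not taken from the repo):

    import torch

    # Toy batch state; the token values are invented for illustration.
    bsz = 2
    last_token = torch.tensor([42, 7])          # e.g. y[:, -1] for a batch of 2
    last_2_token_scalar = 0                     # old, scalar bookkeeping

    try:
        if last_token == last_2_token_scalar:   # multi-element bool tensor in an `if`
            pass
    except RuntimeError as e:
        print(e)  # "Boolean value of Tensor with more than one element is ambiguous"

    # New, per-sequence bookkeeping: comparisons stay elementwise masks.
    last_2_token = torch.full((bsz,), -1, dtype=torch.long)
    is_repeated = last_token == last_2_token    # shape [bsz], dtype=torch.bool
    print(is_repeated)                          # tensor([False, False])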
@@ -718,26 +718,62 @@ class Text2SemanticDecoder(nn.Module):
             last_token = y[:, -1]

             # 2. For each sequence in the batch, check whether the token generated at the previous step is a repeat; if so, lower its selection probability
-            # is_last_token_bad has shape [batch_size], dtype=torch.bool
+            is_repeated = (last_token == last_2_token)
+            repeat_count[is_repeated] += 1
+            # 3. Reset the counter for sequences that did not repeat
+            repeat_count[~is_repeated] = 0  # ~ is the boolean 'not' operator
+            if is_repeated.any():
+                # Get the row indices of the repeating sequences
+                repeated_rows = torch.where(is_repeated)[0]
+
+                # Get the token IDs of those repeating sequences
+                repeated_tokens = last_2_token[repeated_rows]
+
+                # Get the current logit values of those repeated tokens
+                # This is a "gather" operation that selects exactly logits[row, token_id]
+                current_logits_values = logits[repeated_rows, repeated_tokens]
+
+                # Get the repeat counts of those sequences
+                current_repeat_counts = repeat_count[repeated_rows]
+
+                # --- Build masks to emulate if/elif/else ---
+                # Condition 1: 0 < logit < 35
+                mask1 = (current_logits_values > 0) & (current_logits_values < 35)
+                # Condition 2: logit < 0
+                mask2 = current_logits_values < 0
+                # Condition 3: logit >= 35
+                mask3 = current_logits_values >= 35
+
+                # 4. --- Apply the penalties ---
+                # Penalize the logits that satisfy condition 1
+                if mask1.any():
+                    rows_to_update = repeated_rows[mask1]
+                    tokens_to_update = repeated_tokens[mask1]
+                    counts_to_use = current_repeat_counts[mask1]
+                    logits[rows_to_update, tokens_to_update] *= (0.618 ** counts_to_use.unsqueeze(1)).squeeze(1)
+
+                # Penalize the logits that satisfy condition 2
+                if mask2.any():
+                    rows_to_update = repeated_rows[mask2]
+                    tokens_to_update = repeated_tokens[mask2]
+                    counts_to_use = current_repeat_counts[mask2]
+                    logits[rows_to_update, tokens_to_update] *= (1.414 ** counts_to_use.unsqueeze(1)).squeeze(1)
+
+                # Penalize the logits that satisfy condition 3
+                if mask3.any():
+                    rows_to_update = repeated_rows[mask3]
+                    tokens_to_update = repeated_tokens[mask3]
+                    logits[rows_to_update, tokens_to_update] = -float('inf')  # the original logic was *= 0; -inf blocks sampling more reliably
+
+            # 5. Update last_2_token for the next iteration
+            # Use .clone() so last_token is not accidentally modified on the next pass
+            last_2_token = last_token.clone()
+
+            # 6. Handle the bad_tokens list
             is_last_token_bad = torch.isin(last_token, bad_tokens)
-            if last_token == last_2_token:
-                repeat_count += 1
-                if 0 < logits[:, last_2_token] < 35:
-                    logits[:, last_2_token] *= 0.618**repeat_count
-                elif logits[:, last_2_token] < 0:
-                    logits[:, last_2_token] *= 1.414**repeat_count
-                elif logits[:, last_2_token] >= 35:
-                    logits[:, last_2_token] *= 0
-                else:
-                    logits[:, last_2_token] = -float('inf')
-            else:
-                repeat_count = 0
-                last_2_token = last_token
-            # 3. If a sequence's previous token is a bad token, set every bad token's
-            # probability to -inf in this step's logits so they cannot be selected.
-            # We use advanced indexing to modify only the rows that need changing.
             if is_last_token_bad.any():
-                logits[is_last_token_bad, bad_tokens] = -float('inf')
+                bad_rows = is_last_token_bad.nonzero(as_tuple=True)[0]
+                logits[bad_rows.unsqueeze(1), bad_tokens] = -float('inf')
             samples = sample(
                 logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
             )[0]
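The replacement block vectorizes the old scalar if/elif chain: it gathers the logit of each sequence's repeated token, splits the repeating rows into the three value ranges with boolean masks, and scatters the penalties back per row. A self-contained sketch of the same mechanics on toy tensors (shapes and values invented for illustration, not taken from the model):

    import torch

    # Toy setup: 3 sequences, vocab of 6.
    logits = torch.tensor([[5.0, 1.0, 0.5, -2.0, 3.0, 40.0],
                           [1.0, 8.0, 0.5, -2.0, 3.0,  2.0],
                           [1.0, 1.0, 0.5, -2.0, 3.0,  2.0]])
    last_token   = torch.tensor([0, 1, 2])
    last_2_token = torch.tensor([0, 1, 4])       # rows 0 and 1 repeat
    repeat_count = torch.tensor([2.0, 1.0, 0.0])

    is_repeated = last_token == last_2_token     # tensor([True, True, False])
    rows   = torch.where(is_repeated)[0]         # rows that repeated
    tokens = last_2_token[rows]                  # the token each row repeated
    vals   = logits[rows, tokens]                # gather: logits[row, token]

    # Masked, per-row penalties instead of a scalar if/elif chain.
    mask1 = (vals > 0) & (vals < 35)
    logits[rows[mask1], tokens[mask1]] *= 0.618 ** repeat_count[rows][mask1]
    mask2 = vals < 0
    logits[rows[mask2], tokens[mask2]] *= 1.414 ** repeat_count[rows][mask2]
    mask3 = vals >= 35
    logits[rows[mask3], tokens[mask3]] = -float('inf')

    print(logits[0, 0], logits[1, 1])            # both damped; row 2 untouched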
@@ -903,8 +939,8 @@ class Text2SemanticDecoder(nn.Module):
         )
         bad_tokens_list = [809, 207,411,679,676,25,23,7]
         bad_tokens = torch.tensor(bad_tokens_list, device=x.device, dtype=torch.long)
-        last_2_token = 0
-        repeat_count = 0
+        last_2_token = torch.full((bsz,), -1, dtype=torch.long, device=x.device)
+        repeat_count = torch.zeros(bsz, dtype=torch.float, device=x.device)
         for idx in tqdm(range(1500)):
             if xy_attn_mask is not None:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
@@ -921,26 +957,62 @@ class Text2SemanticDecoder(nn.Module):
             last_token = y[:, -1]

             # 2. For each sequence in the batch, check whether the token generated at the previous step is a repeat; if so, lower its selection probability
-            # is_last_token_bad has shape [batch_size], dtype=torch.bool
+            is_repeated = (last_token == last_2_token)
+            repeat_count[is_repeated] += 1
+            # 3. Reset the counter for sequences that did not repeat
+            repeat_count[~is_repeated] = 0  # ~ is the boolean 'not' operator
+            if is_repeated.any():
+                # Get the row indices of the repeating sequences
+                repeated_rows = torch.where(is_repeated)[0]
+
+                # Get the token IDs of those repeating sequences
+                repeated_tokens = last_2_token[repeated_rows]
+
+                # Get the current logit values of those repeated tokens
+                # This is a "gather" operation that selects exactly logits[row, token_id]
+                current_logits_values = logits[repeated_rows, repeated_tokens]
+
+                # Get the repeat counts of those sequences
+                current_repeat_counts = repeat_count[repeated_rows]
+
+                # --- Build masks to emulate if/elif/else ---
+                # Condition 1: 0 < logit < 35
+                mask1 = (current_logits_values > 0) & (current_logits_values < 35)
+                # Condition 2: logit < 0
+                mask2 = current_logits_values < 0
+                # Condition 3: logit >= 35
+                mask3 = current_logits_values >= 35
+
+                # 4. --- Apply the penalties ---
+                # Penalize the logits that satisfy condition 1
+                if mask1.any():
+                    rows_to_update = repeated_rows[mask1]
+                    tokens_to_update = repeated_tokens[mask1]
+                    counts_to_use = current_repeat_counts[mask1]
+                    logits[rows_to_update, tokens_to_update] *= (0.618 ** counts_to_use.unsqueeze(1)).squeeze(1)
+
+                # Penalize the logits that satisfy condition 2
+                if mask2.any():
+                    rows_to_update = repeated_rows[mask2]
+                    tokens_to_update = repeated_tokens[mask2]
+                    counts_to_use = current_repeat_counts[mask2]
+                    logits[rows_to_update, tokens_to_update] *= (1.414 ** counts_to_use.unsqueeze(1)).squeeze(1)
+
+                # Penalize the logits that satisfy condition 3
+                if mask3.any():
+                    rows_to_update = repeated_rows[mask3]
+                    tokens_to_update = repeated_tokens[mask3]
+                    logits[rows_to_update, tokens_to_update] = -float('inf')  # the original logic was *= 0; -inf blocks sampling more reliably
+
+            # 5. Update last_2_token for the next iteration
+            # Use .clone() so last_token is not accidentally modified on the next pass
+            last_2_token = last_token.clone()
+
+            # 6. Handle the bad_tokens list
             is_last_token_bad = torch.isin(last_token, bad_tokens)
-            if last_token == last_2_token:
-                repeat_count += 1
-                if 0 < logits[:, last_2_token] < 35:
-                    logits[:, last_2_token] *= 0.618**repeat_count
-                elif logits[:, last_2_token] < 0:
-                    logits[:, last_2_token] *= 1.414**repeat_count
-                elif logits[:, last_2_token] >= 35:
-                    logits[:, last_2_token] *= 0
-                else:
-                    logits[:, last_2_token] = -float('inf')
-            else:
-                repeat_count = 0
-                last_2_token = last_token
-            # 3. If a sequence's previous token is a bad token, set every bad token's
-            # probability to -inf in this step's logits so they cannot be selected.
-            # We use advanced indexing to modify only the rows that need changing.
             if is_last_token_bad.any():
-                logits[is_last_token_bad, bad_tokens] = -float('inf')
+                bad_rows = is_last_token_bad.nonzero(as_tuple=True)[0]
+                logits[bad_rows.unsqueeze(1), bad_tokens] = -float('inf')
             samples = sample(
                 logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
             )[0]
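The two-line bad-token fix at the end of both hunks is also a shape issue. Indexing with `logits[is_last_token_bad, bad_tokens]` pairs the i-th flagged row with the i-th bad token, so it only broadcasts when the number of flagged rows happens to match `len(bad_tokens)`; with arbitrary batch sizes it raises a shape-mismatch error. Unsqueezing the row indices to shape [N, 1] broadcasts against the token vector instead, masking every bad token in every flagged row. A toy demonstration (hypothetical sizes):

    import torch

    logits = torch.zeros(4, 10)                  # toy: batch of 4, vocab of 10
    bad_tokens = torch.tensor([2, 5, 7])
    is_last_token_bad = torch.tensor([True, False, True, False])

    # Old form: pairs the i-th flagged row with the i-th bad token, so it
    # errors out unless the number of flagged rows equals len(bad_tokens).
    # logits[is_last_token_bad, bad_tokens] = -float('inf')  # shape mismatch here

    # Fixed form: [N, 1] row indices broadcast against [3] token indices,
    # masking all bad tokens in all flagged rows.
    bad_rows = is_last_token_bad.nonzero(as_tuple=True)[0]  # tensor([0, 2])
    logits[bad_rows.unsqueeze(1), bad_tokens] = -float('inf')

    print(logits[0])  # columns 2, 5, 7 are -inf; rows 1 and 3 untouched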