Change the mask strategy during training and inference to fix the repetition issue that occurs when batch_size > 1 (#966)

ChasonJiang authored 2024-04-12 18:00:50 +08:00, committed by GitHub
parent 3706ad1b8b
commit 959269b5ae
2 changed files with 48 additions and 44 deletions

GPT_SoVITS/AR/models/t2s_model.py

@@ -297,7 +297,8 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
+        # Unmask y[0] to prevent repetition; see https://github.com/RVC-Boss/GPT-SoVITS/issues/965 for details
+        x_attn_mask[:, x_len] = False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
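
Note (not part of the diff): a minimal standalone sketch of what the added x_attn_mask[:, x_len] = False does, assuming the shapes implied above (an all-False (x_len, x_len) text block padded with True over the y region; True = masked):

import torch
import torch.nn.functional as F

x_len, y_len = 4, 3  # made-up sizes for illustration

# Text rows attend to all text positions (False = attend); the y region starts fully masked.
x_attn_mask = torch.zeros(x_len, x_len, dtype=torch.bool)
x_attn_mask = F.pad(x_attn_mask, (0, y_len), value=True)  # -> (x_len, x_len + y_len)

# The added line: text positions may now also attend to the first semantic token y[0].
x_attn_mask[:, x_len] = False

print(x_attn_mask.int())
# tensor([[0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1]])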
@@ -393,6 +394,8 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
+        # Unmask y[0] to prevent repetition; see https://github.com/RVC-Boss/GPT-SoVITS/issues/965 for details
+        x_attn_mask[:, x_len] = False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -458,7 +461,7 @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),  # diagonal must be 0, otherwise repetition occurs when batch_size > 1
             (x_len, 0),
             value=False,
         )
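
Note (not part of the diff): the diagonal argument decides whether a step may attend to itself; a quick standalone comparison of the two variants:

import torch

y_len = 4
# diagonal=1 (old): only strictly-future positions are masked (True); step i can attend to y[i].
print(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1).int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]])

# diagonal=0 (new): the diagonal is masked too, so step i attends only to y[<i]
# (plus the text positions contributed by the x part of the mask).
print(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0).int())
# tensor([[1, 1, 1, 1],
#         [0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1]])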
@@ -504,29 +507,29 @@ class Text2SemanticDecoder(nn.Module):
     def infer_panel_batch_infer_with_flash_attn(
         self,
-        x:List[torch.LongTensor],  ##### all text tokens
+        x:torch.LongTensor,  ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
-        # Embed the phones and project the bert_features first, then pad them to the same length, to mitigate the repetition problem. (Other factors may also cause repetition.)
-        max_len = 0
-        for x_item, bert_item in zip(x, bert_feature):
-            max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
-        x_list = [self.ar_text_embedding(item) for item in x]
-        x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
-        x = torch.stack(x_list, dim=0)
-        bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
-        bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
-        bert_feature = torch.stack(bert_features_list, dim=0)
-        # bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
-        # x = self.ar_text_embedding(x)
+        ## Embed the phones and project the bert_features first, then pad them to the same length (the padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor behind repetition is the mask strategy)
+        # max_len = 0
+        # for x_item, bert_item in zip(x, bert_feature):
+        #     max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
+        # x_list = [self.ar_text_embedding(item) for item in x]
+        # x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
+        # x = torch.stack(x_list, dim=0)
+        # bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
+        # bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
+        # bert_feature = torch.stack(bert_features_list, dim=0)
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_embedding(x)
         x = x + bert_feature
         x = self.ar_text_position(x)
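
Note (not part of the diff): the removed path embedded each item and zero-padded the embeddings; the restored path pads the token ids and embeds the whole batch. A standalone sketch of the difference, with a made-up nn.Embedding standing in for ar_text_embedding:

import torch
import torch.nn.functional as F

emb = torch.nn.Embedding(100, 8)   # stand-in for self.ar_text_embedding
a = torch.randint(1, 100, (5,))    # short phone sequence (made-up)
b = torch.randint(1, 100, (7,))    # longer phone sequence (made-up)
max_len = max(a.shape[0], b.shape[0])

# Removed order: embed each item, then zero-pad the embeddings to max_len.
old = torch.stack([F.pad(emb(t), (0, 0, 0, max_len - t.shape[0])) for t in (a, b)])

# Restored order: pad the token ids first (pad id 0 here), then embed the batch.
ids = torch.stack([F.pad(t, (0, max_len - t.shape[0]), value=0) for t in (a, b)])
new = emb(ids)

print(torch.allclose(old[1], new[1]))  # True: the unpadded item is identical
print(torch.allclose(old[0], new[0]))  # False: padded tail is zeros vs. emb(0)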
@@ -573,8 +576,8 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),  ### the all-0 xx block is extended to all-0 xx + all-1 xy, (x, x+y)
             value=True,
         )
-        y_mask = F.pad(  ### the upper-right 1s of yy are extended leftward into the 0s of xy, (y, x+y)
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+        y_mask = F.pad(  ### the upper-right 0s of yy are extended leftward into the 0s of xy, (y, x+y)
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),  # diagonal must be 0, otherwise repetition occurs when batch_size > 1
             (x_len, 0),
             value=False,
         )
@@ -669,29 +672,29 @@ class Text2SemanticDecoder(nn.Module):
     def infer_panel_batch_only(
         self,
-        x:List[torch.LongTensor],  ##### all text tokens
+        x:torch.LongTensor,  ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor,  #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
-        # Embed the phones and project the bert_features first, then pad them to the same length, to mitigate the repetition problem. (Other factors may also cause repetition.)
-        max_len = 0
-        for x_item, bert_item in zip(x, bert_feature):
-            max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
-        x_list = [self.ar_text_embedding(item) for item in x]
-        x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
-        x = torch.stack(x_list, dim=0)
-        bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
-        bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
-        bert_feature = torch.stack(bert_features_list, dim=0)
-        # bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
-        # x = self.ar_text_embedding(x)
+        ## Embed the phones and project the bert_features first, then pad them to the same length (the padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor behind repetition is the mask strategy)
+        # max_len = 0
+        # for x_item, bert_item in zip(x, bert_feature):
+        #     max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
+        # x_list = [self.ar_text_embedding(item) for item in x]
+        # x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
+        # x = torch.stack(x_list, dim=0)
+        # bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
+        # bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
+        # bert_feature = torch.stack(bert_features_list, dim=0)
+        bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_embedding(x)
         x = x + bert_feature
         x = self.ar_text_position(x)
@@ -747,7 +750,7 @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_mask = F.pad(  ### the upper-right 1s of yy are extended leftward into the 0s of xy, (y, x+y)
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),  # diagonal must be 0, otherwise repetition occurs when batch_size > 1
             (x_len, 0),
             value=False,
         )

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -515,16 +515,16 @@ class TTS:
         all_bert_features_batch = all_bert_features_list
-        # max_len = max(bert_max_len, phones_max_len)
+        max_len = max(bert_max_len, phones_max_len)
         # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
-        #### Padding phones and bert_features directly increases the repetition probability.
-        # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-        # all_bert_features_batch = all_bert_features_list
-        # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
-        # for idx, item in enumerate(all_bert_features_list):
-        #     all_bert_features_batch[idx, :, : item.shape[-1]] = item
-        # #### Embed the phones and project the bert_features first, then pad them to the same length, to mitigate the repetition problem. (Other factors may also cause repetition.)
+        #### Pad phones and bert_features directly (the padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor behind repetition is the mask strategy)
+        all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
+        all_bert_features_batch = all_bert_features_list
+        all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
+        for idx, item in enumerate(all_bert_features_list):
+            all_bert_features_batch[idx, :, : item.shape[-1]] = item
+        # #### Embed the phones and project the bert_features first, then pad them to the same length (the padding strategy affects what the T2S model generates, but it does not directly affect the repetition probability; the main factor behind repetition is the mask strategy)
         # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
         # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
         # all_phones_batch = torch.stack(all_phones_list, dim=0)
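
Note (not part of the diff): the re-enabled loop right-pads variable-length (1024, T_i) BERT feature tensors into one zero-initialised batch tensor; a standalone sketch with made-up sizes:

import torch

feats = [torch.randn(1024, 9), torch.randn(1024, 14)]  # per-item BERT features (made-up lengths)
max_len = max(f.shape[-1] for f in feats)

# Copy each item into the left part of a zeros tensor; the tail stays zero padding.
batch = torch.zeros(len(feats), 1024, max_len)
for idx, item in enumerate(feats):
    batch[idx, :, : item.shape[-1]] = item

print(batch.shape)  # torch.Size([2, 1024, 14])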
@@ -734,17 +734,18 @@ class TTS:
                 continue
             batch_phones:List[torch.LongTensor] = item["phones"]
+            # batch_phones:torch.LongTensor = item["phones"]
             batch_phones_len:torch.LongTensor = item["phones_len"]
-            all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
+            all_phoneme_ids:torch.LongTensor = item["all_phones"]
             all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
-            all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
+            all_bert_features:torch.LongTensor = item["all_bert_features"]
             norm_text:str = item["norm_text"]
             print(i18n("前端处理后的文本(每句):"), norm_text)
             if no_prompt_text :
                 prompt = None
             else:
-                prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
             pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
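
Note (not part of the diff): all_phoneme_ids is now a stacked (batch, max_len) tensor rather than a list of tensors, so the batch size is read from shape[0]; expand then broadcasts the single cached prompt across the batch without copying. A tiny sketch with made-up shapes:

import torch

prompt_semantic = torch.randint(0, 1024, (50,))   # cached prompt tokens (shape assumed for illustration)
all_phoneme_ids = torch.randint(0, 512, (3, 40))  # batched phoneme ids, one row per item

# Broadcast the one cached prompt to every item in the batch.
prompt = prompt_semantic.expand(all_phoneme_ids.shape[0], -1)
print(prompt.shape)  # torch.Size([3, 50])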