diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index b49bcfb..dad2440 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -297,7 +297,8 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
-
+        # Unmask y[0] to prevent repetition; see https://github.com/RVC-Boss/GPT-SoVITS/issues/965
+        x_attn_mask[:, x_len]=False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -393,6 +394,8 @@ class Text2SemanticDecoder(nn.Module):
             (0, y_len),
             value=True,
         )
+        # Unmask y[0] to prevent repetition; see https://github.com/RVC-Boss/GPT-SoVITS/issues/965
+        x_attn_mask[:, x_len]=False
         y_attn_mask = F.pad(
             torch.triu(
                 torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
@@ -458,7 +461,7 @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal must be 0, otherwise repetition occurs when batch_size>1
             (x_len, 0),
             value=False,
         )
@@ -504,29 +507,29 @@ class Text2SemanticDecoder(nn.Module):
     def infer_panel_batch_infer_with_flash_attn(
         self,
-        x:List[torch.LongTensor], ##### all text tokens
+        x:torch.LongTensor, ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor, #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
-        # Embed the phones and project the bert_features first, then pad to the same length, to mitigate repetition. (Other factors may also cause repetition.)
-        max_len = 0
-        for x_item, bert_item in zip(x, bert_feature):
-            max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
-        x_list = [self.ar_text_embedding(item) for item in x]
-        x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
-        x = torch.stack(x_list, dim=0)
-
-        bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
-        bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in bert_features_list]
-        bert_feature = torch.stack(bert_features_list, dim=0)
-
-        x = x + bert_feature
-        x = self.ar_text_position(x)
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
@@ ... @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal must be 0, otherwise repetition occurs when batch_size>1
             (x_len, 0),
             value=False,
         )
@@ -669,29 +672,29 @@ class Text2SemanticDecoder(nn.Module):
     def infer_panel_batch_only(
         self,
-        x:List[torch.LongTensor], ##### all text tokens
+        x:torch.LongTensor, ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor, #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
-        # Embed the phones and project the bert_features first, then pad to the same length, to mitigate repetition. (Other factors may also cause repetition.)
-        max_len = 0
-        for x_item, bert_item in zip(x, bert_feature):
-            max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
-        x_list = [self.ar_text_embedding(item) for item in x]
-        x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
-        x = torch.stack(x_list, dim=0)
-
-        bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
-        bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in bert_features_list]
-        bert_feature = torch.stack(bert_features_list, dim=0)
-
-        x = x + bert_feature
-        x = self.ar_text_position(x)
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
@@ ... @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal must be 0, otherwise repetition occurs when batch_size>1
             (x_len, 0),
             value=False,
         )
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index b875151..a6f2541 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -515,16 +515,16 @@ class TTS:
             all_bert_features_batch = all_bert_features_list
-            # max_len = max(bert_max_len, phones_max_len)
+            max_len = max(bert_max_len, phones_max_len)
             # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
-            #### Padding phones and bert_features directly increases the chance of repetition.
-            # all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
-            # all_bert_features_batch = all_bert_features_list
-            # all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
-            # for idx, item in enumerate(all_bert_features_list):
-            #     all_bert_features_batch[idx, :, : item.shape[-1]] = item
+            #### Pad phones and bert_features directly. (The padding strategy affects what the T2S model generates, but it does not directly drive repetition; the masking strategy is the main factor behind repetition.)
+            all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
+            all_bert_features_batch = all_bert_features_list
+            all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
+            for idx, item in enumerate(all_bert_features_list):
+                all_bert_features_batch[idx, :, : item.shape[-1]] = item
 
-            # #### Embed the phones and project the bert_features first, then pad to the same length, to mitigate repetition. (Other factors may also cause repetition.)
+            # #### Embed the phones and project the bert_features first, then pad to the same length. (The padding strategy affects what the T2S model generates, but it does not directly drive repetition; the masking strategy is the main factor behind repetition.)
             # all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
             # all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
             # all_phones_batch = torch.stack(all_phones_list, dim=0)
@@ -734,17 +734,18 @@ class TTS:
                 continue
 
             batch_phones:List[torch.LongTensor] = item["phones"]
+            # batch_phones:torch.LongTensor = item["phones"]
             batch_phones_len:torch.LongTensor = item["phones_len"]
-            all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
+            all_phoneme_ids:torch.LongTensor = item["all_phones"]
             all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
-            all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
+            all_bert_features:torch.LongTensor = item["all_bert_features"]
             norm_text:str = item["norm_text"]
 
             print(i18n("前端处理后的文本(每句):"), norm_text)
             if no_prompt_text :
                 prompt = None
             else:
-                prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
 
             pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
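For context, a minimal self-contained sketch of the combined attention mask the `t2s_model.py` hunks build, assuming the layout implied by the diff: `x_len` text positions followed by `y_len` semantic positions, with `True` meaning "masked". The toy sizes and the `print` are illustrative only; the patched line is `x_attn_mask[:, x_len] = False`, which lets every text position attend to the first semantic token `y[0]`:

```python
import torch
import torch.nn.functional as F

# Toy sizes; in the model these come from the text and semantic sequences.
x_len, y_len = 4, 3

# Text rows: x attends to all of x; the y block starts fully masked (True = masked).
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool),
    (0, y_len),
    value=True,
)
# The patched line: unmask column x_len so text positions can also see y[0],
# which issue #965 reports as reducing repetition.
x_attn_mask[:, x_len] = False

# Semantic rows: causal mask over y, full visibility of x. diagonal=1 is the
# pre-patch value; the batch decode paths in this patch switch it to 0
# (see the hunks above) to avoid repetition when batch_size > 1.
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)

attn_mask = torch.cat([x_attn_mask, y_attn_mask], dim=0)
print(attn_mask.int())  # (x_len + y_len) x (x_len + y_len); 1 = masked
```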
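Likewise, a runnable sketch of the padding strategy the `TTS.py` hunk switches back to: raw phone IDs are right-padded to a shared length, and BERT features are copied into a zero-filled `(batch, 1024, max_len)` buffer. `batch_sequences` below is a simplified stand-in for the repo's `TTS.batch_sequences` helper (its `axis`, dtype, and device handling are omitted), and the toy inputs are assumptions:

```python
import torch

# Simplified stand-in for TTS.batch_sequences: right-pad 1-D tensors with
# pad_value and stack them along a new batch axis.
def batch_sequences(seqs, pad_value=0, max_length=None):
    max_length = max_length if max_length is not None else max(s.shape[0] for s in seqs)
    out = torch.full((len(seqs), max_length), pad_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, : s.shape[0]] = s
    return out

# Toy inputs: phone IDs and (1024, T) BERT features for two sentences.
all_phones_list = [torch.randint(0, 300, (7,)), torch.randint(0, 300, (5,))]
all_bert_features_list = [torch.randn(1024, 7), torch.randn(1024, 5)]

# Plays the role of max(bert_max_len, phones_max_len) in the patch.
max_len = max(
    max(p.shape[0] for p in all_phones_list),
    max(b.shape[-1] for b in all_bert_features_list),
)

all_phones_batch = batch_sequences(all_phones_list, pad_value=0, max_length=max_len)

# Zero-filled buffer; each item is copied into its left part, as in the hunk.
all_bert_features_batch = torch.zeros(len(all_bert_features_list), 1024, max_len)
for idx, item in enumerate(all_bert_features_list):
    all_bert_features_batch[idx, :, : item.shape[-1]] = item

print(all_phones_batch.shape)         # torch.Size([2, 7])
print(all_bert_features_batch.shape)  # torch.Size([2, 1024, 7])
```

The padded tensors are what the new `infer_panel_batch_*` signatures expect: embedding and `bert_proj` now happen inside the model on whole batches, instead of per-item embedding followed by padding in `TTS.py`.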