diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index da95111..b49bcfb 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -504,18 +504,29 @@ class Text2SemanticDecoder(nn.Module): def infer_panel_batch_infer_with_flash_attn( self, - x, #####全部文本token - x_lens, - prompts, ####参考音频token - bert_feature, + x:List[torch.LongTensor], #####全部文本token + x_lens:torch.LongTensor, + prompts:torch.LongTensor, ####参考音频token + bert_feature:List[torch.LongTensor], top_k: int = -100, top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, ): - - bert_feature = self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_embedding(x) + # 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) + max_len = 0 + for x_item, bert_item in zip(x, bert_feature): + max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) + x_list = [self.ar_text_embedding(item) for item in x] + x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]