diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index e12bb11b..95f14859 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -462,7 +462,7 @@ class Text2SemanticDecoder(nn.Module): value=True, ) y_attn_mask = F.pad( - torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal必须为0,否则会导致batch_size>1时的复读情况 (x_len, 0), value=False, ) @@ -508,10 +508,10 @@ class Text2SemanticDecoder(nn.Module): def infer_panel_batch_infer_with_flash_attn( self, - x:List[torch.LongTensor], #####全部文本token + x:torch.LongTensor, #####全部文本token x_lens:torch.LongTensor, prompts:torch.LongTensor, ####参考音频token - bert_feature:List[torch.LongTensor], + bert_feature:torch.LongTensor, top_k: int = -100, top_p: int = 100, early_stop_num: int = -1, @@ -688,7 +688,7 @@ class Text2SemanticDecoder(nn.Module): x:List[torch.LongTensor], #####全部文本token x_lens:torch.LongTensor, prompts:torch.LongTensor, ####参考音频token - bert_feature:List[torch.LongTensor], + bert_feature:torch.LongTensor, top_k: int = -100, top_p: int = 100, early_stop_num: int = -1, diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 4befc0c4..b38e9dc6 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -749,7 +749,7 @@ class TTS: if no_prompt_text : prompt = None else: - prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) + prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device) pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(