Merge branch 'fast_inference_' of https://github.com/RVC-Boss/GPT-SoVITS into sync_main

This commit is contained in:
ChasonJiang 2024-04-16 20:36:51 +08:00
commit 0a0e3634c7
2 changed files with 5 additions and 5 deletions

View File

@ -462,7 +462,7 @@ class Text2SemanticDecoder(nn.Module):
value=True, value=True,
) )
y_attn_mask = F.pad( y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),# diagonal必须为0否则会导致batch_size>1时的复读情况
(x_len, 0), (x_len, 0),
value=False, value=False,
) )
@ -508,10 +508,10 @@ class Text2SemanticDecoder(nn.Module):
def infer_panel_batch_infer_with_flash_attn( def infer_panel_batch_infer_with_flash_attn(
self, self,
x:List[torch.LongTensor], #####全部文本token x:torch.LongTensor, #####全部文本token
x_lens:torch.LongTensor, x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor], bert_feature:torch.LongTensor,
top_k: int = -100, top_k: int = -100,
top_p: int = 100, top_p: int = 100,
early_stop_num: int = -1, early_stop_num: int = -1,
@ -688,7 +688,7 @@ class Text2SemanticDecoder(nn.Module):
x:List[torch.LongTensor], #####全部文本token x:List[torch.LongTensor], #####全部文本token
x_lens:torch.LongTensor, x_lens:torch.LongTensor,
prompts:torch.LongTensor, ####参考音频token prompts:torch.LongTensor, ####参考音频token
bert_feature:List[torch.LongTensor], bert_feature:torch.LongTensor,
top_k: int = -100, top_k: int = -100,
top_p: int = 100, top_p: int = 100,
early_stop_num: int = -1, early_stop_num: int = -1,

View File

@ -749,7 +749,7 @@ class TTS:
if no_prompt_text : if no_prompt_text :
prompt = None prompt = None
else: else:
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(