mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 00:10:00 +08:00

Merge branch 'fast_inference_' of https://github.com/RVC-Boss/GPT-SoVITS into sync_main

This commit is contained in:
commit 0a0e3634c7
@@ -462,7 +462,7 @@ class Text2SemanticDecoder(nn.Module):
             value=True,
         )
         y_attn_mask = F.pad(
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=0),  # diagonal must be 0, otherwise batch_size > 1 produces repeated speech
             (x_len, 0),
             value=False,
         )
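A quick way to see what this one-character change does to the mask (a sketch reusing the x_len/y_len names from the hunk; the toy sizes are made up, and it assumes True marks positions attention is not allowed to see):

import torch
import torch.nn.functional as F

x_len, y_len = 3, 4  # hypothetical text length and semantic length

for diagonal in (1, 0):
    y_attn_mask = F.pad(
        torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=diagonal),
        (x_len, 0),  # prepend x_len unmasked columns for the text prefix
        value=False,
    )
    print(f"diagonal={diagonal}:\n{y_attn_mask.int()}")

With diagonal=1 the main diagonal stays False, so each semantic step can attend to its own position; diagonal=0 masks the diagonal too, which per the commit comment is what prevents the repetition seen when batch_size > 1.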
@@ -508,10 +508,10 @@ class Text2SemanticDecoder(nn.Module):

     def infer_panel_batch_infer_with_flash_attn(
         self,
-        x:List[torch.LongTensor], ##### all text tokens
+        x:torch.LongTensor, ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor, #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
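For callers, the new signature means the batch arrives as one padded LongTensor plus per-item lengths instead of a Python list of tensors. A hypothetical caller-side sketch (pad_sequence and the padding value 0 are illustrative assumptions, not part of this diff):

import torch
from torch.nn.utils.rnn import pad_sequence

phoneme_ids = [torch.randint(1, 100, (n,)) for n in (12, 7, 9)]   # per-utterance text tokens
x = pad_sequence(phoneme_ids, batch_first=True, padding_value=0)  # (B, T_max)
x_lens = torch.tensor([t.shape[0] for t in phoneme_ids])          # (B,)
# bert_feature would be padded along time the same way before the call.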
@@ -688,7 +688,7 @@ class Text2SemanticDecoder(nn.Module):
         x:List[torch.LongTensor], ##### all text tokens
         x_lens:torch.LongTensor,
         prompts:torch.LongTensor, #### reference audio tokens
-        bert_feature:List[torch.LongTensor],
+        bert_feature:torch.LongTensor,
         top_k: int = -100,
         top_p: int = 100,
         early_stop_num: int = -1,
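Note the mixed convention this leaves in the second signature: x is still a list of per-utterance tensors while bert_feature becomes a single tensor. One plausible way a caller could build both (all names, shapes, and the concatenation layout are assumptions):

import torch

x = [torch.randint(1, 100, (n,)) for n in (12, 7)]       # list of text-token tensors
bert_chunks = [torch.randn(1024, n) for n in (12, 7)]    # per-utterance BERT features
bert_feature = torch.cat(bert_chunks, dim=1)             # one (1024, total_len) tensor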
@@ -749,7 +749,7 @@ class TTS:
         if no_prompt_text :
             prompt = None
         else:
-            prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
+            prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)


         pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
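The two expressions are equivalent when all_phoneme_ids is a tensor, since len(t) == t.shape[0] for any tensor with at least one dimension, but .shape[0] names the batch dimension explicitly and matches the single-tensor batching introduced above. A minimal sketch (shapes are illustrative assumptions):

import torch

prompt_semantic = torch.randint(0, 1024, (50,))          # hypothetical cached prompt tokens
all_phoneme_ids = torch.zeros(3, 20, dtype=torch.long)   # padded (B, T) phoneme batch

prompt = prompt_semantic.expand(all_phoneme_ids.shape[0], -1)  # (3, 50) view, no copy
assert prompt.shape == (3, 50)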