From 864a148d75c235d3d7c75316ba02254bcbdf9fb3 Mon Sep 17 00:00:00 2001 From: chasonjiang <1440499136@qq.com> Date: Sat, 16 Mar 2024 21:04:49 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BC=93=E8=A7=A3=E4=BA=86batch=5Fsize>1?= =?UTF-8?q?=E6=97=B6=E7=9A=84=E5=A4=8D=E8=AF=BB=E9=97=AE=E9=A2=98=EF=BC=8C?= =?UTF-8?q?=E7=BC=93=E8=A7=A3=E6=96=B9=E6=B3=95=E6=98=AF=EF=BC=9A=E5=9C=A8?= =?UTF-8?q?T2S=E6=A8=A1=E5=9E=8B=E4=B8=AD=EF=BC=8C=E5=85=88=E5=AF=B9phones?= =?UTF-8?q?=E8=BF=9B=E8=A1=8Cembedding=E3=80=81=E5=AF=B9bert=5Ffeatures?= =?UTF-8?q?=E8=BF=9B=E8=A1=8Cproject=EF=BC=8C=E5=86=8Dpad=E5=88=B0?= =?UTF-8?q?=E7=9B=B8=E5=90=8C=E9=95=BF=E5=BA=A6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/AR/models/t2s_model.py | 50 ++++++++++++++------ GPT_SoVITS/TTS_infer_pack/TTS.py | 78 +++++++++++++++++++------------ 2 files changed, 85 insertions(+), 43 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index da95111..b49bcfb 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -504,18 +504,29 @@ class Text2SemanticDecoder(nn.Module): def infer_panel_batch_infer_with_flash_attn( self, - x, #####全部文本token - x_lens, - prompts, ####参考音频token - bert_feature, + x:List[torch.LongTensor], #####全部文本token + x_lens:torch.LongTensor, + prompts:torch.LongTensor, ####参考音频token + bert_feature:List[torch.LongTensor], top_k: int = -100, top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, ): - - bert_feature = self.bert_proj(bert_feature.transpose(1, 2)) - x = self.ar_text_embedding(x) + # 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读) + max_len = 0 + for x_item, bert_item in zip(x, bert_feature): + max_len = max(max_len, x_item.shape[0], bert_item.shape[1]) + x_list = [self.ar_text_embedding(item) for item in x] + x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]