From 1ceab938bb76a267074713a0fff860725c9fd1b1 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Sun, 23 Mar 2025 14:20:50 +0800 Subject: [PATCH] =?UTF-8?q?Fix:=20=E7=94=B1=E4=BA=8E=20export=5Ftorch=5Fsc?= =?UTF-8?q?ript=5Fv3=20=E7=9A=84=E6=94=B9=E5=8A=A8=EF=BC=8Cv2=20=E7=8E=B0?= =?UTF-8?q?=E5=9C=A8=E9=9C=80=E8=A6=81=E4=BC=A0=E5=85=A5=20top=5Fk?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/export_torch_script.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index 63e74d0..3f2c296 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -654,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be torch._dynamo.mark_dynamic(ref_bert, 0) torch._dynamo.mark_dynamic(text_bert, 0) + top_k = torch.LongTensor([5]).to(device) + with torch.no_grad(): gpt_sovits_export = torch.jit.trace( gpt_sovits, @@ -663,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be ref_seq, text_seq, ref_bert, - text_bert)) + text_bert, + top_k)) gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") gpt_sovits_export.save(gpt_sovits_path) @@ -685,15 +688,26 @@ class GPT_SoVITS(nn.Module): self.t2s = t2s self.vits = vits - def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0): + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): codes = self.vits.vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] prompts = prompt_semantic.unsqueeze(0) - pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert) + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed) return audio + def test(): parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file") @@ -785,8 +799,10 @@ def test(): print('text_bert:',text_bert.shape) text_bert=text_bert.to('cuda') + top_k = torch.LongTensor([5]).to('cuda') + with torch.no_grad(): - audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert) + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) print('start write wav') soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)