diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index 63e74d0..3f2c296 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -654,6 +654,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be torch._dynamo.mark_dynamic(ref_bert, 0) torch._dynamo.mark_dynamic(text_bert, 0) + top_k = torch.LongTensor([5]).to(device) + with torch.no_grad(): gpt_sovits_export = torch.jit.trace( gpt_sovits, @@ -663,7 +665,8 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be ref_seq, text_seq, ref_bert, - text_bert)) + text_bert, + top_k)) gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") gpt_sovits_export.save(gpt_sovits_path) @@ -685,15 +688,26 @@ class GPT_SoVITS(nn.Module): self.t2s = t2s self.vits = vits - def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0): + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): codes = self.vits.vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] prompts = prompt_semantic.unsqueeze(0) - pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert) + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed) return audio + def test(): parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file") @@ -785,8 +799,10 @@ def test(): print('text_bert:',text_bert.shape) text_bert=text_bert.to('cuda') + top_k = torch.LongTensor([5]).to('cuda') + with torch.no_grad(): - audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert) + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) print('start write wav') soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)