From 610b36561a7e47d0e9bd27597c5c24702e90d3f7 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Sun, 17 Aug 2025 19:22:11 -0400 Subject: [PATCH] feat:remove debug, todo:rewrite the onnx export interface --- .../AR/modules/patched_mha_with_cache_onnx.py | 2 +- GPT_SoVITS/onnx_export.py | 25 +------------------ 2 files changed, 2 insertions(+), 25 deletions(-) diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py index b38a3907..bd39ff6e 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -2,7 +2,7 @@ from torch.nn.functional import * from torch.nn.functional import ( _canonical_mask, ) -from typing import Tuple +from typing import Tuple, Optional def multi_head_attention_forward_patched( diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index 9f100a5c..661c2edd 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -236,22 +236,9 @@ class GptSoVits(nn.Module): self.vits = vits self.t2s = t2s - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content): pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) audio = self.vits(text_seq, pred_semantic, ref_audio) - if debug: - import onnxruntime - - sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) - audio1 = sess.run( - None, - { - "text_seq": text_seq.detach().cpu().numpy(), - "pred_semantic": pred_semantic.detach().cpu().numpy(), - "ref_audio": ref_audio.detach().cpu().numpy(), - }, - ) - return audio, audio1 return audio def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): @@ -356,18 +343,8 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"): ssl_content = ssl(ref_audio_16k).float() - # debug = False - debug = False gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) - if debug: - a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) - soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) - soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) - else: - a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() - soundfile.write("out.wav", a, vits.hps.data.sampling_rate) - if voice_model_version == "v1": symbols = symbols_v1 else: