feat:remove debug, todo:rewrite the onnx export interface

This commit is contained in:
zpeng11 2025-08-17 19:22:11 -04:00
parent 8c0f32da3e
commit 610b36561a
2 changed files with 2 additions and 25 deletions

View File

@ -2,7 +2,7 @@ from torch.nn.functional import *
from torch.nn.functional import (
_canonical_mask,
)
from typing import Tuple
from typing import Tuple, Optional
def multi_head_attention_forward_patched(

View File

@ -236,22 +236,9 @@ class GptSoVits(nn.Module):
self.vits = vits
self.t2s = t2s
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
@ -356,18 +343,8 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
ssl_content = ssl(ref_audio_16k).float()
# debug = False
debug = False
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
if voice_model_version == "v1":
symbols = symbols_v1
else: