mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
feat:remove debug, todo:rewrite the onnx export interface
This commit is contained in:
parent
8c0f32da3e
commit
610b36561a
@ -2,7 +2,7 @@ from torch.nn.functional import *
|
||||
from torch.nn.functional import (
|
||||
_canonical_mask,
|
||||
)
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
|
@ -236,22 +236,9 @@ class GptSoVits(nn.Module):
|
||||
self.vits = vits
|
||||
self.t2s = t2s
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
|
||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
||||
if debug:
|
||||
import onnxruntime
|
||||
|
||||
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
||||
audio1 = sess.run(
|
||||
None,
|
||||
{
|
||||
"text_seq": text_seq.detach().cpu().numpy(),
|
||||
"pred_semantic": pred_semantic.detach().cpu().numpy(),
|
||||
"ref_audio": ref_audio.detach().cpu().numpy(),
|
||||
},
|
||||
)
|
||||
return audio, audio1
|
||||
return audio
|
||||
|
||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
||||
@ -356,18 +343,8 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
||||
|
||||
ssl_content = ssl(ref_audio_16k).float()
|
||||
|
||||
# debug = False
|
||||
debug = False
|
||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
||||
|
||||
if debug:
|
||||
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
|
||||
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
|
||||
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
|
||||
else:
|
||||
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
||||
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
||||
|
||||
if voice_model_version == "v1":
|
||||
symbols = symbols_v1
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user