From fd0fb35a491b39eade9244176efd2ecde8fdf5a3 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Wed, 20 Aug 2025 18:32:38 -0400 Subject: [PATCH] fix spectrum take out working --- GPT_SoVITS/onnx_export.py | 58 +++++++++++++++++++++------------------ playground/freerun.py | 11 ++++---- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index c94b1542..e2aa4ea3 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -13,6 +13,7 @@ import json import os import soundfile +from tqdm import tqdm from text import cleaned_text_to_sequence @@ -133,7 +134,7 @@ class T2SModel(nn.Module): y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) stop = False - for idx in range(1, 1500): + for idx in tqdm(range(1, 1500)): # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] enco = self.stage_decoder(y, k, v, y_emb, x_example) y, k, v, y_emb, logits, samples = enco @@ -230,19 +231,11 @@ class VitsModel(nn.Module): self.vq_model.eval() self.vq_model.load_state_dict(dict_s2["weight"], strict=False) #filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048 - def forward(self, text_seq, pred_semantic, ref_audio): - refer = spectrogram_torch( - ref_audio, - self.hps.data.filter_length, - self.hps.data.sampling_rate, - self.hps.data.hop_length, - self.hps.data.win_length, - center=False, - ) + def forward(self, text_seq, pred_semantic, ref_audio, spectrum): if self.sv_model is not None: sv_emb=self.sv_model.compute_embedding3_onnx(resample_audio(ref_audio, 32000, 16000)) - return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0] - return self.vq_model(pred_semantic, text_seq, refer)[0, 0] + return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0] + return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0] class GptSoVits(nn.Module): @@ -251,24 +244,25 @@ class GptSoVits(nn.Module): self.vits = vits self.t2s = t2s - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content): pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - audio = self.vits(text_seq, pred_semantic, ref_audio) + audio = self.vits(text_seq, pred_semantic, ref_audio, spectrum) return audio - def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): + def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content, project_name): self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) torch.onnx.export( self.vits, - (text_seq, pred_semantic, ref_audio), + (text_seq, pred_semantic, ref_audio, spectrum), f"onnx/{project_name}/{project_name}_vits.onnx", - input_names=["text_seq", "pred_semantic", "ref_audio"], + input_names=["text_seq", "pred_semantic", "ref_audio", "spectrum"], output_names=["audio"], dynamic_axes={ - "text_seq": {1: "text_length"}, - "pred_semantic": {2: "pred_length"}, - "ref_audio": {1: "audio_length"}, + "text_seq": {0:"batch_size",1: "text_length"}, + "pred_semantic": {0: "batch_size", 2: "pred_length"}, + "ref_audio": {0: "batch_size", 1: "audio_length"}, + "spectrum": {0: "batch_size", 2: "spectrum_length"}, }, opset_version=17, verbose=False, @@ -292,14 +286,23 @@ class HuBertSSLModel(nn.Module): self.model.eval() def forward(self, ref_audio_32k): + spectrum = spectrogram_torch( + ref_audio_32k, + 2048, + 32000, + 640, + 2048, + center=False, + ) + + ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0) zero_tensor = torch.zeros((1, 4800), dtype=torch.float32) - # concate zero_tensor with waveform ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1) ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) - return ssl_content + return ssl_content, spectrum def export(vits_path, gpt_path, project_name, voice_model_version="v2"): @@ -375,14 +378,17 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"): torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx", input_names=["audio32k"], - output_names=["hubert_ssl_output"], + output_names=["hubert_ssl_output", "spectrum"], dynamic_axes={ "audio32k": {0: "batch_size", 1: "sequence_length"}, - "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"} + "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"}, + "spectrum": {0: "batch_size", 2: "spectrum_length"} }) - ssl_content = ssl(ref_audio32k).float() - gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name) + [ssl_content, spectrum] = ssl(ref_audio32k) + gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float()) + # exit() + gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float(), project_name) if voice_model_version == "v1": symbols = symbols_v1 diff --git a/playground/freerun.py b/playground/freerun.py index 3be6b3f0..a9b8ed14 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -86,10 +86,10 @@ def get_audio_hubert(audio_path): waveform = load_and_preprocess_audio(audio_path) ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx") ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} - hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32) + [hubert_feature, spectrum] = ort_session.run(None, ort_inputs) # transpose axis 1 and 2 with numpy # hubert_feature = hubert_feature.transpose(0, 2, 1) - return hubert_feature + return hubert_feature, spectrum def preprocess_text(text:str): preprocessor = TextPreprocessorOnnx("playground/bert") @@ -100,7 +100,7 @@ def preprocess_text(text:str): # input_phones_saved = np.load("playground/ref/input_phones.npy") # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32) -[input_phones, input_bert] = preprocess_text("震撼视角,感受开幕式,闭幕式烟花") +[input_phones, input_bert] = preprocess_text("像大雨匆匆打击过的屋檐") # ref_phones = np.load("playground/ref/ref_phones.npy") @@ -108,7 +108,7 @@ def preprocess_text(text:str): [ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织") -audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav") +[audio_prompt_hubert, spectrum] = get_audio_hubert("playground/ref/audio.wav") encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx") @@ -160,7 +160,8 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx") [audio] = vtis.run(None, { "text_seq": input_phones, "pred_semantic": pred_semantic, - "ref_audio": ref_audio + "ref_audio": ref_audio, + "spectrum": spectrum.astype(np.float32) }) audio_postprocess([audio])