fix: move spectrum computation out of the VITS graph (now working)

zpeng11 committed 2025-08-20 18:32:38 -04:00
parent 911c53b1ee
commit fd0fb35a49
2 changed files with 38 additions and 31 deletions

View File

@@ -13,6 +13,7 @@ import json
 import os
 import soundfile
+from tqdm import tqdm
 from text import cleaned_text_to_sequence
@@ -133,7 +134,7 @@ class T2SModel(nn.Module):
         y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
         stop = False
-        for idx in range(1, 1500):
+        for idx in tqdm(range(1, 1500)):
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
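For readers skimming the diff, a minimal sketch (not this file's code) of how a stage-decoder loop of this shape typically terminates; torch and tqdm are assumed imported as above, and EOS is an assumed end-of-sequence token id, since the real stop check sits just below this hunk:

# Hedged sketch: termination of an autoregressive stage-decoder loop.
# EOS is an assumed token id; the bound of 1500 mirrors the loop above.
def run_decoder(stage_decoder, y, k, v, y_emb, x_example, EOS):
    for idx in tqdm(range(1, 1500)):
        y, k, v, y_emb, logits, samples = stage_decoder(y, k, v, y_emb, x_example)
        # stop once either the greedy argmax token or the sampled token hits EOS
        if int(torch.argmax(logits, dim=-1)[0]) == EOS or int(samples[0, 0]) == EOS:
            break
    return y, idx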
@@ -230,19 +231,11 @@ class VitsModel(nn.Module):
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
 
     #filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
-    def forward(self, text_seq, pred_semantic, ref_audio):
-        refer = spectrogram_torch(
-            ref_audio,
-            self.hps.data.filter_length,
-            self.hps.data.sampling_rate,
-            self.hps.data.hop_length,
-            self.hps.data.win_length,
-            center=False,
-        )
+    def forward(self, text_seq, pred_semantic, ref_audio, spectrum):
         if self.sv_model is not None:
             sv_emb = self.sv_model.compute_embedding3_onnx(resample_audio(ref_audio, 32000, 16000))
-            return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
-        return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
+            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
+        return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
 
 
 class GptSoVits(nn.Module):
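Because forward no longer builds the spectrogram, every caller must now supply it. A minimal sketch, assuming the hps values recorded in the comment above (filter_length 2048, sampling_rate 32000, hop_length 640, win_length 2048) and this module's existing spectrogram_torch helper; vits_model is an illustrative instance name:

# Sketch: compute the reference spectrum outside the graph, then pass it in.
spectrum = spectrogram_torch(
    ref_audio,   # float32 waveform at 32 kHz, shape [batch, audio_length]
    2048,        # filter_length
    32000,       # sampling_rate
    640,         # hop_length
    2048,        # win_length
    center=False,
)
audio = vits_model(text_seq, pred_semantic, ref_audio, spectrum)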
@@ -251,24 +244,25 @@ class GptSoVits(nn.Module):
         self.vits = vits
         self.t2s = t2s
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content):
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        audio = self.vits(text_seq, pred_semantic, ref_audio)
+        audio = self.vits(text_seq, pred_semantic, ref_audio, spectrum)
         return audio
 
-    def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
+    def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content, project_name):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
         torch.onnx.export(
             self.vits,
-            (text_seq, pred_semantic, ref_audio),
+            (text_seq, pred_semantic, ref_audio, spectrum),
             f"onnx/{project_name}/{project_name}_vits.onnx",
-            input_names=["text_seq", "pred_semantic", "ref_audio"],
+            input_names=["text_seq", "pred_semantic", "ref_audio", "spectrum"],
             output_names=["audio"],
             dynamic_axes={
-                "text_seq": {1: "text_length"},
-                "pred_semantic": {2: "pred_length"},
-                "ref_audio": {1: "audio_length"},
+                "text_seq": {0: "batch_size", 1: "text_length"},
+                "pred_semantic": {0: "batch_size", 2: "pred_length"},
+                "ref_audio": {0: "batch_size", 1: "audio_length"},
+                "spectrum": {0: "batch_size", 2: "spectrum_length"},
             },
             opset_version=17,
             verbose=False,
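A quick way to confirm the re-export took: load the graph with onnxruntime (standard API) and list its inputs; the new spectrum input and the symbolic batch_size dims added above should show up:

# Sketch: inspect the exported VITS graph's inputs and dynamic axes.
import onnxruntime as ort
sess = ort.InferenceSession(f"onnx/{project_name}/{project_name}_vits.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape)  # expect text_seq, pred_semantic, ref_audio, spectrum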
@@ -292,14 +286,23 @@ class HuBertSSLModel(nn.Module):
         self.model.eval()
 
     def forward(self, ref_audio_32k):
+        spectrum = spectrogram_torch(
+            ref_audio_32k,
+            2048,
+            32000,
+            640,
+            2048,
+            center=False,
+        )
         ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
         zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
         # concatenate zero padding (4800 samples = 0.3 s at 16 kHz) with the waveform
         ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
         ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
-        return ssl_content
+        return ssl_content, spectrum
 
 
 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
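One forward pass of the wrapper now yields both tensors, so a single ONNX call can feed the T2S encoder and the VITS decoder. A usage sketch; the shapes in the comments are assumptions (a 2048-point FFT gives 1025 bins; the HuBert feature width depends on the checkpoint):

# Sketch: HuBert features and the reference spectrum from one call.
ssl = HuBertSSLModel()
ssl_content, spectrum = ssl(ref_audio_32k)
print(ssl_content.shape)  # e.g. [1, C, hubert_length], C assumed 768 for the base model
print(spectrum.shape)     # e.g. [1, 1025, spectrum_length] (2048 // 2 + 1 bins)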
@@ -375,14 +378,17 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
                       input_names=["audio32k"],
-                      output_names=["hubert_ssl_output"],
+                      output_names=["hubert_ssl_output", "spectrum"],
                       dynamic_axes={
                           "audio32k": {0: "batch_size", 1: "sequence_length"},
-                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"}
+                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"},
+                          "spectrum": {0: "batch_size", 2: "spectrum_length"}
                       })
-    ssl_content = ssl(ref_audio32k).float()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name)
+    [ssl_content, spectrum] = ssl(ref_audio32k)
+    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float())
+    # exit()
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float(), project_name)
 
     if voice_model_version == "v1":
         symbols = symbols_v1
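Since the eager gpt_sovits(...) call above already exercises the torch graph, a parity check against the ONNX export is cheap. A hedged sketch; the tolerances are illustrative, not validated:

# Sketch: compare the torch SSL wrapper with its ONNX export.
import numpy as np
import onnxruntime as ort
sess = ort.InferenceSession(f"onnx/{project_name}/{project_name}_hubertssl.onnx")
onnx_ssl, onnx_spec = sess.run(None, {"audio32k": ref_audio32k.numpy()})
np.testing.assert_allclose(onnx_ssl, ssl_content.detach().cpu().numpy(), rtol=1e-2, atol=1e-3)
np.testing.assert_allclose(onnx_spec, spectrum.detach().cpu().numpy(), rtol=1e-2, atol=1e-3)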

View File

@@ -86,10 +86,10 @@ def get_audio_hubert(audio_path):
     waveform = load_and_preprocess_audio(audio_path)
     ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-    hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
+    [hubert_feature, spectrum] = ort_session.run(None, ort_inputs)
     # transpose axis 1 and 2 with numpy
     # hubert_feature = hubert_feature.transpose(0, 2, 1)
-    return hubert_feature
+    return hubert_feature, spectrum
 def preprocess_text(text:str):
     preprocessor = TextPreprocessorOnnx("playground/bert")
@@ -100,7 +100,7 @@ def preprocess_text(text:str):
 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
 
-[input_phones, input_bert] = preprocess_text("震撼视角,感受开幕式,闭幕式烟花")
+[input_phones, input_bert] = preprocess_text("像大雨匆匆打击过的屋檐")
 
 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -108,7 +108,7 @@ def preprocess_text(text:str):
 [ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
 
-audio_prompt_hubert = get_audio_hubert("playground/ref/audio.wav")
+[audio_prompt_hubert, spectrum] = get_audio_hubert("playground/ref/audio.wav")
 
 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
@@ -160,7 +160,8 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
 [audio] = vtis.run(None, {
     "text_seq": input_phones,
     "pred_semantic": pred_semantic,
-    "ref_audio": ref_audio
+    "ref_audio": ref_audio,
+    "spectrum": spectrum.astype(np.float32)
 })
 
 audio_postprocess([audio])
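To listen to the result, the waveform can be written out with soundfile; a minimal sketch assuming the 32 kHz output rate from the hps noted earlier and an illustrative output path:

# Sketch: persist the synthesized audio (sample rate assumed 32 kHz).
import soundfile as sf
sf.write("playground/output.wav", audio.squeeze(), 32000)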