From 911c53b1eecf856217fcb07d876fef4ddb5f0702 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Wed, 20 Aug 2025 17:37:41 -0400 Subject: [PATCH] fixed using hubert for full run, 80 works --- GPT_SoVITS/onnx_export.py | 65 +++++++++++++++++++++------------------ playground/freerun.py | 61 +++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 48 deletions(-) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index be1f33ce..c94b1542 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -8,8 +8,7 @@ from torch import nn from sv import SV cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" -cnhubert.cnhubert_base_path = cnhubert_base_path -ssl_model = cnhubert.get_model() +from transformers import HubertModel, HubertConfig import json import os @@ -230,7 +229,7 @@ class VitsModel(nn.Module): ) self.vq_model.eval() self.vq_model.load_state_dict(dict_s2["weight"], strict=False) - + #filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048 def forward(self, text_seq, pred_semantic, ref_audio): refer = spectrogram_torch( ref_audio, @@ -276,20 +275,38 @@ class GptSoVits(nn.Module): ) -class SSLModel(nn.Module): +class HuBertSSLModel(nn.Module): def __init__(self): super().__init__() - self.ssl = ssl_model + self.config = HubertConfig.from_pretrained(cnhubert_base_path) + self.config._attn_implementation = "eager" # Use standard attention + self.config.apply_spec_augment = False # Disable masking for inference + self.config.layerdrop = 0.0 # Disable layer dropout + + # Load the model + self.model = HubertModel.from_pretrained( + cnhubert_base_path, + config=self.config, + local_files_only=True + ) + self.model.eval() - def forward(self, ref_audio_16k): - return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + def forward(self, ref_audio_32k): + ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0) + zero_tensor = torch.zeros((1, 4800), dtype=torch.float32) + + # concate zero_tensor with waveform + ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1) + ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + + return ssl_content def export(vits_path, gpt_path, project_name, voice_model_version="v2"): vits = VitsModel(vits_path, version=voice_model_version) gpt = T2SModel(gpt_path, vits) gpt_sovits = GptSoVits(vits, gpt) - ssl = SSLModel() + ssl = HuBertSSLModel() ref_seq = torch.LongTensor( [ cleaned_text_to_sequence( @@ -349,41 +366,29 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"): text_bert = torch.randn((text_seq.shape[1], 1024)).float() ref_audio = torch.randn((1, 48000 * 5)).float() # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() - ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() - ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float() + ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float() try: os.mkdir(f"onnx/{project_name}") except: pass - ssl_content = ssl(ref_audio_16k).float() + torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx", + input_names=["audio32k"], + output_names=["hubert_ssl_output"], + dynamic_axes={ + "audio32k": {0: "batch_size", 1: "sequence_length"}, + "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"} + }) - gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) + ssl_content = ssl(ref_audio32k).float() + gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name) if voice_model_version == "v1": symbols = symbols_v1 else: symbols = symbols_v2 - MoeVSConf = { - "Folder": f"{project_name}", - "Name": f"{project_name}", - "Type": "GPT-SoVits", - "Rate": vits.hps.data.sampling_rate, - "NumLayers": gpt.t2s_model.num_layers, - "EmbeddingDim": gpt.t2s_model.embedding_dim, - "Dict": "BasicDict", - "BertPath": "chinese-roberta-wwm-ext-large", - # "Symbol": symbols, - "AddBlank": False, - } - - MoeVSConfJson = json.dumps(MoeVSConf) - with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile: - json.dump(MoeVSConf, MoeVsConfFile, indent=4) - - if __name__ == "__main__": try: os.mkdir("onnx") diff --git a/playground/freerun.py b/playground/freerun.py index 66877ff3..3be6b3f0 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -31,39 +31,64 @@ def audio_postprocess( return audio -def load_and_preprocess_audio(audio_path, max_length=160000): - """Load and preprocess audio file to 16k""" - waveform, sample_rate = torchaudio.load(audio_path) +# def load_and_preprocess_audio(audio_path, max_length=160000): +# """Load and preprocess audio file to 16k""" +# waveform, sample_rate = torchaudio.load(audio_path) - # Resample to 16kHz if needed - if sample_rate != 16000: - resampler = torchaudio.transforms.Resample(sample_rate, 16000) +# # Resample to 16kHz if needed +# if sample_rate != 16000: +# resampler = torchaudio.transforms.Resample(sample_rate, 16000) +# waveform = resampler(waveform) + +# # Take first channel +# if waveform.shape[0] > 1: +# waveform = waveform[0:1] + +# # Limit length for testing (10 seconds at 16kHz) +# if waveform.shape[1] > max_length: +# waveform = waveform[:, :max_length] + +# # make a zero tensor that has length 3200*0.3 +# zero_tensor = torch.zeros((1, 4800), dtype=torch.float32) + +# # concate zero_tensor with waveform +# waveform = torch.cat([waveform, zero_tensor], dim=1) + +# return waveform + +# def get_audio_hubert(audio_path): +# """Get HuBERT features for the audio file""" +# waveform = load_and_preprocess_audio(audio_path) +# ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx") +# ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} +# hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32) +# # transpose axis 1 and 2 with numpy +# hubert_feature = hubert_feature.transpose(0, 2, 1) +# return hubert_feature + +def load_and_preprocess_audio(audio_path): + """Load and preprocess audio file to 32k""" + waveform, sample_rate = torchaudio.load(audio_path) + + # Resample to 32kHz if needed + if sample_rate != 32000: + resampler = torchaudio.transforms.Resample(sample_rate, 32000) waveform = resampler(waveform) # Take first channel if waveform.shape[0] > 1: waveform = waveform[0:1] - - # Limit length for testing (10 seconds at 16kHz) - if waveform.shape[1] > max_length: - waveform = waveform[:, :max_length] - - # make a zero tensor that has length 3200*0.3 - zero_tensor = torch.zeros((1, 4800), dtype=torch.float32) - - # concate zero_tensor with waveform - waveform = torch.cat([waveform, zero_tensor], dim=1) return waveform def get_audio_hubert(audio_path): """Get HuBERT features for the audio file""" waveform = load_and_preprocess_audio(audio_path) - ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx") + ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx") ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()} hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32) # transpose axis 1 and 2 with numpy - hubert_feature = hubert_feature.transpose(0, 2, 1) + # hubert_feature = hubert_feature.transpose(0, 2, 1) return hubert_feature def preprocess_text(text:str):