fixed using hubert for full run, 80 works

This commit is contained in:
zpeng11 2025-08-20 17:37:41 -04:00
parent 1cdd41d877
commit 911c53b1ee
2 changed files with 78 additions and 48 deletions

View File

@ -8,8 +8,7 @@ from torch import nn
from sv import SV
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from transformers import HubertModel, HubertConfig
import json
import os
@ -230,7 +229,7 @@ class VitsModel(nn.Module):
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
#filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
@ -276,20 +275,38 @@ class GptSoVits(nn.Module):
)
class SSLModel(nn.Module):
class HuBertSSLModel(nn.Module):
def __init__(self):
super().__init__()
self.ssl = ssl_model
self.config = HubertConfig.from_pretrained(cnhubert_base_path)
self.config._attn_implementation = "eager" # Use standard attention
self.config.apply_spec_augment = False # Disable masking for inference
self.config.layerdrop = 0.0 # Disable layer dropout
# Load the model
self.model = HubertModel.from_pretrained(
cnhubert_base_path,
config=self.config,
local_files_only=True
)
self.model.eval()
def forward(self, ref_audio_16k):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def forward(self, ref_audio_32k):
ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
# concate zero_tensor with waveform
ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
return ssl_content
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
vits = VitsModel(vits_path, version=voice_model_version)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ssl = HuBertSSLModel()
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
@ -349,41 +366,29 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()
try:
os.mkdir(f"onnx/{project_name}")
except:
pass
ssl_content = ssl(ref_audio_16k).float()
torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
input_names=["audio32k"],
output_names=["hubert_ssl_output"],
dynamic_axes={
"audio32k": {0: "batch_size", 1: "sequence_length"},
"hubert_ssl_output": {0: "batch_size", 2: "hubert_length"}
})
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
ssl_content = ssl(ref_audio32k).float()
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name)
if voice_model_version == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = {
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":
try:
os.mkdir("onnx")

View File

@ -31,39 +31,64 @@ def audio_postprocess(
return audio
def load_and_preprocess_audio(audio_path, max_length=160000):
"""Load and preprocess audio file to 16k"""
waveform, sample_rate = torchaudio.load(audio_path)
# def load_and_preprocess_audio(audio_path, max_length=160000):
# """Load and preprocess audio file to 16k"""
# waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 16kHz if needed
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
# # Resample to 16kHz if needed
# if sample_rate != 16000:
# resampler = torchaudio.transforms.Resample(sample_rate, 16000)
# waveform = resampler(waveform)
# # Take first channel
# if waveform.shape[0] > 1:
# waveform = waveform[0:1]
# # Limit length for testing (10 seconds at 16kHz)
# if waveform.shape[1] > max_length:
# waveform = waveform[:, :max_length]
# # make a zero tensor that has length 3200*0.3
# zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
# # concate zero_tensor with waveform
# waveform = torch.cat([waveform, zero_tensor], dim=1)
# return waveform
# def get_audio_hubert(audio_path):
# """Get HuBERT features for the audio file"""
# waveform = load_and_preprocess_audio(audio_path)
# ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
# ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
# hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
# # transpose axis 1 and 2 with numpy
# hubert_feature = hubert_feature.transpose(0, 2, 1)
# return hubert_feature
def load_and_preprocess_audio(audio_path):
"""Load and preprocess audio file to 32k"""
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 32kHz if needed
if sample_rate != 32000:
resampler = torchaudio.transforms.Resample(sample_rate, 32000)
waveform = resampler(waveform)
# Take first channel
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Limit length for testing (10 seconds at 16kHz)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
# make a zero tensor that has length 3200*0.3
zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
# concate zero_tensor with waveform
waveform = torch.cat([waveform, zero_tensor], dim=1)
return waveform
def get_audio_hubert(audio_path):
"""Get HuBERT features for the audio file"""
waveform = load_and_preprocess_audio(audio_path)
ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
# transpose axis 1 and 2 with numpy
hubert_feature = hubert_feature.transpose(0, 2, 1)
# hubert_feature = hubert_feature.transpose(0, 2, 1)
return hubert_feature
def preprocess_text(text:str):