Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 17:10:02 +08:00)

Commit 911c53b1ee (parent 1cdd41d877)

    fixed using hubert for full run, 80 works
@@ -8,8 +8,7 @@ from torch import nn
 from sv import SV

 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
-cnhubert.cnhubert_base_path = cnhubert_base_path
-ssl_model = cnhubert.get_model()
+from transformers import HubertModel, HubertConfig
 import json
 import os
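The new import pulls the HuBERT encoder straight from transformers instead of going through the cnhubert wrapper, which keeps the module self-contained and traceable for ONNX export. A minimal sketch, not part of the commit, of what the direct load looks like in isolation (the output shape assumes the standard HuBERT-base feature stride of 320 samples):

import torch
from transformers import HubertModel, HubertConfig

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
config = HubertConfig.from_pretrained(cnhubert_base_path)
config._attn_implementation = "eager"  # plain attention; fused SDPA kernels do not always trace cleanly to ONNX
model = HubertModel.from_pretrained(cnhubert_base_path, config=config, local_files_only=True).eval()

wav16k = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
with torch.no_grad():
    feats = model(wav16k)["last_hidden_state"]
print(feats.shape)  # roughly (1, 49, 768): about one 768-dim frame per 320 input samples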
@@ -230,7 +229,7 @@ class VitsModel(nn.Module):
         )
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
-
+    # filter_length: 2048  sampling_rate: 32000  hop_length: 640  win_length: 2048
     def forward(self, text_seq, pred_semantic, ref_audio):
         refer = spectrogram_torch(
             ref_audio,
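The added comment pins down the reference-spectrogram geometry. These numbers are deliberate: at 32 kHz a hop of 640 samples gives 50 spectrogram frames per second, the same rate as HuBERT-base features computed on 16 kHz audio with stride 320, so the two time bases stay commensurate. A one-line sanity check:

# Both feature streams tick at 50 frames per second.
sr_spec, hop_spec = 32000, 640  # reference spectrogram, per the comment above
sr_ssl, hop_ssl = 16000, 320    # HuBERT-base feature stride at 16 kHz
assert sr_spec / hop_spec == sr_ssl / hop_ssl == 50.0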
@@ -276,20 +275,38 @@ class GptSoVits(nn.Module):
         )


-class SSLModel(nn.Module):
+class HuBertSSLModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.ssl = ssl_model
+        self.config = HubertConfig.from_pretrained(cnhubert_base_path)
+        self.config._attn_implementation = "eager"  # Use standard attention
+        self.config.apply_spec_augment = False  # Disable masking for inference
+        self.config.layerdrop = 0.0  # Disable layer dropout
+
+        # Load the model
+        self.model = HubertModel.from_pretrained(
+            cnhubert_base_path,
+            config=self.config,
+            local_files_only=True
+        )
+        self.model.eval()

-    def forward(self, ref_audio_16k):
-        return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
+    def forward(self, ref_audio_32k):
+        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
+        zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+
+        # concatenate zero_tensor with the waveform
+        ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
+        ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
+
+        return ssl_content


 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
-    ssl = SSLModel()
+    ssl = HuBertSSLModel()
     ref_seq = torch.LongTensor(
         [
             cleaned_text_to_sequence(
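The renamed HuBertSSLModel now owns the whole front end: it resamples the 32 kHz input to 16 kHz, appends 4800 zero samples (0.3 s at 16 kHz) as tail padding, and returns features transposed to (batch, 768, frames). A rough sketch, not in the commit, of the length bookkeeping this implies; the stride-320 figure is the standard HuBERT-base feature-extractor stride:

def expected_hubert_frames(samples_32k: int) -> int:
    """Approximate HuBertSSLModel output length for a 32 kHz input."""
    samples_16k = samples_32k // 2  # resampling 32 kHz -> 16 kHz halves the sample count
    samples_16k += 4800             # plus the 0.3 s zero pad appended in forward()
    return samples_16k // 320       # ~1 frame per 320 samples; conv edges trim a frame or so

print(expected_hubert_frames(32000 * 5))  # 5 s of audio -> about 265 frames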
@@ -349,41 +366,29 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
     # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
-    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
-    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
+    ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()

     try:
         os.mkdir(f"onnx/{project_name}")
     except:
         pass

-    ssl_content = ssl(ref_audio_16k).float()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
+                      input_names=["audio32k"],
+                      output_names=["hubert_ssl_output"],
+                      dynamic_axes={
+                          "audio32k": {0: "batch_size", 1: "sequence_length"},
+                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"}
+                      })
+
+    ssl_content = ssl(ref_audio32k).float()
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name)

     if voice_model_version == "v1":
         symbols = symbols_v1
     else:
         symbols = symbols_v2

     MoeVSConf = {
         "Folder": f"{project_name}",
         "Name": f"{project_name}",
         "Type": "GPT-SoVits",
         "Rate": vits.hps.data.sampling_rate,
         "NumLayers": gpt.t2s_model.num_layers,
         "EmbeddingDim": gpt.t2s_model.embedding_dim,
         "Dict": "BasicDict",
         "BertPath": "chinese-roberta-wwm-ext-large",
         # "Symbol": symbols,
         "AddBlank": False,
     }

     MoeVSConfJson = json.dumps(MoeVSConf)
     with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
         json.dump(MoeVSConf, MoeVsConfFile, indent=4)


 if __name__ == "__main__":
     try:
         os.mkdir("onnx")
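The dynamic_axes declaration is what lets the exported graph accept reference audio of any length rather than being pinned to the 5-second dummy tensor used for tracing. A hedged usage sketch with onnxruntime; "myvoice" is a placeholder project name, not one from the commit:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/myvoice/myvoice_hubertssl.onnx")  # hypothetical export path
for seconds in (3, 5, 10):
    wav32k = np.random.randn(1, 32000 * seconds).astype(np.float32)
    out = sess.run(None, {"audio32k": wav32k})[0]
    print(seconds, out.shape)  # (1, 768, frames); frames scales with the input length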
@@ -31,39 +31,64 @@ def audio_postprocess(

     return audio

-def load_and_preprocess_audio(audio_path, max_length=160000):
-    """Load and preprocess audio file to 16k"""
+# def load_and_preprocess_audio(audio_path, max_length=160000):
+#     """Load and preprocess audio file to 16k"""
+#     waveform, sample_rate = torchaudio.load(audio_path)
+
+#     # Resample to 16kHz if needed
+#     if sample_rate != 16000:
+#         resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+#         waveform = resampler(waveform)
+
+#     # Take first channel
+#     if waveform.shape[0] > 1:
+#         waveform = waveform[0:1]
+
+#     # Limit length for testing (10 seconds at 16kHz)
+#     if waveform.shape[1] > max_length:
+#         waveform = waveform[:, :max_length]
+
+#     # zero tensor 0.3 s long (16000 * 0.3 = 4800 samples)
+#     zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+
+#     # concatenate zero_tensor with the waveform
+#     waveform = torch.cat([waveform, zero_tensor], dim=1)
+
+#     return waveform

+# def get_audio_hubert(audio_path):
+#     """Get HuBERT features for the audio file"""
+#     waveform = load_and_preprocess_audio(audio_path)
+#     ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
+#     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
+#     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
+#     # transpose axes 1 and 2 with numpy
+#     hubert_feature = hubert_feature.transpose(0, 2, 1)
+#     return hubert_feature

+def load_and_preprocess_audio(audio_path):
+    """Load and preprocess audio file to 32k"""
     waveform, sample_rate = torchaudio.load(audio_path)

-    # Resample to 16kHz if needed
-    if sample_rate != 16000:
-        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+    # Resample to 32kHz if needed
+    if sample_rate != 32000:
+        resampler = torchaudio.transforms.Resample(sample_rate, 32000)
         waveform = resampler(waveform)

     # Take first channel
     if waveform.shape[0] > 1:
         waveform = waveform[0:1]

-    # Limit length for testing (10 seconds at 16kHz)
-    if waveform.shape[1] > max_length:
-        waveform = waveform[:, :max_length]
-
-    # make a zero tensor that has length 3200*0.3
-    zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
-
-    # concate zero_tensor with waveform
-    waveform = torch.cat([waveform, zero_tensor], dim=1)
-
     return waveform

 def get_audio_hubert(audio_path):
     """Get HuBERT features for the audio file"""
     waveform = load_and_preprocess_audio(audio_path)
-    ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
+    ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
-    # transpose axis 1 and 2 with numpy
-    hubert_feature = hubert_feature.transpose(0, 2, 1)
+    # hubert_feature = hubert_feature.transpose(0, 2, 1)
     return hubert_feature

 def preprocess_text(text:str):
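With resampling, padding, and the transpose all baked into the exported graph, the active loader only needs to deliver mono 32 kHz audio, and the numpy transpose in get_audio_hubert stays commented out because the ONNX output is already channel-major. Hypothetical usage, with a placeholder wav path:

feats = get_audio_hubert("playground/ref.wav")  # any mono wav; the loader resamples to 32 kHz
print(feats.shape)  # (1, 768, frames), straight from the ONNX output

Expect roughly one frame per 20 ms of source audio, plus about 15 extra frames from the 0.3 s zero pad applied inside the graph.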