Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Synced 2025-09-29 17:10:02 +08:00

Commit 911c53b1ee (parent 1cdd41d877): fixed using hubert for full run, 80 works
@@ -8,8 +8,7 @@ from torch import nn
 from sv import SV

 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
-cnhubert.cnhubert_base_path = cnhubert_base_path
-ssl_model = cnhubert.get_model()
+from transformers import HubertModel, HubertConfig
 import json
 import os

@@ -230,7 +229,7 @@ class VitsModel(nn.Module):
         )
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
-
+    # filter_length: 2048  sampling_rate: 32000  hop_length: 640  win_length: 2048
     def forward(self, text_seq, pred_semantic, ref_audio):
         refer = spectrogram_torch(
             ref_audio,
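Note: the new comment records the spectrogram settings used for the reference audio. A quick arithmetic sketch of what those values imply (assuming a 5-second reference clip; the variable names are illustrative only):

# Sketch: frame/bin arithmetic implied by the commented parameters above
# (filter_length=2048, sampling_rate=32000, hop_length=640, win_length=2048).
filter_length, sampling_rate, hop_length = 2048, 32000, 640

n_samples = sampling_rate * 5            # 160000 samples in a 5 s clip
n_frames = n_samples // hop_length       # about 250 spectrogram frames
n_bins = filter_length // 2 + 1          # 1025 frequency bins per frame
print(n_frames, n_bins)                  # 250 1025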
@@ -276,20 +275,38 @@ class GptSoVits(nn.Module):
         )


-class SSLModel(nn.Module):
+class HuBertSSLModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.ssl = ssl_model
+        self.config = HubertConfig.from_pretrained(cnhubert_base_path)
+        self.config._attn_implementation = "eager"  # Use standard attention
+        self.config.apply_spec_augment = False  # Disable masking for inference
+        self.config.layerdrop = 0.0  # Disable layer dropout

-    def forward(self, ref_audio_16k):
-        return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
+        # Load the model
+        self.model = HubertModel.from_pretrained(
+            cnhubert_base_path,
+            config=self.config,
+            local_files_only=True
+        )
+        self.model.eval()
+
+    def forward(self, ref_audio_32k):
+        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
+        zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+
+        # concatenate zero_tensor with waveform
+        ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
+        ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
+
+        return ssl_content


 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
-    ssl = SSLModel()
+    ssl = HuBertSSLModel()
     ref_seq = torch.LongTensor(
         [
             cleaned_text_to_sequence(
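Note: the rewritten wrapper moves the 32 kHz to 16 kHz resampling and the 4800-sample zero pad (0.3 s at 16 kHz; the old playground comment says "3200*0.3", but 16000 * 0.3 = 4800) into the traced graph, so the exported ONNX model consumes raw 32 kHz audio directly. The `resample_audio` helper is not shown in this hunk; a minimal sketch of a plausible implementation, assuming a plain torchaudio wrapper over a 1-D waveform (an assumption, not the commit's actual helper):

import torch
import torchaudio

def resample_audio(waveform: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
    """Hypothetical stand-in for the helper used by HuBertSSLModel.forward;
    the commit's real implementation lives elsewhere in the file."""
    # forward() passes a 1-D 32 kHz waveform and unsqueezes the result to (1, T)
    return torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=target_sr)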
@@ -349,41 +366,29 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
     # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
-    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
-    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
+    ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()

     try:
         os.mkdir(f"onnx/{project_name}")
     except:
         pass

-    ssl_content = ssl(ref_audio_16k).float()
+    torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
+                      input_names=["audio32k"],
+                      output_names=["hubert_ssl_output"],
+                      dynamic_axes={
+                          "audio32k": {0: "batch_size", 1: "sequence_length"},
+                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"}
+                      })

-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    ssl_content = ssl(ref_audio32k).float()
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, ssl_content, project_name)

     if voice_model_version == "v1":
         symbols = symbols_v1
     else:
         symbols = symbols_v2

-    MoeVSConf = {
-        "Folder": f"{project_name}",
-        "Name": f"{project_name}",
-        "Type": "GPT-SoVits",
-        "Rate": vits.hps.data.sampling_rate,
-        "NumLayers": gpt.t2s_model.num_layers,
-        "EmbeddingDim": gpt.t2s_model.embedding_dim,
-        "Dict": "BasicDict",
-        "BertPath": "chinese-roberta-wwm-ext-large",
-        # "Symbol": symbols,
-        "AddBlank": False,
-    }
-
-    MoeVSConfJson = json.dumps(MoeVSConf)
-    with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
-        json.dump(MoeVSConf, MoeVsConfFile, indent=4)
-
-
 if __name__ == "__main__":
     try:
         os.mkdir("onnx")
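Note: declaring `dynamic_axes` on both the waveform input and the feature output lets the exported graph accept reference audio of any length. A sketch for sanity-checking the artifact with onnxruntime ("myproject" is a placeholder project name, and the 768-channel output assumes the base-size HuBERT checkpoint):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/myproject/myproject_hubertssl.onnx")
audio32k = np.random.randn(1, 32000 * 5).astype(np.float32)   # dummy 5 s clip at 32 kHz
(ssl_content,) = sess.run(["hubert_ssl_output"], {"audio32k": audio32k})
print(ssl_content.shape)  # (1, 768, hubert_length): transposed inside the graph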
@@ -31,39 +31,64 @@ def audio_postprocess(
     return audio

-def load_and_preprocess_audio(audio_path, max_length=160000):
-    """Load and preprocess audio file to 16k"""
+# def load_and_preprocess_audio(audio_path, max_length=160000):
+#     """Load and preprocess audio file to 16k"""
+#     waveform, sample_rate = torchaudio.load(audio_path)
+
+#     # Resample to 16kHz if needed
+#     if sample_rate != 16000:
+#         resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+#         waveform = resampler(waveform)
+
+#     # Take first channel
+#     if waveform.shape[0] > 1:
+#         waveform = waveform[0:1]
+
+#     # Limit length for testing (10 seconds at 16kHz)
+#     if waveform.shape[1] > max_length:
+#         waveform = waveform[:, :max_length]
+
+#     # make a zero tensor that has length 16000*0.3
+#     zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+
+#     # concatenate zero_tensor with waveform
+#     waveform = torch.cat([waveform, zero_tensor], dim=1)
+
+#     return waveform
+
+# def get_audio_hubert(audio_path):
+#     """Get HuBERT features for the audio file"""
+#     waveform = load_and_preprocess_audio(audio_path)
+#     ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
+#     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
+#     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
+#     # transpose axis 1 and 2 with numpy
+#     hubert_feature = hubert_feature.transpose(0, 2, 1)
+#     return hubert_feature
+
+def load_and_preprocess_audio(audio_path):
+    """Load and preprocess audio file to 32k"""
     waveform, sample_rate = torchaudio.load(audio_path)

-    # Resample to 16kHz if needed
-    if sample_rate != 16000:
-        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+    # Resample to 32kHz if needed
+    if sample_rate != 32000:
+        resampler = torchaudio.transforms.Resample(sample_rate, 32000)
         waveform = resampler(waveform)

     # Take first channel
     if waveform.shape[0] > 1:
         waveform = waveform[0:1]

-    # Limit length for testing (10 seconds at 16kHz)
-    if waveform.shape[1] > max_length:
-        waveform = waveform[:, :max_length]
-
-    # make a zero tensor that has length 16000*0.3
-    zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
-
-    # concatenate zero_tensor with waveform
-    waveform = torch.cat([waveform, zero_tensor], dim=1)
-
     return waveform

 def get_audio_hubert(audio_path):
     """Get HuBERT features for the audio file"""
     waveform = load_and_preprocess_audio(audio_path)
-    ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
+    ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
     # transpose axis 1 and 2 with numpy
-    hubert_feature = hubert_feature.transpose(0, 2, 1)
+    # hubert_feature = hubert_feature.transpose(0, 2, 1)
     return hubert_feature

 def preprocess_text(text:str):
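Note: since the exported graph now resamples, pads, and transposes internally, the host side only loads 32 kHz audio and runs the session; that is why the length cap, zero pad, and numpy transpose were dropped from the playground helpers. A usage sketch, assuming the playground script's context (`MODEL_PATH` is defined elsewhere in the script; the value below is a placeholder, and "rec.wav" is just an example file):

MODEL_PATH = "onnx/myproject/myproject"  # placeholder; the real value lives elsewhere in the script

waveform = load_and_preprocess_audio("rec.wav")  # (1, T) float32 at 32 kHz
feature = get_audio_hubert("rec.wav")            # (1, 768, hubert_length), already channels-first
print(waveform.shape, feature.shape)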