Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
correctly set up onnx export, solved problem

Speaker-verification (SV) embedding computation moves out of VitsModel.forward and into the preprocessing model, which is renamed from HuBertSSLModel to AudioPreprocess and now returns ssl_content, spectrum, and sv_emb in a single pass. The exported VITS graph takes sv_emb as an explicit input instead of ref_audio, the batch axes are dropped from the preprocess graph's dynamic_axes, and the v2 and v2ProPlus export paths plus the playground test script are updated to match.
commit aafa0561d8 (parent 94b31a250f)
@@ -133,14 +133,11 @@ class T2SModel(nn.Module):
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
         y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
-        stop = False
-        for idx in tqdm(range(1, 1500)):
+        for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
-                stop = True
-            if stop:
                 break
         y[0, -1] = 0
 
@@ -216,9 +213,7 @@ class VitsModel(nn.Module):
         else:
             self.hps["model"]["version"] = version
 
-        self.sv_model = None
-        if version == "v2ProPlus" or version == "v2Pro":
-            self.sv_model = SV("cpu", False)
+        self.is_v2p = version.lower() in ['v2pro', 'v2proplus']
 
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
@@ -231,10 +226,10 @@ class VitsModel(nn.Module):
         self.vq_model.eval()
         self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
 
    #filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
-    def forward(self, text_seq, pred_semantic, ref_audio, spectrum):
-        if self.sv_model is not None:
-            sv_emb=self.sv_model.compute_embedding3_onnx(resample_audio(ref_audio, 32000, 16000))
+    def forward(self, text_seq, pred_semantic, spectrum, sv_emb):
+        if self.is_v2p:
             return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb)[0, 0]
+        else:
             return self.vq_model(pred_semantic, text_seq, spectrum)[0, 0]
 
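This refactor matters for export because torch.onnx.export traces the module by default: a Python-level branch like the old "if self.sv_model is not None:" is evaluated once at export time, and only the taken path is recorded in the graph. A toy illustration of that tracing behavior (not from this repo):

import torch

class Branchy(torch.nn.Module):
    def __init__(self, extra):
        super().__init__()
        self.extra = extra  # stands in for the old self.sv_model

    def forward(self, x):
        if self.extra is not None:  # plain Python branch: frozen during tracing
            return x * 2
        return x

# The exported graph always multiplies by 2; the branch itself is gone.
torch.onnx.export(Branchy(extra=object()), (torch.ones(1),), "branchy.onnx")

The remaining "if self.is_v2p:" branch is still frozen at export time, but that is now intentional: each model version exports its own graph, and sv_emb arrives as a real graph input rather than being computed inside the traced forward.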
@@ -244,19 +239,19 @@ class GptSoVits(nn.Module):
         self.vits = vits
         self.t2s = t2s
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb):
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-        audio = self.vits(text_seq, pred_semantic, ref_audio, spectrum)
+        audio = self.vits(text_seq, pred_semantic, spectrum, sv_emb)
         return audio
 
-    def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, spectrum, ssl_content, project_name):
+    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, project_name):
         self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
         pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
         torch.onnx.export(
             self.vits,
-            (text_seq, pred_semantic, ref_audio, spectrum),
+            (text_seq, pred_semantic, spectrum, sv_emb),
             f"onnx/{project_name}/{project_name}_vits.onnx",
-            input_names=["text_seq", "pred_semantic", "ref_audio", "spectrum"],
+            input_names=["text_seq", "pred_semantic", "spectrum", "sv_emb"],
             output_names=["audio"],
             dynamic_axes={
                 "text_seq": {1: "text_length"},
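A quick way to confirm a re-export picked up the new signature is to list the graph inputs with onnxruntime; a minimal sketch, with project_name standing in for the name used above:

import onnxruntime as ort

sess = ort.InferenceSession("onnx/project_name/project_name_vits.onnx")
print([i.name for i in sess.get_inputs()])   # expect: text_seq, pred_semantic, spectrum, sv_emb
print([o.name for o in sess.get_outputs()])  # expect: audio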
@@ -269,7 +264,7 @@ class GptSoVits(nn.Module):
         )
 
 
-class HuBertSSLModel(nn.Module):
+class AudioPreprocess(nn.Module):
     def __init__(self):
         super().__init__()
         self.config = HubertConfig.from_pretrained(cnhubert_base_path)
@@ -285,6 +280,8 @@ class HuBertSSLModel(nn.Module):
         )
         self.model.eval()
 
+        self.sv_model = SV("cpu", False)
+
     def forward(self, ref_audio_32k):
         spectrum = spectrogram_torch(
             ref_audio_32k,
@@ -294,22 +291,24 @@ class HuBertSSLModel(nn.Module):
             2048,
             center=False,
         )
+        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000)
 
+        sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio_16k)
 
-        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000).unsqueeze(0)
         zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
+        ref_audio_16k = ref_audio_16k.unsqueeze(0)
         # concate zero_tensor with waveform
         ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
         ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
-        return ssl_content, spectrum
+        return ssl_content, spectrum, sv_emb
 
 
 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
-    ssl = HuBertSSLModel()
+    preprocessor = AudioPreprocess()
     ref_seq = torch.LongTensor(
         [
             cleaned_text_to_sequence(
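resample_audio is imported elsewhere in the repo and not shown in this diff; for reading purposes it is assumed to behave roughly like this torchaudio-based helper (the real one may differ, and it has to trace cleanly since it now runs inside the exported AudioPreprocess graph):

import torchaudio.functional as AF

def resample_audio(waveform, orig_sr, target_sr):
    # waveform: [1, T] float tensor sampled at orig_sr
    return AF.resample(waveform, orig_freq=orig_sr, new_freq=target_sr)

The 4800-sample zero tensor appended after resampling is 0.3 seconds of silence at 16 kHz, padding the reference audio before it is fed to the HuBERT model.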
@@ -376,19 +375,19 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     except:
         pass
 
-    torch.onnx.export(ssl, (ref_audio32k,), f"onnx/{project_name}/{project_name}_hubertssl.onnx",
+    torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
                       input_names=["audio32k"],
-                      output_names=["hubert_ssl_output", "spectrum"],
+                      output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
                       dynamic_axes={
-                          "audio32k": {0: "batch_size", 1: "sequence_length"},
-                          "hubert_ssl_output": {0: "batch_size", 2: "hubert_length"},
-                          "spectrum": {0: "batch_size", 2: "spectrum_length"}
+                          "audio32k": {1: "sequence_length"},
+                          "hubert_ssl_output": {2: "hubert_length"},
+                          "spectrum": {2: "spectrum_length"}
                       })
 
-    [ssl_content, spectrum] = ssl(ref_audio32k)
-    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float())
+    [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
+    gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
     # exit()
-    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio32k, spectrum.float(), ssl_content.float(), project_name)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
 
     if voice_model_version == "v1":
         symbols = symbols_v1
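Dropping the batch_size entries from dynamic_axes pins the exported preprocess graph to batch size 1; only the time axes stay symbolic. A hedged smoke test of the result, with a placeholder path and one second of silence:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/project_name/project_name_audio_preprocess.onnx")
audio32k = np.zeros((1, 32000), dtype=np.float32)  # [1, T]; the batch axis is fixed at 1
hubert_ssl, spectrum, sv_emb = sess.run(None, {"audio32k": audio32k})
print(hubert_ssl.shape, spectrum.shape, sv_emb.shape)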
@@ -401,11 +400,11 @@ if __name__ == "__main__":
     except:
         pass
 
-    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
-    # exp_path = "v2_export"
-    # version = "v2"
-    # export(vits_path, gpt_path, exp_path, version)
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    exp_path = "v2_export"
+    version = "v2"
+    export(vits_path, gpt_path, exp_path, version)
 
     gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
     vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
@@ -413,10 +412,10 @@ if __name__ == "__main__":
     version = "v2Pro"
     export(vits_path, gpt_path, exp_path, version)
 
-    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
-    # exp_path = "v2proplus_export"
-    # version = "v2ProPlus"
-    # export(vits_path, gpt_path, exp_path, version)
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    exp_path = "v2proplus_export"
+    version = "v2ProPlus"
+    export(vits_path, gpt_path, exp_path, version)
 
 
@@ -7,7 +7,7 @@ import torch
 from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
 
 
-MODEL_PATH = "onnx/v2pro_export/v2pro"
+MODEL_PATH = "onnx/v2proplus_export/v2proplus"
 
 def audio_postprocess(
     audios,
@@ -31,42 +31,7 @@ def audio_postprocess(
 
     return audio
 
-# def load_and_preprocess_audio(audio_path, max_length=160000):
-#     """Load and preprocess audio file to 16k"""
-#     waveform, sample_rate = torchaudio.load(audio_path)
-
-#     # Resample to 16kHz if needed
-#     if sample_rate != 16000:
-#         resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-#         waveform = resampler(waveform)
-
-#     # Take first channel
-#     if waveform.shape[0] > 1:
-#         waveform = waveform[0:1]
-
-#     # Limit length for testing (10 seconds at 16kHz)
-#     if waveform.shape[1] > max_length:
-#         waveform = waveform[:, :max_length]
-
-#     # make a zero tensor that has length 3200*0.3
-#     zero_tensor = torch.zeros((1, 4800), dtype=torch.float32)
-
-#     # concate zero_tensor with waveform
-#     waveform = torch.cat([waveform, zero_tensor], dim=1)
-
-#     return waveform
-
-# def get_audio_hubert(audio_path):
-#     """Get HuBERT features for the audio file"""
-#     waveform = load_and_preprocess_audio(audio_path)
-#     ort_session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
-#     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-#     hubert_feature = ort_session.run(None, ort_inputs)[0].astype(np.float32)
-#     # transpose axis 1 and 2 with numpy
-#     hubert_feature = hubert_feature.transpose(0, 2, 1)
-#     return hubert_feature
-
-def load_and_preprocess_audio(audio_path):
+def load_audio(audio_path):
     """Load and preprocess audio file to 32k"""
     waveform, sample_rate = torchaudio.load(audio_path)
 
@@ -81,15 +46,13 @@ def load_and_preprocess_audio(audio_path):
 
     return waveform
 
-def get_audio_hubert(audio_path):
+def audio_preprocess(audio_path):
     """Get HuBERT features for the audio file"""
-    waveform = load_and_preprocess_audio(audio_path)
-    ort_session = ort.InferenceSession(MODEL_PATH + "_export_hubertssl.onnx")
+    waveform = load_audio(audio_path)
+    ort_session = ort.InferenceSession(MODEL_PATH + "_export_audio_preprocess.onnx")
     ort_inputs = {ort_session.get_inputs()[0].name: waveform.numpy()}
-    [hubert_feature, spectrum] = ort_session.run(None, ort_inputs)
-    # transpose axis 1 and 2 with numpy
-    # hubert_feature = hubert_feature.transpose(0, 2, 1)
-    return hubert_feature, spectrum
+    [hubert_feature, spectrum, sv_emb] = ort_session.run(None, ort_inputs)
+    return hubert_feature, spectrum, sv_emb
 
 def preprocess_text(text:str):
     preprocessor = TextPreprocessorOnnx("playground/bert")
@@ -108,7 +71,7 @@ def preprocess_text(text:str):
 [ref_phones, ref_bert] = preprocess_text("今日江苏苏州荷花市集开张热闹与浪漫交织")
 
 
-[audio_prompt_hubert, spectrum] = get_audio_hubert("playground/ref/audio.wav")
+[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
 
 
 encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx")
@@ -131,7 +94,6 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 # })
 
 
-stop = False
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
     [y, k, v, y_emb, logits, samples] = sdec.run(None, {
@@ -160,8 +122,8 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
 [audio] = vtis.run(None, {
     "text_seq": input_phones,
     "pred_semantic": pred_semantic,
-    "ref_audio": ref_audio,
-    "spectrum": spectrum.astype(np.float32)
+    "spectrum": spectrum.astype(np.float32),
+    "sv_emb": sv_emb.astype(np.float32)
 })
 
 audio_postprocess([audio])
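To listen to the result, the post-processed array can be written out; a usage sketch assuming soundfile as the writer (any WAV writer works) and the 32000 Hz rate noted in the VitsModel config comment:

import soundfile as sf

out = audio_postprocess([audio])
sf.write("playground/output.wav", out, 32000)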