mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
feat:clean up export logics and add notes
This commit is contained in:
parent
e4d1894a8f
commit
48d52778ce
@ -1025,8 +1025,28 @@ def fbank_onnx(
|
|||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
r"""ONNX-compatible fbank function with hardcoded parameters from traced call:
|
r"""ONNX-compatible fbank function with hardcoded parameters from traced call:
|
||||||
num_mel_bins=80, sample_frequency=16000, dither=0
|
num_mel_bins=80, sample_frequency=16000, dither=0
|
||||||
|
blackman_coeff: float = 0.42,
|
||||||
All other parameters use their traced default values.
|
channel: int = -1,
|
||||||
|
energy_floor: float = 1.0,
|
||||||
|
frame_length: float = 25.0,
|
||||||
|
frame_shift: float = 10.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
htk_compat: bool = False,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
min_duration: float = 0.0,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
raw_energy: bool = True,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
round_to_power_of_two: bool = True,
|
||||||
|
snip_edges: bool = True,
|
||||||
|
subtract_mean: bool = False,
|
||||||
|
use_energy: bool = False,
|
||||||
|
use_log_fbank: bool = True,
|
||||||
|
use_power: bool = True,
|
||||||
|
vtln_high: float = -500.0,
|
||||||
|
vtln_low: float = 100.0,
|
||||||
|
vtln_warp: float = 1.0,
|
||||||
|
window_type: str = POVEY
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||||
|
@ -11,8 +11,17 @@ from onnx import helper, TensorProto
|
|||||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||||
from transformers import HubertModel, HubertConfig
|
from transformers import HubertModel, HubertConfig
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
import json
|
||||||
from text import cleaned_text_to_sequence
|
from text import cleaned_text_to_sequence
|
||||||
|
import onnxsim
|
||||||
|
|
||||||
|
def simplify_onnx_model(onnx_model_path: str):
|
||||||
|
# Load the ONNX model
|
||||||
|
model = onnx.load(onnx_model_path)
|
||||||
|
# Simplify the model
|
||||||
|
model_simplified, _ = onnxsim.simplify(model)
|
||||||
|
# Save the simplified model
|
||||||
|
onnx.save(model_simplified, onnx_model_path)
|
||||||
|
|
||||||
|
|
||||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||||
@ -102,7 +111,7 @@ class T2SInitStep(nn.Module):
|
|||||||
bert = bert.unsqueeze(0)
|
bert = bert.unsqueeze(0)
|
||||||
prompt = prompt_semantic.unsqueeze(0)
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
fake_logits = torch.zeros((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
||||||
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
||||||
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
||||||
|
|
||||||
@ -113,7 +122,7 @@ class T2SStageStep(nn.Module):
|
|||||||
|
|
||||||
def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
fake_x_example = torch.zeros((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
||||||
return y, k, v, y_emb, fake_x_example, logits, samples
|
return y, k, v, y_emb, fake_x_example, logits, samples
|
||||||
|
|
||||||
class T2SModel(nn.Module):
|
class T2SModel(nn.Module):
|
||||||
@ -134,22 +143,18 @@ class T2SModel(nn.Module):
|
|||||||
self.init_step = T2SInitStep(self.t2s_model, self.vits_model)
|
self.init_step = T2SInitStep(self.t2s_model, self.vits_model)
|
||||||
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
||||||
self.stage_decoder = self.t2s_model.stage_decoder
|
self.stage_decoder = self.t2s_model.stage_decoder
|
||||||
# self.t2s_model = torch.jit.script(self.t2s_model)
|
|
||||||
|
|
||||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
early_stop_num = self.t2s_model.early_stop_num
|
|
||||||
|
|
||||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
|
for idx in range(5): # This is a fake one! DO NOT take this as reference
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
|
||||||
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
y, k, v, y_emb, logits, samples = enco
|
y, k, v, y_emb, logits, samples = enco
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
# if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
||||||
break
|
# break
|
||||||
|
|
||||||
return y[:, -idx:].unsqueeze(0)
|
return y[:, -5:].unsqueeze(0)
|
||||||
|
|
||||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
@ -167,13 +172,14 @@ class T2SModel(nn.Module):
|
|||||||
},
|
},
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
)
|
)
|
||||||
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
||||||
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
stage_step = T2SStageStep(self.stage_decoder)
|
stage_step = T2SStageStep(self.stage_decoder)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
stage_step,
|
stage_step,
|
||||||
(y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature),
|
(y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx",
|
||||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"],
|
input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"],
|
||||||
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
@ -187,6 +193,7 @@ class T2SModel(nn.Module):
|
|||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
)
|
)
|
||||||
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
|
||||||
|
|
||||||
|
|
||||||
class VitsModel(nn.Module):
|
class VitsModel(nn.Module):
|
||||||
@ -248,6 +255,7 @@ class GptSoVits(nn.Module):
|
|||||||
opset_version=17,
|
opset_version=17,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_vits.onnx")
|
||||||
|
|
||||||
|
|
||||||
class AudioPreprocess(nn.Module):
|
class AudioPreprocess(nn.Module):
|
||||||
@ -347,7 +355,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi
|
|||||||
print(f"Combined model saved to {combined_onnx_path}")
|
print(f"Combined model saved to {combined_onnx_path}")
|
||||||
|
|
||||||
|
|
||||||
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True):
|
||||||
vits = VitsModel(vits_path, version=voice_model_version)
|
vits = VitsModel(vits_path, version=voice_model_version)
|
||||||
gpt = T2SModel(gpt_path, vits)
|
gpt = T2SModel(gpt_path, vits)
|
||||||
gpt_sovits = GptSoVits(vits, gpt)
|
gpt_sovits = GptSoVits(vits, gpt)
|
||||||
@ -409,7 +417,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
)
|
)
|
||||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||||
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5
|
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5 # 5 seconds of dummy audio
|
||||||
top_k = torch.LongTensor([15])
|
top_k = torch.LongTensor([15])
|
||||||
top_p = torch.FloatTensor([1.0])
|
top_p = torch.FloatTensor([1.0])
|
||||||
repetition_penalty = torch.FloatTensor([1.0])
|
repetition_penalty = torch.FloatTensor([1.0])
|
||||||
@ -422,7 +430,8 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
# exit()
|
# exit()
|
||||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
|
||||||
|
|
||||||
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
if export_audio_preprocessor:
|
||||||
|
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
||||||
input_names=["audio32k"],
|
input_names=["audio32k"],
|
||||||
output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
|
output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
@ -430,6 +439,11 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
"hubert_ssl_output": {2: "hubert_length"},
|
"hubert_ssl_output": {2: "hubert_length"},
|
||||||
"spectrum": {2: "spectrum_length"}
|
"spectrum": {2: "spectrum_length"}
|
||||||
})
|
})
|
||||||
|
simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
|
||||||
|
|
||||||
|
if t2s_model_combine:
|
||||||
|
combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
@ -437,32 +451,31 @@ if __name__ == "__main__":
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# 因为io太频繁,可能导致模型导出出错(wsl非常明显),请自行重试
|
||||||
|
|
||||||
|
gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
|
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||||
|
exp_path = "v1_export"
|
||||||
|
version = "v1"
|
||||||
|
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
||||||
# exp_path = "v1_export"
|
exp_path = "v2_export"
|
||||||
# version = "v1"
|
version = "v2"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
|
||||||
|
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
||||||
# exp_path = "v2_export"
|
exp_path = "v2pro_export"
|
||||||
# version = "v2"
|
version = "v2Pro"
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
|
||||||
# combineInitStepAndStageStep('onnx/v2_export/v2_export_t2s_init_step.onnx', 'onnx/v2_export/v2_export_t2s_sdec.onnx', 'onnx/v2_export/v2_export_t2s_combined.onnx')
|
|
||||||
|
|
||||||
# gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
|
||||||
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
|
||||||
# exp_path = "v2pro_export"
|
|
||||||
# version = "v2Pro"
|
|
||||||
# export(vits_path, gpt_path, exp_path, version)
|
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
||||||
exp_path = "v2proplus_export"
|
exp_path = "v2proplus_export"
|
||||||
version = "v2ProPlus"
|
version = "v2ProPlus"
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True)
|
||||||
combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx')
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ ffmpeg-python
|
|||||||
onnx
|
onnx
|
||||||
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
|
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
|
||||||
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
|
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
|
||||||
|
onnxsim
|
||||||
tqdm
|
tqdm
|
||||||
funasr==1.0.27
|
funasr==1.0.27
|
||||||
cn2an
|
cn2an
|
||||||
|
Loading…
x
Reference in New Issue
Block a user