feat: support half-precision (fp16) ONNX export

This commit is contained in:
zpeng11 2025-08-24 02:29:33 -04:00
parent 72c5d3224e
commit 942caa888e
2 changed files with 23 additions and 6 deletions

View File

@ -14,6 +14,7 @@ import os
import json import json
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
import onnxsim import onnxsim
from onnxconverter_common import float16
def simplify_onnx_model(onnx_model_path: str): def simplify_onnx_model(onnx_model_path: str):
# Load the ONNX model # Load the ONNX model
@ -23,6 +24,14 @@ def simplify_onnx_model(onnx_model_path: str):
# Save the simplified model # Save the simplified model
onnx.save(model_simplified, onnx_model_path) onnx.save(model_simplified, onnx_model_path)
def convert_onnx_to_half(onnx_model_path: str):
    """Convert an ONNX model to float16 (half) precision in place.

    Loads the model at ``onnx_model_path``, casts its float32 tensors to
    float16 via ``onnxconverter_common.float16``, and overwrites the original
    file. Failures are reported to stdout rather than raised, so a single bad
    model does not abort a multi-model export run (best-effort by design).
    """
    try:
        fp32_model = onnx.load(onnx_model_path)
        half_model = float16.convert_float_to_float16(fp32_model)
        onnx.save(half_model, onnx_model_path)
    except Exception as e:
        print(f"Error converting {onnx_model_path} to half precision: {e}")
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
@ -355,7 +364,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi
print(f"Combined model saved to {combined_onnx_path}") print(f"Combined model saved to {combined_onnx_path}")
def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True): def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True, half_precision=False):
vits = VitsModel(vits_path, version=voice_model_version) vits = VitsModel(vits_path, version=voice_model_version)
gpt = T2SModel(gpt_path, vits) gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt) gpt_sovits = GptSoVits(vits, gpt)
@ -444,6 +453,14 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
if t2s_model_combine: if t2s_model_combine:
combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx') combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx')
if half_precision:
if t2s_model_combine:
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_combined.onnx")
if export_audio_preprocessor:
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
@ -457,25 +474,25 @@ if __name__ == "__main__":
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
exp_path = "v1_export" exp_path = "v1_export"
version = "v1" version = "v1"
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
exp_path = "v2_export" exp_path = "v2_export"
version = "v2" version = "v2"
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
exp_path = "v2pro_export" exp_path = "v2pro_export"
version = "v2Pro" version = "v2Pro"
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
exp_path = "v2proplus_export" exp_path = "v2proplus_export"
version = "v2ProPlus" version = "v2ProPlus"
export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True)

View File

@ -11,7 +11,7 @@ onnx
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64" onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64" onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
onnxsim onnxsim
onnxruntime-tools onnxconverter-common
tqdm tqdm
funasr==1.0.27 funasr==1.0.27
cn2an cn2an