diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export_v1v2.py similarity index 92% rename from GPT_SoVITS/onnx_export.py rename to GPT_SoVITS/onnx_export_v1v2.py index 59a1a33e..700787c3 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export_v1v2.py @@ -14,6 +14,7 @@ import os import json from text import cleaned_text_to_sequence import onnxsim +from onnxconverter_common import float16 def simplify_onnx_model(onnx_model_path: str): # Load the ONNX model @@ -23,6 +24,14 @@ def simplify_onnx_model(onnx_model_path: str): # Save the simplified model onnx.save(model_simplified, onnx_model_path) +def convert_onnx_to_half(onnx_model_path:str): + try: + model = onnx.load(onnx_model_path) + model_fp16 = float16.convert_float_to_float16(model) + onnx.save(model_fp16, onnx_model_path) + except Exception as e: + print(f"Error converting {onnx_model_path} to half precision: {e}") + def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) @@ -355,7 +364,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi print(f"Combined model saved to {combined_onnx_path}") -def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True): +def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True, half_precision=False): vits = VitsModel(vits_path, version=voice_model_version) gpt = T2SModel(gpt_path, vits) gpt_sovits = GptSoVits(vits, gpt) @@ -444,6 +453,14 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com if t2s_model_combine: combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx') + if half_precision: + if t2s_model_combine: + convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_combined.onnx") + if export_audio_preprocessor: + convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx") + convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx") + convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") + convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx") if __name__ == "__main__": try: @@ -457,25 +474,25 @@ if __name__ == "__main__": vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" exp_path = "v1_export" version = "v1" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) + export(vits_path, gpt_path, exp_path, version) gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" exp_path = "v2_export" version = "v2" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) + export(vits_path, gpt_path, exp_path, version) gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" exp_path = "v2pro_export" version = "v2Pro" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) + export(vits_path, gpt_path, exp_path, version) gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" exp_path = "v2proplus_export" version = "v2ProPlus" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) + export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True) diff --git a/requirements.txt b/requirements.txt index d6f9a9ee..40f0976e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ onnx onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64" onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64" onnxsim -onnxruntime-tools +onnxconverter-common tqdm funasr==1.0.27 cn2an