diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index cb167752..c0b535e2 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -7,12 +7,10 @@ from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn
 from sv import SV
 import onnx
+from onnx import helper, TensorProto
 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 from transformers import HubertModel, HubertConfig
-import json
 import os
-
-import soundfile
 from tqdm import tqdm
 from text import cleaned_text_to_sequence
@@ -104,8 +102,8 @@ class T2SInitStep(nn.Module):
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
         [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
-        fake_logits = torch.randn((1, 1025))  # Dummy logits for ONNX export
-        fake_samples = torch.randn((1, 1))  # Dummy samples for ONNX export
+        fake_logits = torch.randn((1, 1025), dtype=torch.float32)  # Dummy logits; dtype must match the stage step's logits output
+        fake_samples = torch.zeros((1, 1), dtype=torch.int32)  # Dummy samples; int32 to match the stage step's sampled token ids
         return y, k, v, y_emb, x_example, fake_logits, fake_samples
 
 class T2SStageStep(nn.Module):
@@ -115,7 +113,7 @@ class T2SStageStep(nn.Module):
 
     def forward(self, iy, ik, iv, iy_emb, ix_example):
         [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
-        fake_x_example = torch.randn((1, 512))  # Dummy x_example for ONNX export
+        fake_x_example = torch.randn((1, 512), dtype=torch.float32)  # Dummy x_example; dtype must match the init step's x_example output
         return y, k, v, y_emb, fake_x_example, logits, samples
 
 class T2SModel(nn.Module):
@@ -302,20 +300,63 @@ class AudioPreprocess(nn.Module):
         return ssl_content, spectrum, sv_emb
 
-# def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path):
-#     init_step_model = onnx.load(init_step_onnx_path)
-#     stage_step_model = onnx.load(stage_step_onnx_path)
+def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combined_onnx_path):
+    init_step_model = onnx.load(init_step_onnx_path)
+    stage_step_model = onnx.load(stage_step_onnx_path)
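+
+    # Strategy: make the two step graphs the then/else branches of a single ONNX
+    # 'If' node. Branch subgraphs declare no inputs of their own; they capture
+    # tensors by name from the enclosing scope, so the inputs of both models are
+    # promoted to inputs of the combined graph below. Rough sketch of the
+    # resulting interface (the actual driver code lives in playground/freerun.py):
+    #
+    #   sess = ort.InferenceSession(combined_onnx_path)
+    #   y, k, v, y_emb, x_example, logits, samples = sess.run(None, {
+    #       "if_init_step": np.array(True, dtype=bool),  # True: init, False: stage
+    #       # ...plus the inputs of both branches...
+    #   })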
+    then_graph = init_step_model.graph
+    then_graph.name = "init_step_graph"
+    else_graph = stage_step_model.graph
+    else_graph.name = "stage_step_graph"
 
-#     # Combine the models (this is a simplified example; actual combination logic may vary)
-#     combined_graph = helper.make_graph(
-#         nodes=init_step_model.graph.node + stage_step_model.graph.node,
-#         name="combined_graph",
-#         inputs=init_step_model.graph.input,
-#         outputs=stage_step_model.graph.output
-#     )
 
-#     combined_model = helper.make_model(combined_graph, producer_name='onnx-combiner')
-#     onnx.save(combined_model, f"onnx/{project_name}/{project_name}_combined.onnx")
+    data_inputs_init = list(init_step_model.graph.input)
+    data_inputs_stage = list(stage_step_model.graph.input)
+
+    del then_graph.input[:]
+    del else_graph.input[:]
+
+    # Both branches must expose the same output names (and dtypes); the 'If'
+    # node then emits its outputs under these shared names.
+    subgraph_output_names = [output.name for output in then_graph.output]
+    for i, output in enumerate(else_graph.output):
+        assert subgraph_output_names[i] == output.name, "Subgraph output names must match"
+
+    # Inputs of the main graph: the boolean branch selector, followed by the
+    # promoted data inputs of both subgraphs.
+    cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
+
+    # The combined model re-exposes the shared output signature.
+    main_outputs = list(then_graph.output)
+
+    # Create the 'If' node
+    if_node = helper.make_node(
+        'If',
+        inputs=['if_init_step'],
+        outputs=subgraph_output_names,  # these names MUST match the subgraph output names
+        then_branch=then_graph,
+        else_branch=else_graph
+    )
+
+    # Assemble the outer graph: a single 'If' node plus the promoted data inputs
+    main_graph = helper.make_graph(
+        nodes=[if_node],
+        name="t2s_combined_graph",
+        inputs=[cond_input] + data_inputs_init + data_inputs_stage,
+        outputs=main_outputs
+    )
+
+    # Create the final combined model, pinning the opset and IR version
+    opset_version = 16
+    final_model = helper.make_model(
+        main_graph,
+        producer_name='GSV-ONNX-Exporter',
+        ir_version=9,  # for compatibility with older onnxruntime releases
+        opset_imports=[helper.make_opsetid("", opset_version)]
+    )
+
+    # Check the model for correctness
+    onnx.checker.check_model(final_model)
+
+    # Save the combined model
+    onnx.save(final_model, combined_onnx_path)
+    print(f"Combined model saved to {combined_onnx_path}")
+
 
 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
@@ -428,5 +469,6 @@ if __name__ == "__main__":
     exp_path = "v2proplus_export"
     version = "v2ProPlus"
     export(vits_path, gpt_path, exp_path, version)
+    combineInitStepAndStageStep(
+        'onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx',
+        'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx',
+        'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx',
+    )
diff --git a/playground/freerun.py b/playground/freerun.py
index de82d52b..ffd80028 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -63,7 +63,7 @@ def preprocess_text(text:str):
 
 # input_phones_saved = np.load("playground/ref/input_phones.npy")
 # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
-[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追")
+[input_phones, input_bert] = preprocess_text("地上的人儿吵吵闹闹在地上追")
 
 # ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -74,29 +74,33 @@ def preprocess_text(text:str):
 
 [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
 
-init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
+t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
 
-[y, k, v, y_emb, x_example, fake_logits, fake_samples] = init_step.run(None, {
+# Init step: the then-branch of the 'If'. The stage-step inputs are unused here,
+# but every graph input must still be fed, so pass zero-length placeholders.
+[y, k, v, y_emb, x_example, fake_logits, fake_samples] = t2s_combined.run(None, {
+    "if_init_step": np.array(True, dtype=bool),
     "input_text_phones": input_phones,
     "input_text_bert": input_bert,
     "ref_text_phones": ref_phones,
     "ref_text_bert": ref_bert,
-    "hubert_ssl_content": audio_prompt_hubert
+    "hubert_ssl_content": audio_prompt_hubert,
+    "iy": np.empty((1, 0), dtype=np.int64),
+    "ik": np.empty((24, 0, 1, 512), dtype=np.float32),
+    "iv": np.empty((24, 0, 1, 512), dtype=np.float32),
+    "iy_emb": np.empty((1, 0, 512), dtype=np.float32),
+    "ix_example": np.empty((1, 0), dtype=np.float32)
 })
 
-# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
-sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
-
-# for i in tqdm(range(10000)):
-#     [y, k, v, y_emb, x_example] = fsdec.run(None, {
-#         "x": x,
-#         "prompts": prompts
-#     })
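+# Stage steps: the else-branch of the 'If'. Now the init-step inputs are the
+# unused ones, so they get zero-length placeholders, while iy/ik/iv/iy_emb from
+# the previous step are threaded back in to grow the KV cache token by token.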
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
-    [y, k, v, y_emb, fake_x_example, logits, samples] = sdec.run(None, {
+    [y, k, v, y_emb, fake_x_example, logits, samples] = t2s_combined.run(None, {
+        "if_init_step": np.array(False, dtype=bool),
+        "input_text_phones": np.empty((1, 0), dtype=np.int64),
+        "input_text_bert": np.empty((0, 1024), dtype=np.float32),
+        "ref_text_phones": np.empty((1, 0), dtype=np.int64),
+        "ref_text_bert": np.empty((0, 1024), dtype=np.float32),
+        "hubert_ssl_content": np.empty((1, 768, 0), dtype=np.float32),
         "iy": y,
         "ik": k,
         "iv": v,