diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index 8718a815..cb167752 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -6,7 +6,7 @@ from feature_extractor import cnhubert
 from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn
 from sv import SV
-
+import onnx
 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 from transformers import HubertModel, HubertConfig
 import json
@@ -103,9 +103,20 @@ class T2SInitStep(nn.Module):
         all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
         bert = bert.unsqueeze(0)
         prompt = prompt_semantic.unsqueeze(0)
-        return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
-
+        [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
+        fake_logits = torch.randn((1, 1025))  # Dummy logits for ONNX export
+        fake_samples = torch.randn((1, 1))  # Dummy samples for ONNX export
+        return y, k, v, y_emb, x_example, fake_logits, fake_samples
+
+class T2SStageStep(nn.Module):
+    def __init__(self, stage_decoder):
+        super().__init__()
+        self.stage_decoder = stage_decoder
+
+    def forward(self, iy, ik, iv, iy_emb, ix_example):
+        [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
+        fake_x_example = torch.randn((1, 512))  # Dummy x_example for ONNX export
+        return y, k, v, y_emb, fake_x_example, logits, samples
 
 class T2SModel(nn.Module):
     def __init__(self, t2s_path, vits_model):
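The dummy outputs above are the point of this hunk: the init step and the stage step are padded to one shared seven-slot signature (`y, k, v, y_emb, x_example, logits, samples`). Matching output signatures are what an ONNX `If` merge of the two exported graphs would require of its branches (see `playground/onnx_if_node_test.py` below). A minimal sketch of the pattern, with toy modules standing in for the real decoders:

```python
import torch
import torch.nn as nn

# Toy stand-ins for T2SInitStep / T2SStageStep: each pads the outputs it
# cannot produce with dummies so both return the same tuple shape.
class InitToy(nn.Module):
    def forward(self, x):
        state = x + 1.0
        fake_logits = torch.randn((1, 1025))  # dummy, never read by callers
        return state, fake_logits

class StageToy(nn.Module):
    def forward(self, state):
        logits = state.mean(dim=-1, keepdim=True)  # real logits
        return state, logits  # same (state, logits) signature as InitToy
```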
@@ -131,12 +142,13 @@ class T2SModel(nn.Module):
         early_stop_num = self.t2s_model.early_stop_num
 
         # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
         for idx in tqdm(range(1, 20)):  # This is a fake one! do not take this as reference
             # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
             enco = self.stage_decoder(y, k, v, y_emb, x_example)
             y, k, v, y_emb, logits, samples = enco
+            print(logits.shape, samples.shape)
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
                 break
         y[0, -1] = 0
@@ -158,7 +170,7 @@ class T2SModel(nn.Module):
             (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
             f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
             input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
-            output_names=["y", "k", "v", "y_emb", "x_example"],
+            output_names=["y", "k", "v", "y_emb", "x_example", "logits", "samples"],
             dynamic_axes={
                 "ref_text_phones": {1: "ref_length"},
                 "input_text_phones": {1: "text_length"},
@@ -168,35 +180,22 @@ class T2SModel(nn.Module):
             },
             opset_version=16,
         )
-        y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
-
-        # torch.onnx.export(
-        #     self.first_stage_decoder,
-        #     (x, prompts),
-        #     f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
-        #     input_names=["x", "prompts"],
-        #     output_names=["y", "k", "v", "y_emb", "x_example"],
-        #     dynamic_axes={
-        #         "x": {1: "x_length"},
-        #         "prompts": {1: "prompts_length"},
-        #     },
-        #     verbose=False,
-        #     opset_version=16,
-        # )
-        # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
+        stage_step = T2SStageStep(self.stage_decoder)
         torch.onnx.export(
-            self.stage_decoder,
+            stage_step,
             (y, k, v, y_emb, x_example),
             f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
             input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
-            output_names=["y", "k", "v", "y_emb", "logits", "samples"],
+            output_names=["y", "k", "v", "y_emb", "x_example", "logits", "samples"],
             dynamic_axes={
                 "iy": {1: "iy_length"},
                 "ik": {1: "ik_length"},
                 "iv": {1: "iv_length"},
                 "iy_emb": {1: "iy_emb_length"},
                 "ix_example": {1: "ix_example_length"},
+                "x_example": {1: "x_example_length"},
             },
             verbose=False,
             opset_version=16,
@@ -303,6 +302,20 @@ class AudioPreprocess(nn.Module):
         return ssl_content, spectrum, sv_emb
 
 
+# def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path):
+#     init_step_model = onnx.load(init_step_onnx_path)
+#     stage_step_model = onnx.load(stage_step_onnx_path)
+
+#     # Combine the models (this is a simplified example; actual combination logic may vary)
+#     combined_graph = helper.make_graph(
+#         nodes=init_step_model.graph.node + stage_step_model.graph.node,
+#         name="combined_graph",
+#         inputs=init_step_model.graph.input,
+#         outputs=stage_step_model.graph.output
+#     )
+
+#     combined_model = helper.make_model(combined_graph, producer_name='onnx-combiner')
+#     onnx.save(combined_model, f"onnx/{project_name}/{project_name}_combined.onnx")
 
 def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     vits = VitsModel(vits_path, version=voice_model_version)
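For reference, the commented-out combiner above only concatenates node lists, so it wires nothing together and breaks on duplicate tensor names. A sequential merge (init step feeding a single stage step) can instead be sketched with `onnx.compose.merge_models`; untested here, paths assume the v2ProPlus export driven from `__main__` below, and exporter-generated name clashes may still require the `prefix1`/`prefix2` arguments:

```python
import onnx
from onnx import compose

init_m = onnx.load("onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx")
sdec_m = onnx.load("onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx")

# Drop the init step's dummy logits/samples so only the stage step's real
# ones surface in the merged model (relies on the output order set above).
del init_m.graph.output[5:]

merged = compose.merge_models(
    init_m,
    sdec_m,
    io_map=[("y", "iy"), ("k", "ik"), ("v", "iv"),
            ("y_emb", "iy_emb"), ("x_example", "ix_example")],
)
onnx.save(merged, "onnx/v2proplus_export/v2proplus_t2s_init_plus_one_step.onnx")
```

Note that this bakes exactly one decode step into the graph; the `If`-node experiment in `playground/onnx_if_node_test.py` is the route that keeps the step count outside the graph.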
@@ -366,33 +379,23 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
     )
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
-    ref_audio = torch.randn((1, 48000 * 5)).float()
-    # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
-    ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()
+    ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5
 
-    try:
-        os.mkdir(f"onnx/{project_name}")
-    except:
-        pass
-
-    torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
-        input_names=["audio32k"],
-        output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
-        dynamic_axes={
-            "audio32k": {1: "sequence_length"},
-            "hubert_ssl_output": {2: "hubert_length"},
-            "spectrum": {2: "spectrum_length"}
-        })
+    os.makedirs(f"onnx/{project_name}", exist_ok=True)
 
     [ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
     gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
     # exit()
     gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
 
-    if voice_model_version == "v1":
-        symbols = symbols_v1
-    else:
-        symbols = symbols_v2
+    torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
+        input_names=["audio32k"],
+        output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
+        dynamic_axes={
+            "audio32k": {1: "sequence_length"},
+            "hubert_ssl_output": {2: "hubert_length"},
+            "spectrum": {2: "spectrum_length"}
+        })
 
 
 if __name__ == "__main__":
     try:
@@ -400,23 +403,25 @@ if __name__ == "__main__":
     except:
         pass
 
-    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
-    exp_path = "v1_export"
-    version = "v1"
-    export(vits_path, gpt_path, exp_path, version)
-    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
-    exp_path = "v2_export"
-    version = "v2"
-    export(vits_path, gpt_path, exp_path, version)
-    gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
-    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
-    exp_path = "v2pro_export"
-    version = "v2Pro"
-    export(vits_path, gpt_path, exp_path, version)
+    # gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
+    # exp_path = "v1_export"
+    # version = "v1"
+    # export(vits_path, gpt_path, exp_path, version)
+
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    # exp_path = "v2_export"
+    # version = "v2"
+    # export(vits_path, gpt_path, exp_path, version)
+
+    # gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
+    # exp_path = "v2pro_export"
+    # version = "v2Pro"
+    # export(vits_path, gpt_path, exp_path, version)
 
     gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
     vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
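With the directory creation now idempotent, the exporter can be re-run directly. A minimal invocation mirroring the one block left active in `__main__` (the exact `voice_model_version` string for ProPlus is not visible in the hunk, so it is an assumption here):

```python
from GPT_SoVITS.onnx_export import export

# Writes onnx/v2proplus_export/v2proplus_export_*.onnx, the files that
# playground/freerun.py loads below.
export(
    vits_path="GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth",
    gpt_path="GPT_SoVITS/pretrained_models/s1v3.ckpt",
    project_name="v2proplus_export",
    voice_model_version="v2ProPlus",  # assumed version string
)
```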
diff --git a/playground/freerun.py b/playground/freerun.py
index 5733f253..de82d52b 100644
--- a/playground/freerun.py
+++ b/playground/freerun.py
@@ -7,7 +7,7 @@ import torch
 
 from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
 
-MODEL_PATH = "onnx/v1_export/v1"
+MODEL_PATH = "onnx/v2proplus_export/v2proplus"
 
 def audio_postprocess(
     audios,
@@ -56,7 +56,7 @@ def audio_preprocess(audio_path):
 
 def preprocess_text(text:str):
     preprocessor = TextPreprocessorOnnx("playground/bert")
-    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v1')
+    [phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
     phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
     return phones, bert_features.T.astype(np.float32)
@@ -76,7 +76,7 @@ def preprocess_text(text:str):
 
 init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
-[y, k, v, y_emb, x_example] = init_step.run(None, {
+[y, k, v, y_emb, x_example, fake_logits, fake_samples] = init_step.run(None, {
     "input_text_phones": input_phones,
     "input_text_bert": input_bert,
     "ref_text_phones": ref_phones,
@@ -96,7 +96,7 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
 
 for idx in tqdm(range(1, 1500)):
     # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
-    [y, k, v, y_emb, logits, samples] = sdec.run(None, {
+    [y, k, v, y_emb, fake_x_example, logits, samples] = sdec.run(None, {
        "iy": y,
        "ik": k,
        "iv": v,
@@ -123,7 +123,7 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
     "input_text_phones": input_phones,
     "pred_semantic": pred_semantic,
     "spectrum": spectrum.astype(np.float32),
-    # "sv_emb": sv_emb.astype(np.float32)
+    "sv_emb": sv_emb.astype(np.float32)
 })
 
 audio_postprocess([audio])
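Because this script now unpacks seven values from the stage decoder, a quick way to confirm a given export actually carries the extra `x_example` slot (a sketch, assuming the v2ProPlus paths above):

```python
import onnx

sdec_model = onnx.load("onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx")
print([o.name for o in sdec_model.graph.output])
# Expected: ['y', 'k', 'v', 'y_emb', 'x_example', 'logits', 'samples']
```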
diff --git a/playground/onnx_if_node_test.py b/playground/onnx_if_node_test.py
new file mode 100644
index 00000000..0497f1a5
--- /dev/null
+++ b/playground/onnx_if_node_test.py
@@ -0,0 +1,214 @@
+import torch
+import torch.nn as nn
+import onnx
+import onnxruntime as ort
+from onnx import helper, TensorProto
+import numpy as np
+import os
+
+# Define file paths
+PLAYGROUND_DIR = "playground"
+MODEL_A_PATH = os.path.join(PLAYGROUND_DIR, "a.onnx")
+MODEL_B_PATH = os.path.join(PLAYGROUND_DIR, "b.onnx")
+MODEL_C_PATH = os.path.join(PLAYGROUND_DIR, "c.onnx")
+
+# --- 1. Create two simple PyTorch modules ---
+
+class ModelA(nn.Module):
+    """This model adds 1 to the input."""
+    def forward(self, x):
+        return x + 1.0
+
+class ModelB(nn.Module):
+    """This model multiplies the input by 2."""
+    def forward(self, x):
+        return x * 2.0
+
+def create_and_export_models():
+    """Creates two nn.Modules and exports them to ONNX."""
+    print("Step 1: Creating and exporting PyTorch models A and B...")
+    os.makedirs(PLAYGROUND_DIR, exist_ok=True)
+
+    # Define a dummy input with a dynamic axis
+    batch_size = 1
+    sequence_length = 10  # This dimension will be dynamic
+    features = 4
+    dummy_input = torch.randn(batch_size, sequence_length, features)
+
+    # Export Model A
+    print(f"Exporting Model A to {MODEL_A_PATH}")
+    torch.onnx.export(
+        ModelA(),
+        dummy_input,
+        MODEL_A_PATH,
+        input_names=['inputA'],
+        output_names=['output'],
+        dynamic_axes={'inputA': {1: 'sequenceA'}, 'output': {1: 'sequence'}},
+        opset_version=11  # The If node requires opset >= 11
+    )
+
+    # Export Model B
+    print(f"Exporting Model B to {MODEL_B_PATH}")
+    torch.onnx.export(
+        ModelB(),
+        dummy_input,
+        MODEL_B_PATH,
+        input_names=['inputB'],
+        output_names=['output'],
+        dynamic_axes={'inputB': {1: 'sequenceB'}, 'output': {1: 'sequence'}},
+        opset_version=11
+    )
+    print("Models A and B exported successfully.")
+
+def combine_models_with_if():
+    """
+    Reads two ONNX models and combines them into a third model
+    using an 'If' operator.
+    """
+    print("\nStep 2: Combining models A and B into C with an 'If' node...")
+
+    # Load the two exported ONNX models
+    model_a = onnx.load(MODEL_A_PATH)
+    model_b = onnx.load(MODEL_B_PATH)
+
+    # The graphs for the 'then' and 'else' branches of the 'If' operator
+    then_graph = model_a.graph
+    then_graph.name = "then_branch_graph"
+    else_graph = model_b.graph
+    else_graph.name = "else_branch_graph"
+
+    # The data inputs for the main graph are defined here.
+    # We take them from the original models.
+    data_inputA = model_a.graph.input[0]
+    data_inputB = model_b.graph.input[0]
+
+    # For some onnxruntime versions, subgraphs should not have their own
+    # explicit 'input' list if the inputs are captured from the parent graph.
+    # We clear the input lists of the subgraphs to force implicit capture.
+    del then_graph.input[:]
+    del else_graph.input[:]
+
+    # The output names of the subgraphs must be the same.
+    # The 'If' node will have an output with this same name.
+    subgraph_output_name = model_a.graph.output[0].name
+    assert subgraph_output_name == model_b.graph.output[0].name, "Subgraph output names must match"
+
+    # Define the inputs for the main graph
+    # 1. The boolean condition to select the branch
+    cond_input = helper.make_tensor_value_info('if_use_a', TensorProto.BOOL, [])
+
+    # The main graph's output is the output from the 'If' node.
+    # We can use the ValueInfoProto from one of the subgraphs directly.
+    main_output = model_a.graph.output[0]
+
+    # Create the 'If' node
+    if_node = helper.make_node(
+        'If',
+        inputs=['if_use_a'],
+        outputs=[subgraph_output_name],  # This name MUST match the subgraph's output name
+        then_branch=then_graph,
+        else_branch=else_graph
+    )
+
+    # Create the main graph containing the 'If' node. Its inputs are the condition
+    # AND the data that the subgraphs will capture.
+    main_graph = helper.make_graph(
+        nodes=[if_node],
+        name='if_main_graph',
+        inputs=[cond_input, data_inputA, data_inputB],
+        outputs=[main_output]
+    )
+
+    # Create the final combined model, specifying the opset and IR version
+    opset_version = 16
+    final_model = helper.make_model(main_graph,
+                                    producer_name='onnx-if-combiner',
+                                    ir_version=9,  # For compatibility with older onnxruntime
+                                    opset_imports=[helper.make_opsetid("", opset_version)])
+
+    # Check the model for correctness
+    onnx.checker.check_model(final_model)
+
+    # Save the combined model
+    onnx.save(final_model, MODEL_C_PATH)
+    print(f"Combined model C saved to {MODEL_C_PATH}")
+
+def verify_combined_model():
+    """
+    Loads the combined ONNX model and runs inference to verify
+    that the 'If' branching and dynamic shapes work correctly.
+    """
+    print("\nStep 3: Verifying the combined model C...")
+    sess = ort.InferenceSession(MODEL_C_PATH)
+
+    # --- Test Case 1: Select Model A (if_use_a = True) ---
+    print("\n--- Verifying 'then' branch (Model A) ---")
+    use_a = np.array(True)
+    # Use different sequence lengths to test the dynamic axes
+    test_seq_len_a = 15
+    test_seq_len_b = 10
+    input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
+    input_data_b = np.random.randn(1, test_seq_len_b, 4).astype(np.float32)
+
+    # Run inference
+    outputs = sess.run(
+        None,
+        {'if_use_a': use_a, 'inputA': input_data_a, 'inputB': input_data_b}
+    )
+    result_a = outputs[0]
+
+    # Calculate expected output from Model A
+    expected_a = input_data_a + 1.0
+
+    # Verify the output and shape
+    np.testing.assert_allclose(result_a, expected_a, rtol=1e-5, atol=1e-5)
+    assert result_a.shape[1] == test_seq_len_a, "Dynamic shape failed for branch A"
+    print("✅ Branch A (if_use_a=True) works correctly.")
+    print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_a}, output seq_len={result_a.shape[1]})")
+
+    # --- Test Case 2: Select Model B (if_use_a = False) ---
+    print("\n--- Verifying 'else' branch (Model B) ---")
+    use_b = np.array(False)
+    # Use other sequence lengths
+    test_seq_len_a = 8
+    test_seq_len_b = 5
+    input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
+    input_data_b = np.random.randn(1, test_seq_len_b, 4).astype(np.float32)
+
+    # Run inference
+    outputs = sess.run(
+        None,
+        {'if_use_a': use_b, 'inputA': input_data_a, 'inputB': input_data_b}
+    )
+    result_b = outputs[0]
+
+    # Calculate expected output from Model B
+    expected_b = input_data_b * 2.0
+
+    # Verify the output and shape
+    np.testing.assert_allclose(result_b, expected_b, rtol=1e-5, atol=1e-5)
+    assert result_b.shape[1] == test_seq_len_b, "Dynamic shape failed for branch B"
+    print("✅ Branch B (if_use_a=False) works correctly.")
+    print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_b}, output seq_len={result_b.shape[1]})")
+
+def cleanup():
+    """Removes the intermediate ONNX files."""
+    print("\nCleaning up intermediate files...")
+    for path in [MODEL_A_PATH, MODEL_B_PATH]:
+        if os.path.exists(path):
+            os.remove(path)
+            print(f"Removed {path}")
+
+def main():
+    """Main function to run the entire process."""
+    try:
+        create_and_export_models()
+        combine_models_with_if()
+        verify_combined_model()
+    finally:
+        cleanup()
+    print("\nAll steps completed successfully!")
+
+if __name__ == "__main__":
+    main()
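The seven-output padding in `onnx_export.py` exists so this recipe can eventually be applied to the real pair of graphs: identical output names on both branches are exactly what the `If` node demands. A hedged sketch of that final step, untested against the actual exports (node and initializer names colliding between the two exporter-generated graphs would still need renaming first):

```python
import onnx
from onnx import helper, TensorProto

# Paths assume the v2ProPlus export; "is_first_step" is a hypothetical
# condition input, not something the patch defines.
init_g = onnx.load("onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx").graph
sdec_g = onnx.load("onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx").graph
init_g.name, sdec_g.name = "init_branch", "stage_branch"

# The parent graph takes the condition plus both branches' inputs; the
# branches capture their data implicitly, as in the experiment above.
parent_inputs = list(init_g.input) + list(sdec_g.input)
del init_g.input[:]
del sdec_g.input[:]

cond = helper.make_tensor_value_info("is_first_step", TensorProto.BOOL, [])
if_node = helper.make_node(
    "If",
    inputs=["is_first_step"],
    outputs=[o.name for o in init_g.output],  # same 7 names on both branches
    then_branch=init_g,
    else_branch=sdec_g,
)
main_graph = helper.make_graph(
    nodes=[if_node],
    name="t2s_combined",
    inputs=[cond] + parent_inputs,
    outputs=list(init_g.output),
)
model = helper.make_model(main_graph, ir_version=9,
                          opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)
onnx.save(model, "onnx/v2proplus_export/v2proplus_t2s_if_combined.onnx")
```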