From 968ac4c26445704201456731f1395a3d020576b2 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Mon, 25 Aug 2025 22:37:52 -0400 Subject: [PATCH] feat: solved problem, export works --- GPT_SoVITS/onnx_export_v1v2.py | 197 ++++++++++++--------------------- playground/freerun.py | 45 ++++---- 2 files changed, 93 insertions(+), 149 deletions(-) diff --git a/GPT_SoVITS/onnx_export_v1v2.py b/GPT_SoVITS/onnx_export_v1v2.py index f8ffbc28..172eba1d 100644 --- a/GPT_SoVITS/onnx_export_v1v2.py +++ b/GPT_SoVITS/onnx_export_v1v2.py @@ -105,36 +105,29 @@ class DictToAttrRecursive(dict): raise AttributeError(f"Attribute {item} not found") -class T2SInitStep(nn.Module): +class T2SInitStage(nn.Module): def __init__(self, t2s, vits): super().__init__() self.encoder = t2s.onnx_encoder - self.fsdc = t2s.first_stage_decoder self.vits = vits + self.num_layers = t2s.num_layers - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None): - first_infer = first_infer.to(torch.int64) + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): codes = self.vits.extract_latent(ssl_content) prompt_semantic = codes[0, 0] bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) bert = bert.unsqueeze(0) prompt = prompt_semantic.unsqueeze(0) - [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer) - fake_logits = torch.zeros((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export - fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export - return y, k, v, y_emb, x_example, fake_logits, fake_samples + x = self.encoder(all_phoneme_ids, bert) -class T2SStageStep(nn.Module): - def __init__(self, stage_decoder): - super().__init__() - self.stage_decoder = stage_decoder + x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1] + y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1] - def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None): - first_infer = first_infer.to(torch.int64) - [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer) - fake_x_example = torch.zeros((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export - return y, k, v, y_emb, fake_x_example, logits, samples + init_k = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) + init_v = torch.zeros((self.num_layers, (x_seq_len + y_seq_len), 1, 512), dtype=torch.float) + + return x, prompt, init_k, init_v, x_seq_len, y_seq_len class T2SModel(nn.Module): def __init__(self, t2s_path, vits_model): @@ -151,17 +144,26 @@ class T2SModel(nn.Module): self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) self.t2s_model = self.t2s_model.model self.t2s_model.init_onnx() - self.init_step = T2SInitStep(self.t2s_model, self.vits_model) - self.first_stage_decoder = self.t2s_model.first_stage_decoder + self.init_stage = T2SInitStage(self.t2s_model, self.vits_model) self.stage_decoder = self.t2s_model.stage_decoder def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): - # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([1])) + x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + empty_tensor = torch.empty((1,0,512)).to(torch.float) + # first step + y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v, + empty_tensor, + top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, + first_infer=torch.LongTensor([1]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) for idx in range(5): # This is a fake one! DO NOT take this as reference - enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([0])) - y, k, v, y_emb, logits, samples = enco + k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) + v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) + y_seq_len = y.shape[1] + y, k, v, y_emb, logits, samples = self.stage_decoder(empty_tensor, y, k, v, + y_emb, + top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, + first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) # if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: # break @@ -169,11 +171,11 @@ class T2SModel(nn.Module): def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): torch.onnx.export( - self.init_step, - (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature, torch.Tensor([True]).to(torch.bool)), - f"onnx/{project_name}/{project_name}_t2s_init_step.onnx", - input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"], - output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'], + self.init_stage, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx", + input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"], + output_names=["x", "prompt", "init_k", "init_v", 'x_seq_len', 'y_seq_len'], dynamic_axes={ "ref_text_phones": {1: "ref_length"}, "input_text_phones": {1: "text_length"}, @@ -184,28 +186,38 @@ class T2SModel(nn.Module): opset_version=16, do_constant_folding=False ) - # simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.Tensor([True]).to(torch.bool)) + simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx") + x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + empty_tensor = torch.empty((1,0,512)).to(torch.float) + x_seq_len = torch.Tensor([x_seq_len]).to(torch.int64) + y_seq_len = torch.Tensor([y_seq_len]).to(torch.int64) + + y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v, + empty_tensor, + top_k, top_p, repetition_penalty, temperature, + torch.LongTensor([1]), x_seq_len, y_seq_len) + print(y.shape, k.shape, v.shape, y_emb.shape, logits.shape, samples.shape) + k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) + v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) + y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64) - stage_step = T2SStageStep(self.stage_decoder) torch.onnx.export( - stage_step, - (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature, torch.Tensor([False]).to(torch.bool)), - f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx", - input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"], - output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"], + self.stage_decoder, + (x, y, k, v, y_emb, top_k, top_p, repetition_penalty, temperature, torch.LongTensor([0]), x_seq_len, y_seq_len), + f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx", + input_names=["ix", "iy", "ik", "iv", "iy_emb", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step", "x_seq_len", "y_seq_len"], + output_names=["y", "k", "v", "y_emb", "logits", "samples"], dynamic_axes={ + "ix": {1: "ix_length"}, "iy": {1: "iy_length"}, "ik": {1: "ik_length"}, "iv": {1: "iv_length"}, "iy_emb": {1: "iy_emb_length"}, - "ix_example": {1: "ix_example_length"}, - "x_example": {1: "x_example_length"} }, verbose=False, opset_version=16, ) - simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx") + simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx") class VitsModel(nn.Module): @@ -301,73 +313,7 @@ class AudioPreprocess(nn.Module): return ssl_content, spectrum, sv_emb -def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combined_onnx_path): - init_step_model = onnx.load(init_step_onnx_path) - stage_step_model = onnx.load(stage_step_onnx_path) - - then_graph = init_step_model.graph - then_graph.name = "init_step_graph" - else_graph = stage_step_model.graph - else_graph.name = "stage_step_graph" - - data_inputs_init = [input for input in init_step_model.graph.input] - data_inputs_stage = [input for input in stage_step_model.graph.input] - - # Get all names from both lists - names_list_init = {obj.name for obj in data_inputs_init} - names_list_stage = {obj.name for obj in data_inputs_stage} - # Find names that appear in both lists - repeated_input_names = names_list_init.intersection(names_list_stage) - # Filter out objects with repeated names - data_inputs_stage = [obj for obj in data_inputs_stage if obj.name not in repeated_input_names] - - del then_graph.input[:] - del else_graph.input[:] - - # The output names of the subgraphs must be the same. - # The 'If' node will have an output with this same name. - subgraph_output_names = [output.name for output in then_graph.output] - for i, output in enumerate(else_graph.output): - assert subgraph_output_names[i] == output.name, "Subgraph output names must match" - - # Define the inputs for the main graph - # 1. The boolean condition to select the branch - # cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, []) - - main_outputs = [output for output in init_step_model.graph.output] - - # Create the 'If' node - if_node = helper.make_node( - 'If', - inputs=['if_init_step'], - outputs=subgraph_output_names, # This name MUST match the subgraph's output name - then_branch=then_graph, - else_branch=else_graph - ) - - # Combine the models (this is a simplified example; actual combination logic may vary) - main_graph = helper.make_graph( - nodes=[if_node], - name="t2s_combined_graph", - inputs= data_inputs_init + data_inputs_stage, - outputs=main_outputs - ) - - # Create the final combined model, specifying the opset and IR version - opset_version = 16 - final_model = helper.make_model(main_graph, - producer_name='GSV-ONNX-Exporter', - ir_version=9, # For compatibility with older onnxruntime - opset_imports=[helper.make_opsetid("", opset_version)]) - # Check the model for correctness - onnx.checker.check_model(final_model) - - # Save the combined model - onnx.save(final_model, combined_onnx_path) - print(f"Combined model saved to {combined_onnx_path}") - - -def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_combine=False, export_audio_preprocessor=True, half_precision=False): +def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_preprocessor=True, half_precision=False): vits = VitsModel(vits_path, version=voice_model_version) gpt = T2SModel(gpt_path, vits) gpt_sovits = GptSoVits(vits, gpt) @@ -453,12 +399,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com }) simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx") - if t2s_model_combine: - combineInitStepAndStageStep(f'onnx/{project_name}/{project_name}_t2s_init_step.onnx', f'onnx/{project_name}/{project_name}_t2s_stage_step.onnx', f'onnx/{project_name}/{project_name}_t2s_combined.onnx') - if half_precision: - if t2s_model_combine: - convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_combined.onnx") if export_audio_preprocessor: convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx") convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx") @@ -467,7 +408,7 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com configJson = { "project_name": project_name, - "type": "GPTSoVits", + "type": "GPTSoVITS", "version" : voice_model_version, "bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large', "cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base', @@ -486,29 +427,29 @@ if __name__ == "__main__": # 因为io太频繁,可能导致模型导出出错(wsl非常明显),请自行重试 - # gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" - # vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" - # exp_path = "v1_export" - # version = "v1" - # export(vits_path, gpt_path, exp_path, version) + gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" + exp_path = "v1_export" + version = "v1" + export(vits_path, gpt_path, exp_path, version) gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" exp_path = "v2_export" version = "v2" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) + export(vits_path, gpt_path, exp_path, version) - # gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" - # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" - # exp_path = "v2pro_export" - # version = "v2Pro" - # export(vits_path, gpt_path, exp_path, version) + gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" + exp_path = "v2pro_export" + version = "v2Pro" + export(vits_path, gpt_path, exp_path, version) - # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" - # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" - # exp_path = "v2proplus_export" - # version = "v2ProPlus" - # export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True) + gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" + exp_path = "v2proplus_export" + version = "v2ProPlus" + export(vits_path, gpt_path, exp_path, version) diff --git a/playground/freerun.py b/playground/freerun.py index 93045381..9cfefbeb 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -7,7 +7,7 @@ import torch from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx -MODEL_PATH = "onnx/v2_export/v2" +MODEL_PATH = "onnx/v1_export/v1" def audio_postprocess( audios, @@ -80,49 +80,52 @@ top_p = np.array([1.0], dtype=np.float32) repetition_penalty = np.array([1.0], dtype=np.float32) temperature = np.array([1.0], dtype=np.float32) -t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx") +t2s_init_stage = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_stage.onnx") # t2s_init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx") -[y, k, v, y_emb, x_example, fake_logits, fake_samples] = t2s_combined.run(None, { - "if_init_step": np.array(True, dtype=bool), +[x, prompts, init_k, init_v, x_seq_len, y_seq_len] = t2s_init_stage.run(None, { "input_text_phones": input_phones, "input_text_bert": input_bert, "ref_text_phones": ref_phones, "ref_text_bert": ref_bert, "hubert_ssl_content": audio_prompt_hubert, - "iy":np.empty((1, 0), dtype=np.int64), - "ik":np.empty((24, 0, 1, 512), dtype=np.float32), - "iv":np.empty((24, 0, 1, 512), dtype=np.float32), - "iy_emb":np.empty((1, 0, 512), dtype=np.float32), - "ix_example":np.empty((1, 0), dtype=np.float32), +}) +empty_tensor = np.empty((1,0,512)).astype(np.float32) + +t2s_stage_decoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_stage_decoder.onnx") +y, k, v, y_emb, logits, samples = t2s_stage_decoder.run(None, { + "ix": x, + "iy": prompts, + "ik": init_k, + "iv": init_v, + "iy_emb": empty_tensor, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "temperature": temperature, - "if_init_step": np.array([True], dtype=bool) + "if_init_step": np.array([1]).astype(np.int64), + "x_seq_len": np.array([x_seq_len]).astype(np.int64), + "y_seq_len": np.array([y_seq_len]).astype(np.int64) }) -# t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") - for idx in tqdm(range(1, 1500)): + k = np.pad(k, ((0,0), (0,1), (0,0), (0,0))) + v = np.pad(v, ((0,0), (0,1), (0,0), (0,0))) + y_seq_len = np.array([y.shape[1]]).astype(np.int64) # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] - [y, k, v, y_emb, fake_x_example, logits, samples] = t2s_combined.run(None, { - "if_init_step": np.array(False, dtype=bool), - "input_text_phones": np.empty((1, 0), dtype=np.int64), - "input_text_bert": np.empty((0, 1024), dtype=np.float32), - "ref_text_phones": np.empty((1, 0), dtype=np.int64), - "ref_text_bert": np.empty((0, 1024), dtype=np.float32), - "hubert_ssl_content": np.empty((1, 768, 0), dtype=np.float32), + [y, k, v, y_emb, logits, samples] = t2s_stage_decoder.run(None, { + "ix": empty_tensor, "iy": y, "ik": k, "iv": v, "iy_emb": y_emb, - "ix_example": x_example, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, "temperature": temperature, - "if_init_step": np.array([False], dtype=bool) + "if_init_step": np.array([0]).astype(np.int64), + "x_seq_len": np.array([x_seq_len]).astype(np.int64), + "y_seq_len": y_seq_len }) if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token break