feat: export onnx with combined graph ready, todo: link weights in onnx graph

zpeng11 2025-08-21 01:52:34 -04:00
parent 16d30ce1e4
commit 77794a5923
2 changed files with 77 additions and 31 deletions
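For context on the combined graph this commit introduces, here is a minimal, self-contained sketch (not part of the diff) of the ONNX 'If' mechanism it relies on: a boolean graph input selects which of two embedded subgraphs executes. All names below ('cond', 'out', 'if_demo', ...) are illustrative only.

import numpy as np
import onnx
import onnxruntime as ort
from onnx import helper, TensorProto

def make_branch(name, value):
    # Each branch is a complete GraphProto with no inputs; it emits one constant.
    const = helper.make_node(
        "Constant", inputs=[], outputs=["out"],
        value=helper.make_tensor("val", TensorProto.FLOAT, [1], [value]),
    )
    out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])
    return helper.make_graph([const], name, inputs=[], outputs=[out])

cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])
if_node = helper.make_node(
    "If", inputs=["cond"], outputs=["out"],
    then_branch=make_branch("then_graph", 1.0),
    else_branch=make_branch("else_graph", 0.0),
)
model = helper.make_model(
    helper.make_graph([if_node], "if_demo", [cond], [out]),
    opset_imports=[helper.make_opsetid("", 16)],
)
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString())
print(sess.run(None, {"cond": np.array(True)}))  # -> [array([1.], dtype=float32)]

The exporter changed below applies the same pattern, except that the two branches are the exported init-step and stage-step graphs and both branches emit the full (y, k, v, y_emb, ...) output set.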

View File

@@ -7,12 +7,10 @@ from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
from sv import SV
import onnx
from onnx import helper, TensorProto
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
from transformers import HubertModel, HubertConfig
import json
import os
import soundfile
from tqdm import tqdm
from text import cleaned_text_to_sequence
@@ -104,8 +102,8 @@ class T2SInitStep(nn.Module):
bert = bert.unsqueeze(0)
prompt = prompt_semantic.unsqueeze(0)
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
fake_logits = torch.randn((1, 1025)) # Dummy logits for ONNX export
fake_samples = torch.randn((1, 1)) # Dummy samples for ONNX export
fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
return y, k, v, y_emb, x_example, fake_logits, fake_samples
class T2SStageStep(nn.Module):
@@ -115,7 +113,7 @@ class T2SStageStep(nn.Module):
def forward(self, iy, ik, iv, iy_emb, ix_example):
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
fake_x_example = torch.randn((1, 512)) # Dummy x_example for ONNX export
fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
return y, k, v, y_emb, fake_x_example, logits, samples
class T2SModel(nn.Module):
@@ -302,20 +300,63 @@ class AudioPreprocess(nn.Module):
return ssl_content, spectrum, sv_emb
# def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path):
# init_step_model = onnx.load(init_step_onnx_path)
# stage_step_model = onnx.load(stage_step_onnx_path)
def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combined_onnx_path):
init_step_model = onnx.load(init_step_onnx_path)
stage_step_model = onnx.load(stage_step_onnx_path)
# # Combine the models (this is a simplified example; actual combination logic may vary)
# combined_graph = helper.make_graph(
# nodes=init_step_model.graph.node + stage_step_model.graph.node,
# name="combined_graph",
# inputs=init_step_model.graph.input,
# outputs=stage_step_model.graph.output
# )
then_graph = init_step_model.graph
then_graph.name = "init_step_graph"
else_graph = stage_step_model.graph
else_graph.name = "stage_step_graph"
# combined_model = helper.make_model(combined_graph, producer_name='onnx-combiner')
# onnx.save(combined_model, f"onnx/{project_name}/{project_name}_combined.onnx")
data_inputs_init = [input for input in init_step_model.graph.input]
data_inputs_stage = [input for input in stage_step_model.graph.input]
del then_graph.input[:]
del else_graph.input[:]
# The output names of the subgraphs must be the same.
# The 'If' node will have an output with this same name.
subgraph_output_names = [output.name for output in then_graph.output]
for i, output in enumerate(else_graph.output):
assert subgraph_output_names[i] == output.name, "Subgraph output names must match"
# Define the inputs for the main graph
# 1. The boolean condition to select the branch
cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
main_outputs = [output for output in init_step_model.graph.output]
# Create the 'If' node
if_node = helper.make_node(
'If',
inputs=['if_init_step'],
outputs=subgraph_output_names, # These names must match the subgraphs' output names
then_branch=then_graph,
else_branch=else_graph
)
# Build the main graph: a single If node that dispatches to the init-step or stage-step subgraph
main_graph = helper.make_graph(
nodes=[if_node],
name="t2s_combined_graph",
inputs=[cond_input] + data_inputs_init + data_inputs_stage,
outputs=main_outputs
)
# Create the final combined model, specifying the opset and IR version
opset_version = 16
final_model = helper.make_model(main_graph,
producer_name='GSV-ONNX-Exporter',
ir_version=9, # For compatibility with older onnxruntime
opset_imports=[helper.make_opsetid("", opset_version)])
# Check the model for correctness
onnx.checker.check_model(final_model)
# Save the combined model
onnx.save(final_model, combined_onnx_path)
print(f"Combined model saved to {combined_onnx_path}")
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
vits = VitsModel(vits_path, version=voice_model_version)
@@ -428,5 +469,6 @@ if __name__ == "__main__":
exp_path = "v2proplus_export"
version = "v2ProPlus"
export(vits_path, gpt_path, exp_path, version)
combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx')
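The commit message leaves "link weights in onnx graph" as a TODO. One possible follow-up (an assumption, not necessarily the author's plan) is to hoist the branch initializers into the parent graph so that weights shared by the init-step and stage-step branches are stored once and resolved from the outer scope, which ONNX subgraphs are allowed to do. The helper below is hypothetical and assumes identically named initializers in both branches carry identical data; it would run between make_graph and onnx.checker.check_model in the exporter above.

def hoist_initializers(main_graph, then_graph, else_graph):
    # Hypothetical sketch for the 'link weights' TODO; names are illustrative.
    # Assumes same-named initializers in both branches hold the same weights.
    seen = set()
    for branch in (then_graph, else_graph):
        for init in branch.initializer:
            if init.name not in seen:
                seen.add(init.name)
                main_graph.initializer.append(init)  # parent graph now owns the tensor
        del branch.initializer[:]  # branches resolve these names from the outer scope
    return main_graph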

View File

@@ -63,7 +63,7 @@ def preprocess_text(text:str):
# input_phones_saved = np.load("playground/ref/input_phones.npy")
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追")
[input_phones, input_bert] = preprocess_text("地上的人儿吵吵闹闹在地上追")
# ref_phones = np.load("playground/ref/ref_phones.npy")
@@ -74,29 +74,33 @@ def preprocess_text(text:str):
[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = init_step.run(None, {
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = t2s_combined.run(None, {
"if_init_step": np.array(True, dtype=bool),
"input_text_phones": input_phones,
"input_text_bert": input_bert,
"ref_text_phones": ref_phones,
"ref_text_bert": ref_bert,
"hubert_ssl_content": audio_prompt_hubert
"hubert_ssl_content": audio_prompt_hubert,
"iy":np.empty((1, 0), dtype=np.int64),
"ik":np.empty((24, 0, 1, 512), dtype=np.float32),
"iv":np.empty((24, 0, 1, 512), dtype=np.float32),
"iy_emb":np.empty((1, 0, 512), dtype=np.float32),
"ix_example":np.empty((1, 0), dtype=np.float32)
})
# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
# for i in tqdm(range(10000)):
# [y, k, v, y_emb, x_example] = fsdec.run(None, {
# "x": x,
# "prompts": prompts
# })
for idx in tqdm(range(1, 1500)):
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
[y, k, v, y_emb, fake_x_example, logits, samples] = sdec.run(None, {
[y, k, v, y_emb, fake_x_example, logits, samples] = t2s_combined.run(None, {
"if_init_step": np.array(False, dtype=bool),
"input_text_phones": np.empty((1, 0), dtype=np.int64),
"input_text_bert": np.empty((0, 1024), dtype=np.float32),
"ref_text_phones": np.empty((1, 0), dtype=np.int64),
"ref_text_bert": np.empty((0, 1024), dtype=np.float32),
"hubert_ssl_content": np.empty((1, 768, 0), dtype=np.float32),
"iy": y,
"ik": k,
"iv": v,