mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
feat:get ready for if node merge
This commit is contained in:
parent
403c5bf320
commit
16d30ce1e4
@ -6,7 +6,7 @@ from feature_extractor import cnhubert
|
|||||||
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from sv import SV
|
from sv import SV
|
||||||
|
import onnx
|
||||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||||
from transformers import HubertModel, HubertConfig
|
from transformers import HubertModel, HubertConfig
|
||||||
import json
|
import json
|
||||||
@ -103,9 +103,20 @@ class T2SInitStep(nn.Module):
|
|||||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
bert = bert.unsqueeze(0)
|
bert = bert.unsqueeze(0)
|
||||||
prompt = prompt_semantic.unsqueeze(0)
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
return self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
|
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
|
||||||
|
fake_logits = torch.randn((1, 1025)) # Dummy logits for ONNX export
|
||||||
|
fake_samples = torch.randn((1, 1)) # Dummy samples for ONNX export
|
||||||
|
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
||||||
|
|
||||||
|
class T2SStageStep(nn.Module):
|
||||||
|
def __init__(self, stage_decoder):
|
||||||
|
super().__init__()
|
||||||
|
self.stage_decoder = stage_decoder
|
||||||
|
|
||||||
|
def forward(self, iy, ik, iv, iy_emb, ix_example):
|
||||||
|
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
|
||||||
|
fake_x_example = torch.randn((1, 512)) # Dummy x_example for ONNX export
|
||||||
|
return y, k, v, y_emb, fake_x_example, logits, samples
|
||||||
|
|
||||||
class T2SModel(nn.Module):
|
class T2SModel(nn.Module):
|
||||||
def __init__(self, t2s_path, vits_model):
|
def __init__(self, t2s_path, vits_model):
|
||||||
@ -131,12 +142,13 @@ class T2SModel(nn.Module):
|
|||||||
early_stop_num = self.t2s_model.early_stop_num
|
early_stop_num = self.t2s_model.early_stop_num
|
||||||
|
|
||||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||||
y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
|
||||||
for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
|
for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||||
y, k, v, y_emb, logits, samples = enco
|
y, k, v, y_emb, logits, samples = enco
|
||||||
|
print(logits.shape, samples.shape)
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
||||||
break
|
break
|
||||||
y[0, -1] = 0
|
y[0, -1] = 0
|
||||||
@ -158,7 +170,7 @@ class T2SModel(nn.Module):
|
|||||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
|
||||||
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
|
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
|
||||||
output_names=["y", "k", "v", "y_emb", "x_example"],
|
output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"ref_text_phones": {1: "ref_length"},
|
"ref_text_phones": {1: "ref_length"},
|
||||||
"input_text_phones": {1: "text_length"},
|
"input_text_phones": {1: "text_length"},
|
||||||
@ -168,35 +180,22 @@ class T2SModel(nn.Module):
|
|||||||
},
|
},
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
)
|
)
|
||||||
y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||||
|
|
||||||
# torch.onnx.export(
|
|
||||||
# self.first_stage_decoder,
|
|
||||||
# (x, prompts),
|
|
||||||
# f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
|
|
||||||
# input_names=["x", "prompts"],
|
|
||||||
# output_names=["y", "k", "v", "y_emb", "x_example"],
|
|
||||||
# dynamic_axes={
|
|
||||||
# "x": {1: "x_length"},
|
|
||||||
# "prompts": {1: "prompts_length"},
|
|
||||||
# },
|
|
||||||
# verbose=False,
|
|
||||||
# opset_version=16,
|
|
||||||
# )
|
|
||||||
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
|
||||||
|
|
||||||
|
stage_step = T2SStageStep(self.stage_decoder)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
self.stage_decoder,
|
stage_step,
|
||||||
(y, k, v, y_emb, x_example),
|
(y, k, v, y_emb, x_example),
|
||||||
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
||||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
||||||
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"iy": {1: "iy_length"},
|
"iy": {1: "iy_length"},
|
||||||
"ik": {1: "ik_length"},
|
"ik": {1: "ik_length"},
|
||||||
"iv": {1: "iv_length"},
|
"iv": {1: "iv_length"},
|
||||||
"iy_emb": {1: "iy_emb_length"},
|
"iy_emb": {1: "iy_emb_length"},
|
||||||
"ix_example": {1: "ix_example_length"},
|
"ix_example": {1: "ix_example_length"},
|
||||||
|
"x_example": {1: "x_example_length"}
|
||||||
},
|
},
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=16,
|
opset_version=16,
|
||||||
@ -303,6 +302,20 @@ class AudioPreprocess(nn.Module):
|
|||||||
|
|
||||||
return ssl_content, spectrum, sv_emb
|
return ssl_content, spectrum, sv_emb
|
||||||
|
|
||||||
|
# def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path):
|
||||||
|
# init_step_model = onnx.load(init_step_onnx_path)
|
||||||
|
# stage_step_model = onnx.load(stage_step_onnx_path)
|
||||||
|
|
||||||
|
# # Combine the models (this is a simplified example; actual combination logic may vary)
|
||||||
|
# combined_graph = helper.make_graph(
|
||||||
|
# nodes=init_step_model.graph.node + stage_step_model.graph.node,
|
||||||
|
# name="combined_graph",
|
||||||
|
# inputs=init_step_model.graph.input,
|
||||||
|
# outputs=stage_step_model.graph.output
|
||||||
|
# )
|
||||||
|
|
||||||
|
# combined_model = helper.make_model(combined_graph, producer_name='onnx-combiner')
|
||||||
|
# onnx.save(combined_model, f"onnx/{project_name}/{project_name}_combined.onnx")
|
||||||
|
|
||||||
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
||||||
vits = VitsModel(vits_path, version=voice_model_version)
|
vits = VitsModel(vits_path, version=voice_model_version)
|
||||||
@ -366,14 +379,14 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
)
|
)
|
||||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||||
ref_audio = torch.randn((1, 48000 * 5)).float()
|
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5
|
||||||
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
|
||||||
ref_audio32k = torchaudio.functional.resample(ref_audio, 48000, 32000).float()
|
|
||||||
|
|
||||||
try:
|
os.makedirs(f"onnx/{project_name}", exist_ok=True)
|
||||||
os.mkdir(f"onnx/{project_name}")
|
|
||||||
except:
|
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
||||||
pass
|
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
|
||||||
|
# exit()
|
||||||
|
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
|
||||||
|
|
||||||
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
|
||||||
input_names=["audio32k"],
|
input_names=["audio32k"],
|
||||||
@ -384,39 +397,31 @@ def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
|||||||
"spectrum": {2: "spectrum_length"}
|
"spectrum": {2: "spectrum_length"}
|
||||||
})
|
})
|
||||||
|
|
||||||
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
|
|
||||||
gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float())
|
|
||||||
# exit()
|
|
||||||
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), project_name)
|
|
||||||
|
|
||||||
if voice_model_version == "v1":
|
|
||||||
symbols = symbols_v1
|
|
||||||
else:
|
|
||||||
symbols = symbols_v2
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
os.mkdir("onnx")
|
os.mkdir("onnx")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
|
||||||
exp_path = "v1_export"
|
|
||||||
version = "v1"
|
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
|
||||||
exp_path = "v2_export"
|
|
||||||
version = "v2"
|
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
# gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
# vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||||
exp_path = "v2pro_export"
|
# exp_path = "v1_export"
|
||||||
version = "v2Pro"
|
# version = "v1"
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
# gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||||
|
# vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
||||||
|
# exp_path = "v2_export"
|
||||||
|
# version = "v2"
|
||||||
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
|
# gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
||||||
|
# vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
|
||||||
|
# exp_path = "v2pro_export"
|
||||||
|
# version = "v2Pro"
|
||||||
|
# export(vits_path, gpt_path, exp_path, version)
|
||||||
|
|
||||||
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
|
||||||
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
|
||||||
|
@ -7,7 +7,7 @@ import torch
|
|||||||
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATH = "onnx/v1_export/v1"
|
MODEL_PATH = "onnx/v2proplus_export/v2proplus"
|
||||||
|
|
||||||
def audio_postprocess(
|
def audio_postprocess(
|
||||||
audios,
|
audios,
|
||||||
@ -56,7 +56,7 @@ def audio_preprocess(audio_path):
|
|||||||
|
|
||||||
def preprocess_text(text:str):
|
def preprocess_text(text:str):
|
||||||
preprocessor = TextPreprocessorOnnx("playground/bert")
|
preprocessor = TextPreprocessorOnnx("playground/bert")
|
||||||
[phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v1')
|
[phones, bert_features, norm_text] = preprocessor.segment_and_extract_feature_for_text(text, 'all_zh', 'v2')
|
||||||
phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
|
phones = np.expand_dims(np.array(phones, dtype=np.int64), axis=0)
|
||||||
return phones, bert_features.T.astype(np.float32)
|
return phones, bert_features.T.astype(np.float32)
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ def preprocess_text(text:str):
|
|||||||
|
|
||||||
init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
||||||
|
|
||||||
[y, k, v, y_emb, x_example] = init_step.run(None, {
|
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = init_step.run(None, {
|
||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"input_text_bert": input_bert,
|
"input_text_bert": input_bert,
|
||||||
"ref_text_phones": ref_phones,
|
"ref_text_phones": ref_phones,
|
||||||
@ -96,7 +96,7 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
|||||||
|
|
||||||
for idx in tqdm(range(1, 1500)):
|
for idx in tqdm(range(1, 1500)):
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
[y, k, v, y_emb, logits, samples] = sdec.run(None, {
|
[y, k, v, y_emb, fake_x_example, logits, samples] = sdec.run(None, {
|
||||||
"iy": y,
|
"iy": y,
|
||||||
"ik": k,
|
"ik": k,
|
||||||
"iv": v,
|
"iv": v,
|
||||||
@ -123,7 +123,7 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
|||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"pred_semantic": pred_semantic,
|
"pred_semantic": pred_semantic,
|
||||||
"spectrum": spectrum.astype(np.float32),
|
"spectrum": spectrum.astype(np.float32),
|
||||||
# "sv_emb": sv_emb.astype(np.float32)
|
"sv_emb": sv_emb.astype(np.float32)
|
||||||
})
|
})
|
||||||
|
|
||||||
audio_postprocess([audio])
|
audio_postprocess([audio])
|
||||||
|
214
playground/onnx_if_node_test.py
Normal file
214
playground/onnx_if_node_test.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
from onnx import helper, TensorProto
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Define file paths
|
||||||
|
PLAYGROUND_DIR = "playground"
|
||||||
|
MODEL_A_PATH = os.path.join(PLAYGROUND_DIR, "a.onnx")
|
||||||
|
MODEL_B_PATH = os.path.join(PLAYGROUND_DIR, "b.onnx")
|
||||||
|
MODEL_C_PATH = os.path.join(PLAYGROUND_DIR, "c.onnx")
|
||||||
|
|
||||||
|
# --- 1. Create two simple PyTorch modules ---
|
||||||
|
|
||||||
|
class ModelA(nn.Module):
|
||||||
|
"""This model adds 1 to the input."""
|
||||||
|
def forward(self, x):
|
||||||
|
return x + 1.0
|
||||||
|
|
||||||
|
class ModelB(nn.Module):
|
||||||
|
"""This model multiplies the input by 2."""
|
||||||
|
def forward(self, x):
|
||||||
|
return x * 2.0
|
||||||
|
|
||||||
|
def create_and_export_models():
|
||||||
|
"""Creates two nn.Modules and exports them to ONNX."""
|
||||||
|
print("Step 1: Creating and exporting PyTorch models A and B...")
|
||||||
|
os.makedirs(PLAYGROUND_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# Define a dummy input with a dynamic axis
|
||||||
|
batch_size = 1
|
||||||
|
sequence_length = 10 # This dimension will be dynamic
|
||||||
|
features = 4
|
||||||
|
dummy_input = torch.randn(batch_size, sequence_length, features)
|
||||||
|
|
||||||
|
# Export Model A
|
||||||
|
print(f"Exporting Model A to {MODEL_A_PATH}")
|
||||||
|
torch.onnx.export(
|
||||||
|
ModelA(),
|
||||||
|
dummy_input,
|
||||||
|
MODEL_A_PATH,
|
||||||
|
input_names=['inputA'],
|
||||||
|
output_names=['output'],
|
||||||
|
dynamic_axes={'inputA': {1: 'sequenceA'}, 'output': {1: 'sequence'}},
|
||||||
|
opset_version=11 # If node requires opset >= 11
|
||||||
|
)
|
||||||
|
|
||||||
|
# Export Model B
|
||||||
|
print(f"Exporting Model B to {MODEL_B_PATH}")
|
||||||
|
torch.onnx.export(
|
||||||
|
ModelB(),
|
||||||
|
dummy_input,
|
||||||
|
MODEL_B_PATH,
|
||||||
|
input_names=['inputB'],
|
||||||
|
output_names=['output'],
|
||||||
|
dynamic_axes={'inputB': {1: 'sequenceB'}, 'output': {1: 'sequence'}},
|
||||||
|
opset_version=11
|
||||||
|
)
|
||||||
|
print("Models A and B exported successfully.")
|
||||||
|
|
||||||
|
def combine_models_with_if():
|
||||||
|
"""
|
||||||
|
Reads two ONNX models and combines them into a third model
|
||||||
|
using an 'If' operator.
|
||||||
|
"""
|
||||||
|
print("\nStep 2: Combining models A and B into C with an 'If' node...")
|
||||||
|
|
||||||
|
# Load the two exported ONNX models
|
||||||
|
model_a = onnx.load(MODEL_A_PATH)
|
||||||
|
model_b = onnx.load(MODEL_B_PATH)
|
||||||
|
|
||||||
|
# The graphs for the 'then' and 'else' branches of the 'If' operator
|
||||||
|
then_graph = model_a.graph
|
||||||
|
then_graph.name = "then_branch_graph"
|
||||||
|
else_graph = model_b.graph
|
||||||
|
else_graph.name = "else_branch_graph"
|
||||||
|
|
||||||
|
# The data input for the main graph is defined here.
|
||||||
|
# We take it from one of the original models.
|
||||||
|
data_inputA = model_a.graph.input[0]
|
||||||
|
data_inputB = model_b.graph.input[0]
|
||||||
|
|
||||||
|
# For some onnxruntime versions, subgraphs should not have their own
|
||||||
|
# explicit 'input' list if the inputs are captured from the parent graph.
|
||||||
|
# We clear the input lists of the subgraphs to force implicit capture.
|
||||||
|
del then_graph.input[:]
|
||||||
|
del else_graph.input[:]
|
||||||
|
|
||||||
|
# The output names of the subgraphs must be the same.
|
||||||
|
# The 'If' node will have an output with this same name.
|
||||||
|
subgraph_output_name = model_a.graph.output[0].name
|
||||||
|
assert subgraph_output_name == model_b.graph.output[0].name, "Subgraph output names must match"
|
||||||
|
|
||||||
|
|
||||||
|
# Define the inputs for the main graph
|
||||||
|
# 1. The boolean condition to select the branch
|
||||||
|
cond_input = helper.make_tensor_value_info('if_use_a', TensorProto.BOOL, [])
|
||||||
|
|
||||||
|
# The main graph's output is the output from the 'If' node.
|
||||||
|
# We can use the ValueInfoProto from one of the subgraphs directly.
|
||||||
|
main_output = model_a.graph.output[0]
|
||||||
|
|
||||||
|
# Create the 'If' node
|
||||||
|
if_node = helper.make_node(
|
||||||
|
'If',
|
||||||
|
inputs=['if_use_a'],
|
||||||
|
outputs=[subgraph_output_name], # This name MUST match the subgraph's output name
|
||||||
|
then_branch=then_graph,
|
||||||
|
else_branch=else_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the main graph containing the 'If' node. Its inputs are the condition
|
||||||
|
# AND the data that the subgraphs will capture.
|
||||||
|
main_graph = helper.make_graph(
|
||||||
|
nodes=[if_node],
|
||||||
|
name='if_main_graph',
|
||||||
|
inputs=[cond_input, data_inputA, data_inputB],
|
||||||
|
outputs=[main_output]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the final combined model, specifying the opset and IR version
|
||||||
|
opset_version = 16
|
||||||
|
final_model = helper.make_model(main_graph,
|
||||||
|
producer_name='onnx-if-combiner',
|
||||||
|
ir_version=9, # For compatibility with older onnxruntime
|
||||||
|
opset_imports=[helper.make_opsetid("", opset_version)])
|
||||||
|
|
||||||
|
# Check the model for correctness
|
||||||
|
onnx.checker.check_model(final_model)
|
||||||
|
|
||||||
|
# Save the combined model
|
||||||
|
onnx.save(final_model, MODEL_C_PATH)
|
||||||
|
print(f"Combined model C saved to {MODEL_C_PATH}")
|
||||||
|
|
||||||
|
def verify_combined_model():
|
||||||
|
"""
|
||||||
|
Loads the combined ONNX model and runs inference to verify
|
||||||
|
that the 'If' branching and dynamic shapes work correctly.
|
||||||
|
"""
|
||||||
|
print("\nStep 3: Verifying the combined model C...")
|
||||||
|
sess = ort.InferenceSession(MODEL_C_PATH)
|
||||||
|
|
||||||
|
# --- Test Case 1: Select Model A (if_use_a = True) ---
|
||||||
|
print("\n--- Verifying 'then' branch (Model A) ---")
|
||||||
|
use_a = np.array(True)
|
||||||
|
# Use a different sequence length to test dynamic axis
|
||||||
|
test_seq_len_a = 15
|
||||||
|
test_seq_len_b = 10
|
||||||
|
input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
||||||
|
input_data_b = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
outputs = sess.run(
|
||||||
|
None,
|
||||||
|
{'if_use_a': use_a, 'inputA': input_data_a, 'inputB': input_data_b}
|
||||||
|
)
|
||||||
|
result_a = outputs[0]
|
||||||
|
|
||||||
|
# Calculate expected output from Model A
|
||||||
|
expected_a = input_data_a + 1.0
|
||||||
|
|
||||||
|
# Verify the output and shape
|
||||||
|
np.testing.assert_allclose(result_a, expected_a, rtol=1e-5, atol=1e-5)
|
||||||
|
assert result_a.shape[1] == test_seq_len_a, "Dynamic shape failed for branch A"
|
||||||
|
print("✅ Branch A (if_use_a=True) works correctly.")
|
||||||
|
print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_a}, output seq_len={result_a.shape[1]})")
|
||||||
|
|
||||||
|
# --- Test Case 2: Select Model B (if_use_a = False) ---
|
||||||
|
print("\n--- Verifying 'else' branch (Model B) ---")
|
||||||
|
use_b = np.array(False)
|
||||||
|
# Use another sequence length
|
||||||
|
test_seq_len_a = 8
|
||||||
|
test_seq_len_b = 5
|
||||||
|
input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
||||||
|
input_data_b = np.random.randn(1, test_seq_len_b, 4).astype(np.float32)
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
outputs = sess.run(
|
||||||
|
None,
|
||||||
|
{'if_use_a': use_b, 'inputA': input_data_a, 'inputB': input_data_b}
|
||||||
|
)
|
||||||
|
result_b = outputs[0]
|
||||||
|
|
||||||
|
# Calculate expected output from Model B
|
||||||
|
expected_b = input_data_b * 2.0
|
||||||
|
|
||||||
|
# Verify the output and shape
|
||||||
|
np.testing.assert_allclose(result_b, expected_b, rtol=1e-5, atol=1e-5)
|
||||||
|
assert result_b.shape[1] == test_seq_len_b, "Dynamic shape failed for branch B"
|
||||||
|
print("✅ Branch B (if_use_a=False) works correctly.")
|
||||||
|
print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_b}, output seq_len={result_b.shape[1]})")
|
||||||
|
|
||||||
|
def cleanup():
|
||||||
|
"""Removes the intermediate ONNX files."""
|
||||||
|
print("\nCleaning up intermediate files...")
|
||||||
|
for path in [MODEL_A_PATH, MODEL_B_PATH]:
|
||||||
|
if os.path.exists(path):
|
||||||
|
os.remove(path)
|
||||||
|
print(f"Removed {path}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to run the entire process."""
|
||||||
|
try:
|
||||||
|
create_and_export_models()
|
||||||
|
combine_models_with_if()
|
||||||
|
verify_combined_model()
|
||||||
|
finally:
|
||||||
|
cleanup()
|
||||||
|
print("\nAll steps completed successfully!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user