mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat:clean up playground
This commit is contained in:
parent
942caa888e
commit
633e478b24
@ -7,7 +7,7 @@ import torch
|
|||||||
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
from TTS_infer_pack.TextPreprocessor_onnx import TextPreprocessorOnnx
|
||||||
|
|
||||||
|
|
||||||
MODEL_PATH = "onnx/v2proplus_export/v2proplus"
|
MODEL_PATH = "onnx/v2_export/v2"
|
||||||
|
|
||||||
def audio_postprocess(
|
def audio_postprocess(
|
||||||
audios,
|
audios,
|
||||||
@ -135,7 +135,7 @@ vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
|
|||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"pred_semantic": pred_semantic,
|
"pred_semantic": pred_semantic,
|
||||||
"spectrum": spectrum.astype(np.float32),
|
"spectrum": spectrum.astype(np.float32),
|
||||||
"sv_emb": sv_emb.astype(np.float32)
|
# "sv_emb": sv_emb.astype(np.float32)
|
||||||
})
|
})
|
||||||
|
|
||||||
audio_postprocess([audio])
|
audio_postprocess([audio])
|
||||||
|
@ -1,214 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import onnx
|
|
||||||
import onnxruntime as ort
|
|
||||||
from onnx import helper, TensorProto
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Define file paths
|
|
||||||
PLAYGROUND_DIR = "playground"
|
|
||||||
MODEL_A_PATH = os.path.join(PLAYGROUND_DIR, "a.onnx")
|
|
||||||
MODEL_B_PATH = os.path.join(PLAYGROUND_DIR, "b.onnx")
|
|
||||||
MODEL_C_PATH = os.path.join(PLAYGROUND_DIR, "c.onnx")
|
|
||||||
|
|
||||||
# --- 1. Create two simple PyTorch modules ---
|
|
||||||
|
|
||||||
class ModelA(nn.Module):
|
|
||||||
"""This model adds 1 to the input."""
|
|
||||||
def forward(self, x):
|
|
||||||
return x + 1.0
|
|
||||||
|
|
||||||
class ModelB(nn.Module):
|
|
||||||
"""This model multiplies the input by 2."""
|
|
||||||
def forward(self, x):
|
|
||||||
return x * 2.0
|
|
||||||
|
|
||||||
def create_and_export_models():
|
|
||||||
"""Creates two nn.Modules and exports them to ONNX."""
|
|
||||||
print("Step 1: Creating and exporting PyTorch models A and B...")
|
|
||||||
os.makedirs(PLAYGROUND_DIR, exist_ok=True)
|
|
||||||
|
|
||||||
# Define a dummy input with a dynamic axis
|
|
||||||
batch_size = 1
|
|
||||||
sequence_length = 10 # This dimension will be dynamic
|
|
||||||
features = 4
|
|
||||||
dummy_input = torch.randn(batch_size, sequence_length, features)
|
|
||||||
|
|
||||||
# Export Model A
|
|
||||||
print(f"Exporting Model A to {MODEL_A_PATH}")
|
|
||||||
torch.onnx.export(
|
|
||||||
ModelA(),
|
|
||||||
dummy_input,
|
|
||||||
MODEL_A_PATH,
|
|
||||||
input_names=['inputA'],
|
|
||||||
output_names=['output'],
|
|
||||||
dynamic_axes={'inputA': {1: 'sequenceA'}, 'output': {1: 'sequence'}},
|
|
||||||
opset_version=11 # If node requires opset >= 11
|
|
||||||
)
|
|
||||||
|
|
||||||
# Export Model B
|
|
||||||
print(f"Exporting Model B to {MODEL_B_PATH}")
|
|
||||||
torch.onnx.export(
|
|
||||||
ModelB(),
|
|
||||||
dummy_input,
|
|
||||||
MODEL_B_PATH,
|
|
||||||
input_names=['inputB'],
|
|
||||||
output_names=['output'],
|
|
||||||
dynamic_axes={'inputB': {1: 'sequenceB'}, 'output': {1: 'sequence'}},
|
|
||||||
opset_version=11
|
|
||||||
)
|
|
||||||
print("Models A and B exported successfully.")
|
|
||||||
|
|
||||||
def combine_models_with_if():
|
|
||||||
"""
|
|
||||||
Reads two ONNX models and combines them into a third model
|
|
||||||
using an 'If' operator.
|
|
||||||
"""
|
|
||||||
print("\nStep 2: Combining models A and B into C with an 'If' node...")
|
|
||||||
|
|
||||||
# Load the two exported ONNX models
|
|
||||||
model_a = onnx.load(MODEL_A_PATH)
|
|
||||||
model_b = onnx.load(MODEL_B_PATH)
|
|
||||||
|
|
||||||
# The graphs for the 'then' and 'else' branches of the 'If' operator
|
|
||||||
then_graph = model_a.graph
|
|
||||||
then_graph.name = "then_branch_graph"
|
|
||||||
else_graph = model_b.graph
|
|
||||||
else_graph.name = "else_branch_graph"
|
|
||||||
|
|
||||||
# The data input for the main graph is defined here.
|
|
||||||
# We take it from one of the original models.
|
|
||||||
data_inputA = model_a.graph.input[0]
|
|
||||||
data_inputB = model_b.graph.input[0]
|
|
||||||
|
|
||||||
# For some onnxruntime versions, subgraphs should not have their own
|
|
||||||
# explicit 'input' list if the inputs are captured from the parent graph.
|
|
||||||
# We clear the input lists of the subgraphs to force implicit capture.
|
|
||||||
del then_graph.input[:]
|
|
||||||
del else_graph.input[:]
|
|
||||||
|
|
||||||
# The output names of the subgraphs must be the same.
|
|
||||||
# The 'If' node will have an output with this same name.
|
|
||||||
subgraph_output_name = model_a.graph.output[0].name
|
|
||||||
assert subgraph_output_name == model_b.graph.output[0].name, "Subgraph output names must match"
|
|
||||||
|
|
||||||
|
|
||||||
# Define the inputs for the main graph
|
|
||||||
# 1. The boolean condition to select the branch
|
|
||||||
cond_input = helper.make_tensor_value_info('if_use_a', TensorProto.BOOL, [])
|
|
||||||
|
|
||||||
# The main graph's output is the output from the 'If' node.
|
|
||||||
# We can use the ValueInfoProto from one of the subgraphs directly.
|
|
||||||
main_output = model_a.graph.output[0]
|
|
||||||
|
|
||||||
# Create the 'If' node
|
|
||||||
if_node = helper.make_node(
|
|
||||||
'If',
|
|
||||||
inputs=['if_use_a'],
|
|
||||||
outputs=[subgraph_output_name], # This name MUST match the subgraph's output name
|
|
||||||
then_branch=then_graph,
|
|
||||||
else_branch=else_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the main graph containing the 'If' node. Its inputs are the condition
|
|
||||||
# AND the data that the subgraphs will capture.
|
|
||||||
main_graph = helper.make_graph(
|
|
||||||
nodes=[if_node],
|
|
||||||
name='if_main_graph',
|
|
||||||
inputs=[cond_input, data_inputA, data_inputB],
|
|
||||||
outputs=[main_output]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the final combined model, specifying the opset and IR version
|
|
||||||
opset_version = 16
|
|
||||||
final_model = helper.make_model(main_graph,
|
|
||||||
producer_name='onnx-if-combiner',
|
|
||||||
ir_version=9, # For compatibility with older onnxruntime
|
|
||||||
opset_imports=[helper.make_opsetid("", opset_version)])
|
|
||||||
|
|
||||||
# Check the model for correctness
|
|
||||||
onnx.checker.check_model(final_model)
|
|
||||||
|
|
||||||
# Save the combined model
|
|
||||||
onnx.save(final_model, MODEL_C_PATH)
|
|
||||||
print(f"Combined model C saved to {MODEL_C_PATH}")
|
|
||||||
|
|
||||||
def verify_combined_model():
|
|
||||||
"""
|
|
||||||
Loads the combined ONNX model and runs inference to verify
|
|
||||||
that the 'If' branching and dynamic shapes work correctly.
|
|
||||||
"""
|
|
||||||
print("\nStep 3: Verifying the combined model C...")
|
|
||||||
sess = ort.InferenceSession(MODEL_C_PATH)
|
|
||||||
|
|
||||||
# --- Test Case 1: Select Model A (if_use_a = True) ---
|
|
||||||
print("\n--- Verifying 'then' branch (Model A) ---")
|
|
||||||
use_a = np.array(True)
|
|
||||||
# Use a different sequence length to test dynamic axis
|
|
||||||
test_seq_len_a = 15
|
|
||||||
test_seq_len_b = 10
|
|
||||||
input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
|
||||||
input_data_b = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
|
||||||
|
|
||||||
# Run inference
|
|
||||||
outputs = sess.run(
|
|
||||||
None,
|
|
||||||
{'if_use_a': use_a, 'inputA': input_data_a, 'inputB': input_data_b}
|
|
||||||
)
|
|
||||||
result_a = outputs[0]
|
|
||||||
|
|
||||||
# Calculate expected output from Model A
|
|
||||||
expected_a = input_data_a + 1.0
|
|
||||||
|
|
||||||
# Verify the output and shape
|
|
||||||
np.testing.assert_allclose(result_a, expected_a, rtol=1e-5, atol=1e-5)
|
|
||||||
assert result_a.shape[1] == test_seq_len_a, "Dynamic shape failed for branch A"
|
|
||||||
print("✅ Branch A (if_use_a=True) works correctly.")
|
|
||||||
print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_a}, output seq_len={result_a.shape[1]})")
|
|
||||||
|
|
||||||
# --- Test Case 2: Select Model B (if_use_a = False) ---
|
|
||||||
print("\n--- Verifying 'else' branch (Model B) ---")
|
|
||||||
use_b = np.array(False)
|
|
||||||
# Use another sequence length
|
|
||||||
test_seq_len_a = 8
|
|
||||||
test_seq_len_b = 5
|
|
||||||
input_data_a = np.random.randn(1, test_seq_len_a, 4).astype(np.float32)
|
|
||||||
input_data_b = np.random.randn(1, test_seq_len_b, 4).astype(np.float32)
|
|
||||||
|
|
||||||
# Run inference
|
|
||||||
outputs = sess.run(
|
|
||||||
None,
|
|
||||||
{'if_use_a': use_b, 'inputA': input_data_a, 'inputB': input_data_b}
|
|
||||||
)
|
|
||||||
result_b = outputs[0]
|
|
||||||
|
|
||||||
# Calculate expected output from Model B
|
|
||||||
expected_b = input_data_b * 2.0
|
|
||||||
|
|
||||||
# Verify the output and shape
|
|
||||||
np.testing.assert_allclose(result_b, expected_b, rtol=1e-5, atol=1e-5)
|
|
||||||
assert result_b.shape[1] == test_seq_len_b, "Dynamic shape failed for branch B"
|
|
||||||
print("✅ Branch B (if_use_a=False) works correctly.")
|
|
||||||
print(f"✅ Dynamic shape test passed (input seq_len={test_seq_len_b}, output seq_len={result_b.shape[1]})")
|
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
"""Removes the intermediate ONNX files."""
|
|
||||||
print("\nCleaning up intermediate files...")
|
|
||||||
for path in [MODEL_A_PATH, MODEL_B_PATH]:
|
|
||||||
if os.path.exists(path):
|
|
||||||
os.remove(path)
|
|
||||||
print(f"Removed {path}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Main function to run the entire process."""
|
|
||||||
try:
|
|
||||||
create_and_export_models()
|
|
||||||
combine_models_with_if()
|
|
||||||
verify_combined_model()
|
|
||||||
finally:
|
|
||||||
cleanup()
|
|
||||||
print("\nAll steps completed successfully!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user