feat: rename and add features to onnx export

zpeng11 2025-08-25 01:46:53 -04:00
parent 633e478b24
commit 0c5f61f98c
2 changed files with 31 additions and 6 deletions

View File

@@ -8,6 +8,8 @@ import argparse
import os
import shutil
import numpy as np
import onnxsim
import onnx
class CombinedBERTModel(nn.Module):
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
@@ -58,9 +60,15 @@ def export_bert_to_onnx(
        output_names=['logits'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'logits': {0: 'batch_size', 1: 'sequence_length'}
            'logits': {0: 'logits_length'}
        }
    )
    # Load the ONNX model
    model = onnx.load(onnx_path)
    # Simplify the model
    model_simplified, _ = onnxsim.simplify(model)
    # Save the simplified model
    onnx.save(model_simplified, onnx_path)
    # Copy tokenizer.json if it exists
    tokenizer_cache_dir = combined_model.tokenizer.name_or_path
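The simplify step above discards the boolean that onnxsim.simplify returns alongside the model. A minimal sketch of a stricter variant, assuming the same onnx_path argument; the error handling and the checker call are illustrative additions, not part of this commit:

import onnx
import onnxsim

def simplify_and_check(onnx_path: str) -> None:
    # Load the exported graph, simplify it, and save it back in place.
    model = onnx.load(onnx_path)
    # onnxsim.simplify returns the simplified model plus a bool saying whether
    # the simplified graph still matches the original on random inputs.
    model_simplified, check_ok = onnxsim.simplify(model)
    if not check_ok:
        raise RuntimeError(f"onnxsim could not validate the simplified graph: {onnx_path}")
    # Structural sanity check on the simplified graph before overwriting the file.
    onnx.checker.check_model(model_simplified)
    onnx.save(model_simplified, onnx_path)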
@@ -117,16 +125,20 @@ def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int =
    # Create test input
    batch_size = 1
    test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
    input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt")
    # Get PyTorch output
    original_model.eval()
    with torch.no_grad():
        pytorch_output = original_model(test_input_ids).numpy()
        pytorch_output = original_model(input_ids).numpy()
    # Get ONNX output
    ort_session = ort.InferenceSession(onnx_path)
    onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0]
    onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0]
    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(f"ONNX output shape: {onnx_output.shape}")
    # Compare outputs
    max_diff = np.max(np.abs(pytorch_output - onnx_output))
    mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
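The max_diff/mean_diff values above are only computed for printing. A small sketch of how the same comparison could be turned into a pass/fail check; the 1e-4 tolerance is an assumption, not taken from this commit:

import numpy as np

def outputs_match(pytorch_output: np.ndarray, onnx_output: np.ndarray, atol: float = 1e-4) -> bool:
    # Compare PyTorch and ONNX Runtime logits element-wise and report the drift.
    max_diff = np.max(np.abs(pytorch_output - onnx_output))
    mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
    print(f"Max difference: {max_diff:.6g}, mean difference: {mean_diff:.6g}")
    # Treat the export as equivalent when the worst-case drift stays under atol.
    return bool(max_diff <= atol)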

View File

@@ -462,6 +462,19 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
    configJson = {
        "project_name": project_name,
        "type": "GPTSoVits",
        "version" : voice_model_version,
        "bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
        "cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
        "t2s_weights_path": gpt_path,
        "vits_weights_path": vits_path,
        "half_precision": half_precision
    }
    with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
        json.dump(configJson, f, ensure_ascii=False, indent=4)

if __name__ == "__main__":
    try:
        os.mkdir("onnx")
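For reference, a hedged sketch of how a downstream loader might read back the config.json written above; the load_project_config helper is hypothetical and only mirrors the keys of the configJson dict:

import json

def load_project_config(project_name: str) -> dict:
    # Read the config.json that export() writes next to the ONNX files.
    with open(f"onnx/{project_name}/config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    # The exporter tags its configs with type "GPTSoVits"; reject anything else.
    if config.get("type") != "GPTSoVits":
        raise ValueError(f"Unexpected config type: {config.get('type')}")
    return config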