mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
feat:rename and features to onnx export
This commit is contained in:
parent
633e478b24
commit
0c5f61f98c
@ -8,6 +8,8 @@ import argparse
|
||||
import os
|
||||
import shutil
|
||||
import numpy as np
|
||||
import onnxsim
|
||||
import onnx
|
||||
|
||||
class CombinedBERTModel(nn.Module):
|
||||
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
|
||||
@ -58,9 +60,15 @@ def export_bert_to_onnx(
|
||||
output_names=['logits'],
|
||||
dynamic_axes={
|
||||
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
|
||||
'logits': {0: 'batch_size', 1: 'sequence_length'}
|
||||
'logits': {0: 'logits_length'}
|
||||
}
|
||||
)
|
||||
# Load the ONNX model
|
||||
model = onnx.load(onnx_path)
|
||||
# Simplify the model
|
||||
model_simplified, _ = onnxsim.simplify(model)
|
||||
# Save the simplified model
|
||||
onnx.save(model_simplified, onnx_path)
|
||||
|
||||
# Copy tokenizer.json if it exists
|
||||
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
|
||||
@ -117,16 +125,20 @@ def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int =
|
||||
# Create test input
|
||||
batch_size = 1
|
||||
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
|
||||
input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt")
|
||||
|
||||
|
||||
# Get PyTorch output
|
||||
original_model.eval()
|
||||
with torch.no_grad():
|
||||
pytorch_output = original_model(test_input_ids).numpy()
|
||||
pytorch_output = original_model(input_ids).numpy()
|
||||
|
||||
# Get ONNX output
|
||||
ort_session = ort.InferenceSession(onnx_path)
|
||||
onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0]
|
||||
onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0]
|
||||
|
||||
print(f"PyTorch output shape: {pytorch_output.shape}")
|
||||
print(f"ONNX output shape: {onnx_output.shape}")
|
||||
# Compare outputs
|
||||
max_diff = np.max(np.abs(pytorch_output - onnx_output))
|
||||
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
|
@ -462,6 +462,19 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
|
||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
|
||||
|
||||
configJson = {
|
||||
"project_name": project_name,
|
||||
"type": "GPTSoVits",
|
||||
"version" : voice_model_version,
|
||||
"bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
|
||||
"cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
|
||||
"t2s_weights_path": gpt_path,
|
||||
"vits_weights_path": vits_path,
|
||||
"half_precision": half_precision
|
||||
}
|
||||
with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
|
||||
json.dump(configJson, f, ensure_ascii=False, indent=4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
os.mkdir("onnx")
|
||||
|
Loading…
x
Reference in New Issue
Block a user