mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 00:30:15 +08:00
feat:rename and features to onnx export
This commit is contained in:
parent
633e478b24
commit
0c5f61f98c
@ -8,6 +8,8 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import onnxsim
|
||||||
|
import onnx
|
||||||
|
|
||||||
class CombinedBERTModel(nn.Module):
|
class CombinedBERTModel(nn.Module):
|
||||||
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
|
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
|
||||||
@ -58,10 +60,16 @@ def export_bert_to_onnx(
|
|||||||
output_names=['logits'],
|
output_names=['logits'],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
|
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
|
||||||
'logits': {0: 'batch_size', 1: 'sequence_length'}
|
'logits': {0: 'logits_length'}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Load the ONNX model
|
||||||
|
model = onnx.load(onnx_path)
|
||||||
|
# Simplify the model
|
||||||
|
model_simplified, _ = onnxsim.simplify(model)
|
||||||
|
# Save the simplified model
|
||||||
|
onnx.save(model_simplified, onnx_path)
|
||||||
|
|
||||||
# Copy tokenizer.json if it exists
|
# Copy tokenizer.json if it exists
|
||||||
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
|
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
|
||||||
if os.path.isdir(tokenizer_cache_dir):
|
if os.path.isdir(tokenizer_cache_dir):
|
||||||
@ -117,16 +125,20 @@ def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int =
|
|||||||
# Create test input
|
# Create test input
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
|
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
|
||||||
|
input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt")
|
||||||
|
|
||||||
|
|
||||||
# Get PyTorch output
|
# Get PyTorch output
|
||||||
original_model.eval()
|
original_model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pytorch_output = original_model(test_input_ids).numpy()
|
pytorch_output = original_model(input_ids).numpy()
|
||||||
|
|
||||||
# Get ONNX output
|
# Get ONNX output
|
||||||
ort_session = ort.InferenceSession(onnx_path)
|
ort_session = ort.InferenceSession(onnx_path)
|
||||||
onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0]
|
onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0]
|
||||||
|
|
||||||
|
print(f"PyTorch output shape: {pytorch_output.shape}")
|
||||||
|
print(f"ONNX output shape: {onnx_output.shape}")
|
||||||
# Compare outputs
|
# Compare outputs
|
||||||
max_diff = np.max(np.abs(pytorch_output - onnx_output))
|
max_diff = np.max(np.abs(pytorch_output - onnx_output))
|
||||||
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
|
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
|
@ -462,6 +462,19 @@ def export(vits_path, gpt_path, project_name, voice_model_version, t2s_model_com
|
|||||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
|
||||||
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
|
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
|
||||||
|
|
||||||
|
configJson = {
|
||||||
|
"project_name": project_name,
|
||||||
|
"type": "GPTSoVits",
|
||||||
|
"version" : voice_model_version,
|
||||||
|
"bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
|
||||||
|
"cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
|
||||||
|
"t2s_weights_path": gpt_path,
|
||||||
|
"vits_weights_path": vits_path,
|
||||||
|
"half_precision": half_precision
|
||||||
|
}
|
||||||
|
with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(configJson, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
os.mkdir("onnx")
|
os.mkdir("onnx")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user