# GPT-SoVITS/playground/export_hubert.py
# (listing metadata: 170 lines, 6.0 KiB, Python)
import os
import sys
import torch
import torchaudio
import onnxruntime as ort
import numpy as np
import argparse
from transformers import HubertModel, HubertConfig
class HubertONNXExporter:
    """Export a HuBERT speech model to ONNX and validate the exported graph.

    Pipeline: setup_model() -> export_to_onnx() -> test_onnx_export_exists()
    and test_torch_vs_onnx(); run_full_export_and_test() chains all of them.
    """

    def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base",
                 output_path="playground/hubert/chinese-hubert-base.onnx"):
        """Record source/destination paths; no model is loaded yet.

        Args:
            model_path: Directory holding the pretrained HuBERT checkpoint.
            output_path: File path the ONNX graph will be written to.
        """
        self.model_path = model_path
        self.onnx_path = output_path
        self.model = None    # populated by setup_model()
        self.config = None   # populated by setup_model()

    def setup_model(self):
        """Configure and load the HuBERT model for ONNX export."""
        # Tweak the config for better ONNX compatibility before loading.
        self.config = HubertConfig.from_pretrained(self.model_path)
        self.config._attn_implementation = "eager"  # standard attention traces cleanly
        self.config.apply_spec_augment = False      # disable masking for inference
        self.config.layerdrop = 0.0                 # layer dropout is train-time only
        self.model = HubertModel.from_pretrained(
            self.model_path,
            config=self.config,
            local_files_only=True,
        )
        self.model.eval()

    def export_to_onnx(self, dummy_length=16000):
        """Export the loaded model to ONNX at self.onnx_path.

        Args:
            dummy_length: Number of samples in the tracing input
                (16000 == 1 second at 16 kHz).

        Raises:
            ValueError: If setup_model() has not been called.
        """
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")
        # Zero-centered random audio in [-0.5, 0.5) as the tracing example.
        dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5
        torch.onnx.export(
            self.model,
            dummy_input,
            self.onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['audio16k'],
            output_names=['last_hidden_state'],
            dynamic_axes={
                'audio16k': {0: 'batch_size', 1: 'sequence_length'},
                'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
            }
        )
        print(f"[Success] Model exported to {self.onnx_path}")

    def test_onnx_export_exists(self):
        """Return True if the ONNX model file exists at self.onnx_path."""
        if os.path.exists(self.onnx_path):
            print(f"[Success] ONNX model file exists at {self.onnx_path}")
            return True
        else:
            print(f"[Error] ONNX model not found at {self.onnx_path}")
            return False

    def _load_and_preprocess_audio(self, audio_path, max_length=32000):
        """Load audio and normalize it to mono 16 kHz, truncated to max_length.

        Args:
            audio_path: Path to an audio file readable by torchaudio.
            max_length: Maximum number of samples kept (32000 == 2 s at 16 kHz).

        Returns:
            A (1, num_samples) waveform tensor.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        # Resample to the 16 kHz rate HuBERT expects.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        # Keep only the first channel of multi-channel input.
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]
        # Truncate so the comparison stays fast.
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]
        return waveform

    def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
        """Check that ONNX outputs match PyTorch outputs on a reference clip.

        Args:
            audio_path: Audio file used as the comparison input.

        Returns:
            True when the mean absolute difference is <= 1e-5; False on
            mismatch or when the reference audio file is missing.

        Raises:
            ValueError: If setup_model() has not been called.
        """
        if not os.path.exists(audio_path):
            print(f"[Skip] Test audio file not found at {audio_path}")
            return False
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")
        waveform = self._load_and_preprocess_audio(audio_path)
        # Reference output: PyTorch forward pass.
        with torch.no_grad():
            torch_output = self.model(waveform)
            torch_hidden_states = torch_output.last_hidden_state
        # Candidate output: ONNX Runtime on the same audio.
        ort_session = ort.InferenceSession(self.onnx_path)
        input_values = waveform.numpy().astype(np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_values}
        ort_outputs = ort_session.run(None, ort_inputs)
        onnx_hidden_states = ort_outputs[0]
        # Compare via mean absolute difference against a tight tolerance.
        torch_numpy = torch_hidden_states.numpy()
        diff = np.abs(torch_numpy - onnx_hidden_states).mean()
        success = diff <= 1e-5
        status = "[Success]" if success else "[Fail]"
        print(f"{status} ONNX vs PyTorch comparison")
        print(f" > mean_difference={diff}")
        print(f" > torch_shape={torch_numpy.shape}")
        print(f" > onnx_shape={onnx_hidden_states.shape}")
        return success

    def run_full_export_and_test(self):
        """Run the complete export and testing pipeline.

        Returns:
            True when the exported file exists AND its outputs match
            PyTorch's; False otherwise. (The individual check results were
            previously discarded, so failures were invisible to callers.)
        """
        print("Starting HuBERT ONNX export and testing...")
        # Create the output directory if needed; guard against a bare
        # filename, where dirname() is "" and makedirs would raise.
        out_dir = os.path.dirname(self.onnx_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        self.setup_model()
        self.export_to_onnx()
        # Aggregate check results instead of silently dropping them.
        exists_ok = self.test_onnx_export_exists()
        match_ok = self.test_torch_vs_onnx()
        print("Export and testing complete!")
        return exists_ok and match_ok
def main():
    """Command-line entry point: parse arguments, then export and test."""
    arg_parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format")
    arg_parser.add_argument(
        "--model_path",
        type=str,
        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)",
    )
    arg_parser.add_argument(
        "--output_path",
        type=str,
        default="playground/hubert/chinese-hubert-base.onnx",
        help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)",
    )
    parsed = arg_parser.parse_args()
    # Build the exporter from CLI options and run the whole pipeline.
    HubertONNXExporter(
        model_path=parsed.model_path,
        output_path=parsed.output_path,
    ).run_full_export_and_test()
# Run the CLI entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()