feat: text_bert and audio_hubert exports are ready and fully tested; todo: solve dependency in playground runs

zpeng11 2025-08-19 00:05:45 -04:00
parent 4e42a28f9c
commit aef9d26580
3 changed files with 328 additions and 1 deletion

.gitignore (vendored, 3 lines changed)

@@ -194,4 +194,5 @@ cython_debug/
# PyPI configuration file
.pypirc
onnx/
*.onnx
tokenizer.json

playground/export_bert.py (new file, 156 lines)

@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import onnxruntime as ort
import argparse
import os
import shutil
import numpy as np
class CombinedBERTModel(nn.Module):
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
def __init__(self, model_name: str):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
def forward(self, text_input: torch.Tensor):
"""Forward pass that includes tokenization and model inference."""
        # Note: for ONNX export we work with pre-tokenized input_ids;
        # actual text tokenization must happen outside the ONNX graph.
        input_ids = text_input.long()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        # Take the third-to-last hidden layer (as in GPT-SoVITS's BERT feature
        # extraction) and strip the [CLS]/[SEP] positions.
        return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
def export_bert_to_onnx(
model_name: str = "bert-base-uncased",
output_dir: str = "bert_exported",
max_seq_length: int = 512
):
"""Export BERT model to ONNX format and copy tokenizer files."""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
print(f"Loading model: {model_name}")
combined_model = CombinedBERTModel(model_name)
combined_model.eval()
# Create dummy inputs for ONNX export (pre-tokenized input_ids)
batch_size = 1
dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length))
# Export to ONNX
onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx")
print(f"Exporting to ONNX: {onnx_path}")
torch.onnx.export(
combined_model,
dummy_input_ids,
onnx_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=['input_ids'],
output_names=['logits'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'logits': {0: 'batch_size', 1: 'sequence_length'}
}
)
# Copy tokenizer.json if it exists
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
if os.path.isdir(tokenizer_cache_dir):
tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json")
    else:
        # Model was resolved from the HuggingFace cache
        try:
            tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path
        except AttributeError:
            # Fall back to searching the hub cache, which stores models
            # under a models--{org}--{name} directory layout
            cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
            tokenizer_json_path = None
            for root, dirs, files in os.walk(cache_dir):
                if "tokenizer.json" in files and model_name.replace("/", "--") in root:
                    tokenizer_json_path = os.path.join(root, "tokenizer.json")
                    break
if tokenizer_json_path and os.path.exists(tokenizer_json_path):
dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json")
shutil.copy2(tokenizer_json_path, dest_tokenizer_path)
print(f"Copied tokenizer.json to: {dest_tokenizer_path}")
else:
print("Warning: tokenizer.json not found")
print(f"Model exported successfully to: {output_dir}")
return combined_model, onnx_path
def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5):
"""Test if the original PyTorch model and ONNX model produce the same outputs."""
print("Testing model equivalence...")
# Create test input
batch_size = 1
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
# Get PyTorch output
original_model.eval()
with torch.no_grad():
pytorch_output = original_model(test_input_ids).numpy()
# Get ONNX output
ort_session = ort.InferenceSession(onnx_path)
onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0]
# Compare outputs
max_diff = np.max(np.abs(pytorch_output - onnx_output))
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
print(f"Maximum absolute difference: {max_diff}")
print(f"Mean absolute difference: {mean_diff}")
if max_diff < tolerance:
print("✅ Models are numerically equivalent!")
return True
else:
print("❌ Models have significant differences!")
return False
def main():
parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
help="Pretrained BERT model name")
parser.add_argument("--output_dir", type=str, default="playground/bert",
help="Output directory path")
parser.add_argument("--max_seq_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--tolerance", type=float, default=1e-3,
help="Tolerance for numerical comparison")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# Export model
original_model, onnx_path = export_bert_to_onnx(
model_name=args.model_name,
output_dir=args.output_dir,
max_seq_length=args.max_seq_length
)
# Test equivalence
test_model_equivalence(
original_model=original_model,
onnx_path=onnx_path,
max_seq_length=args.max_seq_length,
tolerance=args.tolerance
)
if __name__ == "__main__":
main()
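
Not part of this commit, but for context on the "tokenization must happen outside ONNX" note above: a minimal consumer sketch for a playground run, assuming the script defaults (playground/bert/chinese-roberta-wwm-ext-large.onnx plus the copied tokenizer.json) and assuming the saved tokenizer.json carries the usual BERT post-processor that adds [CLS]/[SEP]:

import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer

# Tokenize outside the ONNX graph, then feed int64 input_ids to the export.
tokenizer = Tokenizer.from_file("playground/bert/tokenizer.json")
session = ort.InferenceSession("playground/bert/chinese-roberta-wwm-ext-large.onnx")
encoding = tokenizer.encode("这是一个测试")  # post-processor adds [CLS]/[SEP]
input_ids = np.array([encoding.ids], dtype=np.int64)
features = session.run(None, {"input_ids": input_ids})[0]
print(features.shape)  # (sequence_length - 2, hidden_size): CLS/SEP are stripped in-graph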

playground/export_hubert.py (new file, 170 lines)

@@ -0,0 +1,170 @@
import os
import torch
import torchaudio
import onnxruntime as ort
import numpy as np
import argparse
from transformers import HubertModel, HubertConfig
class HubertONNXExporter:
"""Export and test HuBERT model to ONNX format"""
def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base", output_path="playground/hubert/chinese-hubert-base.onnx"):
self.model_path = model_path
self.onnx_path = output_path
self.model = None
self.config = None
def setup_model(self):
"""Configure and load the HuBERT model for ONNX export"""
# Configure for better ONNX compatibility
self.config = HubertConfig.from_pretrained(self.model_path)
self.config._attn_implementation = "eager" # Use standard attention
self.config.apply_spec_augment = False # Disable masking for inference
self.config.layerdrop = 0.0 # Disable layer dropout
# Load the model
self.model = HubertModel.from_pretrained(
self.model_path,
config=self.config,
local_files_only=True
)
self.model.eval()
def export_to_onnx(self, dummy_length=16000):
"""Export the model to ONNX format"""
if self.model is None:
raise ValueError("Model not initialized. Call setup_model() first.")
# Create dummy input (1 second at 16kHz)
dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5
# Export to ONNX
torch.onnx.export(
self.model,
dummy_input,
self.onnx_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['audio16k'],
output_names=['last_hidden_state'],
dynamic_axes={
'audio16k': {0: 'batch_size', 1: 'sequence_length'},
'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}
}
)
print(f"[Success] Model exported to {self.onnx_path}")
def test_onnx_export_exists(self):
"""Test that the ONNX model file was created"""
if os.path.exists(self.onnx_path):
print(f"[Success] ONNX model file exists at {self.onnx_path}")
return True
else:
print(f"[Error] ONNX model not found at {self.onnx_path}")
return False
def _load_and_preprocess_audio(self, audio_path, max_length=32000):
"""Load and preprocess audio file"""
waveform, sample_rate = torchaudio.load(audio_path)
# Resample to 16kHz if needed
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
# Take first channel
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Limit length for testing (2 seconds at 16kHz)
if waveform.shape[1] > max_length:
waveform = waveform[:, :max_length]
return waveform
def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
"""Test that ONNX model outputs match PyTorch model outputs"""
if not os.path.exists(audio_path):
print(f"[Skip] Test audio file not found at {audio_path}")
return False
if self.model is None:
raise ValueError("Model not initialized. Call setup_model() first.")
# Load and preprocess audio
waveform = self._load_and_preprocess_audio(audio_path)
# PyTorch inference
with torch.no_grad():
torch_output = self.model(waveform)
torch_hidden_states = torch_output.last_hidden_state
# ONNX inference
ort_session = ort.InferenceSession(self.onnx_path)
input_values = waveform.numpy().astype(np.float32)
ort_inputs = {ort_session.get_inputs()[0].name: input_values}
ort_outputs = ort_session.run(None, ort_inputs)
onnx_hidden_states = ort_outputs[0]
# Compare outputs
torch_numpy = torch_hidden_states.numpy()
diff = np.abs(torch_numpy - onnx_hidden_states).mean()
success = diff <= 1e-5
status = "[Success]" if success else "[Fail]"
print(f"{status} ONNX vs PyTorch comparison")
print(f" > mean_difference={diff}")
print(f" > torch_shape={torch_numpy.shape}")
print(f" > onnx_shape={onnx_hidden_states.shape}")
return success
def run_full_export_and_test(self):
"""Run the complete export and testing pipeline"""
print("Starting HuBERT ONNX export and testing...")
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True)
# Setup model
self.setup_model()
# Export to ONNX
self.export_to_onnx()
# Test export
self.test_onnx_export_exists()
self.test_torch_vs_onnx()
print("Export and testing complete!")
def main():
"""Main execution function"""
parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format")
parser.add_argument(
"--model_path",
type=str,
default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)"
)
parser.add_argument(
"--output_path",
type=str,
default="playground/hubert/chinese-hubert-base.onnx",
help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)"
)
args = parser.parse_args()
exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path)
exporter.run_full_export_and_test()
if __name__ == "__main__":
main()
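
For symmetry, a minimal consumer sketch for the HuBERT export (illustrative only, assuming the default output path and the base config's 768-dim hidden size):

import numpy as np
import onnxruntime as ort
import torchaudio

# Load a reference clip, downmix to mono, resample to 16 kHz, run the export.
waveform, sr = torchaudio.load("playground/ref/audio.wav")
if waveform.shape[0] > 1:
    waveform = waveform[0:1]
if sr != 16000:
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
hidden = session.run(None, {"audio16k": waveform.numpy().astype(np.float32)})[0]
print(hidden.shape)  # (1, frames, 768); roughly one frame per 20 ms of audio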