diff --git a/.gitignore b/.gitignore
index cefe803f..e9721828 100644
--- a/.gitignore
+++ b/.gitignore
@@ -194,4 +194,5 @@ cython_debug/
 # PyPI configuration file
 .pypirc
 onnx/
-*.onnx
\ No newline at end of file
+*.onnx
+tokenizer.json
\ No newline at end of file
diff --git a/playground/export_bert.py b/playground/export_bert.py
new file mode 100644
index 00000000..c6341a08
--- /dev/null
+++ b/playground/export_bert.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import onnx
+import onnxruntime as ort
+from typing import Dict, Any
+import argparse
+import os
+import shutil
+import numpy as np
+
+class CombinedBERTModel(nn.Module):
+    """Wrapper that bundles the BERT model and keeps the tokenizer for vocab/export metadata."""
+
+    def __init__(self, model_name: str):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+
+    def forward(self, text_input: torch.Tensor):
+        """Forward pass over pre-tokenized input_ids."""
+        # Note: for ONNX export we work with pre-tokenized input_ids;
+        # text tokenization has to happen outside the ONNX graph.
+        input_ids = text_input.long()
+
+        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
+        # Keep the third-to-last hidden layer and drop the [CLS]/[SEP] frames
+        return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+
+def export_bert_to_onnx(
+    model_name: str = "bert-base-uncased",
+    output_dir: str = "bert_exported",
+    max_seq_length: int = 512
+):
+    """Export BERT model to ONNX format and copy tokenizer files."""
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    print(f"Loading model: {model_name}")
+    combined_model = CombinedBERTModel(model_name)
+    combined_model.eval()
+
+    # Create dummy inputs for ONNX export (pre-tokenized input_ids)
+    batch_size = 1
+    dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length))
+
+    # Export to ONNX
+    onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx")
+    print(f"Exporting to ONNX: {onnx_path}")
+    torch.onnx.export(
+        combined_model,
+        dummy_input_ids,
+        onnx_path,
+        export_params=True,
+        opset_version=14,
+        do_constant_folding=True,
+        input_names=['input_ids'],
+        # Note: despite the name, this output is the hidden-state feature
+        # returned by forward(), not MLM logits.
+        output_names=['logits'],
+        dynamic_axes={
+            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
+            'logits': {0: 'batch_size', 1: 'sequence_length'}
+        }
+    )
+
+    # Copy tokenizer.json if it exists
+    tokenizer_cache_dir = combined_model.tokenizer.name_or_path
+    if os.path.isdir(tokenizer_cache_dir):
+        tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json")
+    else:
+        # For models resolved from the HuggingFace cache
+        try:
+            tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path
+        except Exception:
+            # Alternative approach: find tokenizer.json in the local cache
+            cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
+            tokenizer_json_path = None
+            for root, dirs, files in os.walk(cache_dir):
+                if "tokenizer.json" in files and model_name.replace("/", "--") in root:
+                    tokenizer_json_path = os.path.join(root, "tokenizer.json")
+                    break
+
+    if tokenizer_json_path and os.path.exists(tokenizer_json_path):
+        dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json")
+        shutil.copy2(tokenizer_json_path, dest_tokenizer_path)
+        print(f"Copied tokenizer.json to: {dest_tokenizer_path}")
+    else:
+        print("Warning: tokenizer.json not found")
+
+    print(f"Model exported successfully to: {output_dir}")
+    return combined_model, onnx_path
+
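+# Illustrative sketch, not wired into the export pipeline: one way a consumer
+# could run the exported model with onnxruntime plus the copied tokenizer.json.
+# The helper name and file paths here are hypothetical.
+def example_onnx_inference(onnx_path: str, tokenizer_json_path: str, text: str):
+    """Run the exported BERT ONNX model on raw text (example only)."""
+    from tokenizers import Tokenizer
+
+    tok = Tokenizer.from_file(tokenizer_json_path)
+    ids = tok.encode(text).ids  # BERT tokenizer.json adds [CLS]/[SEP] by default
+    session = ort.InferenceSession(onnx_path)
+    input_ids = np.array([ids], dtype=np.int64)
+    # forward() strips [CLS]/[SEP], so the output has len(ids) - 2 frames
+    return session.run(None, {"input_ids": input_ids})[0]
+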
{output_dir}") + return combined_model, onnx_path + +def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5): + """Test if the original PyTorch model and ONNX model produce the same outputs.""" + + print("Testing model equivalence...") + + # Create test input + batch_size = 1 + test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length)) + + # Get PyTorch output + original_model.eval() + with torch.no_grad(): + pytorch_output = original_model(test_input_ids).numpy() + + # Get ONNX output + ort_session = ort.InferenceSession(onnx_path) + onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0] + + # Compare outputs + max_diff = np.max(np.abs(pytorch_output - onnx_output)) + mean_diff = np.mean(np.abs(pytorch_output - onnx_output)) + + print(f"Maximum absolute difference: {max_diff}") + print(f"Mean absolute difference: {mean_diff}") + + if max_diff < tolerance: + print("✅ Models are numerically equivalent!") + return True + else: + print("❌ Models have significant differences!") + return False + +def main(): + parser = argparse.ArgumentParser(description="Export BERT model to ONNX") + parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + help="Pretrained BERT model name") + parser.add_argument("--output_dir", type=str, default="playground/bert", + help="Output directory path") + parser.add_argument("--max_seq_length", type=int, default=512, + help="Maximum sequence length") + parser.add_argument("--tolerance", type=float, default=1e-3, + help="Tolerance for numerical comparison") + + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + # Export model + original_model, onnx_path = export_bert_to_onnx( + model_name=args.model_name, + output_dir=args.output_dir, + max_seq_length=args.max_seq_length + ) + + # Test equivalence + test_model_equivalence( + original_model=original_model, + onnx_path=onnx_path, + max_seq_length=args.max_seq_length, + tolerance=args.tolerance + ) + +if __name__ == "__main__": + main() diff --git a/playground/export_hubert.py b/playground/export_hubert.py new file mode 100644 index 00000000..75c9be38 --- /dev/null +++ b/playground/export_hubert.py @@ -0,0 +1,170 @@ +import os +import sys +import torch +import torchaudio +import onnxruntime as ort +import numpy as np +import argparse +from transformers import HubertModel, HubertConfig + + +class HubertONNXExporter: + """Export and test HuBERT model to ONNX format""" + + def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base", output_path="playground/hubert/chinese-hubert-base.onnx"): + self.model_path = model_path + self.onnx_path = output_path + self.model = None + self.config = None + + def setup_model(self): + """Configure and load the HuBERT model for ONNX export""" + # Configure for better ONNX compatibility + self.config = HubertConfig.from_pretrained(self.model_path) + self.config._attn_implementation = "eager" # Use standard attention + self.config.apply_spec_augment = False # Disable masking for inference + self.config.layerdrop = 0.0 # Disable layer dropout + + # Load the model + self.model = HubertModel.from_pretrained( + self.model_path, + config=self.config, + local_files_only=True + ) + self.model.eval() + + def export_to_onnx(self, dummy_length=16000): + """Export the model to ONNX format""" + if self.model is None: + raise ValueError("Model not initialized. 
Call setup_model() first.") + + # Create dummy input (1 second at 16kHz) + dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5 + + # Export to ONNX + torch.onnx.export( + self.model, + dummy_input, + self.onnx_path, + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names=['audio16k'], + output_names=['last_hidden_state'], + dynamic_axes={ + 'audio16k': {0: 'batch_size', 1: 'sequence_length'}, + 'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'} + } + ) + print(f"[Success] Model exported to {self.onnx_path}") + + def test_onnx_export_exists(self): + """Test that the ONNX model file was created""" + if os.path.exists(self.onnx_path): + print(f"[Success] ONNX model file exists at {self.onnx_path}") + return True + else: + print(f"[Error] ONNX model not found at {self.onnx_path}") + return False + + def _load_and_preprocess_audio(self, audio_path, max_length=32000): + """Load and preprocess audio file""" + waveform, sample_rate = torchaudio.load(audio_path) + + # Resample to 16kHz if needed + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + + # Take first channel + if waveform.shape[0] > 1: + waveform = waveform[0:1] + + # Limit length for testing (2 seconds at 16kHz) + if waveform.shape[1] > max_length: + waveform = waveform[:, :max_length] + + return waveform + + def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"): + """Test that ONNX model outputs match PyTorch model outputs""" + if not os.path.exists(audio_path): + print(f"[Skip] Test audio file not found at {audio_path}") + return False + + if self.model is None: + raise ValueError("Model not initialized. Call setup_model() first.") + + # Load and preprocess audio + waveform = self._load_and_preprocess_audio(audio_path) + + # PyTorch inference + with torch.no_grad(): + torch_output = self.model(waveform) + torch_hidden_states = torch_output.last_hidden_state + + # ONNX inference + ort_session = ort.InferenceSession(self.onnx_path) + input_values = waveform.numpy().astype(np.float32) + ort_inputs = {ort_session.get_inputs()[0].name: input_values} + ort_outputs = ort_session.run(None, ort_inputs) + onnx_hidden_states = ort_outputs[0] + + # Compare outputs + torch_numpy = torch_hidden_states.numpy() + diff = np.abs(torch_numpy - onnx_hidden_states).mean() + + success = diff <= 1e-5 + status = "[Success]" if success else "[Fail]" + + print(f"{status} ONNX vs PyTorch comparison") + print(f" > mean_difference={diff}") + print(f" > torch_shape={torch_numpy.shape}") + print(f" > onnx_shape={onnx_hidden_states.shape}") + + return success + + def run_full_export_and_test(self): + """Run the complete export and testing pipeline""" + print("Starting HuBERT ONNX export and testing...") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True) + + # Setup model + self.setup_model() + + # Export to ONNX + self.export_to_onnx() + + # Test export + self.test_onnx_export_exists() + self.test_torch_vs_onnx() + + print("Export and testing complete!") + + +def main(): + """Main execution function""" + parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format") + parser.add_argument( + "--model_path", + type=str, + default="GPT_SoVITS/pretrained_models/chinese-hubert-base", + help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)" + ) + parser.add_argument( + "--output_path", + type=str, + 
default="playground/hubert/chinese-hubert-base.onnx", + help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)" + ) + + args = parser.parse_args() + + exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path) + exporter.run_full_export_and_test() + + +if __name__ == "__main__": + main() \ No newline at end of file