Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 00:30:15 +08:00)
feat: text_bert and audio_hubert exports are ready and fully tested; todo: solve dependency in playground runs
This commit is contained in:
parent
4e42a28f9c
commit
aef9d26580
3	.gitignore (vendored)
@@ -194,4 +194,5 @@ cython_debug/
 # PyPI configuration file
 .pypirc
 onnx/
 *.onnx
+tokenizer.json
156	playground/export_bert.py (new file)
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import onnxruntime as ort
import argparse
import os
import shutil
import numpy as np


class CombinedBERTModel(nn.Module):
    """Wrapper that pairs the BERT tokenizer with masked-LM inference."""

    def __init__(self, model_name: str):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, text_input: torch.Tensor):
        """Forward pass over pre-tokenized input_ids."""
        # For ONNX export we work with pre-tokenized input_ids;
        # text tokenization itself must happen outside the ONNX graph.
        input_ids = text_input.long()

        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        # Take the third-from-last hidden state and strip the [CLS]/[SEP]
        # positions, matching GPT-SoVITS's BERT feature extraction.
        return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]


def export_bert_to_onnx(
    model_name: str = "bert-base-uncased",
    output_dir: str = "bert_exported",
    max_seq_length: int = 512,
):
    """Export the BERT model to ONNX format and copy tokenizer files."""

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading model: {model_name}")
    combined_model = CombinedBERTModel(model_name)
    combined_model.eval()

    # Create dummy inputs for ONNX export (pre-tokenized input_ids)
    batch_size = 1
    dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length))

    # Export to ONNX
    onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx")
    print(f"Exporting to ONNX: {onnx_path}")
    torch.onnx.export(
        combined_model,
        dummy_input_ids,
        onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input_ids'],
        output_names=['logits'],  # actually the selected hidden-state features
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'logits': {0: 'batch_size', 1: 'sequence_length'},
        },
    )

    # Copy tokenizer.json if it exists
    tokenizer_cache_dir = combined_model.tokenizer.name_or_path
    if os.path.isdir(tokenizer_cache_dir):
        tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json")
    else:
        # For models resolved through the HuggingFace cache
        try:
            tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path
        except Exception:
            # Alternative approach: search the cache for tokenizer.json
            cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
            tokenizer_json_path = None
            for root, dirs, files in os.walk(cache_dir):
                if "tokenizer.json" in files and model_name.replace("/", "--") in root:
                    tokenizer_json_path = os.path.join(root, "tokenizer.json")
                    break

    if tokenizer_json_path and os.path.exists(tokenizer_json_path):
        dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json")
        shutil.copy2(tokenizer_json_path, dest_tokenizer_path)
        print(f"Copied tokenizer.json to: {dest_tokenizer_path}")
    else:
        print("Warning: tokenizer.json not found")

    print(f"Model exported successfully to: {output_dir}")
    return combined_model, onnx_path


def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5):
    """Test whether the PyTorch model and the ONNX model produce the same outputs."""

    print("Testing model equivalence...")

    # Create test input
    batch_size = 1
    test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))

    # Get PyTorch output
    original_model.eval()
    with torch.no_grad():
        pytorch_output = original_model(test_input_ids).numpy()

    # Get ONNX output
    ort_session = ort.InferenceSession(onnx_path)
    onnx_output = ort_session.run(None, {"input_ids": test_input_ids.numpy()})[0]

    # Compare outputs
    max_diff = np.max(np.abs(pytorch_output - onnx_output))
    mean_diff = np.mean(np.abs(pytorch_output - onnx_output))

    print(f"Maximum absolute difference: {max_diff}")
    print(f"Mean absolute difference: {mean_diff}")

    if max_diff < tolerance:
        print("✅ Models are numerically equivalent!")
        return True
    else:
        print("❌ Models have significant differences!")
        return False


def main():
    parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
    parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
                        help="Pretrained BERT model name")
    parser.add_argument("--output_dir", type=str, default="playground/bert",
                        help="Output directory path")
    parser.add_argument("--max_seq_length", type=int, default=512,
                        help="Maximum sequence length")
    parser.add_argument("--tolerance", type=float, default=1e-3,
                        help="Tolerance for numerical comparison")

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Export model
    original_model, onnx_path = export_bert_to_onnx(
        model_name=args.model_name,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
    )

    # Test equivalence
    test_model_equivalence(
        original_model=original_model,
        onnx_path=onnx_path,
        max_seq_length=args.max_seq_length,
        tolerance=args.tolerance,
    )


if __name__ == "__main__":
    main()
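For the "solve dependency in playground runs" todo in the commit message, a runtime-only consumption sketch may be the direction this is heading. The following is a hypothetical example, not part of this commit: it assumes the exported playground/bert/chinese-roberta-wwm-ext-large.onnx plus the copied tokenizer.json, and needs only the tokenizers and onnxruntime packages (no torch or transformers at inference time).

# Hypothetical playground usage: run the exported BERT graph without transformers.
import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("playground/bert/tokenizer.json")
ids = tokenizer.encode("但是 BERT 的特征是逐字提取的。").ids  # includes [CLS]/[SEP]

session = ort.InferenceSession("playground/bert/chinese-roberta-wwm-ext-large.onnx")
(features,) = session.run(None, {"input_ids": np.array([ids], dtype=np.int64)})
print(features.shape)  # (len(ids) - 2, 1024): per-token features, specials stripped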
170	playground/export_hubert.py (new file)
@@ -0,0 +1,170 @@
import os
import torch
import torchaudio
import onnxruntime as ort
import numpy as np
import argparse
from transformers import HubertModel, HubertConfig


class HubertONNXExporter:
    """Export and test the HuBERT model in ONNX format"""

    def __init__(self, model_path="GPT_SoVITS/pretrained_models/chinese-hubert-base", output_path="playground/hubert/chinese-hubert-base.onnx"):
        self.model_path = model_path
        self.onnx_path = output_path
        self.model = None
        self.config = None

    def setup_model(self):
        """Configure and load the HuBERT model for ONNX export"""
        # Configure for better ONNX compatibility
        self.config = HubertConfig.from_pretrained(self.model_path)
        self.config._attn_implementation = "eager"  # Use standard attention
        self.config.apply_spec_augment = False  # Disable masking for inference
        self.config.layerdrop = 0.0  # Disable layer dropout

        # Load the model
        self.model = HubertModel.from_pretrained(
            self.model_path,
            config=self.config,
            local_files_only=True,
        )
        self.model.eval()

    def export_to_onnx(self, dummy_length=16000):
        """Export the model to ONNX format"""
        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Create zero-centered dummy input (1 second at 16kHz)
        dummy_input = torch.rand(1, dummy_length, dtype=torch.float32) - 0.5

        # Export to ONNX
        torch.onnx.export(
            self.model,
            dummy_input,
            self.onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['audio16k'],
            output_names=['last_hidden_state'],
            dynamic_axes={
                'audio16k': {0: 'batch_size', 1: 'sequence_length'},
                'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
            },
        )
        print(f"[Success] Model exported to {self.onnx_path}")

    def test_onnx_export_exists(self):
        """Test that the ONNX model file was created"""
        if os.path.exists(self.onnx_path):
            print(f"[Success] ONNX model file exists at {self.onnx_path}")
            return True
        else:
            print(f"[Error] ONNX model not found at {self.onnx_path}")
            return False

    def _load_and_preprocess_audio(self, audio_path, max_length=32000):
        """Load and preprocess an audio file"""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        # Take the first channel
        if waveform.shape[0] > 1:
            waveform = waveform[0:1]

        # Limit length for testing (2 seconds at 16kHz)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]

        return waveform

    def test_torch_vs_onnx(self, audio_path="playground/ref/audio.wav"):
        """Test that ONNX model outputs match PyTorch model outputs"""
        if not os.path.exists(audio_path):
            print(f"[Skip] Test audio file not found at {audio_path}")
            return False

        if self.model is None:
            raise ValueError("Model not initialized. Call setup_model() first.")

        # Load and preprocess audio
        waveform = self._load_and_preprocess_audio(audio_path)

        # PyTorch inference
        with torch.no_grad():
            torch_output = self.model(waveform)
            torch_hidden_states = torch_output.last_hidden_state

        # ONNX inference
        ort_session = ort.InferenceSession(self.onnx_path)
        input_values = waveform.numpy().astype(np.float32)
        ort_inputs = {ort_session.get_inputs()[0].name: input_values}
        ort_outputs = ort_session.run(None, ort_inputs)
        onnx_hidden_states = ort_outputs[0]

        # Compare outputs
        torch_numpy = torch_hidden_states.numpy()
        diff = np.abs(torch_numpy - onnx_hidden_states).mean()

        success = diff <= 1e-5
        status = "[Success]" if success else "[Fail]"

        print(f"{status} ONNX vs PyTorch comparison")
        print(f"  > mean_difference={diff}")
        print(f"  > torch_shape={torch_numpy.shape}")
        print(f"  > onnx_shape={onnx_hidden_states.shape}")

        return success

    def run_full_export_and_test(self):
        """Run the complete export and testing pipeline"""
        print("Starting HuBERT ONNX export and testing...")

        # Create the output directory if it doesn't exist
        os.makedirs(os.path.dirname(self.onnx_path), exist_ok=True)

        # Setup model
        self.setup_model()

        # Export to ONNX
        self.export_to_onnx()

        # Test export
        self.test_onnx_export_exists()
        self.test_torch_vs_onnx()

        print("Export and testing complete!")


def main():
    """Main execution function"""
    parser = argparse.ArgumentParser(description="Export HuBERT model to ONNX format")
    parser.add_argument(
        "--model_path",
        type=str,
        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        help="Path to the HuBERT model directory (default: GPT_SoVITS/pretrained_models/chinese-hubert-base)",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="playground/hubert/chinese-hubert-base.onnx",
        help="Output path for the ONNX model (default: playground/hubert/chinese-hubert-base.onnx)",
    )

    args = parser.parse_args()

    exporter = HubertONNXExporter(model_path=args.model_path, output_path=args.output_path)
    exporter.run_full_export_and_test()


if __name__ == "__main__":
    main()
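A matching runtime-only sketch for the HuBERT side, again hypothetical and not part of this commit: it assumes the exported playground/hubert/chinese-hubert-base.onnx, a 16 kHz reference wav, and only the soundfile and onnxruntime packages.

# Hypothetical playground usage: extract HuBERT features without torch/torchaudio.
import numpy as np
import soundfile as sf
import onnxruntime as ort

# Assumes a 16 kHz file; resampling would have to happen elsewhere.
audio16k, sr = sf.read("playground/ref/audio.wav", dtype="float32")
assert sr == 16000, "the exported graph expects 16 kHz input"
if audio16k.ndim > 1:
    audio16k = audio16k[:, 0]  # keep the first channel, mirroring the exporter

session = ort.InferenceSession("playground/hubert/chinese-hubert-base.onnx")
(hidden,) = session.run(None, {"audio16k": audio16k[np.newaxis, :]})
print(hidden.shape)  # (1, frames, 768): one 768-dim vector per ~20 ms frame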