feat: v2ProPlus ONNX export, ready for testing

zpeng11 2025-08-17 17:54:57 -04:00
parent fdf794e31d
commit 8c0f32da3e
6 changed files with 413 additions and 15 deletions

.gitignore vendored
View File

@@ -193,3 +193,4 @@ cython_debug/
# PyPI configuration file
.pypirc
+onnx/

View File

@@ -2,6 +2,7 @@ from torch.nn.functional import *
from torch.nn.functional import (
    _canonical_mask,
)
+from typing import Tuple
def multi_head_attention_forward_patched(

View File

@@ -13,6 +13,7 @@ __all__ = [
    "mel_scale_scalar",
    "spectrogram",
    "fbank",
+    "fbank_onnx",
    "mfcc",
    "vtln_warp_freq",
    "vtln_warp_mel_freq",
@@ -842,3 +843,371 @@ def mfcc(
    feature = _subtract_column_mean(feature, subtract_mean)
    return feature

def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
    r"""Returns the log energy of size (m) for a strided_input (m,*).
    ONNX-compatible variant: with the traced energy_floor=1.0, the floor is
    math.log(1.0) == 0.0, which is hardcoded below."""
    device, dtype = strided_input.device, strided_input.dtype
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    # energy_floor=1.0 -> clamp at log(1.0) == 0.0
    return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))

def _get_waveform_and_window_properties_onnx(
    waveform: Tensor,
) -> Tuple[Tensor, int, int, int]:
    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
    channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
    round_to_power_of_two=True, preemphasis_coefficient=0.97"""
    # Hardcoded values from traced parameters:
    # channel=-1 -> 0 after max(channel, 0)
    channel = 0
    # Extract channel 0 from the waveform
    if waveform.dim() == 1:
        # Mono waveform, use as-is
        waveform_selected = waveform
    else:
        # Multi-channel, select the first channel
        waveform_selected = waveform[channel, :]
    # Hardcoded calculations:
    # window_shift = int(16000 * 10.0 * 0.001) = 160
    # window_size = int(16000 * 25.0 * 0.001) = 400
    # padded_window_size = _next_power_of_2(400) = 512
    window_shift = 160
    window_size = 400
    padded_window_size = 512
    return waveform_selected, window_shift, window_size, padded_window_size
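
# Hedged sanity-check sketch (added for illustration, not part of the export path):
# the constants above should match what the generic helper computes for the traced
# arguments. The helper name _check_window_properties_onnx is an assumption.
def _check_window_properties_onnx(waveform: Tensor) -> bool:
    _, shift, size, padded = _get_waveform_and_window_properties(
        waveform if waveform.dim() > 1 else waveform.unsqueeze(0),
        -1, 16000.0, 10.0, 25.0, True, 0.97,
    )
    return (shift, size, padded) == (160, 400, 512)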

def _get_window_onnx(
    waveform: Tensor,
) -> Tuple[Tensor, Tensor]:
    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
    padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
    blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
    dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97

    Returns:
        (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)
    # Hardcoded values from traced parameters
    window_size = 400
    window_shift = 160
    padded_window_size = 512
    snip_edges = True
    # size (m, window_size)
    strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)
    # dither=0, so skip dithering (lines 209-211 in the original)
    # remove_dc_offset=True, so execute this branch
    row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
    strided_input = strided_input - row_means
    # raw_energy=True, so execute this branch
    signal_log_energy = _get_log_energy_onnx(strided_input, epsilon)  # energy_floor=1.0
    # preemphasis_coefficient=0.97 != 0.0, so execute this branch
    offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
    strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]
    # Apply the povey window function to each row/frame:
    # povey window = torch.hann_window(window_size, periodic=False).pow(0.85)
    window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
    strided_input = strided_input * window_function
    # Pad columns from window_size=400 to padded_window_size=512
    padding_right = padded_window_size - window_size  # 112
    strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)
    # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245 in the original)
    return strided_input, signal_log_energy

def _get_strided_onnx(waveform: Tensor, window_size: int = 400, window_shift: int = 160, snip_edges: bool = True) -> Tensor:
    r"""ONNX-compatible _get_strided: only the snip_edges=True branch is implemented,
    and advanced indexing replaces Tensor.as_strided, which the ONNX exporter
    does not handle."""
    seq_len = waveform.size(0)
    # Number of complete windows that fit (snip_edges=True semantics)
    num_windows = 1 + (seq_len - window_size) // window_shift
    # Start index of every window
    window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
    # (num_windows, window_size) index grid
    window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)
    # Extract all windows at once using advanced indexing
    windows = waveform[window_indices]  # [num_windows, window_size]
    return windows
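
# Hedged equivalence sketch (illustrative only): the indexing above should match
# torch.Tensor.unfold for the same window/shift. The checker name below is an
# assumption, not part of the traced pipeline.
def _check_strided_onnx(waveform: Tensor) -> bool:
    expected = waveform.unfold(0, 400, 160)  # (num_windows, 400)
    return bool(torch.equal(_get_strided_onnx(waveform), expected))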

def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
    """ONNX-compatible version with hardcoded parameters from the traced fbank call:
    subtract_mean=False, so this function returns the input tensor unchanged.

    Args:
        tensor: Input tensor of size (m, n)
    Returns:
        Tensor: Same as the input tensor (m, n), since subtract_mean=False
    """
    # subtract_mean=False in the traced parameters, so return the tensor as-is
    return tensor

def get_mel_banks_onnx(
    device=None,
    dtype=None,
) -> Tensor:
    """ONNX-compatible version with hardcoded parameters from the traced fbank call:
    num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
    high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0

    Returns:
        Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
    """
    # Hardcoded values from traced parameters
    num_bins = 80
    window_length_padded = 512
    sample_freq = 16000.0
    low_freq = 20.0
    high_freq = 0.0  # Will be adjusted to the Nyquist frequency below
    vtln_warp_factor = 1.0
    # Derived values, computed rather than hardcoded to keep them consistent
    num_fft_bins = window_length_padded // 2  # 256 (integer division)
    nyquist = 0.5 * sample_freq  # 8000.0
    # high_freq <= 0.0, so high_freq += nyquist
    if high_freq <= 0.0:
        high_freq += nyquist  # 8000.0
    # FFT-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
    fft_bin_width = sample_freq / window_length_padded
    # Mel-scale endpoints and spacing
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
    # vtln_warp_factor == 1.0, so no VTLN warping is needed
    # Mel bin edges and centers, one row per bin
    bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
    left_mel = mel_low_freq + bin_indices * mel_freq_delta
    center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
    right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta
    # Frequency of each FFT bin, mapped onto the mel scale
    fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
    mel = mel_scale(fft_freqs).unsqueeze(0)  # size (1, num_fft_bins)
    # Triangular filters: rising and falling slopes, clamped at zero
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)
    # Since vtln_warp_factor == 1.0, use the simpler branch
    bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))
    return bins

def fbank_onnx(
    waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
) -> Tensor:
    r"""ONNX-compatible fbank function with hardcoded parameters from the traced call:
    num_mel_bins=80, sample_frequency=16000, dither=0.
    The keyword arguments are kept only for signature parity with fbank; their values
    are hardcoded throughout, and all other parameters use their traced defaults.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80),
        where m is calculated in _get_strided_onnx
    """
    device, dtype = waveform.device, waveform.dtype
    # Use the ONNX-compatible version of _get_waveform_and_window_properties
    waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)
    # min_duration=0.0, so skip the duration check (the signal is never "too short")
    # Use the ONNX-compatible version of _get_window
    strided_input, signal_log_energy = _get_window_onnx(waveform_selected)
    # Equivalent to spectrum = torch.fft.rfft(strided_input).abs(), expressed via
    # torch.stft because the ONNX exporter handles stft more reliably.
    m, _ = strided_input.shape
    # Rectangular window: the povey window was already applied in _get_window_onnx
    rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)
    # One STFT over all m frames at once; hop_length=n_fft yields exactly one
    # FFT per 512-sample frame
    stft_result = torch.stft(
        strided_input,  # Shape: (m, 512)
        n_fft=512,
        hop_length=512,  # Process each entire frame at once
        window=rectangular_window,
        center=False,  # Don't add padding
        return_complex=False,
    )
    # stft_result shape: (m, 257, 1, 2), where the last dim is [real, imag]
    real_part = stft_result[..., 0]  # Shape: (m, 257, 1)
    imag_part = stft_result[..., 1]  # Shape: (m, 257, 1)
    # Magnitude: sqrt(real^2 + imag^2)
    spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1)  # Shape: (m, 257)
    # use_power=True, so execute this branch
    spectrum = spectrum.pow(2.0)
    # Get the mel filterbanks using the ONNX-compatible version
    mel_energies = get_mel_banks_onnx(device, dtype)
    # Pad a right column of zeros to match the FFT output size: (80, 256) -> (80, 257)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
    # Sum the power spectrum against the mel filterbanks, size (m, 80)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    # use_log_fbank=True, so execute this branch
    mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
    # use_energy=False, so skip the energy addition (lines 828-834 in the original)
    # Use the ONNX-compatible version of _subtract_column_mean
    mel_energies = _subtract_column_mean_onnx(mel_energies)
    return mel_energies
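
# Hedged export sketch (illustrative; the wrapper class, file name, and opset are
# assumptions, not part of this commit's pipeline): fbank_onnx is trace-friendly,
# so it could be exported standalone roughly like this.
class _FbankOnnxWrapper(torch.nn.Module):
    """Thin nn.Module wrapper so torch.onnx.export can trace fbank_onnx."""

    def forward(self, waveform: Tensor) -> Tensor:
        return fbank_onnx(waveform)

def _export_fbank_onnx_demo(path: str = "fbank_demo.onnx") -> None:
    dummy = torch.randn(1, 16000)  # 1 s of 16 kHz audio, shape (c, n)
    torch.onnx.export(
        _FbankOnnxWrapper(),
        (dummy,),
        path,
        input_names=["waveform"],
        output_names=["fbank"],
        dynamic_axes={"waveform": {1: "num_samples"}, "fbank": {0: "num_frames"}},
        opset_version=17,  # STFT support lands at opset 17
    )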

# Test to compare the original fbank against fbank_onnx
if __name__ == "__main__":
    import torch

    print("Testing fbank vs fbank_onnx with traced parameters...")
    # Create a reproducible test waveform
    torch.manual_seed(42)
    sample_rate = 16000
    duration = 1.0  # 1 second
    num_samples = int(sample_rate * duration)
    # Test waveform: sine wave + noise
    t = torch.linspace(0, duration, num_samples)
    frequency = 440.0  # A4 note
    waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)
    # Shape the input as (c, n) with c=1 (mono)
    mono_waveform = waveform.unsqueeze(0)  # Shape: (1, num_samples)
    print(f"Test waveform shape: {mono_waveform.shape}")
    # Test parameters from the trace: num_mel_bins=80, sample_frequency=16000, dither=0
    try:
        print("\n=== DEBUGGING: Step-by-step comparison ===")
        # Step 1: Check waveform and window-property processing
        orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
            mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
        )
        onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)
        print(f"Original waveform shape: {orig_waveform.shape}")
        print(f"ONNX waveform shape: {onnx_waveform.shape}")
        print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
        print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
        print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")
        # Step 2: Check windowing
        orig_strided, orig_energy = _get_window(
            orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
            "povey", 0.42, True, True, 1.0, 0, True, 0.97,
        )
        onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)
        print(f"\nOriginal strided shape: {orig_strided.shape}")
        print(f"ONNX strided shape: {onnx_strided.shape}")
        print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
        print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")
        # Step 3: Check mel banks
        orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
        onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)
        print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
        print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
        print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")
        # Step 4: Full comparison
        print("\n=== FULL COMPARISON ===")
        # Original fbank
        original_result = fbank(
            mono_waveform,
            num_mel_bins=80,
            sample_frequency=16000,
            dither=0,
        )
        # ONNX-compatible fbank
        onnx_result = fbank_onnx(mono_waveform)
        print(f"Original fbank output shape: {original_result.shape}")
        print(f"ONNX fbank output shape: {onnx_result.shape}")
        # Check whether the shapes match
        if original_result.shape == onnx_result.shape:
            print("✅ Output shapes match")
        else:
            print("❌ Output shapes don't match")
            print(f"   Original: {original_result.shape}")
            print(f"   ONNX: {onnx_result.shape}")
        # Check the numerical differences
        if original_result.shape == onnx_result.shape:
            diff = torch.abs(original_result - onnx_result)
            max_diff = torch.max(diff).item()
            mean_diff = torch.mean(diff).item()
            relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()
            print(f"Max absolute difference: {max_diff:.2e}")
            print(f"Mean absolute difference: {mean_diff:.2e}")
            print(f"Mean relative difference: {relative_diff:.2e}")
            # Locate where the max difference occurs
            max_idx = torch.argmax(diff)
            max_coords = torch.unravel_index(max_idx, diff.shape)
            print(f"Max difference at coordinates: {max_coords}")
            print(f"   Original value: {original_result[max_coords].item():.6f}")
            print(f"   ONNX value: {onnx_result[max_coords].item():.6f}")
            # Check whether the results are numerically close
            tolerance = 1e-5
            if max_diff < tolerance:
                print(f"✅ Results are numerically identical (within {tolerance})")
            else:
                print(f"❌ Results differ by {max_diff:.2e}, more than the tolerance {tolerance}")
            # Additional statistics
            print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
            print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")
    except Exception as e:
        print(f"❌ Error during testing: {e}")
        import traceback

        traceback.print_exc()

View File

@@ -4,6 +4,7 @@ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
+from sv import SV

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -190,14 +191,18 @@ class T2SModel(nn.Module):

class VitsModel(nn.Module):
-    def __init__(self, vits_path):
+    def __init__(self, vits_path, version: str = "v2"):
        super().__init__()
        dict_s2 = torch.load(vits_path, map_location="cpu")
        self.hps = dict_s2["config"]
        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            self.hps["model"]["version"] = "v1"
        else:
-            self.hps["model"]["version"] = "v2"
+            self.hps["model"]["version"] = version
+        self.sv_model = None
+        if version in ("v2Pro", "v2ProPlus"):
+            self.sv_model = SV("cpu", False)
        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
@@ -219,6 +224,9 @@ class VitsModel(nn.Module):
            self.hps.data.win_length,
            center=False,
        )
+        if self.sv_model is not None:
+            sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio)
+            return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
        return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
@@ -274,8 +282,8 @@ class SSLModel(nn.Module):
        return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)

-def export(vits_path, gpt_path, project_name, vits_model="v2"):
-    vits = VitsModel(vits_path)
+def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
+    vits = VitsModel(vits_path, version=voice_model_version)
    gpt = T2SModel(gpt_path, vits)
    gpt_sovits = GptSoVits(vits, gpt)
    ssl = SSLModel()
@@ -297,7 +305,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                    "y",
                    "e4",
                ],
-                version=vits_model,
+                version=voice_model_version,
            )
        ]
    )
@@ -330,7 +338,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                    "y",
                    "e4",
                ],
-                version=vits_model,
+                version=voice_model_version,
            )
        ]
    )
@@ -349,9 +357,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
    ssl_content = ssl(ref_audio_16k).float()

    # debug = False
-    debug = True
+    debug = False
-    # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

    if debug:
        a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
@@ -361,7 +368,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
        a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
        soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

-    if vits_model == "v1":
+    if voice_model_version == "v1":
        symbols = symbols_v1
    else:
        symbols = symbols_v2
@@ -390,9 +397,16 @@ if __name__ == "__main__":
    except:
        pass

-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
-    exp_path = "nahida"
-    export(vits_path, gpt_path, exp_path)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    # exp_path = "v2_export"
+    # version = "v2"
+    # export(vits_path, gpt_path, exp_path, version)
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    exp_path = "v2proplus_export"
+    version = "v2ProPlus"
+    export(vits_path, gpt_path, exp_path, version)

    # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

View File

@@ -30,3 +30,15 @@ class SV:
            )
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb

+    def compute_embedding3_onnx(self, wav):
+        # Freeze the embedding model so tracing never records gradient state
+        for param in self.embedding_model.parameters():
+            param.requires_grad = False
+        with torch.no_grad():
+            if self.is_half:
+                wav = wav.half()
+            feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
+            sv_emb = self.embedding_model.forward3(feat)
+        return sv_emb
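
# Hedged usage sketch (editorial note, not part of the commit): compute_embedding3_onnx
# mirrors compute_embedding3 but swaps Kaldi.fbank for the export-friendly
# Kaldi.fbank_onnx, so it can sit on an ONNX trace path, e.g.:
#   sv = SV("cpu", False)
#   emb = sv.compute_embedding3_onnx(torch.randn(1, 16000))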

View File

@@ -7,6 +7,7 @@ numba
pytorch-lightning>=2.4
gradio<5
ffmpeg-python
+onnx
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
tqdm