diff --git a/.gitignore b/.gitignore
index d280e459..473c09c3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -193,3 +193,4 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+onnx/
\ No newline at end of file
diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
index 8144c9c6..b38a3907 100644
--- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
+++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
@@ -2,6 +2,7 @@ from torch.nn.functional import *
 from torch.nn.functional import (
     _canonical_mask,
 )
+from typing import Tuple
 
 
 def multi_head_attention_forward_patched(
diff --git a/GPT_SoVITS/eres2net/kaldi.py b/GPT_SoVITS/eres2net/kaldi.py
index a80e5e6b..42565e00 100644
--- a/GPT_SoVITS/eres2net/kaldi.py
+++ b/GPT_SoVITS/eres2net/kaldi.py
@@ -13,6 +13,7 @@ __all__ = [
     "mel_scale_scalar",
     "spectrogram",
     "fbank",
+    "fbank_onnx",
     "mfcc",
     "vtln_warp_freq",
     "vtln_warp_mel_freq",
@@ -842,3 +843,371 @@ def mfcc(
 
     feature = _subtract_column_mean(feature, subtract_mean)
     return feature
+
+def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
+    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
+    device, dtype = strided_input.device, strided_input.dtype
+    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
+    return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))
+
+
+def _get_waveform_and_window_properties_onnx(
+    waveform: Tensor,
+) -> Tuple[Tensor, int, int, int]:
+    r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
+    channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
+    round_to_power_of_two=True, preemphasis_coefficient=0.97"""
+
+    # Hardcoded values from traced parameters
+    # channel=-1 -> 0 after max(channel, 0)
+    channel = 0
+
+    # Extract channel 0 from waveform
+    if waveform.dim() == 1:
+        # Mono waveform, use as-is
+        waveform_selected = waveform
+    else:
+        # Multi-channel, select first channel
+        waveform_selected = waveform[channel, :]
+
+    # Hardcoded calculations:
+    # window_shift = int(16000 * 10.0 * 0.001) = 160
+    # window_size = int(16000 * 25.0 * 0.001) = 400
+    # padded_window_size = _next_power_of_2(400) = 512
+    window_shift = 160
+    window_size = 400
+    padded_window_size = 512
+
+    return waveform_selected, window_shift, window_size, padded_window_size
+
+
+def _get_window_onnx(
+    waveform: Tensor,
+) -> Tuple[Tensor, Tensor]:
+    r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
+    padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
+    blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
+    dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97
+
+    Returns:
+        (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
+    """
+    device, dtype = waveform.device, waveform.dtype
+    epsilon = _get_epsilon(device, dtype)
+
+    # Hardcoded values from traced parameters
+    window_size = 400
+    window_shift = 160
+    padded_window_size = 512
+    snip_edges = True
+
+    # size (m, window_size)
+    strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)
+
+    # dither=0, so skip dithering (lines 209-211 from original)
+
+    # remove_dc_offset=True, so execute this branch
+    row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
+    strided_input = strided_input - row_means
+
+    # raw_energy=True, so execute this branch
+    signal_log_energy = _get_log_energy_onnx(strided_input, epsilon) # energy_floor=1.0
+
+    # preemphasis_coefficient=0.97 != 0.0, so execute this branch
+    offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
+    strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]
+
+    # Apply povey window function to each row/frame
+    # povey window: torch.hann_window(window_size, periodic=False).pow(0.85)
+    window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
+    strided_input = strided_input * window_function
+
+    # Pad columns from window_size=400 to padded_window_size=512
+    padding_right = padded_window_size - window_size # 112
+    strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)
+
+    # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245)
+    return strided_input, signal_log_energy
+
+
+def _get_strided_onnx(waveform: Tensor, window_size: int = 400, window_shift: int = 160, snip_edges: bool = True) -> Tensor:
+    seq_len = waveform.size(0)
+
+    # Calculate number of windows
+    num_windows = 1 + (seq_len - window_size) // window_shift
+
+    # Create indices for all windows at once
+    window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
+    window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)
+
+    # Extract windows using advanced indexing
+    windows = waveform[window_indices] # [num_windows, window_size]
+
+    return windows
+
+
+def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
+    """ONNX-compatible version with hardcoded parameters from traced fbank call:
+    subtract_mean=False, so this function returns the input tensor unchanged.
+
+    Args:
+        tensor: Input tensor of size (m, n)
+
+    Returns:
+        Tensor: Same as input tensor (m, n) since subtract_mean=False
+    """
+    # subtract_mean=False from traced parameters, so return tensor as-is
+    return tensor
+
+
+def get_mel_banks_onnx(
+    device=None,
+    dtype=None,
+) -> Tensor:
+    """ONNX-compatible version with hardcoded parameters from traced fbank call:
+    num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
+    high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0
+
+    Returns:
+        Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
+    """
+    # Hardcoded values from traced parameters
+    num_bins = 80
+    window_length_padded = 512
+    sample_freq = 16000.0
+    low_freq = 20.0
+    high_freq = 0.0 # Will be adjusted to nyquist
+    vtln_warp_factor = 1.0
+
+    # Calculate dynamic values to ensure accuracy
+    num_fft_bins = window_length_padded // 2 # 256 (integer division)
+    nyquist = 0.5 * sample_freq # 8000.0
+
+    # high_freq <= 0.0, so high_freq += nyquist
+    if high_freq <= 0.0:
+        high_freq += nyquist # 8000.0
+
+    # fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
+    fft_bin_width = sample_freq / window_length_padded
+
+    # Calculate mel scale values dynamically
+    mel_low_freq = mel_scale_scalar(low_freq)
+    mel_high_freq = mel_scale_scalar(high_freq)
+
+    # mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+
+    # vtln_warp_factor == 1.0, so no VTLN warping needed
+
+    # Create mel bin centers
+    bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
+    left_mel = mel_low_freq + bin_indices * mel_freq_delta
+    center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
+    right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta
+
+    # No VTLN warping since vtln_warp_factor == 1.0
+
+    # Create frequency bins for FFT
+    fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
+    mel = mel_scale(fft_freqs).unsqueeze(0) # size(1, num_fft_bins)
+
+    # Calculate triangular filter banks
+    up_slope = (mel - left_mel) / (center_mel - left_mel)
+    down_slope = (right_mel - mel) / (right_mel - center_mel)
+
+    # Since vtln_warp_factor == 1.0, use the simpler branch
+    bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))
+
+    return bins
+
+
+def fbank_onnx(
+    waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
+) -> Tensor:
+    r"""ONNX-compatible fbank function with hardcoded parameters from traced call:
+    num_mel_bins=80, sample_frequency=16000, dither=0
+
+    All other parameters use their traced default values.
+
+    Args:
+        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
+
+    Returns:
+        Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80)
+        where m is calculated in _get_strided
+    """
+    device, dtype = waveform.device, waveform.dtype
+
+    # Use ONNX-compatible version of _get_waveform_and_window_properties
+    waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)
+
+    # min_duration=0.0, so skip the duration check (signal will never be too short)
+
+    # Use ONNX-compatible version of _get_window
+    strided_input, signal_log_energy = _get_window_onnx(waveform_selected)
+
+    # spectrum = torch.fft.rfft(strided_input).abs()
+
+    m, frame_size = strided_input.shape
+
+    # Process all frames at once using batch processing
+    # Reshape to (m, 1, frame_size) to treat each frame as a separate batch item
+    batched_frames = strided_input.unsqueeze(1) # Shape: (m, 1, 512)
+
+    # Create rectangular window for all frames at once
+    rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)
+
+    # Apply STFT to all frames simultaneously
+    # The batch dimension allows us to process all m frames in parallel
+    stft_result = torch.stft(
+        batched_frames.flatten(0, 1), # Shape: (m, 512) - flatten batch and channel dims
+        n_fft=512,
+        hop_length=512, # Process entire frame at once
+        window=rectangular_window,
+        center=False, # Don't add padding
+        return_complex=False
+    )
+
+    # stft_result shape: (m, 257, 1, 2) where last dim is [real, imag]
+    # Calculate magnitude: sqrt(real^2 + imag^2)
+    real_part = stft_result[..., 0] # Shape: (m, 257, 1)
+    imag_part = stft_result[..., 1] # Shape: (m, 257, 1)
+    spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1) # Shape: (m, 257)
+
+    # use_power=True, so execute this branch
+    spectrum = spectrum.pow(2.0)
+
+    # Get mel filterbanks using ONNX-compatible version
+    mel_energies = get_mel_banks_onnx(device, dtype)
+
+    # pad right column with zeros to match FFT output size (80, 256) -> (80, 257)
+    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
+
+    # sum with mel filterbanks over the power spectrum, size (m, 80)
+    mel_energies = torch.mm(spectrum, mel_energies.T)
+
+    # use_log_fbank=True, so execute this branch
+    mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
+
+    # use_energy=False, so skip the energy addition (lines 828-834)
+
+    # Use ONNX-compatible version of _subtract_column_mean
+    mel_energies = _subtract_column_mean_onnx(mel_energies)
+
+    return mel_energies
+
+# Test to compare original fbank vs fbank_onnx
+if __name__ == "__main__":
+    import torch
+
+    print("Testing fbank vs fbank_onnx with traced parameters...")
+
+    # Create test waveform
+    torch.manual_seed(42)
+    sample_rate = 16000
+    duration = 1.0 # 1 second
+    num_samples = int(sample_rate * duration)
+
+    # Create a test waveform (sine wave + noise)
+    t = torch.linspace(0, duration, num_samples)
+    frequency = 440.0 # A4 note
+    waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)
+
+    # Test with both mono and stereo inputs
+    mono_waveform = waveform.unsqueeze(0) # Shape: (1, num_samples)
+
+    print(f"Test waveform shape: {mono_waveform.shape}")
+
+    # Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0
+    try:
+        print("\n=== DEBUGGING: Step-by-step comparison ===")
+
+        # Step 1: Check waveform processing
+        orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
+            mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
+        )
+        onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)
+
+        print(f"Original waveform shape: {orig_waveform.shape}")
+        print(f"ONNX waveform shape: {onnx_waveform.shape}")
+        print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
+        print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
+        print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")
+
+        # Step 2: Check windowing
+        orig_strided, orig_energy = _get_window(
+            orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
+            'povey', 0.42, True, True, 1.0, 0, True, 0.97
+        )
+        onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)
+
+        print(f"\nOriginal strided shape: {orig_strided.shape}")
+        print(f"ONNX strided shape: {onnx_strided.shape}")
+        print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
+        print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")
+
+        # Step 3: Check mel banks
+        orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
+        onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)
+
+        print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
+        print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
+        print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")
+
+        # Step 4: Full comparison
+        print("\n=== FULL COMPARISON ===")
+
+        # Original fbank
+        original_result = fbank(
+            mono_waveform,
+            num_mel_bins=80,
+            sample_frequency=16000,
+            dither=0
+        )
+
+        # ONNX-compatible fbank
+        onnx_result = fbank_onnx(mono_waveform)
+
+        print(f"Original fbank output shape: {original_result.shape}")
+        print(f"ONNX fbank output shape: {onnx_result.shape}")
+
+        # Check if shapes match
+        if original_result.shape == onnx_result.shape:
+            print("✅ Output shapes match")
+        else:
+            print("❌ Output shapes don't match")
+            print(f"   Original: {original_result.shape}")
+            print(f"   ONNX: {onnx_result.shape}")
+
+        # Check numerical differences
+        if original_result.shape == onnx_result.shape:
+            diff = torch.abs(original_result - onnx_result)
+            max_diff = torch.max(diff).item()
+            mean_diff = torch.mean(diff).item()
+            relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()
+
+            print(f"Max absolute difference: {max_diff:.2e}")
+            print(f"Mean absolute difference: {mean_diff:.2e}")
+            print(f"Mean relative difference: {relative_diff:.2e}")
+
+            # Find where the max difference occurs
+            max_idx = torch.argmax(diff)
+            max_coords = torch.unravel_index(max_idx, diff.shape)
+            print(f"Max difference at coordinates: {max_coords}")
+            print(f"   Original value: {original_result[max_coords].item():.6f}")
+            print(f"   ONNX value: {onnx_result[max_coords].item():.6f}")
+
+            # Check if results are numerically close
+            tolerance = 1e-5
+            if max_diff < tolerance:
+                print(f"✅ Results are numerically identical (within {tolerance})")
+            else:
+                print(f"❌ Results differ by {max_diff}, which exceeds the tolerance of {tolerance}")
+
+            # Additional statistics
+            print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
+            print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")
+
+    except Exception as e:
+        print(f"❌ Error during testing: {e}")
+        import traceback
+        traceback.print_exc()
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index fd680135..9f100a5c 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -4,6 +4,7 @@ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 from feature_extractor import cnhubert
 from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn
+from sv import SV
 
 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -190,14 +191,18 @@ class T2SModel(nn.Module):
 
 
 class VitsModel(nn.Module):
-    def __init__(self, vits_path):
+    def __init__(self, vits_path, version: str = "v2"):
         super().__init__()
         dict_s2 = torch.load(vits_path, map_location="cpu")
         self.hps = dict_s2["config"]
         if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
         else:
-            self.hps["model"]["version"] = "v2"
+            self.hps["model"]["version"] = version
+
+        self.sv_model = None
+        if version == "v2ProPlus" or version == "v2Pro":
+            self.sv_model = SV("cpu", False)
 
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
@@ -219,6 +224,9 @@ class VitsModel(nn.Module):
             self.hps.data.win_length,
             center=False,
         )
+        if self.sv_model is not None:
+            sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio)
+            return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
 
 
@@ -274,8 +282,8 @@ class SSLModel(nn.Module):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
 
-def export(vits_path, gpt_path, project_name, vits_model="v2"):
-    vits = VitsModel(vits_path)
+def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
+    vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
@@ -297,7 +305,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
             )
         ]
     )
@@ -330,7 +338,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
             )
         ]
    )
@@ -349,9 +357,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
     ssl_content = ssl(ref_audio_16k).float()
 
     # debug = False
-    debug = True
-
-    # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    debug = False
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
 
     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
@@ -361,7 +368,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
         a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
         soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
 
-    if vits_model == "v1":
+    if voice_model_version == "v1":
         symbols = symbols_v1
     else:
         symbols = symbols_v2
@@ -390,9 +397,16 @@ if __name__ == "__main__":
     except:
         pass
 
-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
-    exp_path = "nahida"
-    export(vits_path, gpt_path, exp_path)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    # exp_path = "v2_export"
+    # version = "v2"
= "v2" + # export(vits_path, gpt_path, exp_path, version) + + gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" + exp_path = "v2proplus_export" + version = "v2ProPlus" + export(vits_path, gpt_path, exp_path, version) + - # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) diff --git a/GPT_SoVITS/sv.py b/GPT_SoVITS/sv.py index 22e70369..50069c7e 100644 --- a/GPT_SoVITS/sv.py +++ b/GPT_SoVITS/sv.py @@ -30,3 +30,15 @@ class SV: ) sv_emb = self.embedding_model.forward3(feat) return sv_emb + + def compute_embedding3_onnx(self, wav): + # Disable gradients for all parameters + for param in self.embedding_model.parameters(): + param.requires_grad = False + + with torch.no_grad(): + if self.is_half == True: + wav = wav.half() + feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0) + sv_emb = self.embedding_model.forward3(feat) + return sv_emb \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 90e4957d..b09d2b79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ numba pytorch-lightning>=2.4 gradio<5 ffmpeg-python +onnx onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64" onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64" tqdm