mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-30 01:25:58 +08:00

feat: v2pp onnx export ready, testing...

This commit is contained in:
parent fdf794e31d
commit 8c0f32da3e

.gitignore (vendored)
@@ -193,3 +193,4 @@ cython_debug/
 # PyPI configuration file
 .pypirc
+onnx/

@@ -2,6 +2,7 @@ from torch.nn.functional import *
 from torch.nn.functional import (
     _canonical_mask,
 )
+from typing import Tuple
 
 
 def multi_head_attention_forward_patched(

@@ -13,6 +13,7 @@ __all__ = [
     "mel_scale_scalar",
     "spectrogram",
     "fbank",
+    "fbank_onnx",
     "mfcc",
     "vtln_warp_freq",
     "vtln_warp_mel_freq",

@@ -842,3 +843,371 @@ def mfcc(
     feature = _subtract_column_mean(feature, subtract_mean)
     return feature
+
+
+def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
+    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
+    device, dtype = strided_input.device, strided_input.dtype
+    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
+    return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))
+
+
+def _get_waveform_and_window_properties_onnx(
+    waveform: Tensor,
+) -> Tuple[Tensor, int, int, int]:
+    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
+    channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
+    round_to_power_of_two=True, preemphasis_coefficient=0.97"""
+
+    # Hardcoded values from traced parameters
+    # channel=-1 -> 0 after max(channel, 0)
+    channel = 0
+
+    # Extract channel 0 from the waveform
+    if waveform.dim() == 1:
+        # Mono waveform, use as-is
+        waveform_selected = waveform
+    else:
+        # Multi-channel, select the first channel
+        waveform_selected = waveform[channel, :]
+
+    # Hardcoded calculations:
+    # window_shift = int(16000 * 10.0 * 0.001) = 160
+    # window_size = int(16000 * 25.0 * 0.001) = 400
+    # padded_window_size = _next_power_of_2(400) = 512
+    window_shift = 160
+    window_size = 400
+    padded_window_size = 512
+
+    return waveform_selected, window_shift, window_size, padded_window_size
+
+
+def _get_window_onnx(
+    waveform: Tensor,
+) -> Tuple[Tensor, Tensor]:
+    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
+    padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
+    blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
+    dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97
+
+    Returns:
+        (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
+    """
+    device, dtype = waveform.device, waveform.dtype
+    epsilon = _get_epsilon(device, dtype)
+
+    # Hardcoded values from traced parameters
+    window_size = 400
+    window_shift = 160
+    padded_window_size = 512
+    snip_edges = True
+
+    # size (m, window_size)
+    strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)
+
+    # dither=0, so skip dithering (lines 209-211 from the original)
+
+    # remove_dc_offset=True, so execute this branch
+    row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
+    strided_input = strided_input - row_means
+
+    # raw_energy=True, so execute this branch
+    signal_log_energy = _get_log_energy_onnx(strided_input, epsilon)  # energy_floor=1.0
+
+    # preemphasis_coefficient=0.97 != 0.0, so execute this branch
+    offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
+    strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]
+
+    # Apply the povey window function to each row/frame
+    # povey window: torch.hann_window(window_size, periodic=False).pow(0.85)
+    window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
+    strided_input = strided_input * window_function
+
+    # Pad columns from window_size=400 to padded_window_size=512
+    padding_right = padded_window_size - window_size  # 112
+    strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)
+
+    # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245)
+    return strided_input, signal_log_energy
+
+
+def _get_strided_onnx(waveform: Tensor, window_size: int = 400, window_shift: int = 160, snip_edges: bool = True) -> Tensor:
+    r"""ONNX-compatible framing with snip_edges=True: only windows that fit
+    entirely inside the signal are produced."""
+    seq_len = waveform.size(0)
+
+    # Calculate the number of windows
+    num_windows = 1 + (seq_len - window_size) // window_shift
+
+    # Create indices for all windows at once
+    window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
+    window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)
+
+    # Extract windows using advanced indexing
+    windows = waveform[window_indices]  # [num_windows, window_size]
+
+    return windows
+
+
+def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
+    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
+    subtract_mean=False, so this function returns the input tensor unchanged.
+
+    Args:
+        tensor: Input tensor of size (m, n)
+
+    Returns:
+        Tensor: Same as the input tensor (m, n), since subtract_mean=False
+    """
+    # subtract_mean=False from traced parameters, so return the tensor as-is
+    return tensor
+
+
+def get_mel_banks_onnx(
+    device=None,
+    dtype=None,
+) -> Tensor:
+    """ONNX-compatible version with hardcoded parameters from the traced fbank call:
+    num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
+    high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0
+
+    Returns:
+        Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
+    """
+    # Hardcoded values from traced parameters
+    num_bins = 80
+    window_length_padded = 512
+    sample_freq = 16000.0
+    low_freq = 20.0
+    high_freq = 0.0  # Will be adjusted to the Nyquist frequency below
+    vtln_warp_factor = 1.0
+
+    # Calculate dynamic values to ensure accuracy
+    num_fft_bins = window_length_padded // 2  # 256 (integer division)
+    nyquist = 0.5 * sample_freq  # 8000.0
+
+    # high_freq <= 0.0, so high_freq += nyquist
+    if high_freq <= 0.0:
+        high_freq += nyquist  # 8000.0
+
+    # fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
+    fft_bin_width = sample_freq / window_length_padded
+
+    # Calculate mel scale values dynamically
+    mel_low_freq = mel_scale_scalar(low_freq)
+    mel_high_freq = mel_scale_scalar(high_freq)
+    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
+
+    # vtln_warp_factor == 1.0, so no VTLN warping is needed
+
+    # Create mel bin centers
+    bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
+    left_mel = mel_low_freq + bin_indices * mel_freq_delta
+    center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
+    right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta
+
+    # Create frequency bins for the FFT
+    fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
+    mel = mel_scale(fft_freqs).unsqueeze(0)  # size (1, num_fft_bins)
+
+    # Calculate the triangular filter banks
+    up_slope = (mel - left_mel) / (center_mel - left_mel)
+    down_slope = (right_mel - mel) / (right_mel - center_mel)
+
+    # Since vtln_warp_factor == 1.0, use the simpler branch
+    bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))
+
+    return bins
+
+
+def fbank_onnx(
+    waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
+) -> Tensor:
+    r"""ONNX-compatible fbank function with hardcoded parameters from the traced call:
+    num_mel_bins=80, sample_frequency=16000, dither=0
+
+    All other parameters use their traced default values.
+
+    Args:
+        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
+
+    Returns:
+        Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80)
+        where m is calculated in _get_strided
+    """
+    device, dtype = waveform.device, waveform.dtype
+
+    # Use the ONNX-compatible version of _get_waveform_and_window_properties
+    waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)
+
+    # min_duration=0.0, so skip the duration check (the signal is never "too short")
+
+    # Use the ONNX-compatible version of _get_window
+    strided_input, signal_log_energy = _get_window_onnx(waveform_selected)
+
+    # spectrum = torch.fft.rfft(strided_input).abs()  # original path; replaced below with an ONNX-exportable STFT
+
+    m, frame_size = strided_input.shape
+
+    # Process all frames at once using batch processing
+    # Reshape to (m, 1, frame_size) to treat each frame as a separate batch item
+    batched_frames = strided_input.unsqueeze(1)  # Shape: (m, 1, 512)
+
+    # Create a rectangular window shared by all frames
+    rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)
+
+    # Apply the STFT to all frames simultaneously;
+    # the batch dimension lets us process all m frames in parallel
+    stft_result = torch.stft(
+        batched_frames.flatten(0, 1),  # Shape: (m, 512) - flatten batch and channel dims
+        n_fft=512,
+        hop_length=512,  # Process the entire frame at once
+        window=rectangular_window,
+        center=False,  # Don't add padding
+        return_complex=False,
+    )
+
+    # stft_result shape: (m, 257, 1, 2) where the last dim is [real, imag]
+    # Calculate the magnitude: sqrt(real^2 + imag^2)
+    real_part = stft_result[..., 0]  # Shape: (m, 257, 1)
+    imag_part = stft_result[..., 1]  # Shape: (m, 257, 1)
+    spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1)  # Shape: (m, 257)
+
+    # use_power=True, so execute this branch
+    spectrum = spectrum.pow(2.0)
+
+    # Get the mel filterbanks using the ONNX-compatible version
+    mel_energies = get_mel_banks_onnx(device, dtype)
+
+    # Pad the right column with zeros to match the FFT output size: (80, 256) -> (80, 257)
+    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
+
+    # Sum with mel filterbanks over the power spectrum, size (m, 80)
+    mel_energies = torch.mm(spectrum, mel_energies.T)
+
+    # use_log_fbank=True, so execute this branch
+    mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
+
+    # use_energy=False, so skip the energy addition (lines 828-834)
+
+    # Use the ONNX-compatible version of _subtract_column_mean
+    mel_energies = _subtract_column_mean_onnx(mel_energies)
+
+    return mel_energies
+
+
+# Test to compare the original fbank vs fbank_onnx
+if __name__ == "__main__":
+    import torch
+
+    print("Testing fbank vs fbank_onnx with traced parameters...")
+
+    # Create a test waveform
+    torch.manual_seed(42)
+    sample_rate = 16000
+    duration = 1.0  # 1 second
+    num_samples = int(sample_rate * duration)
+
+    # Create a test waveform (sine wave + noise)
+    t = torch.linspace(0, duration, num_samples)
+    frequency = 440.0  # A4 note
+    waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)
+
+    # Wrap as a mono (1, num_samples) input
+    mono_waveform = waveform.unsqueeze(0)  # Shape: (1, num_samples)
+
+    print(f"Test waveform shape: {mono_waveform.shape}")
+
+    # Test parameters from the trace: num_mel_bins=80, sample_frequency=16000, dither=0
+    try:
+        print("\n=== DEBUGGING: Step-by-step comparison ===")
+
+        # Step 1: Check waveform processing
+        orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
+            mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
+        )
+        onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)
+
+        print(f"Original waveform shape: {orig_waveform.shape}")
+        print(f"ONNX waveform shape: {onnx_waveform.shape}")
+        print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
+        print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
+        print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")
+
+        # Step 2: Check windowing
+        orig_strided, orig_energy = _get_window(
+            orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
+            "povey", 0.42, True, True, 1.0, 0, True, 0.97
+        )
+        onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)
+
+        print(f"\nOriginal strided shape: {orig_strided.shape}")
+        print(f"ONNX strided shape: {onnx_strided.shape}")
+        print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
+        print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")
+
+        # Step 3: Check mel banks
+        orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
+        onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)
+
+        print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
+        print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
+        print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")
+
+        # Step 4: Full comparison
+        print("\n=== FULL COMPARISON ===")
+
+        # Original fbank
+        original_result = fbank(
+            mono_waveform,
+            num_mel_bins=80,
+            sample_frequency=16000,
+            dither=0,
+        )
+
+        # ONNX-compatible fbank
+        onnx_result = fbank_onnx(mono_waveform)
+
+        print(f"Original fbank output shape: {original_result.shape}")
+        print(f"ONNX fbank output shape: {onnx_result.shape}")
+
+        # Check whether the shapes match
+        if original_result.shape == onnx_result.shape:
+            print("✅ Output shapes match")
+        else:
+            print("❌ Output shapes don't match")
+            print(f"   Original: {original_result.shape}")
+            print(f"   ONNX: {onnx_result.shape}")
+
+        # Check numerical differences
+        if original_result.shape == onnx_result.shape:
+            diff = torch.abs(original_result - onnx_result)
+            max_diff = torch.max(diff).item()
+            mean_diff = torch.mean(diff).item()
+            relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()
+
+            print(f"Max absolute difference: {max_diff:.2e}")
+            print(f"Mean absolute difference: {mean_diff:.2e}")
+            print(f"Mean relative difference: {relative_diff:.2e}")
+
+            # Find where the max difference occurs
+            max_idx = torch.argmax(diff)
+            max_coords = torch.unravel_index(max_idx, diff.shape)
+            print(f"Max difference at coordinates: {max_coords}")
+            print(f"   Original value: {original_result[max_coords].item():.6f}")
+            print(f"   ONNX value: {onnx_result[max_coords].item():.6f}")
+
+            # Check whether the results are numerically close
+            tolerance = 1e-5
+            if max_diff < tolerance:
+                print(f"✅ Results are numerically identical (within {tolerance})")
+            else:
+                print(f"❌ Results differ by {max_diff:.2e}, more than the tolerance {tolerance}")
+
+            # Additional statistics
+            print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
+            print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")
+
+    except Exception as e:
+        print(f"❌ Error during testing: {e}")
+        import traceback
+
+        traceback.print_exc()

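Why the rectangular-window STFT above can stand in for torch.fft.rfft: with n_fft=512, hop_length=512, a window of all ones, and center=False, torch.stft computes exactly one unwindowed FFT per 512-sample row, which is what rfft would return; torch.fft.rfft itself does not trace to ONNX on commonly used opsets. A minimal standalone sketch of that equivalence (the frame count and contents are arbitrary):

import torch

# One rectangular-window STFT frame per row equals the real FFT of that row.
frames = torch.randn(3, 512)
stft = torch.stft(
    frames,
    n_fft=512,
    hop_length=512,          # exactly one hop -> one FFT per row
    window=torch.ones(512),  # rectangular window = no windowing
    center=False,            # no implicit padding
    return_complex=True,
)                            # shape: (3, 257, 1)
magnitude = stft.abs().squeeze(-1)        # (3, 257)
reference = torch.fft.rfft(frames).abs()  # (3, 257)
print(torch.allclose(magnitude, reference, atol=1e-4))  # expected: True

The diff keeps return_complex=False in fbank_onnx because the real/imaginary pair traces through the ONNX exporter; the complex form above is only for the standalone check.
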
@@ -4,6 +4,7 @@ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 from feature_extractor import cnhubert
 from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn
+from sv import SV
 
 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path

@@ -190,14 +191,18 @@ class T2SModel(nn.Module):
 
 
 class VitsModel(nn.Module):
-    def __init__(self, vits_path):
+    def __init__(self, vits_path, version: str = "v2"):
         super().__init__()
         dict_s2 = torch.load(vits_path, map_location="cpu")
         self.hps = dict_s2["config"]
         if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
         else:
-            self.hps["model"]["version"] = "v2"
+            self.hps["model"]["version"] = version
 
+        self.sv_model = None
+        if version == "v2ProPlus" or version == "v2Pro":
+            self.sv_model = SV("cpu", False)
+
         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"

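A hypothetical call site for the updated constructor; only the v2Pro/v2ProPlus branch brings up the speaker-verification model. The checkpoint path below is a placeholder, not a file shipped with this commit:

# Hypothetical usage; the checkpoint path is a placeholder.
vits = VitsModel("SoVITS_weights/your_v2proplus_model.pth", version="v2ProPlus")
assert vits.sv_model is not None  # the SV branch is active only for v2Pro/v2ProPlus
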
@@ -219,6 +224,9 @@ class VitsModel(nn.Module):
             self.hps.data.win_length,
             center=False,
         )
+        if self.sv_model is not None:
+            sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio)
+            return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
 
 
@@ -274,8 +282,8 @@ class SSLModel(nn.Module):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
 
-def export(vits_path, gpt_path, project_name, vits_model="v2"):
-    vits = VitsModel(vits_path)
+def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
+    vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()

@@ -297,7 +305,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
             )
         ]
     )

@@ -330,7 +338,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
            )
         ]
     )

@@ -349,9 +357,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
     ssl_content = ssl(ref_audio_16k).float()
 
     # debug = False
-    debug = True
+    debug = False
 
-    # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
 
     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)

@@ -361,7 +368,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
         a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
         soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
 
-    if vits_model == "v1":
+    if voice_model_version == "v1":
         symbols = symbols_v1
     else:
         symbols = symbols_v2

@@ -390,9 +397,16 @@ if __name__ == "__main__":
     except:
         pass
 
-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
-    exp_path = "nahida"
-    export(vits_path, gpt_path, exp_path)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    # exp_path = "v2_export"
+    # version = "v2"
+    # export(vits_path, gpt_path, exp_path, version)
+
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    exp_path = "v2proplus_export"
+    version = "v2ProPlus"
+    export(vits_path, gpt_path, exp_path, version)
 
 # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

@@ -30,3 +30,15 @@ class SV:
         )
         sv_emb = self.embedding_model.forward3(feat)
         return sv_emb
+
+    def compute_embedding3_onnx(self, wav):
+        # Disable gradients for all parameters
+        for param in self.embedding_model.parameters():
+            param.requires_grad = False
+
+        with torch.no_grad():
+            if self.is_half:
+                wav = wav.half()
+            feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
+            sv_emb = self.embedding_model.forward3(feat)
+        return sv_emb

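compute_embedding3_onnx keeps the whole speaker-verification path (fbank_onnx plus forward3) traceable, so the embedding can also be exported on its own. A sketch of how that trace might look; the wrapper class, output file name, and the (1, 16000) input layout are assumptions, not part of this commit:

import torch
from sv import SV

class SVEmbedding(torch.nn.Module):
    """Hypothetical wrapper so torch.onnx.export can trace the SV path."""

    def __init__(self):
        super().__init__()
        self.sv = SV("cpu", False)

    def forward(self, wav_16k: torch.Tensor) -> torch.Tensor:
        return self.sv.compute_embedding3_onnx(wav_16k)

model = SVEmbedding().eval()
dummy = torch.randn(1, 16000)  # assumed layout: (channels, samples) at 16 kHz
torch.onnx.export(
    model,
    dummy,
    "sv_embedding.onnx",  # placeholder file name
    input_names=["wav_16k"],
    output_names=["sv_emb"],
    opset_version=16,
)

Note that the frame count inside _get_strided_onnx is computed with Python integer arithmetic, so the traced graph bakes in the dummy input's length; feed fixed-length audio or re-export for other lengths.
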
@@ -7,6 +7,7 @@ numba
 pytorch-lightning>=2.4
 gradio<5
 ffmpeg-python
+onnx
 onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
 onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
 tqdm

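The new onnx pin complements the per-architecture onnxruntime pins already present, so exported graphs can be validated and run in the same environment. A quick smoke test for any graph produced by this commit might look like the following; the file name and input signature are assumptions:

import numpy as np
import onnx
import onnxruntime as ort

path = "sv_embedding.onnx"  # placeholder; use whatever the export step wrote
onnx.checker.check_model(onnx.load(path))  # structural validation

session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
wav = np.random.randn(1, 16000).astype(np.float32)  # assumed input layout
(sv_emb,) = session.run(None, {"wav_16k": wav})
print(sv_emb.shape)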