mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00

feat:v2pp onnx export ready testing...

This commit is contained in:
parent fdf794e31d
commit 8c0f32da3e
.gitignore (vendored): 1 change
@@ -193,3 +193,4 @@ cython_debug/
 # PyPI configuration file
 .pypirc
+onnx/
@@ -2,6 +2,7 @@ from torch.nn.functional import *
 from torch.nn.functional import (
     _canonical_mask,
 )
+from typing import Tuple


 def multi_head_attention_forward_patched(
@@ -13,6 +13,7 @@ __all__ = [
     "mel_scale_scalar",
     "spectrogram",
     "fbank",
+    "fbank_onnx",
     "mfcc",
     "vtln_warp_freq",
     "vtln_warp_mel_freq",
@@ -842,3 +843,371 @@ def mfcc(
    feature = _subtract_column_mean(feature, subtract_mean)
    return feature


def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
    r"""Returns the log energy of size (m) for a strided_input (m,*)"""
    device, dtype = strided_input.device, strided_input.dtype
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    # energy_floor is pinned to 1.0 in the traced call, so the floor is log(1.0) = 0.0
    return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))


def _get_waveform_and_window_properties_onnx(
    waveform: Tensor,
) -> Tuple[Tensor, int, int, int]:
    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
    channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
    round_to_power_of_two=True, preemphasis_coefficient=0.97"""

    # Hardcoded values from the traced parameters
    # channel=-1 -> 0 after max(channel, 0)
    channel = 0

    # Extract channel 0 from the waveform
    if waveform.dim() == 1:
        # Mono waveform, use as-is
        waveform_selected = waveform
    else:
        # Multi-channel, select the first channel
        waveform_selected = waveform[channel, :]

    # Hardcoded calculations:
    # window_shift = int(16000 * 10.0 * 0.001) = 160
    # window_size = int(16000 * 25.0 * 0.001) = 400
    # padded_window_size = _next_power_of_2(400) = 512
    window_shift = 160
    window_size = 400
    padded_window_size = 512

    return waveform_selected, window_shift, window_size, padded_window_size
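
# Note (hedged, not part of the commit): hardcoding shift/size presumably keeps
# the export free of data-dependent Python arithmetic; 512 = _next_power_of_2(400)
# because 2**9 = 512 is the smallest power of two >= 400.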


def _get_window_onnx(
    waveform: Tensor,
) -> Tuple[Tensor, Tensor]:
    r"""ONNX-compatible version with hardcoded parameters from the traced fbank call:
    padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
    blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
    dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97

    Returns:
        (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # Hardcoded values from the traced parameters
    window_size = 400
    window_shift = 160
    padded_window_size = 512
    snip_edges = True

    # size (m, window_size)
    strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)

    # dither=0, so skip dithering (lines 209-211 of the original)

    # remove_dc_offset=True, so execute this branch
    row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
    strided_input = strided_input - row_means

    # raw_energy=True, so execute this branch
    signal_log_energy = _get_log_energy_onnx(strided_input, epsilon)  # energy_floor=1.0

    # preemphasis_coefficient=0.97 != 0.0, so execute this branch
    # (replicate-padding on the left makes x[-1] = x[0], i.e. y[t] = x[t] - 0.97 * x[t-1])
    offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
    strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]

    # Apply the povey window function to each row/frame
    # povey window: torch.hann_window(window_size, periodic=False).pow(0.85)
    window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
    strided_input = strided_input * window_function

    # Pad columns from window_size=400 to padded_window_size=512
    padding_right = padded_window_size - window_size  # 112
    strided_input = torch.nn.functional.pad(
        strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
    ).squeeze(0)

    # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245)
    return strided_input, signal_log_energy


def _get_strided_onnx(waveform: Tensor, window_size: int = 400, window_shift: int = 160, snip_edges: bool = True) -> Tensor:
    seq_len = waveform.size(0)

    # Calculate the number of windows (valid only for snip_edges=True,
    # which is the hardcoded setting in the traced call)
    num_windows = 1 + (seq_len - window_size) // window_shift

    # Create indices for all windows at once
    window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
    window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)

    # Extract windows using advanced indexing
    windows = waveform[window_indices]  # [num_windows, window_size]

    return windows
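
# A quick check (editor sketch, not part of the commit): for snip_edges=True the
# indexing above matches Tensor.unfold:
#
#     frames = waveform.unfold(0, 400, 160)
#     assert torch.equal(frames, _get_strided_onnx(waveform))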


def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
    """ONNX-compatible version with hardcoded parameters from the traced fbank call:
    subtract_mean=False, so this function returns the input tensor unchanged.

    Args:
        tensor: Input tensor of size (m, n)

    Returns:
        Tensor: Same as the input tensor (m, n), since subtract_mean=False
    """
    # subtract_mean=False in the traced parameters, so return the tensor as-is
    return tensor


def get_mel_banks_onnx(
    device=None,
    dtype=None,
) -> Tensor:
    """ONNX-compatible version with hardcoded parameters from the traced fbank call:
    num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
    high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0

    Returns:
        Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
    """
    # Hardcoded values from the traced parameters
    num_bins = 80
    window_length_padded = 512
    sample_freq = 16000.0
    low_freq = 20.0
    high_freq = 0.0  # Will be adjusted to the Nyquist frequency
    vtln_warp_factor = 1.0

    # Calculate derived values dynamically to ensure accuracy
    num_fft_bins = window_length_padded // 2  # 256 (integer division)
    nyquist = 0.5 * sample_freq  # 8000.0

    # high_freq <= 0.0, so high_freq += nyquist
    if high_freq <= 0.0:
        high_freq += nyquist  # 8000.0

    # fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
    fft_bin_width = sample_freq / window_length_padded

    # Calculate the mel scale values dynamically
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    # vtln_warp_factor == 1.0, so no VTLN warping is needed

    # Create the mel bin centers
    bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
    left_mel = mel_low_freq + bin_indices * mel_freq_delta
    center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
    right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta

    # No VTLN warping since vtln_warp_factor == 1.0

    # Create the frequency bins for the FFT
    fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
    mel = mel_scale(fft_freqs).unsqueeze(0)  # size (1, num_fft_bins)

    # Calculate the triangular filter banks
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    # Since vtln_warp_factor == 1.0, use the simpler branch
    bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))

    return bins
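
# For reference (hedged): torchaudio's Kaldi-style mel scale is
# mel(f) = 1127.0 * ln(1 + f / 700.0), so with low_freq=20 and high_freq=8000
# the 80 triangles span roughly mel(20) ≈ 31.7 to mel(8000) ≈ 2840.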


def fbank_onnx(
    waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
) -> Tensor:
    r"""ONNX-compatible fbank function with hardcoded parameters from the traced call:
    num_mel_bins=80, sample_frequency=16000, dither=0

    All other parameters use their traced default values.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    # Use the ONNX-compatible version of _get_waveform_and_window_properties
    waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)

    # min_duration=0.0, so skip the duration check (the signal is never "too short")

    # Use the ONNX-compatible version of _get_window
    strided_input, signal_log_energy = _get_window_onnx(waveform_selected)

    # spectrum = torch.fft.rfft(strided_input).abs()
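    # Note (hedged, not part of the commit): torch.fft.rfft reportedly lacks an
    # ONNX symbolic in the commonly used opsets, so the magnitude spectrum is
    # computed below with torch.stft (rectangular length-512 window,
    # hop_length=512, center=False): one FFT per 512-sample frame, matching
    # torch.fft.rfft(frame).abs() numerically.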

    m, frame_size = strided_input.shape

    # Process all frames at once using batch processing
    # Reshape to (m, 1, frame_size) to treat each frame as a separate batch item
    batched_frames = strided_input.unsqueeze(1)  # Shape: (m, 1, 512)

    # Create a rectangular window shared by all frames
    rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)

    # Apply the STFT to all frames simultaneously
    # The batch dimension lets us process all m frames in parallel
    stft_result = torch.stft(
        batched_frames.flatten(0, 1),  # Shape: (m, 512) - flatten batch and channel dims
        n_fft=512,
        hop_length=512,  # Process the entire frame at once
        window=rectangular_window,
        center=False,  # Don't add padding
        return_complex=False,
    )

    # stft_result shape: (m, 257, 1, 2) where the last dim is [real, imag]
    # Calculate the magnitude: sqrt(real^2 + imag^2)
    real_part = stft_result[..., 0]  # Shape: (m, 257, 1)
    imag_part = stft_result[..., 1]  # Shape: (m, 257, 1)
    spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1)  # Shape: (m, 257)

    # use_power=True, so execute this branch
    spectrum = spectrum.pow(2.0)

    # Get the mel filterbanks using the ONNX-compatible version
    mel_energies = get_mel_banks_onnx(device, dtype)

    # Pad the right column with zeros to match the FFT output size: (80, 256) -> (80, 257)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # Sum with the mel filterbanks over the power spectrum, size (m, 80)
    mel_energies = torch.mm(spectrum, mel_energies.T)

    # use_log_fbank=True, so execute this branch
    mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # use_energy=False, so skip the energy addition (lines 828-834)

    # Use the ONNX-compatible version of _subtract_column_mean
    mel_energies = _subtract_column_mean_onnx(mel_energies)

    return mel_energies
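

# A minimal export sketch (an assumption, not part of this commit): every
# parameter above is hardcoded, so fbank_onnx traces into a static graph.
# The wrapper name, output file and opset below are hypothetical:
#
#     class _FbankExport(torch.nn.Module):
#         def forward(self, waveform):
#             return fbank_onnx(waveform)
#
#     dummy = torch.randn(1, 16000)  # (channels, samples) at 16 kHz
#     torch.onnx.export(
#         _FbankExport(), (dummy,), "fbank.onnx",
#         input_names=["waveform"], output_names=["fbank"],
#         dynamic_axes={"waveform": {1: "num_samples"}, "fbank": {0: "num_frames"}},
#         opset_version=17,
#     )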


# Test to compare the original fbank vs fbank_onnx
if __name__ == "__main__":
    import torch

    print("Testing fbank vs fbank_onnx with traced parameters...")

    # Create a test waveform
    torch.manual_seed(42)
    sample_rate = 16000
    duration = 1.0  # 1 second
    num_samples = int(sample_rate * duration)

    # Create a test waveform (sine wave + noise)
    t = torch.linspace(0, duration, num_samples)
    frequency = 440.0  # A4 note
    waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)

    # Test with a mono input
    mono_waveform = waveform.unsqueeze(0)  # Shape: (1, num_samples)

    print(f"Test waveform shape: {mono_waveform.shape}")

    # Test parameters from the trace: num_mel_bins=80, sample_frequency=16000, dither=0
    try:
        print("\n=== DEBUGGING: Step-by-step comparison ===")

        # Step 1: Check the waveform processing
        orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
            mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
        )
        onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)

        print(f"Original waveform shape: {orig_waveform.shape}")
        print(f"ONNX waveform shape: {onnx_waveform.shape}")
        print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
        print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
        print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")

        # Step 2: Check the windowing
        orig_strided, orig_energy = _get_window(
            orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
            'povey', 0.42, True, True, 1.0, 0, True, 0.97
        )
        onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)

        print(f"\nOriginal strided shape: {orig_strided.shape}")
        print(f"ONNX strided shape: {onnx_strided.shape}")
        print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
        print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")

        # Step 3: Check the mel banks
        orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
        onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)

        print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
        print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
        print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")

        # Step 4: Full comparison
        print("\n=== FULL COMPARISON ===")

        # Original fbank
        original_result = fbank(
            mono_waveform,
            num_mel_bins=80,
            sample_frequency=16000,
            dither=0
        )

        # ONNX-compatible fbank
        onnx_result = fbank_onnx(mono_waveform)

        print(f"Original fbank output shape: {original_result.shape}")
        print(f"ONNX fbank output shape: {onnx_result.shape}")

        # Check whether the shapes match
        if original_result.shape == onnx_result.shape:
            print("✅ Output shapes match")
        else:
            print("❌ Output shapes don't match")
            print(f"  Original: {original_result.shape}")
            print(f"  ONNX: {onnx_result.shape}")

        # Check the numerical differences
        if original_result.shape == onnx_result.shape:
            diff = torch.abs(original_result - onnx_result)
            max_diff = torch.max(diff).item()
            mean_diff = torch.mean(diff).item()
            relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()

            print(f"Max absolute difference: {max_diff:.2e}")
            print(f"Mean absolute difference: {mean_diff:.2e}")
            print(f"Mean relative difference: {relative_diff:.2e}")

            # Find where the max difference occurs
            max_idx = torch.argmax(diff)
            max_coords = torch.unravel_index(max_idx, diff.shape)
            print(f"Max difference at coordinates: {max_coords}")
            print(f"  Original value: {original_result[max_coords].item():.6f}")
            print(f"  ONNX value: {onnx_result[max_coords].item():.6f}")

            # Check whether the results are numerically close
            tolerance = 1e-5
            if max_diff < tolerance:
                print(f"✅ Results are numerically identical (within {tolerance})")
            else:
                print(f"❌ Results differ by {max_diff:.2e}, which exceeds {tolerance}")

            # Additional statistics
            print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
            print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")

    except Exception as e:
        print(f"❌ Error during testing: {e}")
        import traceback
        traceback.print_exc()

@@ -4,6 +4,7 @@ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
 from feature_extractor import cnhubert
 from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
 from torch import nn
+from sv import SV

 cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path
@@ -190,14 +191,18 @@ class T2SModel(nn.Module):


 class VitsModel(nn.Module):
-    def __init__(self, vits_path):
+    def __init__(self, vits_path, version: str = "v2"):
         super().__init__()
         dict_s2 = torch.load(vits_path, map_location="cpu")
         self.hps = dict_s2["config"]
-        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
-            self.hps["model"]["version"] = "v1"
-        else:
-            self.hps["model"]["version"] = "v2"
+        self.hps["model"]["version"] = version
+
+        self.sv_model = None
+        if version == "v2ProPlus" or version == "v2Pro":
+            self.sv_model = SV("cpu", False)

         self.hps = DictToAttrRecursive(self.hps)
         self.hps.model.semantic_frame_rate = "25hz"
@@ -219,6 +224,9 @@ class VitsModel(nn.Module):
             self.hps.data.win_length,
             center=False,
         )
+        if self.sv_model is not None:
+            sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio)
+            return self.vq_model(pred_semantic, text_seq, refer, sv_emb=sv_emb)[0, 0]
         return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
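Note (hedged): torch.onnx.export traces this forward pass, so the `if self.sv_model is not None` branch is resolved at export time; v2Pro/v2ProPlus exports presumably bake the sv_emb path into the graph, while v1/v2 exports keep the plain three-argument vq_model call.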

@@ -274,8 +282,8 @@ class SSLModel(nn.Module):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)


-def export(vits_path, gpt_path, project_name, vits_model="v2"):
-    vits = VitsModel(vits_path)
+def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
+    vits = VitsModel(vits_path, version=voice_model_version)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()

@@ -297,7 +305,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
             )
         ]
     )

@@ -330,7 +338,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
                     "y",
                     "e4",
                 ],
-                version=vits_model,
+                version=voice_model_version,
             )
         ]
     )

@@ -349,9 +357,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
     ssl_content = ssl(ref_audio_16k).float()

-    # debug = False
-    debug = True
-
-    # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
+    debug = False
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)

@@ -361,7 +368,7 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
         a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
         soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

-    if vits_model == "v1":
+    if voice_model_version == "v1":
         symbols = symbols_v1
     else:
         symbols = symbols_v2

@@ -390,9 +397,16 @@ if __name__ == "__main__":
     except:
         pass

-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
-    exp_path = "nahida"
-    export(vits_path, gpt_path, exp_path)
+    # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    # vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    # exp_path = "v2_export"
+    # version = "v2"
+    # export(vits_path, gpt_path, exp_path, version)
+
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    exp_path = "v2proplus_export"
+    version = "v2ProPlus"
+    export(vits_path, gpt_path, exp_path, version)


     # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
@@ -30,3 +30,15 @@ class SV:
         )
         sv_emb = self.embedding_model.forward3(feat)
         return sv_emb
+
+    def compute_embedding3_onnx(self, wav):
+        # Disable gradients for all parameters
+        for param in self.embedding_model.parameters():
+            param.requires_grad = False
+
+        with torch.no_grad():
+            if self.is_half:
+                wav = wav.half()
+            feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
+            sv_emb = self.embedding_model.forward3(feat)
+            return sv_emb
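
A minimal sketch (an assumption, not part of the commit) of how this branch-free path could be traced on its own; the wrapper class, file name and opset are hypothetical:

    import torch

    sv = SV("cpu", False)  # CPU, full precision, as constructed in VitsModel

    class _SVEmbExport(torch.nn.Module):
        def forward(self, wav):
            # wav: 1-D 16 kHz waveform; fbank_onnx handles the mono case
            return sv.compute_embedding3_onnx(wav)

    torch.onnx.export(
        _SVEmbExport(), (torch.randn(16000),), "sv_emb.onnx",
        input_names=["wav"], output_names=["sv_emb"],
        dynamic_axes={"wav": {0: "num_samples"}},
        opset_version=17,
    )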

@@ -7,6 +7,7 @@ numba
pytorch-lightning>=2.4
gradio<5
ffmpeg-python
onnx
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
tqdm
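
Note: the PEP 508 environment markers above pick the CPU onnxruntime wheel on ARM machines and onnxruntime-gpu on x86_64/AMD64. A quick post-install sanity check (a suggestion, not from the commit):

    python -c "import onnxruntime as ort; print(ort.get_available_providers())"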