Merge 8858492f56326dce521db2c4a4b3a7323e786596 into 11aa78bd9bda8b53047cfcae03abf7ca94d27391

zpeng11 2025-09-12 03:46:13 +08:00 committed by GitHub
commit 7e082dba50
11 changed files with 1202 additions and 565 deletions

.gitignore

@ -193,3 +193,10 @@ cython_debug/
# PyPI configuration file
.pypirc
#onnx
onnx/
*.onnx
tokenizer.json
output.wav
config.json

View File

@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
from tqdm import tqdm
default_config = {
"embedding_dim": 512,
@ -27,45 +28,45 @@ def logits_to_probs(
logits,
previous_tokens=None,
temperature: float = 1.0,
top_k=None,
top_p=None,
top_k=15,
top_p=1.0,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
# if previous_tokens is not None and repetition_penalty != 1.0: # Always captured by onnx
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
# if top_p is not None and top_p < 1.0: #To be captured by onnx
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits = logits / max(temperature, 1e-5)
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))
if top_k is not None:
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)
# if top_k is not None: # To be captured by onnx
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
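# Note on the rewrite above: ONNX tracing records only the branch that actually executes,
# so the repetition-penalty, top-p and top-k blocks are made unconditional; callers disable a
# step by passing neutral values (top_p=1.0, repetition_penalty=1.0) rather than by skipping the
# branch, and the temperature clamp uses torch.max so the division stays a traced tensor op.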
@ -104,88 +105,6 @@ class OnnxEncoder(nn.Module):
x = x + self.bert_proj(bert_feature.transpose(1, 2))
return self.ar_text_position(x)
class T2SFirstStageDecoder(nn.Module):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
self.h = h
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, x, prompt):
y = prompt
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
"v": None,
"y_emb": None,
"first_infer": 1,
"stage": 0,
}
y_emb = self.ar_audio_embedding(y)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)
return y, cache["k"], cache["v"], cache["y_emb"], x_example
class T2SStageDecoder(nn.Module):
def __init__(
self,
@ -195,7 +114,6 @@ class T2SStageDecoder(nn.Module):
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
@ -206,40 +124,80 @@ class T2SStageDecoder(nn.Module):
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers
def forward(self, y, k, v, y_emb, x_example):
def forward(self, x, y, k, v, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
if top_k is None:
top_k = torch.LongTensor([15]).to(device=y.device)
if top_p is None:
top_p = torch.FloatTensor([1.0]).to(device=y.device)
if repetition_penalty is None:
repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
if temperature is None:
temperature = torch.FloatTensor([1.0]).to(device=y.device)
minus_one = torch.tensor([-1]).to(y.device).to(torch.int64)
cache = {
"all_stage": self.num_layers,
"k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
"v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
"k": k,
"v": v,
"y_emb": y_emb,
"first_infer": 0,
"first_infer": first_infer,
"stage": 0,
"x_seq_len": x_seq_len,
"y_seq_len": y_seq_len,
}
# At runtime, decide whether to embed only the last y token or the whole y sequence, so the first step and subsequent steps are both handled correctly
multipled = minus_one * first_infer * y_seq_len
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
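# Worked example of the indexing trick above (illustrative numbers only): with y_seq_len = 10,
# first_infer = 1 gives multipled = -10 and index_offset = min(-1, -10) = -10, so the whole
# prompt y[:, -10:] is embedded; first_infer = 0 gives multipled = 0 and index_offset = min(-1, 0) = -1,
# so only the newly sampled token y[:, -1:] is embedded.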
# Embed the y input
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
self.ar_audio_embedding(y_to_emb),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# Concatenate with the x input in preparation for attention
xy_pos = torch.concat([x, y_pos], dim=1)
xy_pos = y_pos[:, -1:]
# At runtime, decide whether self-attention runs over only the last xy_pos position or the whole xy_pos sequence
multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]
y_example = y_pos[:, :, 0] * 0.0
# Build the attention mask for the concatenated x/y sequence
x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones(
(y_seq_len, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)
x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
# At runtime, decide whether to use only the last row of the attention mask or the whole mask
multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
index_offset = torch.min(minus_one, multipled)
xy_attn_mask = xy_attn_mask[index_offset:, :]
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)
@ -291,17 +249,6 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
@ -309,33 +256,56 @@ class Text2SemanticDecoder(nn.Module):
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
def forward(self, x, prompts, bert_feature, top_k = None):
# torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)
if top_k is None:
top_k = self.top_k
early_stop_num = self.early_stop_num
prefix_len = prompts.shape[1]
x = self.onnx_encoder(x, bert_feature)
y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1]
init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
empty_tensor = torch.empty((1,0,512)).to(torch.float)
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
empty_tensor, top_k=top_k,
first_infer=torch.LongTensor([1]),
x_seq_len=x_seq_len, y_seq_len=y_seq_len)
stop = False
for idx in range(1, 1500):
enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
y, k, v, y_emb, stage, logits, samples = enco
for idx in tqdm(range(1, 1500)):
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
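# Each iteration grows the kv cache by one position along the sequence axis; the patched
# attention (see multi_head_attention_forward_patched further down) then writes the new
# key/value into that last slot, so the cache length stays equal to x_seq_len + y_seq_len.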
y_seq_len = y.shape[1]
enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k,
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
y = y[:,:-1]
break
y[0, -1] = 0
# torch.use_deterministic_algorithms(False)
return y, idx
def infer(self, x, prompts, bert_feature):
top_k = self.top_k
def infer(self, x, prompts, bert_feature, top_k=None):
# torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)
if top_k is None:
top_k = self.top_k
early_stop_num = self.early_stop_num
x = self.onnx_encoder(x, bert_feature)
@ -356,11 +326,14 @@ class Text2SemanticDecoder(nn.Module):
"first_infer": 1,
"stage": 0,
}
for idx in range(1500):
for idx in tqdm(range(1500)):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
for i in range(len(cache["k"])):
cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@ -380,15 +353,14 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
# torch.use_deterministic_algorithms(False)
return y, idx

View File

@ -2,6 +2,7 @@ from torch.nn.functional import *
from torch.nn.functional import (
_canonical_mask,
)
from typing import Tuple, Optional
def multi_head_attention_forward_patched(
@ -48,14 +49,21 @@ def multi_head_attention_forward_patched(
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
cache["v"][cache["stage"]] = v
else:
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# Use dynamic shape inference to handle, in one code path, the kv-cache shape difference between the first step and subsequent steps
# # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
# # cache_k, cache_v : [N, 1, 512] for one head, N size increasement is prepared outside
# cache["k"][:, cache["stage"]:cache["stage"]+1, :]
# cache["v"][:, cache["stage"]:cache["stage"]+1, :]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled)
# On the first step the index is -N; on subsequent steps it is -1
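# Net effect (a sketch of the intent): on the first step the assignment below fills the entire
# (x_seq_len + y_seq_len) range of the preallocated cache, while on later steps only the newest
# slot is overwritten; either way cache["k"] / cache["v"] keep a static shape, which is what
# makes the graph exportable to ONNX without data-dependent control flow.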
cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
attn_mask = _canonical_mask(

View File

@ -13,6 +13,7 @@ __all__ = [
"mel_scale_scalar",
"spectrogram",
"fbank",
"fbank_onnx"
"mfcc",
"vtln_warp_freq",
"vtln_warp_mel_freq",
@ -842,3 +843,391 @@ def mfcc(
feature = _subtract_column_mean(feature, subtract_mean)
return feature
def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))
def _get_waveform_and_window_properties_onnx(
waveform: Tensor,
) -> Tuple[Tensor, int, int, int]:
r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
round_to_power_of_two=True, preemphasis_coefficient=0.97"""
# Hardcoded values from traced parameters
# channel=-1 -> 0 after max(channel, 0)
channel = 0
# Extract channel 0 from waveform
if waveform.dim() == 1:
# Mono waveform, use as-is
waveform_selected = waveform
else:
# Multi-channel, select first channel
waveform_selected = waveform[channel, :]
# Hardcoded calculations:
# window_shift = int(16000 * 10.0 * 0.001) = 160
# window_size = int(16000 * 25.0 * 0.001) = 400
# padded_window_size = _next_power_of_2(400) = 512
window_shift = 160
window_size = 400
padded_window_size = 512
return waveform_selected, window_shift, window_size, padded_window_size
def _get_window_onnx(
waveform: Tensor,
) -> Tuple[Tensor, Tensor]:
r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97
Returns:
(Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
# Hardcoded values from traced parameters
window_size = 400
window_shift = 160
padded_window_size = 512
snip_edges = True
# size (m, window_size)
strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)
# dither=0, so skip dithering (lines 209-211 from original)
# remove_dc_offset=True, so execute this branch
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
# raw_energy=True, so execute this branch
signal_log_energy = _get_log_energy_onnx(strided_input, epsilon) # energy_floor=1.0
# preemphasis_coefficient=0.97 != 0.0, so execute this branch
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]
# Apply povey window function to each row/frame
# povey window: torch.hann_window(window_size, periodic=False).pow(0.85)
window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
strided_input = strided_input * window_function
# Pad columns from window_size=400 to padded_window_size=512
padding_right = padded_window_size - window_size # 112
strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)
# raw_energy=True, so skip the "not raw_energy" branch (lines 244-245)
return strided_input, signal_log_energy
def _get_strided_onnx(waveform: Tensor, window_size=400, window_shift=160, snip_edges=True) -> Tensor:
seq_len = waveform.size(0)
# Calculate number of windows
num_windows = 1 + (seq_len - window_size) // window_shift
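# For example (illustrative arithmetic only): a 1-second 16 kHz input of 16000 samples with
# window_size=400 and window_shift=160 yields num_windows = 1 + (16000 - 400) // 160 = 98 frames.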
# Create indices for all windows at once
window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)
# Extract windows using advanced indexing
windows = waveform[window_indices] # [num_windows, window_size]
return windows
def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
"""ONNX-compatible version with hardcoded parameters from traced fbank call:
subtract_mean=False, so this function returns the input tensor unchanged.
Args:
tensor: Input tensor of size (m, n)
Returns:
Tensor: Same as input tensor (m, n) since subtract_mean=False
"""
# subtract_mean=False from traced parameters, so return tensor as-is
return tensor
def get_mel_banks_onnx(
device=None,
dtype=None,
) -> Tensor:
"""ONNX-compatible version with hardcoded parameters from traced fbank call:
num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0
Returns:
Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
"""
# Hardcoded values from traced parameters
num_bins = 80
window_length_padded = 512
sample_freq = 16000.0
low_freq = 20.0
high_freq = 0.0 # Will be adjusted to nyquist
vtln_warp_factor = 1.0
# Calculate dynamic values to ensure accuracy
num_fft_bins = window_length_padded // 2 # 256 (integer division)
nyquist = 0.5 * sample_freq # 8000.0
# high_freq <= 0.0, so high_freq += nyquist
if high_freq <= 0.0:
high_freq += nyquist # 8000.0
# fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
fft_bin_width = sample_freq / window_length_padded
# Calculate mel scale values dynamically
mel_low_freq = mel_scale_scalar(low_freq)
mel_high_freq = mel_scale_scalar(high_freq)
# mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
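# For reference, torchaudio's Kaldi-compatible mel scale is mel(f) = 1127 * ln(1 + f / 700), so with
# the hardcoded values above mel(20.0) ≈ 31.75, mel(8000.0) ≈ 2840.0, and
# mel_freq_delta ≈ (2840.0 - 31.75) / 81 ≈ 34.7 (approximate figures, shown only to illustrate the arithmetic).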
# vtln_warp_factor == 1.0, so no VTLN warping needed
# Create mel bin centers
bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
left_mel = mel_low_freq + bin_indices * mel_freq_delta
center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta
# No VTLN warping since vtln_warp_factor == 1.0
# Create frequency bins for FFT
fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
mel = mel_scale(fft_freqs).unsqueeze(0) # size(1, num_fft_bins)
# Calculate triangular filter banks
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
# Since vtln_warp_factor == 1.0, use the simpler branch
bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))
return bins
def fbank_onnx(
waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
) -> Tensor:
r"""ONNX-compatible fbank function with hardcoded parameters from traced call:
num_mel_bins=80, sample_frequency=16000, dither=0
blackman_coeff: float = 0.42,
channel: int = -1,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
Returns:
Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
# Use ONNX-compatible version of _get_waveform_and_window_properties
waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)
# min_duration=0.0, so skip the duration check (signal will never be too short)
# Use ONNX-compatible version of _get_window
strided_input, signal_log_energy = _get_window_onnx(waveform_selected)
# spectrum = torch.fft.rfft(strided_input).abs()
m, frame_size = strided_input.shape
# Process all frames at once using batch processing
# Reshape to (m, 1, frame_size) to treat each frame as a separate batch item
batched_frames = strided_input.unsqueeze(1) # Shape: (m, 1, 512)
# Create rectangular window for all frames at once
rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)
# Apply STFT to all frames simultaneously
# The batch dimension allows us to process all m frames in parallel
stft_result = torch.stft(
batched_frames.flatten(0, 1), # Shape: (m, 512) - flatten batch and channel dims
n_fft=512,
hop_length=512, # Process entire frame at once
window=rectangular_window,
center=False, # Don't add padding
return_complex=False
)
# stft_result shape: (m, 257, 1, 2) where last dim is [real, imag]
# Calculate magnitude: sqrt(real^2 + imag^2)
real_part = stft_result[..., 0] # Shape: (m, 257, 1)
imag_part = stft_result[..., 1] # Shape: (m, 257, 1)
spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1) # Shape: (m, 257)
# use_power=True, so execute this branch
spectrum = spectrum.pow(2.0)
# Get mel filterbanks using ONNX-compatible version
mel_energies = get_mel_banks_onnx(device, dtype)
# pad right column with zeros to match FFT output size (80, 256) -> (80, 257)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
# sum with mel filterbanks over the power spectrum, size (m, 80)
mel_energies = torch.mm(spectrum, mel_energies.T)
# use_log_fbank=True, so execute this branch
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# use_energy=False, so skip the energy addition (lines 828-834)
# Use ONNX-compatible version of _subtract_column_mean
mel_energies = _subtract_column_mean_onnx(mel_energies)
return mel_energies
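# Rough shape check for the function above, assuming the hardcoded 16 kHz / 25 ms / 10 ms framing:
#   waveform = torch.randn(1, 16000)   # 1 second of mono audio
#   feats = fbank_onnx(waveform)       # -> (98, 80): 1 + (16000 - 400) // 160 = 98 frames, 80 mel bins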
# Test to compare original fbank vs fbank_onnx
if __name__ == "__main__":
import torch
print("Testing fbank vs fbank_onnx with traced parameters...")
# Create test waveform
torch.manual_seed(42)
sample_rate = 16000
duration = 1.0 # 1 second
num_samples = int(sample_rate * duration)
# Create a test waveform (sine wave + noise)
t = torch.linspace(0, duration, num_samples)
frequency = 440.0 # A4 note
waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)
# Test with both mono and stereo inputs
mono_waveform = waveform.unsqueeze(0) # Shape: (1, num_samples)
print(f"Test waveform shape: {mono_waveform.shape}")
# Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0
try:
print("\n=== DEBUGGING: Step-by-step comparison ===")
# Step 1: Check waveform processing
orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
)
onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)
print(f"Original waveform shape: {orig_waveform.shape}")
print(f"ONNX waveform shape: {onnx_waveform.shape}")
print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")
# Step 2: Check windowing
orig_strided, orig_energy = _get_window(
orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
'povey', 0.42, True, True, 1.0, 0, True, 0.97
)
onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)
print(f"\nOriginal strided shape: {orig_strided.shape}")
print(f"ONNX strided shape: {onnx_strided.shape}")
print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")
# Step 3: Check mel banks
orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)
print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")
# Step 4: Full comparison
print("\n=== FULL COMPARISON ===")
# Original fbank
original_result = fbank(
mono_waveform,
num_mel_bins=80,
sample_frequency=16000,
dither=0
)
# ONNX-compatible fbank
onnx_result = fbank_onnx(mono_waveform)
print(f"Original fbank output shape: {original_result.shape}")
print(f"ONNX fbank output shape: {onnx_result.shape}")
# Check if shapes match
if original_result.shape == onnx_result.shape:
print("✅ Output shapes match")
else:
print("❌ Output shapes don't match")
print(f" Original: {original_result.shape}")
print(f" ONNX: {onnx_result.shape}")
# Check numerical differences
if original_result.shape == onnx_result.shape:
diff = torch.abs(original_result - onnx_result)
max_diff = torch.max(diff).item()
mean_diff = torch.mean(diff).item()
relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Mean absolute difference: {mean_diff:.2e}")
print(f"Mean relative difference: {relative_diff:.2e}")
# Find where the max difference occurs
max_idx = torch.argmax(diff)
max_coords = torch.unravel_index(max_idx, diff.shape)
print(f"Max difference at coordinates: {max_coords}")
print(f" Original value: {original_result[max_coords].item():.6f}")
print(f" ONNX value: {onnx_result[max_coords].item():.6f}")
# Check if results are numerically close
tolerance = 1e-5
if max_diff < tolerance:
print(f"✅ Results are numerically identical (within {tolerance})")
else:
print(f"❌ Results {max_diff} differ by more than {tolerance}")
# Additional statistics
print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")
except Exception as e:
print(f"❌ Error during testing: {e}")
import traceback
traceback.print_exc()

View File

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import onnx
import onnxruntime as ort
from typing import Dict, Any
import argparse
import os
import shutil
import numpy as np
import onnxsim
import onnx
class CombinedBERTModel(nn.Module):
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""
def __init__(self, model_name: str):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
def forward(self, text_input: torch.Tensor):
"""Forward pass that includes tokenization and model inference."""
# Note: For ONNX export, we'll work with pre-tokenized input_ids
# In practice, text tokenization needs to happen outside ONNX
input_ids = text_input.long()
outputs = self.model(input_ids=input_ids, output_hidden_states=True)
return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
def export_bert_to_onnx(
model_name: str = "bert-base-uncased",
output_dir: str = "bert_exported",
max_seq_length: int = 512
):
"""Export BERT model to ONNX format and copy tokenizer files."""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
print(f"Loading model: {model_name}")
combined_model = CombinedBERTModel(model_name)
combined_model.eval()
# Create dummy inputs for ONNX export (pre-tokenized input_ids)
batch_size = 1
dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length))
# Export to ONNX
onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx")
print(f"Exporting to ONNX: {onnx_path}")
torch.onnx.export(
combined_model,
dummy_input_ids,
onnx_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=['input_ids'],
output_names=['logits'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'logits': {0: 'logits_length'}
}
)
# Load the ONNX model
model = onnx.load(onnx_path)
# Simplify the model
model_simplified, _ = onnxsim.simplify(model)
# Save the simplified model
onnx.save(model_simplified, onnx_path)
# Copy tokenizer.json if it exists
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
if os.path.isdir(tokenizer_cache_dir):
tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json")
else:
# For models from HuggingFace cache
from transformers import cached_path
try:
tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path
except:
# Alternative approach to find tokenizer.json in cache
cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
tokenizer_json_path = None
for root, dirs, files in os.walk(cache_dir):
if "tokenizer.json" in files and model_name.replace("/", "--") in root:
tokenizer_json_path = os.path.join(root, "tokenizer.json")
break
if tokenizer_json_path and os.path.exists(tokenizer_json_path):
dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json")
shutil.copy2(tokenizer_json_path, dest_tokenizer_path)
print(f"Copied tokenizer.json to: {dest_tokenizer_path}")
else:
print("Warning: tokenizer.json not found")
# Copy config.json if it exists
if tokenizer_cache_dir and os.path.isdir(tokenizer_cache_dir):
config_json_path = os.path.join(tokenizer_cache_dir, "config.json")
else:
# For models from HuggingFace cache
cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
config_json_path = None
for root, dirs, files in os.walk(cache_dir):
if "config.json" in files and model_name.replace("/", "--") in root:
config_json_path = os.path.join(root, "config.json")
break
if config_json_path and os.path.exists(config_json_path):
dest_config_path = os.path.join(output_dir, "config.json")
shutil.copy2(config_json_path, dest_config_path)
print(f"Copied config.json to: {dest_config_path}")
else:
print("Warning: config.json not found")
print(f"Model exported successfully to: {output_dir}")
return combined_model, onnx_path
def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5):
"""Test if the original PyTorch model and ONNX model produce the same outputs."""
print("Testing model equivalence...")
# Create test input
batch_size = 1
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt")
# Get PyTorch output
original_model.eval()
with torch.no_grad():
pytorch_output = original_model(input_ids).numpy()
# Get ONNX output
ort_session = ort.InferenceSession(onnx_path)
onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0]
print(f"PyTorch output shape: {pytorch_output.shape}")
print(f"ONNX output shape: {onnx_output.shape}")
# Compare outputs
max_diff = np.max(np.abs(pytorch_output - onnx_output))
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))
print(f"Maximum absolute difference: {max_diff}")
print(f"Mean absolute difference: {mean_diff}")
if max_diff < tolerance:
print("✅ Models are numerically equivalent!")
return True
else:
print("❌ Models have significant differences!")
return False
def main():
parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
help="Pretrained BERT model name")
parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
help="Output directory path")
parser.add_argument("--max_seq_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--tolerance", type=float, default=1e-3,
help="Tolerance for numerical comparison")
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# Export model
original_model, onnx_path = export_bert_to_onnx(
model_name=args.model_name,
output_dir=args.output_dir,
max_seq_length=args.max_seq_length
)
# Test equivalence
test_model_equivalence(
original_model=original_model,
onnx_path=onnx_path,
max_seq_length=args.max_seq_length,
tolerance=args.tolerance
)
if __name__ == "__main__":
main()
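# Hypothetical invocation (the script file name is an assumption, not taken from the repo);
# the argument values below are simply the argparse defaults defined above:
#   python export_bert_onnx.py \
#       --model_name GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large \
#       --output_dir playground/chinese-roberta-wwm-ext-large \
#       --max_seq_length 512 --tolerance 1e-3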

View File

@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
return out_indices
def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
# ONNX-friendly approach: use unbind instead of enumerate loop
indices_list = torch.unbind(q_indices, dim=0)
quantized_list = []
for i, indices in enumerate(indices_list):
if st + i < len(self.layers):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_list.append(quantized)
# Stack and sum instead of iterative addition
quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0)
return quantized_out
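# The unbind/stack form above is numerically equivalent to the original running addition over
# quantizer layers; summing a stacked tensor simply avoids accumulating into a 0-d tensor inside
# a Python loop, which tends to trace into a cleaner static ONNX graph (this rationale is a
# sketch, not a statement from the original authors).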

View File

@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge, speed=1):
if type(speed) == float:
speed = torch.FloatTensor([speed])
y_mask = torch.ones_like(y[:1, :1, :])
y = self.ssl_proj(y * y_mask) * y_mask
@ -217,9 +219,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)

View File

@ -1,398 +0,0 @@
import torch
import torchaudio
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
import json
import os
import soundfile
from text import cleaned_text_to_sequence
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
class T2SEncoder(nn.Module):
def __init__(self, t2s, vits):
super().__init__()
self.encoder = t2s.onnx_encoder
self.vits = vits
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
prompt = prompt_semantic.unsqueeze(0)
return self.encoder(all_phoneme_ids, bert), prompt
class T2SModel(nn.Module):
def __init__(self, t2s_path, vits_model):
super().__init__()
dict_s1 = torch.load(t2s_path, map_location="cpu")
self.config = dict_s1["config"]
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
self.t2s_model.load_state_dict(dict_s1["weight"])
self.t2s_model.eval()
self.vits_model = vits_model.vq_model
self.hz = 50
self.max_sec = self.config["data"]["max_sec"]
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
self.t2s_model = self.t2s_model.model
self.t2s_model.init_onnx()
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
# self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1]
# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
stop = True
if stop:
break
y[0, -1] = 0
return y[:, -idx:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
torch.onnx.export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
torch.onnx.export(
self.first_stage_decoder,
(x, prompts),
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
torch.onnx.export(
self.stage_decoder,
(y, k, v, y_emb, x_example),
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16,
)
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
class GptSoVits(nn.Module):
def __init__(self, vits, t2s):
super().__init__()
self.vits = vits
self.t2s = t2s
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
torch.onnx.export(
self.vits,
(text_seq, pred_semantic, ref_audio),
f"onnx/{project_name}/{project_name}_vits.onnx",
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False,
)
class SSLModel(nn.Module):
def __init__(self):
super().__init__()
self.ssl = ssl_model
def forward(self, ref_audio_16k):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
try:
os.mkdir(f"onnx/{project_name}")
except:
pass
ssl_content = ssl(ref_audio_16k).float()
# debug = False
debug = True
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2
MoeVSConf = {
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":
try:
os.mkdir("onnx")
except:
pass
gpt_path = "GPT_weights/nahida-e25.ckpt"
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
exp_path = "nahida"
export(vits_path, gpt_path, exp_path)
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

View File

@ -0,0 +1,448 @@
import torch
import torch.nn.functional as F
import torchaudio
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
from sv import SV
import onnx
from onnx import helper, TensorProto
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
from transformers import HubertModel, HubertConfig
import os
import json
from text import cleaned_text_to_sequence
import onnxsim
from onnxconverter_common import float16
def simplify_onnx_model(onnx_model_path: str):
# Load the ONNX model
model = onnx.load(onnx_model_path)
# Simplify the model
model_simplified, _ = onnxsim.simplify(model)
# Save the simplified model
onnx.save(model_simplified, onnx_model_path)
def convert_onnx_to_half(onnx_model_path:str):
try:
model = onnx.load(onnx_model_path)
model_fp16 = float16.convert_float_to_float16(model)
onnx.save(model_fp16, onnx_model_path)
except Exception as e:
print(f"Error converting {onnx_model_path} to half precision: {e}")
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def resample_audio(audio: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
"""
Resample audio from orig_sr to target_sr using linear interpolation.
audio: (batch, channels, samples) or (channels, samples) or (samples,)
"""
if audio.dim() == 1:
audio = audio.unsqueeze(0).unsqueeze(0)
elif audio.dim() == 2:
audio = audio.unsqueeze(0)
# audio shape: (batch, channels, samples)
batch, channels, samples = audio.shape
# Reshape to combine batch and channels for interpolation
audio = audio.reshape(batch * channels, 1, samples)
# Use scale_factor instead of a computed size for ONNX export compatibility
resampled = F.interpolate(audio, scale_factor=target_sr / orig_sr, mode='linear', align_corners=False)
new_samples = resampled.shape[-1]
resampled = resampled.reshape(batch, channels, new_samples)
resampled = resampled.squeeze(0).squeeze(0)
return resampled
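# Rough usage sketch (illustrative values only): a 5-second clip at 32 kHz has 160000 samples, and
# resampling to 16 kHz uses scale_factor = 16000 / 32000 = 0.5:
#   audio_32k = torch.randn(1, 32000 * 5)
#   audio_16k = resample_audio(audio_32k, 32000, 16000)   # 1-D tensor with 80000 samples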
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
class T2SInitStage(nn.Module):
def __init__(self, t2s, vits):
super().__init__()
self.encoder = t2s.onnx_encoder
self.vits = vits
self.num_layers = t2s.num_layers
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
prompt = prompt_semantic.unsqueeze(0)
x = self.encoder(all_phoneme_ids, bert)
x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]
init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
return x, prompt, init_k, init_v, x_seq_len, y_seq_len
class T2SModel(nn.Module):
def __init__(self, t2s_path, vits_model):
super().__init__()
dict_s1 = torch.load(t2s_path, map_location="cpu")
self.config = dict_s1["config"]
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
self.t2s_model.load_state_dict(dict_s1["weight"])
self.t2s_model.eval()
self.vits_model = vits_model.vq_model
self.hz = 50
self.max_sec = self.config["data"]["max_sec"]
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
self.t2s_model = self.t2s_model.model
self.t2s_model.init_onnx()
self.init_stage = T2SInitStage(self.t2s_model, self.vits_model)
self.stage_decoder = self.t2s_model.stage_decoder
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
empty_tensor = torch.empty((1,0,512)).to(torch.float)
# first step
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
empty_tensor,
top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
first_infer=torch.LongTensor([1]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
for idx in range(5):  # Placeholder loop used only to exercise the decoder during export; do NOT copy this fixed count for real inference
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
y_seq_len = y.shape[1]
y, k, v, y_emb, logits, samples = self.stage_decoder(empty_tensor, y, k, v,
y_emb,
top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
# if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
# break
return y[:, -5:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
torch.onnx.export(
self.init_stage,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx",
input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
output_names=["x", "prompt", "init_k", "init_v", 'x_seq_len', 'y_seq_len'],
dynamic_axes={
"ref_text_phones": {1: "ref_length"},
"input_text_phones": {1: "text_length"},
"ref_text_bert": {0: "ref_length"},
"input_text_bert": {0: "text_length"},
"hubert_ssl_content": {2: "ssl_length"},
},
opset_version=16,
do_constant_folding=False
)
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx")
x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
empty_tensor = torch.empty((1,0,512)).to(torch.float)
x_seq_len = torch.Tensor([x_seq_len]).to(torch.int64)
y_seq_len = torch.Tensor([y_seq_len]).to(torch.int64)
y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
empty_tensor,
top_k, top_p, repetition_penalty, temperature,
torch.LongTensor([1]), x_seq_len, y_seq_len)
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64)
torch.onnx.export(
self.stage_decoder,
(x, y, k, v, y_emb, top_k, top_p, repetition_penalty, temperature, torch.LongTensor([0]), x_seq_len, y_seq_len),
f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx",
input_names=["ix", "iy", "ik", "iv", "iy_emb", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step", "x_seq_len", "y_seq_len"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"ix": {1: "ix_length"},
"iy": {1: "iy_length"},
"ik": {0: "ik_length"},
"iv": {0: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
},
verbose=False,
opset_version=16,
)
simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx")
class VitsModel(nn.Module):
def __init__(self, vits_path, version:str = 'v2'):
super().__init__()
dict_s2 = torch.load(vits_path, map_location="cpu", weights_only=False)
self.hps = dict_s2["config"]
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = version
self.is_v2p = version.lower() in ['v2pro', 'v2proplus']
self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model:SynthesizerTrn = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
# print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
#v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048
def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
if self.is_v2p:
return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
else:
return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]
class GptSoVits():
def __init__(self, vits, t2s):
super().__init__()
self.vits = vits
self.t2s = t2s
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
torch.onnx.export(
self.vits,
(text_seq, pred_semantic, spectrum, sv_emb, speed),
f"onnx/{project_name}/{project_name}_vits.onnx",
input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
output_names=["audio"],
dynamic_axes={
"input_text_phones": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"spectrum": {2: "spectrum_length"},
},
opset_version=17,
verbose=False,
)
simplify_onnx_model(f"onnx/{project_name}/{project_name}_vits.onnx")
class AudioPreprocess(nn.Module):
def __init__(self):
super().__init__()
# Load the model
self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
self.model.eval()
self.sv_model = SV("cpu", False)
def forward(self, ref_audio_32k):
spectrum = spectrogram_torch(
ref_audio_32k,
2048,
32000,
640,
2048,
center=False,
)
ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000)
sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio_16k)
zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
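# 9600 zero samples correspond to 0.6 s of silence at 16 kHz, appended to the reference audio below.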
ref_audio_16k = ref_audio_16k.unsqueeze(0)
# concate zero_tensor with waveform
ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
return ssl_content, spectrum, sv_emb
def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_preprocessor=True, half_precision=False):
vits = VitsModel(vits_path, version=voice_model_version)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
preprocessor = AudioPreprocess()
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version='v2',
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version='v2',
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5 # 5 seconds of dummy audio
top_k = torch.LongTensor([15])
top_p = torch.FloatTensor([1.0])
repetition_penalty = torch.FloatTensor([1.0])
temperature = torch.FloatTensor([1.0])
speed = torch.FloatTensor([1.0])
os.makedirs(f"onnx/{project_name}", exist_ok=True)
[ssl_content, spectrum, sv_emb] = preprocessor(ref_audio32k)
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
if export_audio_preprocessor:
torch.onnx.export(preprocessor, (ref_audio32k,), f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
input_names=["audio32k"],
output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
dynamic_axes={
"audio32k": {1: "sequence_length"},
"hubert_ssl_output": {2: "hubert_length"},
"spectrum": {2: "spectrum_length"}
})
simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
if half_precision:
if export_audio_preprocessor:
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx")
convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx")
configJson = {
"project_name": project_name,
"type": "GPTSoVITS",
"version" : voice_model_version,
"bert_base_path": 'GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large',
"cnhuhbert_base_path": 'GPT_SoVITS/pretrained_models/chinese-hubert-base',
"t2s_weights_path": gpt_path,
"vits_weights_path": vits_path,
"half_precision": half_precision
}
with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
json.dump(configJson, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
try:
os.mkdir("onnx")
except:
pass
# Very frequent I/O can make the model export fail (especially noticeable on WSL); simply retry if that happens
gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
exp_path = "v1_export"
version = "v1"
export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
exp_path = "v2_export"
version = "v2"
export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
exp_path = "v2pro_export"
version = "v2Pro"
export(vits_path, gpt_path, exp_path, version)
gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
exp_path = "v2proplus_export"
version = "v2ProPlus"
export(vits_path, gpt_path, exp_path, version)
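# Based on the export() code above, each run is expected to leave the following files under
# onnx/{project_name}/ (a summary, not an official list): {project_name}_t2s_init_stage.onnx,
# {project_name}_t2s_stage_decoder.onnx, {project_name}_vits.onnx, config.json, and, when
# export_audio_preprocessor=True, {project_name}_audio_preprocess.onnx.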

View File

@ -30,3 +30,15 @@ class SV:
)
sv_emb = self.embedding_model.forward3(feat)
return sv_emb
def compute_embedding3_onnx(self, wav):
# Disable gradients for all parameters
for param in self.embedding_model.parameters():
param.requires_grad = False
with torch.no_grad():
if self.is_half == True:
wav = wav.half()
feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
sv_emb = self.embedding_model.forward3(feat)
return sv_emb

View File

@ -7,8 +7,11 @@ numba
pytorch-lightning>=2.4
gradio<5
ffmpeg-python
onnx
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
onnxsim
onnxconverter-common
tqdm
funasr==1.0.27
cn2an
@ -32,7 +35,7 @@ rotary_embedding_torch
ToJyutping
g2pk2
ko_pron
opencc
opencc==1.1.6
python_mecab_ko; sys_platform != 'win32'
fastapi[standard]>=0.115.2
x_transformers