Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 00:30:15 +08:00)
Merge 8858492f56326dce521db2c4a4b3a7323e786596 into 11aa78bd9bda8b53047cfcae03abf7ca94d27391
This commit is contained in commit 7e082dba50
.gitignore (vendored): 7 changes
@@ -193,3 +193,10 @@ cython_debug/

# PyPI configuration file
.pypirc

#onnx
onnx/
*.onnx
tokenizer.json
output.wav
config.json
@@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
from tqdm import tqdm

default_config = {
"embedding_dim": 512,
@@ -27,45 +28,45 @@ def logits_to_probs(
logits,
previous_tokens=None,
temperature: float = 1.0,
top_k=None,
top_p=None,
top_k=15,
top_p=1.0,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
# if previous_tokens is not None and repetition_penalty != 1.0: # Always captured by onnx
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)

if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
# if top_p is not None and top_p < 1.0: # To be captured by onnx
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

logits = logits / max(temperature, 1e-5)
logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float))

if top_k is not None:
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)
# if top_k is not None: # To be captured by onnx
v, _ = torch.topk(logits, top_k)
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, inf_tensor_value, logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
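The hunk above removes the Python `if` guards around repetition penalty, top-p and top-k so every branch is traced into the ONNX graph unconditionally, and it keeps `temperature` as a tensor so the clamp-and-divide stays a graph op. A minimal sketch of the same idea, with illustrative names rather than the repository's API:

import torch

def branchless_top_k(logits, k, temperature):
    # temperature stays a tensor, so the clamp/divide remain graph ops, not Python floats
    logits = logits / torch.max(temperature, torch.tensor(1e-5))
    v, _ = torch.topk(logits, k)
    pivot = v.select(-1, -1).unsqueeze(-1)  # k-th largest logit
    # torch.where stands in for the `if top_k is not None:` guard and is always traced
    logits = torch.where(logits < pivot, torch.tensor(-float("inf")), logits)
    return torch.nn.functional.softmax(logits, dim=-1)

probs = branchless_top_k(torch.randn(1025), 15, torch.tensor(1.0))
assert torch.isclose(probs.sum(), torch.tensor(1.0))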
@@ -104,88 +105,6 @@ class OnnxEncoder(nn.Module):
x = x + self.bert_proj(bert_feature.transpose(1, 2))
return self.ar_text_position(x)


class T2SFirstStageDecoder(nn.Module):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
self.h = h
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers

def forward(self, x, prompt):
y = prompt
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
"v": None,
"y_emb": None,
"first_infer": 1,
"stage": 0,
}

y_emb = self.ar_audio_embedding(y)

cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)

xy_pos = torch.concat([x, y_pos], dim=1)

y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0

x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)

xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

y = torch.concat([y, samples], dim=1)

return y, cache["k"], cache["v"], cache["y_emb"], x_example


class T2SStageDecoder(nn.Module):
def __init__(
self,
@@ -195,7 +114,6 @@ class T2SStageDecoder(nn.Module):
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
@@ -206,40 +124,80 @@ class T2SStageDecoder(nn.Module):
self.ar_predict_layer = ar_predict_layer
self.loss_fct = loss_fct
self.ar_accuracy_metric = ar_accuracy_metric
self.top_k = top_k
self.early_stop_num = early_stop_num
self.num_layers = num_layers

def forward(self, y, k, v, y_emb, x_example):
def forward(self, x, y, k, v, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None):
if top_k is None:
top_k = torch.LongTensor([15]).to(device=y.device)
if top_p is None:
top_p = torch.FloatTensor([1.0]).to(device=y.device)
if repetition_penalty is None:
repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device)
if temperature is None:
temperature = torch.FloatTensor([1.0]).to(device=y.device)
minus_one = torch.tensor([-1]).to(y.device).to(torch.int64)

cache = {
"all_stage": self.num_layers,
"k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
"v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
"k": k,
"v": v,
"y_emb": y_emb,
"first_infer": 0,
"first_infer": first_infer,
"stage": 0,
"x_seq_len": x_seq_len,
"y_seq_len": y_seq_len,
}

# Decide at runtime whether to embed only the last y token or the whole y, so the first step and later steps are both handled correctly
multipled = minus_one * first_infer * y_seq_len
index_offset = torch.min(minus_one, multipled)
y_to_emb = y[:, index_offset:]
# Embed the y input
y_emb = torch.cat(
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
self.ar_audio_embedding(y_to_emb),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
# Concatenate with the x input in preparation for attention
xy_pos = torch.concat([x, y_pos], dim=1)

xy_pos = y_pos[:, -1:]
# Decide at runtime whether self-attention runs over only the last xy_pos position or the whole xy_pos
multipled = minus_one * first_infer * (x_seq_len + y_seq_len)  # xy_pos = 1 or x_seq_len + y_seq_len
index_offset = torch.min(minus_one, multipled)
xy_pos = xy_pos[:, index_offset:]

y_example = y_pos[:, :, 0] * 0.0
# Build the attention mask for xy
x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool()
y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones(
(y_seq_len, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0

xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool)
y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool)

x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)

# Decide at runtime whether the attention mask uses only the last row or the whole mask
multipled = minus_one * first_infer * (x_seq_len + y_seq_len)
index_offset = torch.min(minus_one, multipled)
xy_attn_mask = xy_attn_mask[index_offset:, :]

xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0)

y = torch.concat([y, samples], dim=1)

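The `torch.min(minus_one, minus_one * first_infer * seq_len)` pattern above is how the exported graph picks between "take the whole sequence" on the first step and "take only the last position" afterwards, without a Python branch. A small standalone check of that arithmetic (a sketch, not repository code):

import torch

def dynamic_tail_offset(first_infer, seq_len):
    # first_infer is 1 on the first step and 0 afterwards; both stay tensors so the
    # selection is part of the traced graph instead of being frozen to a constant
    minus_one = torch.tensor([-1], dtype=torch.int64)
    return torch.min(minus_one, minus_one * first_infer * seq_len)

seq = torch.arange(10).unsqueeze(0)                      # shape (1, 10)
first = dynamic_tail_offset(torch.tensor([1]), torch.tensor([10]))
later = dynamic_tail_offset(torch.tensor([0]), torch.tensor([10]))
print(first, later)                                      # tensor([-10]) tensor([-1])
print(seq[:, first:].shape, seq[:, later:].shape)        # (1, 10) and (1, 1)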
@@ -291,17 +249,6 @@ class Text2SemanticDecoder(nn.Module):

def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
@@ -309,33 +256,56 @@ class Text2SemanticDecoder(nn.Module):
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)

def forward(self, x, prompts, bert_feature):
def forward(self, x, prompts, bert_feature, top_k = None):
# torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)
if top_k is None:
top_k = self.top_k
early_stop_num = self.early_stop_num
prefix_len = prompts.shape[1]

x = self.onnx_encoder(x, bert_feature)
y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)

x_seq_len = x.shape[1]
y_seq_len = prompts.shape[1]

init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)

empty_tensor = torch.empty((1,0,512)).to(torch.float)

y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
empty_tensor, top_k=top_k,
first_infer=torch.LongTensor([1]),
x_seq_len=x_seq_len, y_seq_len=y_seq_len)

stop = False
for idx in range(1, 1500):
enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
y, k, v, y_emb, stage, logits, samples = enco
for idx in tqdm(range(1, 1500)):
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
y_seq_len = y.shape[1]
enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k,
first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
y = y[:,:-1]
break
y[0, -1] = 0
# torch.use_deterministic_algorithms(False)
return y, idx
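In the rewritten loop the per-step decoder graph never resizes the cache itself; the caller pads k and v by one time step before each call, so the exported stage_decoder only fills the slot it is handed. A sketch of that calling convention under the same layout assumptions (cache of shape (seq, n_layers, 512); the layer count here is illustrative):

import torch
import torch.nn.functional as F

n_layers, dim = 24, 512
x_seq_len, y_seq_len = 7, 3

# Preallocated cache covering prompt + prefix, as in the first call with first_infer=1
k = torch.zeros((x_seq_len + y_seq_len, n_layers, dim))
v = torch.zeros_like(k)

for step in range(3):
    # Grow the cache by one slot per generated token, outside the exported graph
    k = F.pad(k, (0, 0, 0, 0, 0, 1))
    v = F.pad(v, (0, 0, 0, 0, 0, 1))
    # stage_decoder(...) would write the new key/value into the last slot here
    print(step, k.shape)   # (11, 24, 512), (12, 24, 512), (13, 24, 512)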
def infer(self, x, prompts, bert_feature):
top_k = self.top_k
def infer(self, x, prompts, bert_feature, top_k=None):
# torch.manual_seed(42)
# torch.use_deterministic_algorithms(True)
if top_k is None:
top_k = self.top_k
early_stop_num = self.early_stop_num

x = self.onnx_encoder(x, bert_feature)
@@ -356,11 +326,14 @@ class Text2SemanticDecoder(nn.Module):
"first_infer": 1,
"stage": 0,
}
for idx in range(1500):
for idx in tqdm(range(1500)):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
for i in range(len(cache["k"])):
cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1))
cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1))
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@@ -380,15 +353,14 @@ class Text2SemanticDecoder(nn.Module):
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
stop = True
if stop:
if prompts.shape[1] == y.shape[1]:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
break
y = torch.concat([y, samples], dim=1)
cache["first_infer"] = 0
# torch.use_deterministic_algorithms(False)
return y, idx
@@ -2,6 +2,7 @@ from torch.nn.functional import *
from torch.nn.functional import (
_canonical_mask,
)
from typing import Tuple, Optional


def multi_head_attention_forward_patched(
@@ -48,14 +49,21 @@ def multi_head_attention_forward_patched(
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]

if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
cache["v"][cache["stage"]] = v
else:
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# Use dynamic shape inference so the differing kv-cache shapes of the first step and later steps go through one code path
# # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards
# # cache_k, cache_v : [N, 1, 512] for one head, N size increase is prepared outside
# cache["k"][:, cache["stage"]:cache["stage"]+1, :]
# cache["v"][:, cache["stage"]:cache["stage"]+1, :]
# Magic to get an index of either -1 or -N according to if first_infer_mask is set
minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len'])
index_offset = torch.min(minus_one, multipled)
# On the first step the index is -N; on later steps it is -1
cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k
cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v
k = cache["k"][:, cache["stage"]:cache["stage"]+1, :]
v = cache["v"][:, cache["stage"]:cache["stage"]+1, :]

cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]

attn_mask = _canonical_mask(
@@ -13,6 +13,7 @@ __all__ = [
"mel_scale_scalar",
"spectrogram",
"fbank",
"fbank_onnx",
"mfcc",
"vtln_warp_freq",
"vtln_warp_mel_freq",
@@ -842,3 +843,391 @@ def mfcc(

feature = _subtract_column_mean(feature, subtract_mean)
return feature

def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype))


def _get_waveform_and_window_properties_onnx(
waveform: Tensor,
) -> Tuple[Tensor, int, int, int]:
r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0,
round_to_power_of_two=True, preemphasis_coefficient=0.97"""

# Hardcoded values from traced parameters
# channel=-1 -> 0 after max(channel, 0)
channel = 0

# Extract channel 0 from waveform
if waveform.dim() == 1:
# Mono waveform, use as-is
waveform_selected = waveform
else:
# Multi-channel, select first channel
waveform_selected = waveform[channel, :]

# Hardcoded calculations:
# window_shift = int(16000 * 10.0 * 0.001) = 160
# window_size = int(16000 * 25.0 * 0.001) = 400
# padded_window_size = _next_power_of_2(400) = 512
window_shift = 160
window_size = 400
padded_window_size = 512

return waveform_selected, window_shift, window_size, padded_window_size
def _get_window_onnx(
waveform: Tensor,
) -> Tuple[Tensor, Tensor]:
r"""ONNX-compatible version with hardcoded parameters from traced fbank call:
padded_window_size=512, window_size=400, window_shift=160, window_type='povey',
blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0,
dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97

Returns:
(Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)

# Hardcoded values from traced parameters
window_size = 400
window_shift = 160
padded_window_size = 512
snip_edges = True

# size (m, window_size)
strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges)

# dither=0, so skip dithering (lines 209-211 from original)

# remove_dc_offset=True, so execute this branch
row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
strided_input = strided_input - row_means

# raw_energy=True, so execute this branch
signal_log_energy = _get_log_energy_onnx(strided_input, epsilon)  # energy_floor=1.0

# preemphasis_coefficient=0.97 != 0.0, so execute this branch
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0)
strided_input = strided_input - 0.97 * offset_strided_input[:, :-1]

# Apply povey window function to each row/frame
# povey window: torch.hann_window(window_size, periodic=False).pow(0.85)
window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0)
strided_input = strided_input * window_function

# Pad columns from window_size=400 to padded_window_size=512
padding_right = padded_window_size - window_size  # 112
strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0)

# raw_energy=True, so skip the "not raw_energy" branch (lines 244-245)
return strided_input, signal_log_energy
def _get_strided_onnx(waveform: Tensor, window_size = 400, window_shift = 160, snip_edges = True) -> Tensor:
seq_len = waveform.size(0)

# Calculate number of windows
num_windows = 1 + (seq_len - window_size) // window_shift

# Create indices for all windows at once
window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device)
window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0)

# Extract windows using advanced indexing
windows = waveform[window_indices]  # [num_windows, window_size]

return windows
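For reference, the same framing can be expressed with Tensor.unfold, which also exports cleanly; this is a sketch, not the repository's implementation, and it assumes the snip-edges behaviour above (drop the tail that does not fill a window):

import torch

waveform = torch.randn(16000)                     # 1 s at 16 kHz
frames_a = waveform.unfold(0, 400, 160)           # (num_windows, window_size)
# Equivalent to the index-based construction used by _get_strided_onnx:
starts = torch.arange(0, (1 + (16000 - 400) // 160) * 160, 160)
frames_b = waveform[starts.unsqueeze(1) + torch.arange(400)]
print(frames_a.shape, torch.equal(frames_a, frames_b))   # torch.Size([98, 400]) True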
def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor:
"""ONNX-compatible version with hardcoded parameters from traced fbank call:
subtract_mean=False, so this function returns the input tensor unchanged.

Args:
tensor: Input tensor of size (m, n)

Returns:
Tensor: Same as input tensor (m, n) since subtract_mean=False
"""
# subtract_mean=False from traced parameters, so return tensor as-is
return tensor
def get_mel_banks_onnx(
device=None,
dtype=None,
) -> Tensor:
"""ONNX-compatible version with hardcoded parameters from traced fbank call:
num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0,
high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0

Returns:
Tensor: melbank of size (80, 256) (num_bins, num_fft_bins)
"""
# Hardcoded values from traced parameters
num_bins = 80
window_length_padded = 512
sample_freq = 16000.0
low_freq = 20.0
high_freq = 0.0  # Will be adjusted to nyquist
vtln_warp_factor = 1.0

# Calculate dynamic values to ensure accuracy
num_fft_bins = window_length_padded // 2  # 256 (integer division)
nyquist = 0.5 * sample_freq  # 8000.0

# high_freq <= 0.0, so high_freq += nyquist
if high_freq <= 0.0:
high_freq += nyquist  # 8000.0

# fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25
fft_bin_width = sample_freq / window_length_padded

# Calculate mel scale values dynamically
mel_low_freq = mel_scale_scalar(low_freq)
mel_high_freq = mel_scale_scalar(high_freq)

# mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

# vtln_warp_factor == 1.0, so no VTLN warping needed

# Create mel bin centers
bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1)
left_mel = mel_low_freq + bin_indices * mel_freq_delta
center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta
right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta

# No VTLN warping since vtln_warp_factor == 1.0

# Create frequency bins for FFT
fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype)
mel = mel_scale(fft_freqs).unsqueeze(0)  # size(1, num_fft_bins)

# Calculate triangular filter banks
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)

# Since vtln_warp_factor == 1.0, use the simpler branch
bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope))

return bins
def fbank_onnx(
waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0
) -> Tensor:
r"""ONNX-compatible fbank function with hardcoded parameters from traced call:
num_mel_bins=80, sample_frequency=16000, dither=0
blackman_coeff: float = 0.42,
channel: int = -1,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY

Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)

Returns:
Tensor: A fbank identical to what Kaldi would output. The shape is (m, 80)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype

# Use ONNX-compatible version of _get_waveform_and_window_properties
waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform)

# min_duration=0.0, so skip the duration check (signal will never be too short)

# Use ONNX-compatible version of _get_window
strided_input, signal_log_energy = _get_window_onnx(waveform_selected)

# spectrum = torch.fft.rfft(strided_input).abs()

m, frame_size = strided_input.shape

# Process all frames at once using batch processing
# Reshape to (m, 1, frame_size) to treat each frame as a separate batch item
batched_frames = strided_input.unsqueeze(1)  # Shape: (m, 1, 512)

# Create rectangular window for all frames at once
rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype)

# Apply STFT to all frames simultaneously
# The batch dimension allows us to process all m frames in parallel
stft_result = torch.stft(
batched_frames.flatten(0, 1),  # Shape: (m, 512) - flatten batch and channel dims
n_fft=512,
hop_length=512,  # Process entire frame at once
window=rectangular_window,
center=False,  # Don't add padding
return_complex=False
)

# stft_result shape: (m, 257, 1, 2) where last dim is [real, imag]
# Calculate magnitude: sqrt(real^2 + imag^2)
real_part = stft_result[..., 0]  # Shape: (m, 257, 1)
imag_part = stft_result[..., 1]  # Shape: (m, 257, 1)
spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1)  # Shape: (m, 257)

# use_power=True, so execute this branch
spectrum = spectrum.pow(2.0)

# Get mel filterbanks using ONNX-compatible version
mel_energies = get_mel_banks_onnx(device, dtype)

# pad right column with zeros to match FFT output size (80, 256) -> (80, 257)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

# sum with mel filterbanks over the power spectrum, size (m, 80)
mel_energies = torch.mm(spectrum, mel_energies.T)

# use_log_fbank=True, so execute this branch
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

# use_energy=False, so skip the energy addition (lines 828-834)

# Use ONNX-compatible version of _subtract_column_mean
mel_energies = _subtract_column_mean_onnx(mel_energies)

return mel_energies
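The batched torch.stft with a rectangular window, one hop per full 512-sample frame, and center=False computes the same DFT magnitude as the commented-out torch.fft.rfft(strided_input).abs(); the stft route is used presumably because torch.fft.rfft does not export cleanly to the ONNX opset targeted here. A quick equivalence check (a sketch, not part of the module):

import torch

frames = torch.randn(6, 512)                               # m frames, one padded window each
ref = torch.fft.rfft(frames, dim=-1).abs()                 # (6, 257)
stft = torch.stft(
    frames, n_fft=512, hop_length=512,
    window=torch.ones(512), center=False, return_complex=True,
).abs().squeeze(-1)                                        # (6, 257, 1) -> (6, 257)
print(torch.allclose(ref, stft, atol=1e-4))                # True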
# Test to compare original fbank vs fbank_onnx
if __name__ == "__main__":
import torch

print("Testing fbank vs fbank_onnx with traced parameters...")

# Create test waveform
torch.manual_seed(42)
sample_rate = 16000
duration = 1.0  # 1 second
num_samples = int(sample_rate * duration)

# Create a test waveform (sine wave + noise)
t = torch.linspace(0, duration, num_samples)
frequency = 440.0  # A4 note
waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples)

# Test with both mono and stereo inputs
mono_waveform = waveform.unsqueeze(0)  # Shape: (1, num_samples)

print(f"Test waveform shape: {mono_waveform.shape}")

# Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0
try:
print("\n=== DEBUGGING: Step-by-step comparison ===")

# Step 1: Check waveform processing
orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties(
mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97
)
onnx_waveform, onnx_window_shift, onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform)

print(f"Original waveform shape: {orig_waveform.shape}")
print(f"ONNX waveform shape: {onnx_waveform.shape}")
print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}")
print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}")
print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}")

# Step 2: Check windowing
orig_strided, orig_energy = _get_window(
orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift,
'povey', 0.42, True, True, 1.0, 0, True, 0.97
)
onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform)

print(f"\nOriginal strided shape: {orig_strided.shape}")
print(f"ONNX strided shape: {onnx_strided.shape}")
print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}")
print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}")

# Step 3: Check mel banks
orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype)
onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype)

print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}")
print(f"ONNX mel banks shape: {onnx_mel_banks.shape}")
print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}")

# Step 4: Full comparison
print("\n=== FULL COMPARISON ===")

# Original fbank
original_result = fbank(
mono_waveform,
num_mel_bins=80,
sample_frequency=16000,
dither=0
)

# ONNX-compatible fbank
onnx_result = fbank_onnx(mono_waveform)

print(f"Original fbank output shape: {original_result.shape}")
print(f"ONNX fbank output shape: {onnx_result.shape}")

# Check if shapes match
if original_result.shape == onnx_result.shape:
print("✅ Output shapes match")
else:
print("❌ Output shapes don't match")
print(f"  Original: {original_result.shape}")
print(f"  ONNX: {onnx_result.shape}")

# Check numerical differences
if original_result.shape == onnx_result.shape:
diff = torch.abs(original_result - onnx_result)
max_diff = torch.max(diff).item()
mean_diff = torch.mean(diff).item()
relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item()

print(f"Max absolute difference: {max_diff:.2e}")
print(f"Mean absolute difference: {mean_diff:.2e}")
print(f"Mean relative difference: {relative_diff:.2e}")

# Find where the max difference occurs
max_idx = torch.argmax(diff)
max_coords = torch.unravel_index(max_idx, diff.shape)
print(f"Max difference at coordinates: {max_coords}")
print(f"  Original value: {original_result[max_coords].item():.6f}")
print(f"  ONNX value: {onnx_result[max_coords].item():.6f}")

# Check if results are numerically close
tolerance = 1e-5
if max_diff < tolerance:
print(f"✅ Results are numerically identical (within {tolerance})")
else:
print(f"❌ Results differ by {max_diff}, which exceeds {tolerance}")

# Additional statistics
print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]")
print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]")

except Exception as e:
print(f"❌ Error during testing: {e}")
import traceback
traceback.print_exc()
GPT_SoVITS/export_roberta_onnx.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import onnx
import onnxruntime as ort
from typing import Dict, Any
import argparse
import os
import shutil
import numpy as np
import onnxsim
import onnx


class CombinedBERTModel(nn.Module):
"""Wrapper class that combines BERT tokenizer preprocessing and model inference."""

def __init__(self, model_name: str):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForMaskedLM.from_pretrained(model_name)

def forward(self, text_input: torch.Tensor):
"""Forward pass that includes tokenization and model inference."""
# Note: For ONNX export, we'll work with pre-tokenized input_ids
# In practice, text tokenization needs to happen outside ONNX
input_ids = text_input.long()

outputs = self.model(input_ids=input_ids, output_hidden_states=True)
return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
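The returned feature is built from hidden_states[-3:-2], a one-element tuple holding the third-to-last hidden layer, so the torch.cat(..., -1) is effectively a no-op and the output is that single layer's (seq_len, hidden) activations with the [CLS]/[SEP] positions stripped by [1:-1]. A sketch of the equivalent indexing (illustrative shapes, not the repository's API):

import torch

seq_len, hidden = 10, 1024
# Stand-in for outputs["hidden_states"]: embeddings plus 24 encoder layers
hidden_states = tuple(torch.randn(1, seq_len, hidden) for _ in range(25))

a = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]   # what the forward() above computes
b = hidden_states[-3][0][1:-1]                     # the same single layer, indexed directly
print(a.shape, torch.equal(a, b))                  # torch.Size([8, 1024]) True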
def export_bert_to_onnx(
model_name: str = "bert-base-uncased",
output_dir: str = "bert_exported",
max_seq_length: int = 512
):
"""Export BERT model to ONNX format and copy tokenizer files."""

# Create output directory
os.makedirs(output_dir, exist_ok=True)

print(f"Loading model: {model_name}")
combined_model = CombinedBERTModel(model_name)
combined_model.eval()

# Create dummy inputs for ONNX export (pre-tokenized input_ids)
batch_size = 1
dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length))

# Export to ONNX
onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx")
print(f"Exporting to ONNX: {onnx_path}")
torch.onnx.export(
combined_model,
dummy_input_ids,
onnx_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=['input_ids'],
output_names=['logits'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'logits': {0: 'logits_length'}
}
)
# Load the ONNX model
model = onnx.load(onnx_path)
# Simplify the model
model_simplified, _ = onnxsim.simplify(model)
# Save the simplified model
onnx.save(model_simplified, onnx_path)
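Once the simplified graph is saved it can be smoke-tested with onnxruntime; a minimal usage sketch under the default paths this script uses (the path and token ids below are illustrative):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("playground/chinese-roberta-wwm-ext-large/chinese-roberta-wwm-ext-large.onnx")
dummy_ids = np.random.randint(0, 1000, size=(1, 12), dtype=np.int64)
features = sess.run(None, {"input_ids": dummy_ids})[0]
print(features.shape)   # (sequence_length - 2, 1024) for the roberta-wwm-ext-large checkpoint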
# Copy tokenizer.json if it exists
tokenizer_cache_dir = combined_model.tokenizer.name_or_path
if os.path.isdir(tokenizer_cache_dir):
tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json")
else:
# For models from HuggingFace cache
from transformers import cached_path
try:
tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path
except:
# Alternative approach to find tokenizer.json in cache
cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
tokenizer_json_path = None
for root, dirs, files in os.walk(cache_dir):
if "tokenizer.json" in files and model_name.replace("/", "--") in root:
tokenizer_json_path = os.path.join(root, "tokenizer.json")
break

if tokenizer_json_path and os.path.exists(tokenizer_json_path):
dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json")
shutil.copy2(tokenizer_json_path, dest_tokenizer_path)
print(f"Copied tokenizer.json to: {dest_tokenizer_path}")
else:
print("Warning: tokenizer.json not found")

# Copy config.json if it exists
if tokenizer_cache_dir and os.path.isdir(tokenizer_cache_dir):
config_json_path = os.path.join(tokenizer_cache_dir, "config.json")
else:
# For models from HuggingFace cache
cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
config_json_path = None
for root, dirs, files in os.walk(cache_dir):
if "config.json" in files and model_name.replace("/", "--") in root:
config_json_path = os.path.join(root, "config.json")
break

if config_json_path and os.path.exists(config_json_path):
dest_config_path = os.path.join(output_dir, "config.json")
shutil.copy2(config_json_path, dest_config_path)
print(f"Copied config.json to: {dest_config_path}")
else:
print("Warning: config.json not found")

print(f"Model exported successfully to: {output_dir}")
return combined_model, onnx_path
def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5):
"""Test if the original PyTorch model and ONNX model produce the same outputs."""

print("Testing model equivalence...")

# Create test input
batch_size = 1
test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length))
input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt")

# Get PyTorch output
original_model.eval()
with torch.no_grad():
pytorch_output = original_model(input_ids).numpy()

# Get ONNX output
ort_session = ort.InferenceSession(onnx_path)
onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0]

print(f"PyTorch output shape: {pytorch_output.shape}")
print(f"ONNX output shape: {onnx_output.shape}")
# Compare outputs
max_diff = np.max(np.abs(pytorch_output - onnx_output))
mean_diff = np.mean(np.abs(pytorch_output - onnx_output))

print(f"Maximum absolute difference: {max_diff}")
print(f"Mean absolute difference: {mean_diff}")

if max_diff < tolerance:
print("✅ Models are numerically equivalent!")
return True
else:
print("❌ Models have significant differences!")
return False
def main():
parser = argparse.ArgumentParser(description="Export BERT model to ONNX")
parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
help="Pretrained BERT model name")
parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large",
help="Output directory path")
parser.add_argument("--max_seq_length", type=int, default=512,
help="Maximum sequence length")
parser.add_argument("--tolerance", type=float, default=1e-3,
help="Tolerance for numerical comparison")

args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)

# Export model
original_model, onnx_path = export_bert_to_onnx(
model_name=args.model_name,
output_dir=args.output_dir,
max_seq_length=args.max_seq_length
)

# Test equivalence
test_model_equivalence(
original_model=original_model,
onnx_path=onnx_path,
max_seq_length=args.max_seq_length,
tolerance=args.tolerance
)

if __name__ == "__main__":
main()
@@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module):
return out_indices

def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_out = quantized_out + quantized
# ONNX-friendly approach: use unbind instead of enumerate loop
indices_list = torch.unbind(q_indices, dim=0)
quantized_list = []

for i, indices in enumerate(indices_list):
if st + i < len(self.layers):
layer = self.layers[st + i]
quantized = layer.decode(indices)
quantized_list.append(quantized)

# Stack and sum instead of iterative addition
quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0)

return quantized_out
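The rewrite replaces a running sum over an enumerated tensor with torch.unbind plus a single stack(...).sum(0), which traces into a fixed set of per-quantizer decodes rather than a scalar-seeded accumulation. A toy sketch of the pattern with stand-in codebooks (hypothetical, not the repository's quantizer):

import torch

codebooks = [torch.randn(64, 8) for _ in range(4)]          # 4 residual stages, 64 codes, dim 8

def decode_stages(q_indices):
    # q_indices: (n_q, T) integer codes, one row per residual quantizer
    parts = [codebooks[i][idx] for i, idx in enumerate(torch.unbind(q_indices, dim=0))]
    return torch.stack(parts, dim=0).sum(dim=0)             # (T, 8)

out = decode_stages(torch.randint(0, 64, (4, 10)))
print(out.shape)                                            # torch.Size([10, 8])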
@@ -205,6 +205,8 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, y, text, ge, speed=1):
if type(speed) == float:
speed = torch.FloatTensor([speed])
y_mask = torch.ones_like(y[:1, :1, :])

y = self.ssl_proj(y * y_mask) * y_mask
@@ -217,9 +219,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)

y = self.encoder2(y * y_mask, y_mask)
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
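Dropping the `if speed != 1:` guard and computing the target length as (y.shape[-1] / speed).to(torch.int) + 1 keeps the resize length a tensor expression, presumably so the exported graph re-evaluates it for whatever speed value is fed at runtime instead of baking in a constant from Python int() at trace time. A small sketch of that constant-versus-tensor difference (illustrative values only; it does not call F.interpolate):

import torch

speed = torch.FloatTensor([1.25])
length = 100

baked = int(length / speed) + 1                    # plain Python int 81, frozen when traced
dynamic = (length / speed).to(torch.int) + 1       # stays a tensor: tensor([81], dtype=torch.int32)
print(baked, dynamic)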
@@ -1,398 +0,0 @@
import torch
import torchaudio
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
import json
import os

import soundfile
from text import cleaned_text_to_sequence


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
)
y = y.squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)

def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")

def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)

def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
class T2SEncoder(nn.Module):
def __init__(self, t2s, vits):
super().__init__()
self.encoder = t2s.onnx_encoder
self.vits = vits

def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
codes = self.vits.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
bert = bert.unsqueeze(0)
prompt = prompt_semantic.unsqueeze(0)
return self.encoder(all_phoneme_ids, bert), prompt
class T2SModel(nn.Module):
def __init__(self, t2s_path, vits_model):
super().__init__()
dict_s1 = torch.load(t2s_path, map_location="cpu")
self.config = dict_s1["config"]
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
self.t2s_model.load_state_dict(dict_s1["weight"])
self.t2s_model.eval()
self.vits_model = vits_model.vq_model
self.hz = 50
self.max_sec = self.config["data"]["max_sec"]
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
self.t2s_model = self.t2s_model.model
self.t2s_model.init_onnx()
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
# self.t2s_model = torch.jit.script(self.t2s_model)

def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num

# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

prefix_len = prompts.shape[1]

# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

stop = False
for idx in range(1, 1500):
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
stop = True
if stop:
break
y[0, -1] = 0

return y[:, -idx:].unsqueeze(0)

def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return

torch.onnx.export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

torch.onnx.export(
self.first_stage_decoder,
(x, prompts),
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

torch.onnx.export(
self.stage_decoder,
(y, k, v, y_emb, x_example),
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16,
)
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"

self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz"
self.vq_model = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)

def forward(self, text_seq, pred_semantic, ref_audio):
refer = spectrogram_torch(
ref_audio,
self.hps.data.filter_length,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
class GptSoVits(nn.Module):
def __init__(self, vits, t2s):
super().__init__()
self.vits = vits
self.t2s = t2s

def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime

sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio

def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
torch.onnx.export(
self.vits,
(text_seq, pred_semantic, ref_audio),
f"onnx/{project_name}/{project_name}_vits.onnx",
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False,
)
class SSLModel(nn.Module):
def __init__(self):
super().__init__()
self.ssl = ssl_model

def forward(self, ref_audio_16k):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(vits_path, gpt_path, project_name, vits_model="v2"):
vits = VitsModel(vits_path)
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()

try:
os.mkdir(f"onnx/{project_name}")
except:
pass

ssl_content = ssl(ref_audio_16k).float()

# debug = False
debug = True

# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

if debug:
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
else:
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

if vits_model == "v1":
symbols = symbols_v1
else:
symbols = symbols_v2

MoeVSConf = {
"Folder": f"{project_name}",
"Name": f"{project_name}",
"Type": "GPT-SoVits",
"Rate": vits.hps.data.sampling_rate,
"NumLayers": gpt.t2s_model.num_layers,
"EmbeddingDim": gpt.t2s_model.embedding_dim,
"Dict": "BasicDict",
"BertPath": "chinese-roberta-wwm-ext-large",
# "Symbol": symbols,
"AddBlank": False,
}

MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":
|
||||
try:
|
||||
os.mkdir("onnx")
|
||||
except:
|
||||
pass
|
||||
|
||||
gpt_path = "GPT_weights/nahida-e25.ckpt"
|
||||
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
|
||||
exp_path = "nahida"
|
||||
export(vits_path, gpt_path, exp_path)
|
||||
|
||||
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
448
GPT_SoVITS/onnx_export_v1v2.py
Normal file
@ -0,0 +1,448 @@
import torch
import torch.nn.functional as F
import torchaudio
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
from sv import SV
import onnx
from onnx import helper, TensorProto
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
from transformers import HubertModel, HubertConfig
import os
import json
from text import cleaned_text_to_sequence
import onnxsim
from onnxconverter_common import float16


def simplify_onnx_model(onnx_model_path: str):
    # Load the ONNX model
    model = onnx.load(onnx_model_path)
    # Simplify the model
    model_simplified, _ = onnxsim.simplify(model)
    # Save the simplified model
    onnx.save(model_simplified, onnx_model_path)


def convert_onnx_to_half(onnx_model_path: str):
    try:
        model = onnx.load(onnx_model_path)
        model_fp16 = float16.convert_float_to_float16(model)
        onnx.save(model_fp16, onnx_model_path)
    except Exception as e:
        print(f"Error converting {onnx_model_path} to half precision: {e}")


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
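
# For the v2 family used below, the spectrogram parameters are n_fft=2048,
# sampling_rate=32000, hop_size=640, win_size=2048 (see the comment in VitsModel and
# the call in AudioPreprocess.forward). Illustrative call on a 5 s dummy clip:
#   spec = spectrogram_torch(torch.randn(1, 32000 * 5), 2048, 32000, 640, 2048, center=False)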


def resample_audio(audio: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
    """
    Resample audio from orig_sr to target_sr using linear interpolation.
    audio: (batch, channels, samples) or (channels, samples) or (samples,)
    """
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)
    # audio shape: (batch, channels, samples)
    batch, channels, samples = audio.shape
    # Reshape to combine batch and channels for interpolation
    audio = audio.reshape(batch * channels, 1, samples)
    # Use scale_factor instead of a computed size for ONNX export compatibility
    resampled = F.interpolate(audio, scale_factor=target_sr / orig_sr, mode="linear", align_corners=False)
    new_samples = resampled.shape[-1]
    resampled = resampled.reshape(batch, channels, new_samples)
    resampled = resampled.squeeze(0).squeeze(0)
    return resampled
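
# Quick sanity check (illustrative, not part of the export flow): downsampling the 32 kHz
# reference used in AudioPreprocess to 16 kHz halves the sample count, and the output is
# squeezed back to a 1-D waveform:
#   wav_16k = resample_audio(torch.randn(1, 32000 * 5), 32000, 16000)   # -> shape (16000 * 5,)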


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


class T2SInitStage(nn.Module):
    def __init__(self, t2s, vits):
        super().__init__()
        self.encoder = t2s.onnx_encoder
        self.vits = vits
        self.num_layers = t2s.num_layers

    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
        codes = self.vits.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
        all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
        bert = bert.unsqueeze(0)
        prompt = prompt_semantic.unsqueeze(0)
        x = self.encoder(all_phoneme_ids, bert)

        x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1]
        y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1]

        # Zero-initialised key/value caches, one (x_seq_len + y_seq_len, num_layers, 512)
        # tensor each, handed to the stage decoder as its initial KV state.
        init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)
        init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float)

        return x, prompt, init_k, init_v, x_seq_len, y_seq_len


class T2SModel(nn.Module):
    def __init__(self, t2s_path, vits_model):
        super().__init__()
        dict_s1 = torch.load(t2s_path, map_location="cpu")
        self.config = dict_s1["config"]
        self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
        self.t2s_model.load_state_dict(dict_s1["weight"])
        self.t2s_model.eval()
        self.vits_model = vits_model.vq_model
        self.hz = 50
        self.max_sec = self.config["data"]["max_sec"]
        self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
        self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
        self.t2s_model = self.t2s_model.model
        self.t2s_model.init_onnx()
        self.init_stage = T2SInitStage(self.t2s_model, self.vits_model)
        self.stage_decoder = self.t2s_model.stage_decoder

    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
        x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
        empty_tensor = torch.empty((1, 0, 512)).to(torch.float)
        # first step
        y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
                                                             empty_tensor,
                                                             top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
                                                             first_infer=torch.LongTensor([1]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)

        # A fixed 5 extra steps only produce a dummy pred_semantic for tracing the VITS
        # graph; real decoding (with the EOS check below) is done by whoever drives the
        # exported graphs.
        for idx in range(5):  # This is a fake one! DO NOT take this as reference
            k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
            v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
            y_seq_len = y.shape[1]
            y, k, v, y_emb, logits, samples = self.stage_decoder(empty_tensor, y, k, v,
                                                                 y_emb,
                                                                 top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature,
                                                                 first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len)
            # if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
            #     break

        return y[:, -5:].unsqueeze(0)

    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
        torch.onnx.export(
            self.init_stage,
            (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
            f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx",
            input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
            output_names=["x", "prompt", "init_k", "init_v", "x_seq_len", "y_seq_len"],
            dynamic_axes={
                "ref_text_phones": {1: "ref_length"},
                "input_text_phones": {1: "text_length"},
                "ref_text_bert": {0: "ref_length"},
                "input_text_bert": {0: "text_length"},
                "hubert_ssl_content": {2: "ssl_length"},
            },
            opset_version=16,
            do_constant_folding=False,
        )
        simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx")
        x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
        empty_tensor = torch.empty((1, 0, 512)).to(torch.float)
        x_seq_len = torch.Tensor([x_seq_len]).to(torch.int64)
        y_seq_len = torch.Tensor([y_seq_len]).to(torch.int64)

        y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v,
                                                             empty_tensor,
                                                             top_k, top_p, repetition_penalty, temperature,
                                                             torch.LongTensor([1]), x_seq_len, y_seq_len)
        k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1))
        v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1))
        y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64)

        torch.onnx.export(
            self.stage_decoder,
            (x, y, k, v, y_emb, top_k, top_p, repetition_penalty, temperature, torch.LongTensor([0]), x_seq_len, y_seq_len),
            f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx",
            input_names=["ix", "iy", "ik", "iv", "iy_emb", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step", "x_seq_len", "y_seq_len"],
            output_names=["y", "k", "v", "y_emb", "logits", "samples"],
            dynamic_axes={
                "ix": {1: "ix_length"},
                "iy": {1: "iy_length"},
                "ik": {0: "ik_length"},
                "iv": {0: "iv_length"},
                "iy_emb": {1: "iy_emb_length"},
            },
            verbose=False,
            opset_version=16,
        )
        simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx")


class VitsModel(nn.Module):
    def __init__(self, vits_path, version: str = "v2"):
        super().__init__()
        dict_s2 = torch.load(vits_path, map_location="cpu", weights_only=False)
        self.hps = dict_s2["config"]
        if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            self.hps["model"]["version"] = "v1"
        else:
            self.hps["model"]["version"] = version

        self.is_v2p = version.lower() in ["v2pro", "v2proplus"]

        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
        self.vq_model: SynthesizerTrn = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model,
        )
        self.vq_model.eval()
        self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
        # print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}")
        # v2: filter_length 2048, sampling_rate 32000, hop_length 640, win_length 2048

    def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed):
        if self.is_v2p:
            return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0]
        else:
            return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0]


class GptSoVits:
    def __init__(self, vits, t2s):
        super().__init__()
        self.vits = vits
        self.t2s = t2s

    def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None):
        self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
        torch.onnx.export(
            self.vits,
            (text_seq, pred_semantic, spectrum, sv_emb, speed),
            f"onnx/{project_name}/{project_name}_vits.onnx",
            input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"],
            output_names=["audio"],
            dynamic_axes={
                "input_text_phones": {1: "text_length"},
                "pred_semantic": {2: "pred_length"},
                "spectrum": {2: "spectrum_length"},
            },
            opset_version=17,
            verbose=False,
        )
        simplify_onnx_model(f"onnx/{project_name}/{project_name}_vits.onnx")


class AudioPreprocess(nn.Module):
    def __init__(self):
        super().__init__()

        # Load the HuBERT feature extractor
        self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True)
        self.model.eval()

        self.sv_model = SV("cpu", False)

    def forward(self, ref_audio_32k):
        spectrum = spectrogram_torch(
            ref_audio_32k,
            2048,
            32000,
            640,
            2048,
            center=False,
        )
        ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000)

        sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio_16k)

        zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
        ref_audio_16k = ref_audio_16k.unsqueeze(0)
        # concatenate the zero padding with the waveform before HuBERT
        ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
        ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)

        return ssl_content, spectrum, sv_emb


def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_preprocessor=True, half_precision=False):
    vits = VitsModel(vits_path, version=voice_model_version)
    gpt = T2SModel(gpt_path, vits)
    gpt_sovits = GptSoVits(vits, gpt)
    preprocessor = AudioPreprocess()
    ref_seq = torch.LongTensor(
        [
            cleaned_text_to_sequence(
                ["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],
                version="v2",
            )
        ]
    )
    text_seq = torch.LongTensor(
        [
            cleaned_text_to_sequence(
                ["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],
                version="v2",
            )
        ]
    )
    ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
    text_bert = torch.randn((text_seq.shape[1], 1024)).float()
    ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5  # 5 seconds of dummy audio
    top_k = torch.LongTensor([15])
    top_p = torch.FloatTensor([1.0])
    repetition_penalty = torch.FloatTensor([1.0])
    temperature = torch.FloatTensor([1.0])
    speed = torch.FloatTensor([1.0])

    os.makedirs(f"onnx/{project_name}", exist_ok=True)

    ssl_content, spectrum, sv_emb = preprocessor(ref_audio32k)
    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)

    if export_audio_preprocessor:
        torch.onnx.export(
            preprocessor,
            (ref_audio32k,),
            f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
            input_names=["audio32k"],
            output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
            dynamic_axes={
                "audio32k": {1: "sequence_length"},
                "hubert_ssl_output": {2: "hubert_length"},
                "spectrum": {2: "spectrum_length"},
            },
        )
        simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")

    if half_precision:
        if export_audio_preprocessor:
            convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx")
        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx")

    configJson = {
        "project_name": project_name,
        "type": "GPTSoVITS",
        "version": voice_model_version,
        "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
        "t2s_weights_path": gpt_path,
        "vits_weights_path": vits_path,
        "half_precision": half_precision,
    }
    with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
        json.dump(configJson, f, ensure_ascii=False, indent=4)
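

# Illustrative sketch (not called anywhere in this file): chaining the exported preprocessor
# and VITS graphs with onnxruntime. It reuses _example_run_t2s_with_onnxruntime above; the
# reshaping of the semantic tokens to (1, 1, T) and the speed value are assumptions about how
# a consumer would drive the graphs, not something read back from the ONNX files.
def _example_run_exported_pipeline(project_name, ref_audio_32k, ref_seq, text_seq, ref_bert, text_bert):
    import numpy as np
    import onnxruntime as ort

    pre_sess = ort.InferenceSession(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
    vits_sess = ort.InferenceSession(f"onnx/{project_name}/{project_name}_vits.onnx")

    ssl_content, spectrum, sv_emb = pre_sess.run(None, {"audio32k": ref_audio_32k})
    semantic_tokens = _example_run_t2s_with_onnxruntime(
        project_name, ref_seq, text_seq, ref_bert, text_bert, ssl_content
    )
    # In a real driver the reference-prompt tokens at the start of the sequence would be
    # trimmed off first (the torch wrapper only keeps the newly generated tokens).
    pred_semantic = semantic_tokens.reshape(1, 1, -1).astype(np.int64)
    # Note: for non-v2Pro exports sv_emb may be absent from the graph; adjust the feed accordingly.
    audio = vits_sess.run(
        None,
        {
            "input_text_phones": text_seq,
            "pred_semantic": pred_semantic,
            "spectrum": spectrum,
            "sv_emb": sv_emb,
            "speed": np.array([1.0], dtype=np.float32),
        },
    )[0]
    return audio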


if __name__ == "__main__":
    os.makedirs("onnx", exist_ok=True)

    # I/O here is very frequent, which can make the export fail sporadically (very noticeable
    # under WSL); simply retry if that happens.

    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
    exp_path = "v1_export"
    version = "v1"
    export(vits_path, gpt_path, exp_path, version)

    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
    vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
    exp_path = "v2_export"
    version = "v2"
    export(vits_path, gpt_path, exp_path, version)

    gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
    exp_path = "v2pro_export"
    version = "v2Pro"
    export(vits_path, gpt_path, exp_path, version)

    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
    exp_path = "v2proplus_export"
    version = "v2ProPlus"
    export(vits_path, gpt_path, exp_path, version)
@ -30,3 +30,15 @@ class SV:
            )
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb

    def compute_embedding3_onnx(self, wav):
        # Disable gradients for all parameters
        for param in self.embedding_model.parameters():
            param.requires_grad = False

        with torch.no_grad():
            if self.is_half:
                wav = wav.half()
            feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb
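
    # Example usage (illustrative): called from AudioPreprocess.forward in
    # GPT_SoVITS/onnx_export_v1v2.py with a 1-D 16 kHz waveform, e.g.
    #   sv_emb = self.sv_model.compute_embedding3_onnx(ref_audio_16k)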
@ -7,8 +7,11 @@ numba
pytorch-lightning>=2.4
gradio<5
ffmpeg-python
onnx
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
onnxsim
onnxconverter-common
tqdm
funasr==1.0.27
cn2an
@ -32,7 +35,7 @@ rotary_embedding_torch
ToJyutping
g2pk2
ko_pron
opencc
opencc==1.1.6
python_mecab_ko; sys_platform != 'win32'
fastapi[standard]>=0.115.2
x_transformers