diff --git a/.gitignore b/.gitignore index d280e459..b98ec8f4 100644 --- a/.gitignore +++ b/.gitignore @@ -193,3 +193,10 @@ cython_debug/ # PyPI configuration file .pypirc + +#onnx +onnx/ +*.onnx +tokenizer.json +output.wav +config.json \ No newline at end of file diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index 4f7b50a3..c8a3d876 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -7,6 +7,7 @@ from torchmetrics.classification import MulticlassAccuracy from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer +from tqdm import tqdm default_config = { "embedding_dim": 512, @@ -27,45 +28,45 @@ def logits_to_probs( logits, previous_tokens=None, temperature: float = 1.0, - top_k=None, - top_p=None, + top_k=15, + top_p=1.0, repetition_penalty: float = 1.0, ): previous_tokens = previous_tokens.squeeze() - if previous_tokens is not None and repetition_penalty != 1.0: - previous_tokens = previous_tokens.long() - score = torch.gather(logits, dim=0, index=previous_tokens) - score = torch.where( - score < 0, - score * repetition_penalty, - score / repetition_penalty, - ) - logits.scatter_(dim=0, index=previous_tokens, src=score) + # if previous_tokens is not None and repetition_penalty != 1.0: # Always captured by onnx + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, + score * repetition_penalty, + score / repetition_penalty, + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) - if top_p is not None and top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cum_probs = torch.cumsum( - torch.nn.functional.softmax( - sorted_logits, - dim=-1, - ), + # if top_p is not None and top_p < 1.0: #To be captured by onnx + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum( + torch.nn.functional.softmax( + sorted_logits, dim=-1, - ) - sorted_indices_to_remove = cum_probs > top_p - sorted_indices_to_remove[0] = False # keep at least one option - indices_to_remove = sorted_indices_to_remove.scatter( - dim=0, - index=sorted_indices, - src=sorted_indices_to_remove, - ) - logits = logits.masked_fill(indices_to_remove, -float("Inf")) + ), + dim=-1, + ) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, + index=sorted_indices, + src=sorted_indices_to_remove, + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) - logits = logits / max(temperature, 1e-5) + logits = logits / torch.max(temperature, torch.tensor(1e-5, device=logits.device, dtype=torch.float)) - if top_k is not None: - v, _ = torch.topk(logits, top_k) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, inf_tensor_value, logits) + # if top_k is not None: # To be captured by onnx + v, _ = torch.topk(logits, top_k) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, inf_tensor_value, logits) probs = torch.nn.functional.softmax(logits, dim=-1) return probs @@ -104,88 +105,6 @@ class OnnxEncoder(nn.Module): x = x + self.bert_proj(bert_feature.transpose(1, 2)) return self.ar_text_position(x) - -class T2SFirstStageDecoder(nn.Module): - def __init__( - self, - ar_audio_embedding, - 
ar_audio_position, - h, - ar_predict_layer, - loss_fct, - ar_accuracy_metric, - top_k, - early_stop_num, - num_layers, - ): - super().__init__() - self.ar_audio_embedding = ar_audio_embedding - self.ar_audio_position = ar_audio_position - self.h = h - self.ar_predict_layer = ar_predict_layer - self.loss_fct = loss_fct - self.ar_accuracy_metric = ar_accuracy_metric - self.top_k = top_k - self.early_stop_num = early_stop_num - self.num_layers = num_layers - - def forward(self, x, prompt): - y = prompt - x_example = x[:, :, 0] * 0.0 - # N, 1, 512 - cache = { - "all_stage": self.num_layers, - "k": None, - "v": None, - "y_emb": None, - "first_infer": 1, - "stage": 0, - } - - y_emb = self.ar_audio_embedding(y) - - cache["y_emb"] = y_emb - y_pos = self.ar_audio_position(y_emb) - - xy_pos = torch.concat([x, y_pos], dim=1) - - y_example = y_pos[:, :, 0] * 0.0 - x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool() - y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64) - y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( - torch.ones_like( - y_example.transpose(0, 1), - dtype=torch.int64, - ), - dim=0, - ) - y_attn_mask = y_attn_mask > 0 - - x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool() - y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool() - x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1) - y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) - xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) - cache["k"] = ( - torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512))) - .unsqueeze(1) - .repeat(self.num_layers, 1, 1, 1) - ) - cache["v"] = ( - torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512))) - .unsqueeze(1) - .repeat(self.num_layers, 1, 1, 1) - ) - - xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) - logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) - - y = torch.concat([y, samples], dim=1) - - return y, cache["k"], cache["v"], cache["y_emb"], x_example - - class T2SStageDecoder(nn.Module): def __init__( self, @@ -195,7 +114,6 @@ class T2SStageDecoder(nn.Module): ar_predict_layer, loss_fct, ar_accuracy_metric, - top_k, early_stop_num, num_layers, ): @@ -206,40 +124,80 @@ class T2SStageDecoder(nn.Module): self.ar_predict_layer = ar_predict_layer self.loss_fct = loss_fct self.ar_accuracy_metric = ar_accuracy_metric - self.top_k = top_k self.early_stop_num = early_stop_num self.num_layers = num_layers - def forward(self, y, k, v, y_emb, x_example): + def forward(self, x, y, k, v, y_emb, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None, x_seq_len = None, y_seq_len = None): + if top_k is None: + top_k = torch.LongTensor([15]).to(device=y.device) + if top_p is None: + top_p = torch.FloatTensor([1.0]).to(device=y.device) + if repetition_penalty is None: + repetition_penalty = torch.FloatTensor([1.0]).to(device=y.device) + if temperature is None: + temperature = torch.FloatTensor([1.0]).to(device=y.device) + minus_one = torch.tensor([-1]).to(y.device).to(torch.int64) + cache = { "all_stage": self.num_layers, - "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), - "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), + "k": k, + "v": v, "y_emb": y_emb, - "first_infer": 0, + "first_infer": first_infer, "stage": 0, + "x_seq_len": x_seq_len, + "y_seq_len": 
y_seq_len, } + # At runtime, decide whether to embed only the last y token or the whole y, so the first step and later steps are both handled correctly + multipled = minus_one * first_infer * y_seq_len + index_offset = torch.min(minus_one, multipled) + y_to_emb = y[:, index_offset:] + # Embed the y input y_emb = torch.cat( [ cache["y_emb"], - self.ar_audio_embedding(y[:, -1:]), + self.ar_audio_embedding(y_to_emb), ], 1, ) cache["y_emb"] = y_emb y_pos = self.ar_audio_position(y_emb) + # Concatenate with the x input in preparation for attention + xy_pos = torch.concat([x, y_pos], dim=1) - xy_pos = y_pos[:, -1:] + # At runtime, decide whether self attention sees only the last xy_pos entry or the whole xy_pos + multipled = minus_one * first_infer * (x_seq_len + y_seq_len) # xy_pos = 1 or x_seq_len + y_seq_len + index_offset = torch.min(minus_one, multipled) + xy_pos = xy_pos[:, index_offset:] - y_example = y_pos[:, :, 0] * 0.0 + # Build the attention mask for xy + x_attn_mask = torch.zeros((x_seq_len, x_seq_len)).bool() + y_attn_mask = torch.ones((y_seq_len, y_seq_len)).to(torch.int64) + y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum( + torch.ones( + (y_seq_len, 1), + dtype=torch.int64, + ), + dim=0, + ) + y_attn_mask = y_attn_mask > 0 - xy_attn_mask = torch.cat([x_example, y_example], dim=1) - xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool) + x_y_pad = torch.ones((x_seq_len, y_seq_len)).to(torch.bool) + y_x_pad = torch.zeros((y_seq_len, x_seq_len)).to(torch.bool) + + x_attn_mask_pad = torch.cat([x_attn_mask, x_y_pad], dim=1) + y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + + # At runtime, decide whether to use only the last row of the attention mask or the whole mask + multipled = minus_one * first_infer * (x_seq_len + y_seq_len) + index_offset = torch.min(minus_one, multipled) + xy_attn_mask = xy_attn_mask[index_offset:, :] xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + samples = sample(logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)[0].unsqueeze(0) y = torch.concat([y, samples], dim=1) @@ -291,17 +249,6 @@ class Text2SemanticDecoder(nn.Module): def init_onnx(self): self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position) - self.first_stage_decoder = T2SFirstStageDecoder( - self.ar_audio_embedding, - self.ar_audio_position, - self.h, - self.ar_predict_layer, - self.loss_fct, - self.ar_accuracy_metric, - self.top_k, - self.early_stop_num, - self.num_layers, - ) self.stage_decoder = T2SStageDecoder( self.ar_audio_embedding, self.ar_audio_position, @@ -309,33 +256,56 @@ class Text2SemanticDecoder(nn.Module): self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, - self.top_k, self.early_stop_num, self.num_layers, ) - def forward(self, x, prompts, bert_feature): + def forward(self, x, prompts, bert_feature, top_k = None): + # torch.manual_seed(42) + # torch.use_deterministic_algorithms(True) + if top_k is None: + top_k = self.top_k early_stop_num = self.early_stop_num prefix_len = prompts.shape[1] x = self.onnx_encoder(x, bert_feature) - y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts) + + x_seq_len = x.shape[1] + y_seq_len = prompts.shape[1] + + init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float) + init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float) + + empty_tensor = torch.empty((1,0,512)).to(torch.float) + + y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompts, init_k, init_v,
empty_tensor, top_k=top_k, + first_infer=torch.LongTensor([1]), + x_seq_len=x_seq_len, y_seq_len=y_seq_len) stop = False - for idx in range(1, 1500): - enco = self.stage_decoder(y, k, v, y_emb, stage, x_example) - y, k, v, y_emb, stage, logits, samples = enco + for idx in tqdm(range(1, 1500)): + k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) + v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) + y_seq_len = y.shape[1] + enco = self.stage_decoder(empty_tensor, y, k, v, y_emb, top_k=top_k, + first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) + y, k, v, y_emb, logits, samples = enco if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True if stop: + y = y[:,:-1] break - y[0, -1] = 0 + # torch.use_deterministic_algorithms(False) return y, idx - def infer(self, x, prompts, bert_feature): - top_k = self.top_k + def infer(self, x, prompts, bert_feature, top_k=None): + # torch.manual_seed(42) + # torch.use_deterministic_algorithms(True) + if top_k is None: + top_k = self.top_k early_stop_num = self.early_stop_num x = self.onnx_encoder(x, bert_feature) @@ -356,11 +326,14 @@ class Text2SemanticDecoder(nn.Module): "first_infer": 1, "stage": 0, } - for idx in range(1500): + for idx in tqdm(range(1500)): if cache["first_infer"] == 1: y_emb = self.ar_audio_embedding(y) else: y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1) + for i in range(len(cache["k"])): + cache["k"][i] = torch.nn.functional.pad(cache["k"][i], (0, 0, 0, 0, 0, 1)) + cache["v"][i] = torch.nn.functional.pad(cache["v"][i], (0, 0, 0, 0, 0, 1)) cache["y_emb"] = y_emb y_pos = self.ar_audio_position(y_emb) if cache["first_infer"] == 1: @@ -380,15 +353,14 @@ class Text2SemanticDecoder(nn.Module): xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool) xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache) logits = self.ar_predict_layer(xy_dec[:, -1]) - samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0) + samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35, temperature=torch.Tensor([1.0]))[0].unsqueeze(0) if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: stop = True if stop: - if prompts.shape[1] == y.shape[1]: - y = torch.concat([y, torch.zeros_like(samples)], dim=1) break y = torch.concat([y, samples], dim=1) cache["first_infer"] = 0 + # torch.use_deterministic_algorithms(False) return y, idx diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py index 8144c9c6..f8461f7d 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -2,6 +2,7 @@ from torch.nn.functional import * from torch.nn.functional import ( _canonical_mask, ) +from typing import Tuple, Optional def multi_head_attention_forward_patched( @@ -48,14 +49,21 @@ def multi_head_attention_forward_patched( proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] - if cache["first_infer"] == 1: - cache["k"][cache["stage"]] = k - cache["v"][cache["stage"]] = v - else: - cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) - cache["v"][cache["stage"]] = 
torch.cat([cache["v"][cache["stage"]][:-1], v], 0) - k = cache["k"][cache["stage"]] - v = cache["v"][cache["stage"]] + # Use dynamic shape inference to handle the kv-cache shape difference between the first step and later steps uniformly + # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards + # # cache_k, cache_v : [N, 1, 512] for one head; growing N by one slot per step is prepared outside + # cache["k"][:, cache["stage"]:cache["stage"]+1, :] + # cache["v"][:, cache["stage"]:cache["stage"]+1, :] + # Magic to get an index of either -1 or -N according to whether first_infer is set + minus_one = torch.tensor([-1]).to(k.device).to(torch.int64) + multipled = minus_one * cache["first_infer"] * (cache['x_seq_len'] + cache['y_seq_len']) + index_offset = torch.min(minus_one, multipled) + # On the first step the index is -N; on later steps it is -1 + cache["k"][index_offset:, cache["stage"]:cache["stage"]+1, :] = k + cache["v"][index_offset:, cache["stage"]:cache["stage"]+1, :] = v + k = cache["k"][:, cache["stage"]:cache["stage"]+1, :] + v = cache["v"][:, cache["stage"]:cache["stage"]+1, :] + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] attn_mask = _canonical_mask( diff --git a/GPT_SoVITS/eres2net/kaldi.py b/GPT_SoVITS/eres2net/kaldi.py index a80e5e6b..cf32b04a 100644 --- a/GPT_SoVITS/eres2net/kaldi.py +++ b/GPT_SoVITS/eres2net/kaldi.py @@ -13,6 +13,7 @@ __all__ = [ "mel_scale_scalar", "spectrogram", "fbank", + "fbank_onnx", "mfcc", "vtln_warp_freq", "vtln_warp_mel_freq", @@ -842,3 +843,391 @@ def mfcc( feature = _subtract_column_mean(feature, subtract_mean) return feature + +def _get_log_energy_onnx(strided_input: Tensor, epsilon: Tensor, energy_floor: float = 1.0) -> Tensor: + r"""Returns the log energy of size (m) for a strided_input (m,*)""" + device, dtype = strided_input.device, strided_input.dtype + log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m) + return torch.max(log_energy, torch.tensor(0.0, device=device, dtype=dtype)) + + +def _get_waveform_and_window_properties_onnx( + waveform: Tensor, +) -> Tuple[Tensor, int, int, int]: + r"""ONNX-compatible version with hardcoded parameters from traced fbank call: + channel=-1, sample_frequency=16000, frame_shift=10.0, frame_length=25.0, + round_to_power_of_two=True, preemphasis_coefficient=0.97""" + + # Hardcoded values from traced parameters + # channel=-1 -> 0 after max(channel, 0) + channel = 0 + + # Extract channel 0 from waveform + if waveform.dim() == 1: + # Mono waveform, use as-is + waveform_selected = waveform + else: + # Multi-channel, select first channel + waveform_selected = waveform[channel, :] + + # Hardcoded calculations: + # window_shift = int(16000 * 10.0 * 0.001) = 160 + # window_size = int(16000 * 25.0 * 0.001) = 400 + # padded_window_size = _next_power_of_2(400) = 512 + window_shift = 160 + window_size = 400 + padded_window_size = 512 + + return waveform_selected, window_shift, window_size, padded_window_size + +def _get_window_onnx( + waveform: Tensor, +) -> Tuple[Tensor, Tensor]: + r"""ONNX-compatible version with hardcoded parameters from traced fbank call: + padded_window_size=512, window_size=400, window_shift=160, window_type='povey', + blackman_coeff=0.42, snip_edges=True, raw_energy=True, energy_floor=1.0, + dither=0, remove_dc_offset=True, preemphasis_coefficient=0.97 + + Returns: + (Tensor, Tensor): strided_input of size (m, 512) and signal_log_energy of size (m) + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + + # Hardcoded values from traced parameters + window_size = 400 + window_shift = 160 + padded_window_size = 512 + snip_edges = True +
+ # size (m, window_size) + strided_input = _get_strided_onnx(waveform, window_size, window_shift, snip_edges) + + # dither=0, so skip dithering (lines 209-211 from original) + + # remove_dc_offset=True, so execute this branch + row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1) + strided_input = strided_input - row_means + + # raw_energy=True, so execute this branch + signal_log_energy = _get_log_energy_onnx(strided_input, epsilon) # energy_floor=1.0 + + # preemphasis_coefficient=0.97 != 0.0, so execute this branch + offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(0) + strided_input = strided_input - 0.97 * offset_strided_input[:, :-1] + + # Apply povey window function to each row/frame + # povey window: torch.hann_window(window_size, periodic=False).pow(0.85) + window_function = torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85).unsqueeze(0) + strided_input = strided_input * window_function + + # Pad columns from window_size=400 to padded_window_size=512 + padding_right = padded_window_size - window_size # 112 + strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0).squeeze(0) + + # raw_energy=True, so skip the "not raw_energy" branch (lines 244-245) + return strided_input, signal_log_energy + + +def _get_strided_onnx(waveform: Tensor, window_size = 400, window_shift = 160, snip_edges = 512) -> Tensor: + seq_len = waveform.size(0) + + # Calculate number of windows + num_windows = 1 + (seq_len - window_size) // window_shift + + # Create indices for all windows at once + window_starts = torch.arange(0, num_windows * window_shift, window_shift, device=waveform.device) + window_indices = window_starts.unsqueeze(1) + torch.arange(window_size, device=waveform.device).unsqueeze(0) + + # Extract windows using advanced indexing + windows = waveform[window_indices] # [num_windows, window_size] + + return windows + + +def _subtract_column_mean_onnx(tensor: Tensor) -> Tensor: + """ONNX-compatible version with hardcoded parameters from traced fbank call: + subtract_mean=False, so this function returns the input tensor unchanged. 
+ + Args: + tensor: Input tensor of size (m, n) + + Returns: + Tensor: Same as input tensor (m, n) since subtract_mean=False + """ + # subtract_mean=False from traced parameters, so return tensor as-is + return tensor + + +def get_mel_banks_onnx( + device=None, + dtype=None, +) -> Tensor: + """ONNX-compatible version with hardcoded parameters from traced fbank call: + num_bins=80, window_length_padded=512, sample_freq=16000, low_freq=20.0, + high_freq=0.0, vtln_low=100.0, vtln_high=-500.0, vtln_warp_factor=1.0 + + Returns: + Tensor: melbank of size (80, 256) (num_bins, num_fft_bins) + """ + # Hardcoded values from traced parameters + num_bins = 80 + window_length_padded = 512 + sample_freq = 16000.0 + low_freq = 20.0 + high_freq = 0.0 # Will be adjusted to nyquist + vtln_warp_factor = 1.0 + + # Calculate dynamic values to ensure accuracy + num_fft_bins = window_length_padded // 2 # 256 (integer division) + nyquist = 0.5 * sample_freq # 8000.0 + + # high_freq <= 0.0, so high_freq += nyquist + if high_freq <= 0.0: + high_freq += nyquist # 8000.0 + + # fft-bin width = sample_freq / window_length_padded = 16000 / 512 = 31.25 + fft_bin_width = sample_freq / window_length_padded + + # Calculate mel scale values dynamically + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + + # mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + + # vtln_warp_factor == 1.0, so no VTLN warping needed + + # Create mel bin centers + bin_indices = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(1) + left_mel = mel_low_freq + bin_indices * mel_freq_delta + center_mel = mel_low_freq + (bin_indices + 1.0) * mel_freq_delta + right_mel = mel_low_freq + (bin_indices + 2.0) * mel_freq_delta + + # No VTLN warping since vtln_warp_factor == 1.0 + + # Create frequency bins for FFT + fft_freqs = fft_bin_width * torch.arange(num_fft_bins, device=device, dtype=dtype) + mel = mel_scale(fft_freqs).unsqueeze(0) # size(1, num_fft_bins) + + # Calculate triangular filter banks + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + + # Since vtln_warp_factor == 1.0, use the simpler branch + bins = torch.max(torch.zeros(1, device=device, dtype=dtype), torch.min(up_slope, down_slope)) + + return bins + + +def fbank_onnx( + waveform: Tensor, num_mel_bins=80, sample_frequency=16000, dither=0 +) -> Tensor: + r"""ONNX-compatible fbank function with hardcoded parameters from traced call: + num_mel_bins=80, sample_frequency=16000, dither=0 + blackman_coeff: float = 0.42, + channel: int = -1, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + + Returns: + Tensor: A fbank identical to what Kaldi would output. 
The shape is (m, 80) + where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + + # Use ONNX-compatible version of _get_waveform_and_window_properties + waveform_selected, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties_onnx(waveform) + + # min_duration=0.0, so skip the duration check (signal will never be too short) + + # Use ONNX-compatible version of _get_window + strided_input, signal_log_energy = _get_window_onnx(waveform_selected) + + # spectrum = torch.fft.rfft(strided_input).abs() + + m, frame_size = strided_input.shape + + # Process all frames at once using batch processing + # Reshape to (m, 1, frame_size) to treat each frame as a separate batch item + batched_frames = strided_input.unsqueeze(1) # Shape: (m, 1, 512) + + # Create rectangular window for all frames at once + rectangular_window = torch.ones(512, device=strided_input.device, dtype=strided_input.dtype) + + # Apply STFT to all frames simultaneously + # The batch dimension allows us to process all m frames in parallel + stft_result = torch.stft( + batched_frames.flatten(0, 1), # Shape: (m, 512) - flatten batch and channel dims + n_fft=512, + hop_length=512, # Process entire frame at once + window=rectangular_window, + center=False, # Don't add padding + return_complex=False + ) + + # stft_result shape: (m, 257, 1, 2) where last dim is [real, imag] + # Calculate magnitude: sqrt(real^2 + imag^2) + real_part = stft_result[..., 0] # Shape: (m, 257, 1) + imag_part = stft_result[..., 1] # Shape: (m, 257, 1) + spectrum = torch.sqrt(real_part.pow(2) + imag_part.pow(2)).squeeze(-1) # Shape: (m, 257) + + # use_power=True, so execute this branch + spectrum = spectrum.pow(2.0) + + # Get mel filterbanks using ONNX-compatible version + mel_energies = get_mel_banks_onnx(device, dtype) + + # pad right column with zeros to match FFT output size (80, 256) -> (80, 257) + mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) + + # sum with mel filterbanks over the power spectrum, size (m, 80) + mel_energies = torch.mm(spectrum, mel_energies.T) + + # use_log_fbank=True, so execute this branch + mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() + + # use_energy=False, so skip the energy addition (lines 828-834) + + # Use ONNX-compatible version of _subtract_column_mean + mel_energies = _subtract_column_mean_onnx(mel_energies) + + return mel_energies + +# Test to compare original fbank vs fbank_onnx +if __name__ == "__main__": + import torch + + print("Testing fbank vs fbank_onnx with traced parameters...") + + # Create test waveform + torch.manual_seed(42) + sample_rate = 16000 + duration = 1.0 # 1 second + num_samples = int(sample_rate * duration) + + # Create a test waveform (sine wave + noise) + t = torch.linspace(0, duration, num_samples) + frequency = 440.0 # A4 note + waveform = torch.sin(2 * torch.pi * frequency * t) + 0.1 * torch.randn(num_samples) + + # Test with both mono and stereo inputs + mono_waveform = waveform.unsqueeze(0) # Shape: (1, num_samples) + + print(f"Test waveform shape: {mono_waveform.shape}") + + # Test parameters from trace: num_mel_bins=80, sample_frequency=16000, dither=0 + try: + print("\n=== DEBUGGING: Step-by-step comparison ===") + + # Step 1: Check waveform processing + orig_waveform, orig_window_shift, orig_window_size, orig_padded_window_size = _get_waveform_and_window_properties( + mono_waveform, -1, 16000.0, 10.0, 25.0, True, 0.97 + ) + onnx_waveform, onnx_window_shift, 
onnx_window_size, onnx_padded_window_size = _get_waveform_and_window_properties_onnx(mono_waveform) + + print(f"Original waveform shape: {orig_waveform.shape}") + print(f"ONNX waveform shape: {onnx_waveform.shape}") + print(f"Waveform difference: {torch.max(torch.abs(orig_waveform - onnx_waveform)).item():.2e}") + print(f"Window params - orig: shift={orig_window_shift}, size={orig_window_size}, padded={orig_padded_window_size}") + print(f"Window params - onnx: shift={onnx_window_shift}, size={onnx_window_size}, padded={onnx_padded_window_size}") + + # Step 2: Check windowing + orig_strided, orig_energy = _get_window( + orig_waveform, orig_padded_window_size, orig_window_size, orig_window_shift, + 'povey', 0.42, True, True, 1.0, 0, True, 0.97 + ) + onnx_strided, onnx_energy = _get_window_onnx(onnx_waveform) + + print(f"\nOriginal strided shape: {orig_strided.shape}") + print(f"ONNX strided shape: {onnx_strided.shape}") + print(f"Strided difference: {torch.max(torch.abs(orig_strided - onnx_strided)).item():.2e}") + print(f"Energy difference: {torch.max(torch.abs(orig_energy - onnx_energy)).item():.2e}") + + # Step 3: Check mel banks + orig_mel_banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0, mono_waveform.device, mono_waveform.dtype) + onnx_mel_banks = get_mel_banks_onnx(mono_waveform.device, mono_waveform.dtype) + + print(f"\nOriginal mel banks shape: {orig_mel_banks.shape}") + print(f"ONNX mel banks shape: {onnx_mel_banks.shape}") + print(f"Mel banks difference: {torch.max(torch.abs(orig_mel_banks - onnx_mel_banks)).item():.2e}") + + # Step 4: Full comparison + print("\n=== FULL COMPARISON ===") + + # Original fbank + original_result = fbank( + mono_waveform, + num_mel_bins=80, + sample_frequency=16000, + dither=0 + ) + + # ONNX-compatible fbank + onnx_result = fbank_onnx(mono_waveform) + + print(f"Original fbank output shape: {original_result.shape}") + print(f"ONNX fbank output shape: {onnx_result.shape}") + + # Check if shapes match + if original_result.shape == onnx_result.shape: + print("✅ Output shapes match") + else: + print("❌ Output shapes don't match") + print(f" Original: {original_result.shape}") + print(f" ONNX: {onnx_result.shape}") + + # Check numerical differences + if original_result.shape == onnx_result.shape: + diff = torch.abs(original_result - onnx_result) + max_diff = torch.max(diff).item() + mean_diff = torch.mean(diff).item() + relative_diff = torch.mean(diff / (torch.abs(original_result) + 1e-8)).item() + + print(f"Max absolute difference: {max_diff:.2e}") + print(f"Mean absolute difference: {mean_diff:.2e}") + print(f"Mean relative difference: {relative_diff:.2e}") + + # Find where the max difference occurs + max_idx = torch.argmax(diff) + max_coords = torch.unravel_index(max_idx, diff.shape) + print(f"Max difference at coordinates: {max_coords}") + print(f" Original value: {original_result[max_coords].item():.6f}") + print(f" ONNX value: {onnx_result[max_coords].item():.6f}") + + # Check if results are numerically close + tolerance = 1e-5 + if max_diff < tolerance: + print(f"✅ Results are numerically identical (within {tolerance})") + else: + print(f"❌ Results {max_diff} differ by more than {tolerance}") + + # Additional statistics + print(f"Original result range: [{torch.min(original_result).item():.3f}, {torch.max(original_result).item():.3f}]") + print(f"ONNX result range: [{torch.min(onnx_result).item():.3f}, {torch.max(onnx_result).item():.3f}]") + + except Exception as e: + print(f"❌ Error during testing: {e}") + import traceback + 
traceback.print_exc() diff --git a/GPT_SoVITS/export_roberta_onnx.py b/GPT_SoVITS/export_roberta_onnx.py new file mode 100644 index 00000000..cea4c4ee --- /dev/null +++ b/GPT_SoVITS/export_roberta_onnx.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +from transformers import AutoTokenizer, AutoModelForMaskedLM +import onnx +import onnxruntime as ort +from typing import Dict, Any +import argparse +import os +import shutil +import numpy as np +import onnxsim +import onnx + +class CombinedBERTModel(nn.Module): + """Wrapper class that combines BERT tokenizer preprocessing and model inference.""" + + def __init__(self, model_name: str): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForMaskedLM.from_pretrained(model_name) + + def forward(self, text_input: torch.Tensor): + """Forward pass that includes tokenization and model inference.""" + # Note: For ONNX export, we'll work with pre-tokenized input_ids + # In practice, text tokenization needs to happen outside ONNX + input_ids = text_input.long() + + outputs = self.model(input_ids=input_ids, output_hidden_states=True) + return torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + +def export_bert_to_onnx( + model_name: str = "bert-base-uncased", + output_dir: str = "bert_exported", + max_seq_length: int = 512 +): + """Export BERT model to ONNX format and copy tokenizer files.""" + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + print(f"Loading model: {model_name}") + combined_model = CombinedBERTModel(model_name) + combined_model.eval() + + # Create dummy inputs for ONNX export (pre-tokenized input_ids) + batch_size = 1 + dummy_input_ids = torch.randint(0, combined_model.tokenizer.vocab_size, (batch_size, max_seq_length)) + + # Export to ONNX + onnx_path = os.path.join(output_dir, "chinese-roberta-wwm-ext-large.onnx") + print(f"Exporting to ONNX: {onnx_path}") + torch.onnx.export( + combined_model, + dummy_input_ids, + onnx_path, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=['input_ids'], + output_names=['logits'], + dynamic_axes={ + 'input_ids': {0: 'batch_size', 1: 'sequence_length'}, + 'logits': {0: 'logits_length'} + } + ) + # Load the ONNX model + model = onnx.load(onnx_path) + # Simplify the model + model_simplified, _ = onnxsim.simplify(model) + # Save the simplified model + onnx.save(model_simplified, onnx_path) + + # Copy tokenizer.json if it exists + tokenizer_cache_dir = combined_model.tokenizer.name_or_path + if os.path.isdir(tokenizer_cache_dir): + tokenizer_json_path = os.path.join(tokenizer_cache_dir, "tokenizer.json") + else: + # For models from HuggingFace cache + from transformers import cached_path + try: + tokenizer_json_path = combined_model.tokenizer._tokenizer.model_path + except: + # Alternative approach to find tokenizer.json in cache + cache_dir = os.path.expanduser("~/.cache/huggingface/transformers") + tokenizer_json_path = None + for root, dirs, files in os.walk(cache_dir): + if "tokenizer.json" in files and model_name.replace("/", "--") in root: + tokenizer_json_path = os.path.join(root, "tokenizer.json") + break + + if tokenizer_json_path and os.path.exists(tokenizer_json_path): + dest_tokenizer_path = os.path.join(output_dir, "tokenizer.json") + shutil.copy2(tokenizer_json_path, dest_tokenizer_path) + print(f"Copied tokenizer.json to: {dest_tokenizer_path}") + else: + print("Warning: tokenizer.json not found") + + # Copy config.json if it exists + if tokenizer_cache_dir 
and os.path.isdir(tokenizer_cache_dir): + config_json_path = os.path.join(tokenizer_cache_dir, "config.json") + else: + # For models from HuggingFace cache + cache_dir = os.path.expanduser("~/.cache/huggingface/transformers") + config_json_path = None + for root, dirs, files in os.walk(cache_dir): + if "config.json" in files and model_name.replace("/", "--") in root: + config_json_path = os.path.join(root, "config.json") + break + + if config_json_path and os.path.exists(config_json_path): + dest_config_path = os.path.join(output_dir, "config.json") + shutil.copy2(config_json_path, dest_config_path) + print(f"Copied config.json to: {dest_config_path}") + else: + print("Warning: config.json not found") + + print(f"Model exported successfully to: {output_dir}") + return combined_model, onnx_path + +def test_model_equivalence(original_model, onnx_path: str, max_seq_length: int = 512, tolerance: float = 1e-5): + """Test if the original PyTorch model and ONNX model produce the same outputs.""" + + print("Testing model equivalence...") + + # Create test input + batch_size = 1 + test_input_ids = torch.randint(0, original_model.tokenizer.vocab_size, (batch_size, max_seq_length)) + input_ids = original_model.tokenizer.encode("原神,启动!", return_tensors="pt") + + + # Get PyTorch output + original_model.eval() + with torch.no_grad(): + pytorch_output = original_model(input_ids).numpy() + + # Get ONNX output + ort_session = ort.InferenceSession(onnx_path) + onnx_output = ort_session.run(None, {"input_ids": input_ids.numpy()})[0] + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"ONNX output shape: {onnx_output.shape}") + # Compare outputs + max_diff = np.max(np.abs(pytorch_output - onnx_output)) + mean_diff = np.mean(np.abs(pytorch_output - onnx_output)) + + print(f"Maximum absolute difference: {max_diff}") + print(f"Mean absolute difference: {mean_diff}") + + if max_diff < tolerance: + print("✅ Models are numerically equivalent!") + return True + else: + print("❌ Models have significant differences!") + return False + +def main(): + parser = argparse.ArgumentParser(description="Export BERT model to ONNX") + parser.add_argument("--model_name", type=str, default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + help="Pretrained BERT model name") + parser.add_argument("--output_dir", type=str, default="playground/chinese-roberta-wwm-ext-large", + help="Output directory path") + parser.add_argument("--max_seq_length", type=int, default=512, + help="Maximum sequence length") + parser.add_argument("--tolerance", type=float, default=1e-3, + help="Tolerance for numerical comparison") + + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + # Export model + original_model, onnx_path = export_bert_to_onnx( + model_name=args.model_name, + output_dir=args.output_dir, + max_seq_length=args.max_seq_length + ) + + # Test equivalence + test_model_equivalence( + original_model=original_model, + onnx_path=onnx_path, + max_seq_length=args.max_seq_length, + tolerance=args.tolerance + ) + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/module/core_vq.py b/GPT_SoVITS/module/core_vq.py index b7dab317..876a6984 100644 --- a/GPT_SoVITS/module/core_vq.py +++ b/GPT_SoVITS/module/core_vq.py @@ -357,9 +357,17 @@ class ResidualVectorQuantization(nn.Module): return out_indices def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = 
self.layers[st + i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized + # ONNX-friendly approach: use unbind instead of enumerate loop + indices_list = torch.unbind(q_indices, dim=0) + quantized_list = [] + + for i, indices in enumerate(indices_list): + if st + i < len(self.layers): + layer = self.layers[st + i] + quantized = layer.decode(indices) + quantized_list.append(quantized) + + # Stack and sum instead of iterative addition + quantized_out = torch.stack(quantized_list, dim=0).sum(dim=0) + return quantized_out diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py index b62b8b71..9ee78d60 100644 --- a/GPT_SoVITS/module/models_onnx.py +++ b/GPT_SoVITS/module/models_onnx.py @@ -205,6 +205,8 @@ class TextEncoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, y, text, ge, speed=1): + if type(speed) == float: + speed = torch.FloatTensor([speed]) y_mask = torch.ones_like(y[:1, :1, :]) y = self.ssl_proj(y * y_mask) * y_mask @@ -217,9 +219,8 @@ class TextEncoder(nn.Module): y = self.mrte(y, y_mask, text, text_mask, ge) y = self.encoder2(y * y_mask, y_mask) - if speed != 1: - y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear") - y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") + y = F.interpolate(y, size=(y.shape[-1] / speed).to(torch.int) + 1, mode="linear") + y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest") stats = self.proj(y) * y_mask m, logs = torch.split(stats, self.out_channels, dim=1) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py deleted file mode 100644 index fd680135..00000000 --- a/GPT_SoVITS/onnx_export.py +++ /dev/null @@ -1,398 +0,0 @@ -import torch -import torchaudio -from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule -from feature_extractor import cnhubert -from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2 -from torch import nn - -cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" -cnhubert.cnhubert_base_path = cnhubert_base_path -ssl_model = cnhubert.get_model() -import json -import os - -import soundfile -from text import cleaned_text_to_sequence - - -def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window, - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, - ) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -class DictToAttrRecursive(dict): - def __init__(self, input_dict): - super().__init__(input_dict) - for key, value in input_dict.items(): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - self[key] = value - setattr(self, key, value) - - def __getattr__(self, item): - try: - return self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - def __setattr__(self, key, value): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - super(DictToAttrRecursive, self).__setitem__(key, value) - super().__setattr__(key, value) - - def __delattr__(self, item): - try: - del self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - -class 
T2SEncoder(nn.Module): - def __init__(self, t2s, vits): - super().__init__() - self.encoder = t2s.onnx_encoder - self.vits = vits - - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): - codes = self.vits.extract_latent(ssl_content) - prompt_semantic = codes[0, 0] - bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) - all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) - bert = bert.unsqueeze(0) - prompt = prompt_semantic.unsqueeze(0) - return self.encoder(all_phoneme_ids, bert), prompt - - -class T2SModel(nn.Module): - def __init__(self, t2s_path, vits_model): - super().__init__() - dict_s1 = torch.load(t2s_path, map_location="cpu") - self.config = dict_s1["config"] - self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False) - self.t2s_model.load_state_dict(dict_s1["weight"]) - self.t2s_model.eval() - self.vits_model = vits_model.vq_model - self.hz = 50 - self.max_sec = self.config["data"]["max_sec"] - self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]]) - self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) - self.t2s_model = self.t2s_model.model - self.t2s_model.init_onnx() - self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) - self.first_stage_decoder = self.t2s_model.first_stage_decoder - self.stage_decoder = self.t2s_model.stage_decoder - # self.t2s_model = torch.jit.script(self.t2s_model) - - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): - early_stop_num = self.t2s_model.early_stop_num - - # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] - x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - - prefix_len = prompts.shape[1] - - # [1,N,512] [1,N] - y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) - - stop = False - for idx in range(1, 1500): - # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] - enco = self.stage_decoder(y, k, v, y_emb, x_example) - y, k, v, y_emb, logits, samples = enco - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: - stop = True - if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: - stop = True - if stop: - break - y[0, -1] = 0 - - return y[:, -idx:].unsqueeze(0) - - def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): - # self.onnx_encoder = torch.jit.script(self.onnx_encoder) - if dynamo: - export_options = torch.onnx.ExportOptions(dynamic_shapes=True) - onnx_encoder_export_output = torch.onnx.dynamo_export( - self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options - ) - onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") - return - - torch.onnx.export( - self.onnx_encoder, - (ref_seq, text_seq, ref_bert, text_bert, ssl_content), - f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", - input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], - output_names=["x", "prompts"], - dynamic_axes={ - "ref_seq": {1: "ref_length"}, - "text_seq": {1: "text_length"}, - "ref_bert": {0: "ref_length"}, - "text_bert": {0: "text_length"}, - "ssl_content": {2: "ssl_length"}, - }, - opset_version=16, - ) - x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - - torch.onnx.export( - self.first_stage_decoder, - (x, prompts), - f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx", - 
input_names=["x", "prompts"], - output_names=["y", "k", "v", "y_emb", "x_example"], - dynamic_axes={ - "x": {1: "x_length"}, - "prompts": {1: "prompts_length"}, - }, - verbose=False, - opset_version=16, - ) - y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) - - torch.onnx.export( - self.stage_decoder, - (y, k, v, y_emb, x_example), - f"onnx/{project_name}/{project_name}_t2s_sdec.onnx", - input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], - output_names=["y", "k", "v", "y_emb", "logits", "samples"], - dynamic_axes={ - "iy": {1: "iy_length"}, - "ik": {1: "ik_length"}, - "iv": {1: "iv_length"}, - "iy_emb": {1: "iy_emb_length"}, - "ix_example": {1: "ix_example_length"}, - }, - verbose=False, - opset_version=16, - ) - - -class VitsModel(nn.Module): - def __init__(self, vits_path): - super().__init__() - dict_s2 = torch.load(vits_path, map_location="cpu") - self.hps = dict_s2["config"] - if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: - self.hps["model"]["version"] = "v1" - else: - self.hps["model"]["version"] = "v2" - - self.hps = DictToAttrRecursive(self.hps) - self.hps.model.semantic_frame_rate = "25hz" - self.vq_model = SynthesizerTrn( - self.hps.data.filter_length // 2 + 1, - self.hps.train.segment_size // self.hps.data.hop_length, - n_speakers=self.hps.data.n_speakers, - **self.hps.model, - ) - self.vq_model.eval() - self.vq_model.load_state_dict(dict_s2["weight"], strict=False) - - def forward(self, text_seq, pred_semantic, ref_audio): - refer = spectrogram_torch( - ref_audio, - self.hps.data.filter_length, - self.hps.data.sampling_rate, - self.hps.data.hop_length, - self.hps.data.win_length, - center=False, - ) - return self.vq_model(pred_semantic, text_seq, refer)[0, 0] - - -class GptSoVits(nn.Module): - def __init__(self, vits, t2s): - super().__init__() - self.vits = vits - self.t2s = t2s - - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): - pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - audio = self.vits(text_seq, pred_semantic, ref_audio) - if debug: - import onnxruntime - - sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) - audio1 = sess.run( - None, - { - "text_seq": text_seq.detach().cpu().numpy(), - "pred_semantic": pred_semantic.detach().cpu().numpy(), - "ref_audio": ref_audio.detach().cpu().numpy(), - }, - ) - return audio, audio1 - return audio - - def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): - self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) - pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - torch.onnx.export( - self.vits, - (text_seq, pred_semantic, ref_audio), - f"onnx/{project_name}/{project_name}_vits.onnx", - input_names=["text_seq", "pred_semantic", "ref_audio"], - output_names=["audio"], - dynamic_axes={ - "text_seq": {1: "text_length"}, - "pred_semantic": {2: "pred_length"}, - "ref_audio": {1: "audio_length"}, - }, - opset_version=17, - verbose=False, - ) - - -class SSLModel(nn.Module): - def __init__(self): - super().__init__() - self.ssl = ssl_model - - def forward(self, ref_audio_16k): - return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) - - -def export(vits_path, gpt_path, project_name, vits_model="v2"): - vits = VitsModel(vits_path) - gpt = T2SModel(gpt_path, vits) - gpt_sovits = GptSoVits(vits, gpt) - ssl = SSLModel() - ref_seq = torch.LongTensor( - [ - 
cleaned_text_to_sequence( - [ - "n", - "i2", - "h", - "ao3", - ",", - "w", - "o3", - "sh", - "i4", - "b", - "ai2", - "y", - "e4", - ], - version=vits_model, - ) - ] - ) - text_seq = torch.LongTensor( - [ - cleaned_text_to_sequence( - [ - "w", - "o3", - "sh", - "i4", - "b", - "ai2", - "y", - "e4", - "w", - "o3", - "sh", - "i4", - "b", - "ai2", - "y", - "e4", - "w", - "o3", - "sh", - "i4", - "b", - "ai2", - "y", - "e4", - ], - version=vits_model, - ) - ] - ) - ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() - text_bert = torch.randn((text_seq.shape[1], 1024)).float() - ref_audio = torch.randn((1, 48000 * 5)).float() - # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() - ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() - ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float() - - try: - os.mkdir(f"onnx/{project_name}") - except: - pass - - ssl_content = ssl(ref_audio_16k).float() - - # debug = False - debug = True - - # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) - - if debug: - a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) - soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) - soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) - else: - a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() - soundfile.write("out.wav", a, vits.hps.data.sampling_rate) - - if vits_model == "v1": - symbols = symbols_v1 - else: - symbols = symbols_v2 - - MoeVSConf = { - "Folder": f"{project_name}", - "Name": f"{project_name}", - "Type": "GPT-SoVits", - "Rate": vits.hps.data.sampling_rate, - "NumLayers": gpt.t2s_model.num_layers, - "EmbeddingDim": gpt.t2s_model.embedding_dim, - "Dict": "BasicDict", - "BertPath": "chinese-roberta-wwm-ext-large", - # "Symbol": symbols, - "AddBlank": False, - } - - MoeVSConfJson = json.dumps(MoeVSConf) - with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile: - json.dump(MoeVSConf, MoeVsConfFile, indent=4) - - -if __name__ == "__main__": - try: - os.mkdir("onnx") - except: - pass - - gpt_path = "GPT_weights/nahida-e25.ckpt" - vits_path = "SoVITS_weights/nahida_e30_s3930.pth" - exp_path = "nahida" - export(vits_path, gpt_path, exp_path) - - # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) diff --git a/GPT_SoVITS/onnx_export_v1v2.py b/GPT_SoVITS/onnx_export_v1v2.py new file mode 100644 index 00000000..45f4679b --- /dev/null +++ b/GPT_SoVITS/onnx_export_v1v2.py @@ -0,0 +1,448 @@ +import torch +import torch.nn.functional as F +import torchaudio +from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule +from feature_extractor import cnhubert +from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2 +from torch import nn +from sv import SV +import onnx +from onnx import helper, TensorProto +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +from transformers import HubertModel, HubertConfig +import os +import json +from text import cleaned_text_to_sequence +import onnxsim +from onnxconverter_common import float16 + +def simplify_onnx_model(onnx_model_path: str): + # Load the ONNX model + model = onnx.load(onnx_model_path) + # Simplify the model + model_simplified, _ = onnxsim.simplify(model) + # Save the simplified model + onnx.save(model_simplified, onnx_model_path) + +def convert_onnx_to_half(onnx_model_path:str): + 
try: + model = onnx.load(onnx_model_path) + model_fp16 = float16.convert_float_to_float16(model) + onnx.save(model_fp16, onnx_model_path) + except Exception as e: + print(f"Error converting {onnx_model_path} to half precision: {e}") + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + +def resample_audio(audio: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor: + """ + Resample audio from orig_sr to target_sr using linear interpolation. + audio: (batch, channels, samples) or (channels, samples) or (samples,) + """ + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + # audio shape: (batch, channels, samples) + batch, channels, samples = audio.shape + # Reshape to combine batch and channels for interpolation + audio = audio.reshape(batch * channels, 1, samples) + # Use scale_factor instead of a computed size for ONNX export compatibility + resampled = F.interpolate(audio, scale_factor=target_sr / orig_sr, mode='linear', align_corners=False) + new_samples = resampled.shape[-1] + resampled = resampled.reshape(batch, channels, new_samples) + resampled = resampled.squeeze(0).squeeze(0) + return resampled + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class T2SInitStage(nn.Module): + def __init__(self, t2s, vits): + super().__init__() + self.encoder = t2s.onnx_encoder + self.vits = vits + self.num_layers = t2s.num_layers + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + codes = self.vits.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + prompt = prompt_semantic.unsqueeze(0) + x = self.encoder(all_phoneme_ids, bert) + + x_seq_len = torch.onnx.operators.shape_as_tensor(x)[1] + y_seq_len = torch.onnx.operators.shape_as_tensor(prompt)[1] + + init_k = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float) + init_v = torch.zeros(((x_seq_len + y_seq_len), self.num_layers, 512), dtype=torch.float) + + return x, prompt, init_k, init_v, x_seq_len, y_seq_len + +class T2SModel(nn.Module): + def __init__(self, t2s_path, vits_model): + super().__init__() + dict_s1 = torch.load(t2s_path, map_location="cpu") + self.config = 
dict_s1["config"] + self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False) + self.t2s_model.load_state_dict(dict_s1["weight"]) + self.t2s_model.eval() + self.vits_model = vits_model.vq_model + self.hz = 50 + self.max_sec = self.config["data"]["max_sec"] + self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]]) + self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + self.t2s_model = self.t2s_model.model + self.t2s_model.init_onnx() + self.init_stage = T2SInitStage(self.t2s_model, self.vits_model) + self.stage_decoder = self.t2s_model.stage_decoder + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + empty_tensor = torch.empty((1,0,512)).to(torch.float) + # first step + y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v, + empty_tensor, + top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, + first_infer=torch.LongTensor([1]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) + + for idx in range(5): # This is a fake one! DO NOT take this as reference + k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) + v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) + y_seq_len = y.shape[1] + y, k, v, y_emb, logits, samples = self.stage_decoder(empty_tensor, y, k, v, + y_emb, + top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, + first_infer=torch.LongTensor([0]), x_seq_len=x_seq_len, y_seq_len=y_seq_len) + # if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: + # break + + return y[:, -5:].unsqueeze(0) + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + torch.onnx.export( + self.init_stage, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx", + input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"], + output_names=["x", "prompt", "init_k", "init_v", 'x_seq_len', 'y_seq_len'], + dynamic_axes={ + "ref_text_phones": {1: "ref_length"}, + "input_text_phones": {1: "text_length"}, + "ref_text_bert": {0: "ref_length"}, + "input_text_bert": {0: "text_length"}, + "hubert_ssl_content": {2: "ssl_length"}, + }, + opset_version=16, + do_constant_folding=False + ) + simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx") + x, prompt, init_k, init_v, x_seq_len, y_seq_len = self.init_stage(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + empty_tensor = torch.empty((1,0,512)).to(torch.float) + x_seq_len = torch.Tensor([x_seq_len]).to(torch.int64) + y_seq_len = torch.Tensor([y_seq_len]).to(torch.int64) + + y, k, v, y_emb, logits, samples = self.stage_decoder(x, prompt, init_k, init_v, + empty_tensor, + top_k, top_p, repetition_penalty, temperature, + torch.LongTensor([1]), x_seq_len, y_seq_len) + k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)) + v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)) + y_seq_len = torch.Tensor([y.shape[1]]).to(torch.int64) + + torch.onnx.export( + self.stage_decoder, + (x, y, k, v, y_emb, top_k, top_p, repetition_penalty, temperature, torch.LongTensor([0]), x_seq_len, y_seq_len), + 
f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx", + input_names=["ix", "iy", "ik", "iv", "iy_emb", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step", "x_seq_len", "y_seq_len"], + output_names=["y", "k", "v", "y_emb", "logits", "samples"], + dynamic_axes={ + "ix": {1: "ix_length"}, + "iy": {1: "iy_length"}, + "ik": {0: "ik_length"}, + "iv": {0: "iv_length"}, + "iy_emb": {1: "iy_emb_length"}, + }, + verbose=False, + opset_version=16, + ) + simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx") + + +class VitsModel(nn.Module): + def __init__(self, vits_path, version:str = 'v2'): + super().__init__() + dict_s2 = torch.load(vits_path, map_location="cpu", weights_only=False) + self.hps = dict_s2["config"] + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = version + + self.is_v2p = version.lower() in ['v2pro', 'v2proplus'] + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model:SynthesizerTrn = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model, + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + # print(f"filter_length:{self.hps.data.filter_length} sampling_rate:{self.hps.data.sampling_rate} hop_length:{self.hps.data.hop_length} win_length:{self.hps.data.win_length}") + #v2 filter_length: 2048 sampling_rate: 32000 hop_length: 640 win_length: 2048 + def forward(self, text_seq, pred_semantic, spectrum, sv_emb, speed): + if self.is_v2p: + return self.vq_model(pred_semantic, text_seq, spectrum, sv_emb=sv_emb, speed=speed)[0, 0] + else: + return self.vq_model(pred_semantic, text_seq, spectrum, speed=speed)[0, 0] + + +class GptSoVits(): + def __init__(self, vits, t2s): + super().__init__() + self.vits = vits + self.t2s = t2s + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, spectrum, sv_emb, speed, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + torch.onnx.export( + self.vits, + (text_seq, pred_semantic, spectrum, sv_emb, speed), + f"onnx/{project_name}/{project_name}_vits.onnx", + input_names=["input_text_phones", "pred_semantic", "spectrum", "sv_emb", "speed"], + output_names=["audio"], + dynamic_axes={ + "input_text_phones": {1: "text_length"}, + "pred_semantic": {2: "pred_length"}, + "spectrum": {2: "spectrum_length"}, + }, + opset_version=17, + verbose=False, + ) + simplify_onnx_model(f"onnx/{project_name}/{project_name}_vits.onnx") + + +class AudioPreprocess(nn.Module): + def __init__(self): + super().__init__() + + # Load the model + self.model = HubertModel.from_pretrained(cnhubert_base_path, local_files_only=True) + self.model.eval() + + self.sv_model = SV("cpu", False) + + def forward(self, ref_audio_32k): + spectrum = spectrogram_torch( + ref_audio_32k, + 2048, + 32000, + 640, + 2048, + center=False, + ) + ref_audio_16k = resample_audio(ref_audio_32k, 32000, 16000) + + sv_emb = 
+
+        zero_tensor = torch.zeros((1, 9600), dtype=torch.float32)
+        ref_audio_16k = ref_audio_16k.unsqueeze(0)
+        # concatenate zero padding with the waveform
+        ref_audio_16k = torch.cat([ref_audio_16k, zero_tensor], dim=1)
+        ssl_content = self.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
+
+        return ssl_content, spectrum, sv_emb
+
+
+def export(vits_path, gpt_path, project_name, voice_model_version, export_audio_preprocessor=True, half_precision=False):
+    vits = VitsModel(vits_path, version=voice_model_version)
+    gpt = T2SModel(gpt_path, vits)
+    gpt_sovits = GptSoVits(vits, gpt)
+    preprocessor = AudioPreprocess()
+    ref_seq = torch.LongTensor(
+        [
+            cleaned_text_to_sequence(
+                ["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],
+                version="v2",
+            )
+        ]
+    )
+    text_seq = torch.LongTensor(
+        [
+            cleaned_text_to_sequence(
+                ["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"] * 3,  # same phrase repeated three times
+                version="v2",
+            )
+        ]
+    )
+    ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
+    text_bert = torch.randn((text_seq.shape[1], 1024)).float()
+    ref_audio32k = torch.randn((1, 32000 * 5)).float() - 0.5  # 5 seconds of dummy audio
+    top_k = torch.LongTensor([15])
+    top_p = torch.FloatTensor([1.0])
+    repetition_penalty = torch.FloatTensor([1.0])
+    temperature = torch.FloatTensor([1.0])
+    speed = torch.FloatTensor([1.0])
+
+    os.makedirs(f"onnx/{project_name}", exist_ok=True)
+
+    ssl_content, spectrum, sv_emb = preprocessor(ref_audio32k)
+    gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content.float(), spectrum.float(), sv_emb.float(), speed, project_name, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature)
+
+    if export_audio_preprocessor:
+        torch.onnx.export(
+            preprocessor,
+            (ref_audio32k,),
+            f"onnx/{project_name}/{project_name}_audio_preprocess.onnx",
+            input_names=["audio32k"],
+            output_names=["hubert_ssl_output", "spectrum", "sv_emb"],
+            dynamic_axes={
+                "audio32k": {1: "sequence_length"},
+                "hubert_ssl_output": {2: "hubert_length"},
+                "spectrum": {2: "spectrum_length"},
+            },
+        )
+        simplify_onnx_model(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
+
+    if half_precision:
+        if export_audio_preprocessor:
+            convert_onnx_to_half(f"onnx/{project_name}/{project_name}_audio_preprocess.onnx")
+        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_vits.onnx")
+        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_init_stage.onnx")
+        convert_onnx_to_half(f"onnx/{project_name}/{project_name}_t2s_stage_decoder.onnx")
+
+    configJson = {
+        "project_name": project_name,
+        "type": "GPTSoVITS",
+        "version": voice_model_version,
+        "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+        "t2s_weights_path": gpt_path,
+        "vits_weights_path": vits_path,
+        "half_precision": half_precision,
+    }
+    with open(f"onnx/{project_name}/config.json", "w", encoding="utf-8") as f:
+        json.dump(configJson, f, ensure_ascii=False, indent=4)
+
+
+if __name__ == "__main__":
+    os.makedirs("onnx", exist_ok=True)
+
+    # Because of the very frequent I/O here, model export can occasionally fail
+    # (especially noticeable under WSL); simply retry if that happens.
+
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
+    exp_path = "v1_export"
+    version = "v1"
+    export(vits_path, gpt_path, exp_path, version)
+
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
+    exp_path = "v2_export"
+    version = "v2"
+    export(vits_path, gpt_path, exp_path, version)
+
+    gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth"
+    exp_path = "v2pro_export"
+    version = "v2Pro"
+    export(vits_path, gpt_path, exp_path, version)
+
+    gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth"
+    exp_path = "v2proplus_export"
+    version = "v2ProPlus"
+    export(vits_path, gpt_path, exp_path, version)
diff --git a/GPT_SoVITS/sv.py b/GPT_SoVITS/sv.py
index 22e70369..50069c7e 100644
--- a/GPT_SoVITS/sv.py
+++ b/GPT_SoVITS/sv.py
@@ -30,3 +30,15 @@ class SV:
         )
         sv_emb = self.embedding_model.forward3(feat)
         return sv_emb
+
+    def compute_embedding3_onnx(self, wav):
+        # Disable gradients for all parameters
+        for param in self.embedding_model.parameters():
+            param.requires_grad = False
+
+        with torch.no_grad():
+            if self.is_half:
+                wav = wav.half()
+            feat = Kaldi.fbank_onnx(wav.detach()).unsqueeze(0)
+            sv_emb = self.embedding_model.forward3(feat)
+            return sv_emb
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 90e4957d..40f0976e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,8 +7,11 @@ numba
 pytorch-lightning>=2.4
 gradio<5
 ffmpeg-python
+onnx
 onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
 onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
+onnxsim
+onnxconverter-common
 tqdm
 funasr==1.0.27
 cn2an
@@ -32,7 +35,7 @@ rotary_embedding_torch
 ToJyutping
 g2pk2
 ko_pron
-opencc
+opencc==1.1.6
 python_mecab_ko; sys_platform != 'win32'
 fastapi[standard]>=0.115.2
 x_transformers
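
For anyone consuming the exported graphs, below is a minimal onnxruntime driver sketch. It is not part of this diff: the project name, the EOS token id, the step cap, and the random phoneme/BERT placeholders are assumptions; a real caller would obtain phoneme ids and BERT features from the usual GPT-SoVITS text front end (cleaned_text_to_sequence plus the BERT model) and feed genuine 32 kHz reference audio.

# Minimal onnxruntime driver sketch; paths, EOS id, and the step cap are
# assumptions for illustration, not values defined by the diff above.
import numpy as np
import onnxruntime as ort

proj = "v2_export"  # assumed: one of the project names from the __main__ block
pre = ort.InferenceSession(f"onnx/{proj}/{proj}_audio_preprocess.onnx")
init = ort.InferenceSession(f"onnx/{proj}/{proj}_t2s_init_stage.onnx")
step = ort.InferenceSession(f"onnx/{proj}/{proj}_t2s_stage_decoder.onnx")
vits = ort.InferenceSession(f"onnx/{proj}/{proj}_vits.onnx")

# Placeholder inputs; replace with real front-end outputs.
audio32k = np.random.rand(1, 32000 * 5).astype(np.float32) - 0.5
ref_seq = np.random.randint(1, 300, (1, 13), dtype=np.int64)
text_seq = np.random.randint(1, 300, (1, 24), dtype=np.int64)
ref_bert = np.zeros((13, 1024), dtype=np.float32)
text_bert = np.zeros((24, 1024), dtype=np.float32)

ssl, spectrum, sv_emb = pre.run(None, {"audio32k": audio32k})
x, prompt, k, v, x_len, y_len = init.run(None, {
    "ref_text_phones": ref_seq, "input_text_phones": text_seq,
    "ref_text_bert": ref_bert, "input_text_bert": text_bert,
    "hubert_ssl_content": ssl,
})

empty = np.zeros((1, 0, 512), dtype=np.float32)
knobs = {
    "top_k": np.array([15], dtype=np.int64),
    "top_p": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "temperature": np.array([1.0], dtype=np.float32),
}
x_len = np.array([int(x_len)], dtype=np.int64)
y_len = np.array([int(y_len)], dtype=np.int64)

# First step: feed x/prompt and the zero-initialised caches with if_init_step=1.
y, k, v, y_emb, logits, samples = step.run(None, dict(
    ix=x, iy=prompt, ik=k, iv=v, iy_emb=empty,
    if_init_step=np.array([1], dtype=np.int64),
    x_seq_len=x_len, y_seq_len=y_len, **knobs))

EOS = 1024  # assumed terminator id; check the T2S vocabulary of your checkpoint
prompt_len = prompt.shape[1]
for _ in range(1500):  # assumed hard cap, roughly hz * max_sec
    # Mirror torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)): grow the
    # third-from-last axis of each cache by one slot per generated token.
    pad = [(0, 0)] * k.ndim
    pad[-3] = (0, 1)
    k, v = np.pad(k, pad), np.pad(v, pad)
    y_len = np.array([y.shape[1]], dtype=np.int64)
    y, k, v, y_emb, logits, samples = step.run(None, dict(
        ix=empty, iy=y, ik=k, iv=v, iy_emb=y_emb,
        if_init_step=np.array([0], dtype=np.int64),
        x_seq_len=x_len, y_seq_len=y_len, **knobs))
    if samples[0, 0] == EOS:
        break

new_tokens = y[:, prompt_len:]  # keep only freshly generated semantic tokens
if new_tokens.shape[1] and new_tokens[0, -1] == EOS:
    new_tokens = new_tokens[:, :-1]  # drop the terminator before vocoding
pred_semantic = new_tokens[np.newaxis, ...].astype(np.int64)  # (1, 1, T)

feed = {
    "input_text_phones": text_seq,
    "pred_semantic": pred_semantic,
    "spectrum": spectrum,
    "sv_emb": sv_emb,
    "speed": np.array([1.0], dtype=np.float32),
}
wanted = {i.name for i in vits.get_inputs()}  # non-v2Pro graphs may drop sv_emb
audio = vits.run(None, {name: value for name, value in feed.items() if name in wanted})[0]

Filtering the VITS feed by the graph's declared inputs is deliberate: torch.onnx.export can prune inputs that the traced forward never uses, so the sv_emb input may not exist in v1/v2 exports.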