From 825588b526a537258abc4162dee082a6297bb836 Mon Sep 17 00:00:00 2001
From: Kazuki Kyakuno
Date: Thu, 21 Mar 2024 15:16:02 +0900
Subject: [PATCH] Improve the consistency between ONNX and torch

---
 GPT_SoVITS/AR/models/t2s_model_onnx.py  |  4 +++-
 GPT_SoVITS/AR/modules/embedding_onnx.py | 13 +++++++++----
 GPT_SoVITS/module/models_onnx.py        |  4 ++--
 GPT_SoVITS/onnx_export.py               | 23 ++++++++++++++++++++---
 4 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py
index 7834297d..558ee9af 100644
--- a/GPT_SoVITS/AR/models/t2s_model_onnx.py
+++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py
@@ -69,7 +69,8 @@ def logits_to_probs(
 def multinomial_sample_one_no_sync(
     probs_sort
 ):  # Does multinomial sampling without a cuda synchronization
-    q = torch.randn_like(probs_sort)
+    lambda_ = 1.0
+    q = -torch.log(torch.rand_like(probs_sort)) / lambda_
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
 
 
@@ -152,6 +153,7 @@ class T2SFirstStageDecoder(nn.Module):
 
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
+        logits = logits[:, :-1]  ###刨除1024终止符号的概率
         samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
 
         y = torch.concat([y, samples], dim=1)
diff --git a/GPT_SoVITS/AR/modules/embedding_onnx.py b/GPT_SoVITS/AR/modules/embedding_onnx.py
index b93405b4..f8511e6e 100644
--- a/GPT_SoVITS/AR/modules/embedding_onnx.py
+++ b/GPT_SoVITS/AR/modules/embedding_onnx.py
@@ -50,10 +50,15 @@ class SinePositionalEmbedding(nn.Module):
         self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
 
     def extend_pe(self, x):
-        position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
-        scpe = (position * self.div_term).unsqueeze(0)
-        pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
-        pe = pe.contiguous().view(1, -1, self.embedding_dim)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.embedding_dim)
+        )
+        pe = torch.zeros(x.size(1), self.embedding_dim)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
         return pe
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py
index 232fd74d..9784fe6f 100644
--- a/GPT_SoVITS/module/models_onnx.py
+++ b/GPT_SoVITS/module/models_onnx.py
@@ -892,7 +892,7 @@ class SynthesizerTrn(nn.Module):
         # self.enc_p.encoder_text.requires_grad_(False)
         # self.enc_p.mrte.requires_grad_(False)
 
-    def forward(self, codes, text, refer):
+    def forward(self, codes, text, refer, noise_scale=0.5):
         refer_mask = torch.ones_like(refer[:1,:1,:])
         ge = self.ref_enc(refer * refer_mask, refer_mask)
 
@@ -905,7 +905,7 @@ class SynthesizerTrn(nn.Module):
             quantized, text, ge
         )
 
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
 
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
 
diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index b82e987f..9738a293 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -4,7 +4,7 @@ import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
@@ -266,6 +266,22 @@ class SSLModel(nn.Module):
     def forward(self, ref_audio_16k):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
+    def export(self, ref_audio_16k, project_name):
+        self.ssl.model.eval()
+        torch.onnx.export(
+            self,
+            (ref_audio_16k),
+            f"onnx/{project_name}/{project_name}_cnhubert.onnx",
+            input_names=["ref_audio_16k"],
+            output_names=["last_hidden_state"],
+            dynamic_axes={
+                "ref_audio_16k": {1 : "text_length"},
+                "last_hidden_state": {2 : "pred_length"}
+            },
+            opset_version=17,
+            verbose=False
+        )
+
 
 def export(vits_path, gpt_path, project_name):
     vits = VitsModel(vits_path)
@@ -300,6 +316,7 @@ def export(vits_path, gpt_path, project_name):
 
     soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
 
+    ssl.export(ref_audio_16k, project_name)
     gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
 
     MoeVSConf = {
@@ -326,8 +343,8 @@ if __name__ == "__main__":
     except:
         pass
 
-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
     exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)
 