Improve the consistency between ONNX and torch

Kazuki Kyakuno 2024-03-21 15:16:02 +09:00
parent a495540f2a
commit 825588b526
4 changed files with 34 additions and 10 deletions

View File

@@ -69,7 +69,8 @@ def logits_to_probs(
 def multinomial_sample_one_no_sync(
     probs_sort
 ):  # Does multinomial sampling without a cuda synchronization
-    q = torch.randn_like(probs_sort)
+    lambda_ = 1.0
+    q = -torch.log(torch.rand_like(probs_sort)) / lambda_
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
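Note on the hunk above: drawing q as -log(U) makes it Exponential(1) noise, and argmax(p / q) over exponential noise is an exact draw from the categorical distribution p (the exponential-race form of the Gumbel-max trick), which the previous Gaussian torch.randn_like noise was not; the constant lambda_ = 1.0 rescales every entry equally and leaves the argmax unchanged. A standalone sketch, not part of the diff, that checks the empirical frequencies:

import torch

torch.manual_seed(0)
# 100k draws from a 4-way categorical, all taken with the argmax-over-noise trick
probs = torch.tensor([0.1, 0.2, 0.3, 0.4]).expand(100_000, 4)
q = -torch.log(torch.rand_like(probs))      # Exponential(1) noise, as in the patch
idx = torch.argmax(probs / q, dim=-1)
print(torch.bincount(idx, minlength=4).float() / idx.numel())
# approaches tensor([0.1000, 0.2000, 0.3000, 0.4000])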
@@ -152,6 +153,7 @@ class T2SFirstStageDecoder(nn.Module):
         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
+        logits = logits[:, :-1]  ### exclude the probability of the 1024 termination token
         samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
         y = torch.concat([y, samples], dim=1)
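Note on the hunk above: assuming the prediction head emits one score per semantic token plus a final EOS class (id 1024, as the comment indicates), slicing off that last column before sampling means the first-stage decoder cannot terminate on its very first step. A standalone sketch with a hypothetical vocabulary size:

import torch

vocab = 1025                      # hypothetical: 1024 semantic tokens + 1 EOS class
logits = torch.randn(1, vocab)    # stand-in for the ar_predict_layer output
logits = logits[:, :-1]           # drop the EOS column
print(logits.shape)               # torch.Size([1, 1024]); EOS can no longer be sampled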

View File

@@ -50,10 +50,15 @@ class SinePositionalEmbedding(nn.Module):
         self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))

     def extend_pe(self, x):
-        position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
-        scpe = (position * self.div_term).unsqueeze(0)
-        pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
-        pe = pe.contiguous().view(1, -1, self.embedding_dim)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.embedding_dim)
+        )
+        pe = torch.zeros(x.size(1), self.embedding_dim)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
         return pe

     def forward(self, x: torch.Tensor) -> torch.Tensor:
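Note on the hunk above: the rewritten extend_pe builds the standard interleaved sinusoidal table, PE[pos, 2i] = sin(pos / 10000^(2i/d)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d)), rather than the old concatenated sin/cos blocks, which is presumably what brings the ONNX path in line with the torch-side embedding. A standalone sketch, not part of the diff, checking the new construction against the textbook formula:

import math
import torch

d_model, seq_len = 8, 16
position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# element-by-element reference using the textbook formula
ref = torch.tensor([
    [math.sin(p / 10000 ** ((i - i % 2) / d_model)) if i % 2 == 0
     else math.cos(p / 10000 ** ((i - i % 2) / d_model))
     for i in range(d_model)]
    for p in range(seq_len)
])
print(torch.allclose(pe, ref, atol=1e-5))   # True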

View File

@@ -892,7 +892,7 @@ class SynthesizerTrn(nn.Module):
         # self.enc_p.encoder_text.requires_grad_(False)
         # self.enc_p.mrte.requires_grad_(False)

-    def forward(self, codes, text, refer):
+    def forward(self, codes, text, refer, noise_scale=0.5):
         refer_mask = torch.ones_like(refer[:1,:1,:])
         ge = self.ref_enc(refer * refer_mask, refer_mask)
@@ -905,7 +905,7 @@ class SynthesizerTrn(nn.Module):
             quantized, text, ge
         )

-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
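Note on the two hunks above: the stochastic part of the prior sample is now scaled by a noise_scale argument (default 0.5), presumably mirroring the default on the torch inference path, so the exported graph no longer injects full-variance noise. Setting it to 0 makes z_p deterministic, which is convenient when diffing ONNX and torch outputs numerically. A standalone sketch of the effect, with hypothetical tensor shapes:

import torch

m_p = torch.zeros(1, 192, 50)             # hypothetical prior mean
logs_p = torch.full((1, 192, 50), -1.0)   # hypothetical prior log-std
for noise_scale in (0.0, 0.5, 1.0):
    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
    print(noise_scale, round(z_p.std().item(), 3))  # std scales with exp(logs_p) * noise_scale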

View File

@@ -4,7 +4,7 @@ import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
@@ -266,6 +266,22 @@ class SSLModel(nn.Module):
     def forward(self, ref_audio_16k):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)

+    def export(self, ref_audio_16k, project_name):
+        self.ssl.model.eval()
+        torch.onnx.export(
+            self,
+            (ref_audio_16k),
+            f"onnx/{project_name}/{project_name}_cnhubert.onnx",
+            input_names=["ref_audio_16k"],
+            output_names=["last_hidden_state"],
+            dynamic_axes={
+                "ref_audio_16k": {1 : "text_length"},
+                "last_hidden_state": {2 : "pred_length"}
+            },
+            opset_version=17,
+            verbose=False
+        )
+

 def export(vits_path, gpt_path, project_name):
     vits = VitsModel(vits_path)
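Note on the hunk above: SSLModel.export writes the cnhubert encoder to its own ONNX graph with a dynamic time axis, so its output can be compared against the torch module directly. A standalone sketch of such a check, not part of the diff; it assumes onnxruntime is installed, the export has already been run with the project name "nahida", and SSLModel from this script is importable:

import numpy as np
import onnxruntime
import torch

ref_audio_16k = torch.randn(1, 48000)      # hypothetical 3 s of 16 kHz audio

sess = onnxruntime.InferenceSession("onnx/nahida/nahida_cnhubert.onnx")
onnx_out = sess.run(None, {"ref_audio_16k": ref_audio_16k.numpy()})[0]

ssl = SSLModel()                           # wrapper class defined in this script
with torch.no_grad():
    torch_out = ssl(ref_audio_16k).numpy()

print(np.abs(onnx_out - torch_out).max())  # should stay small for an fp32 export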
@@ -300,6 +316,7 @@ def export(vits_path, gpt_path, project_name):
     soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

+    ssl.export(ref_audio_16k, project_name)
     gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

     MoeVSConf = {
@@ -326,8 +343,8 @@ if __name__ == "__main__":
     except:
         pass

-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
     exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)