Merge 7632175d7f17bb551fa3a44220b9cd34363b14e2 into 40cd22e69d439954a74914855f05e9c4c3ab26da

This commit is contained in:
Kazuki Kyakuno 2024-09-10 21:48:34 +07:00 committed by GitHub
commit 466859e87e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 174 additions and 10 deletions

View File

@ -69,7 +69,8 @@ def logits_to_probs(
def multinomial_sample_one_no_sync(
probs_sort
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
lambda_ = 1.0
q = -torch.log(torch.rand_like(probs_sort)) / lambda_
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@ -152,6 +153,7 @@ class T2SFirstStageDecoder(nn.Module):
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
logits = logits[:, :-1] ###刨除1024终止符号的概率
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
y = torch.concat([y, samples], dim=1)

View File

@ -50,10 +50,15 @@ class SinePositionalEmbedding(nn.Module):
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
)
pe = torch.zeros(x.size(1), self.embedding_dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
return pe
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@ -892,7 +892,7 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, codes, text, refer):
def forward(self, codes, text, refer, noise_scale=0.5):
refer_mask = torch.ones_like(refer[:1,:1,:])
ge = self.ref_enc(refer * refer_mask, refer_mask)
@ -905,7 +905,7 @@ class SynthesizerTrn(nn.Module):
quantized, text, ge
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)

View File

@ -4,7 +4,7 @@ import torch
import torchaudio
from torch import nn
from feature_extractor import cnhubert
cnhubert_base_path = "pretrained_models/chinese-hubert-base"
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path=cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
@ -266,6 +266,22 @@ class SSLModel(nn.Module):
def forward(self, ref_audio_16k):
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
def export(self, ref_audio_16k, project_name):
self.ssl.model.eval()
torch.onnx.export(
self,
(ref_audio_16k),
f"onnx/{project_name}/{project_name}_cnhubert.onnx",
input_names=["ref_audio_16k"],
output_names=["last_hidden_state"],
dynamic_axes={
"ref_audio_16k": {1 : "text_length"},
"last_hidden_state": {2 : "pred_length"}
},
opset_version=17,
verbose=False
)
def export(vits_path, gpt_path, project_name):
vits = VitsModel(vits_path)
@ -300,6 +316,7 @@ def export(vits_path, gpt_path, project_name):
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
ssl.export(ref_audio_16k, project_name)
gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
MoeVSConf = {
@ -326,8 +343,8 @@ if __name__ == "__main__":
except:
pass
gpt_path = "GPT_weights/nahida-e25.ckpt"
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
exp_path = "nahida"
export(vits_path, gpt_path, exp_path)

View File

@ -0,0 +1,140 @@
import torch
import torchaudio
from torch import nn
import onnxruntime
import os
from text import cleaned_text_to_sequence
from text.japanese import g2p
import soundfile
import ffmpeg
import numpy as np
import librosa
from my_utils import load_audio
class T2SModel(nn.Module):
def __init__(self):
super().__init__()
self.hz = 50
self.max_sec = 54
self.top_k = 5
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
self.sess_encoder = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_encoder.onnx", providers=["CPUExecutionProvider"])
self.sess_fsdec = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_fsdec.onnx", providers=["CPUExecutionProvider"])
self.sess_sdec = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_sdec.onnx", providers=["CPUExecutionProvider"])
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.early_stop_num
EOS = 1024
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.sess_encoder.run(None, {"ref_seq":ref_seq.detach().numpy(), "text_seq":text_seq.detach().numpy(), "ref_bert":ref_bert.detach().numpy(), "text_bert":text_bert.detach().numpy(), "ssl_content":ssl_content.detach().numpy()})
x = torch.from_numpy(x)
prompts = torch.from_numpy(prompts)
prefix_len = prompts.shape[1]
#[1,N,512] [1,N]
y, k, v, y_emb, x_example = self.sess_fsdec.run(None, {"x":x.detach().numpy(), "prompts":prompts.detach().numpy()})
y = torch.from_numpy(y)
k = torch.from_numpy(k)
v = torch.from_numpy(v)
y_emb = torch.from_numpy(y_emb)
x_example = torch.from_numpy(x_example)
stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
y, k, v, y_emb, logits, samples = self.sess_sdec.run(None, {"iy":y.detach().numpy(), "ik":k.detach().numpy(), "iv":v.detach().numpy(), "iy_emb":y_emb.detach().numpy(), "ix_example":x_example.detach().numpy()})
y = torch.from_numpy(y)
k = torch.from_numpy(k)
v = torch.from_numpy(v)
y_emb = torch.from_numpy(y_emb)
logits = torch.from_numpy(logits)
samples = torch.from_numpy(samples)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if torch.argmax(logits, dim=-1)[0] == EOS or samples[0, 0] == EOS:
stop = True
if stop:
break
y[0, -1] = 0
return y[:, -idx:-1].unsqueeze(0)
class GptSoVits(nn.Module):
def __init__(self, t2s):
super().__init__()
self.t2s = t2s
self.sess = onnxruntime.InferenceSession("./onnx/nahida/nahida_vits.onnx", providers=["CPUExecutionProvider"])
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
audio1 = self.sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
return torch.from_numpy(audio1[0])
class SSLModel(nn.Module):
def __init__(self):
super().__init__()
self.sess = onnxruntime.InferenceSession("./onnx/nahida/nahida_cnhubert.onnx", providers=["CPUExecutionProvider"])
def forward(self, ref_audio_16k):
last_hidden_state = self.sess.run(None, {
"ref_audio_16k" : ref_audio_16k.detach().cpu().numpy()
})
return torch.from_numpy(last_hidden_state[0])
def inference():
gpt = T2SModel()
gpt_sovits = GptSoVits(gpt)
ssl = SSLModel()
ref_audio = torch.randn((1, 48000 * 5)).float()
input_audio = "JSUT.wav"
ref_phones = g2p("水をマレーシアから買わなくてはならない。")
ref_audio = torch.tensor([load_audio(input_audio, 48000)]).float()
ref_seq = torch.LongTensor([cleaned_text_to_sequence(ref_phones)])
text_phones = g2p("音声合成のテストを行なっています。")
text_seq = torch.LongTensor([cleaned_text_to_sequence(text_phones)])
# empty for ja or en
ref_bert = torch.zeros((ref_seq.shape[1], 1024)).float()
text_bert = torch.zeros((text_seq.shape[1], 1024)).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
vits_hps_data_sampling_rate = 32000
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits_hps_data_sampling_rate).float()
zero_wav = np.zeros(
int(vits_hps_data_sampling_rate * 0.3),
dtype=np.float32,
)
wav16k, sr = librosa.load(input_audio, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
wav16k = torch.cat([wav16k, zero_wav_torch]).unsqueeze(0)
ref_audio_16k = wav16k
ssl_content = ssl(ref_audio_16k).float()
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content)
soundfile.write("out.wav", a.cpu().detach().numpy(), vits_hps_data_sampling_rate)
if __name__ == "__main__":
inference()