Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-01-11 04:46:57 +08:00)
Commit 466859e87e: Merge 7632175d7f17bb551fa3a44220b9cd34363b14e2 into 40cd22e69d439954a74914855f05e9c4c3ab26da
@@ -69,7 +69,8 @@ def logits_to_probs(
 def multinomial_sample_one_no_sync(
     probs_sort
 ):  # Does multinomial sampling without a cuda synchronization
-    q = torch.randn_like(probs_sort)
+    lambda_ = 1.0
+    q = -torch.log(torch.rand_like(probs_sort)) / lambda_
     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


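Context note on the hunk above: replacing Gaussian noise (torch.randn_like) with -log(uniform) noise turns the divide-then-argmax into the standard exponential-race (Gumbel-style) sampling trick, which draws a token index proportionally to the probabilities without calling torch.multinomial and therefore without a CUDA sync. A minimal, self-contained sketch (not part of the commit; the helper name is hypothetical) illustrating the equivalence:

import torch

def multinomial_no_sync(probs: torch.Tensor) -> torch.Tensor:
    # q ~ Exponential(1); argmax(probs / q) == argmin(q / probs), and the minimum of
    # independent exponentials with rates probs_i falls on index i with probability
    # probs_i / probs.sum(), i.e. a categorical draw over the (unnormalized) weights.
    q = -torch.log(torch.rand_like(probs))
    return torch.argmax(probs / q, dim=-1, keepdim=True)

probs = torch.tensor([0.1, 0.2, 0.7])
draws = torch.stack([multinomial_no_sync(probs) for _ in range(10000)]).flatten()
print(torch.bincount(draws, minlength=3) / 10000.0)  # roughly [0.1, 0.2, 0.7]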
@@ -152,6 +153,7 @@ class T2SFirstStageDecoder(nn.Module):

         xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
         logits = self.ar_predict_layer(xy_dec[:, -1])
+        logits = logits[:, :-1]  ### exclude the probability of the 1024 EOS token
         samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)

         y = torch.concat([y, samples], dim=1)
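The added line drops the last logits column before sampling. Assuming the semantic vocabulary has 1025 ids with id 1024 acting as the EOS/stop token (as the translated comment states), this keeps the first decoding step from emitting the stop symbol. A minimal sketch of the effect, with made-up shapes:

import torch

logits = torch.randn(1, 1025)    # [batch, vocab]; the last column (id 1024) is EOS
logits_no_eos = logits[:, :-1]   # [1, 1024]; EOS can no longer be sampled at this step
print(logits_no_eos.shape)       # torch.Size([1, 1024])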
@@ -50,10 +50,15 @@ class SinePositionalEmbedding(nn.Module):
         self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))

     def extend_pe(self, x):
-        position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
-        scpe = (position * self.div_term).unsqueeze(0)
-        pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
-        pe = pe.contiguous().view(1, -1, self.embedding_dim)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.embedding_dim)
+        )
+        pe = torch.zeros(x.size(1), self.embedding_dim)
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
         return pe

     def forward(self, x: torch.Tensor) -> torch.Tensor:
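The rewritten extend_pe above builds the classic Transformer sinusoidal table directly from arange, which traces to static ONNX ops more cleanly than the previous cumsum/permute formulation. A standalone sketch (assuming an even embedding_dim; the function name is illustrative) equivalent to the new body:

import math
import torch

def sine_positional_table(seq_len: int, embedding_dim: int) -> torch.Tensor:
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)   # [T, 1]
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2, dtype=torch.float32)
        * -(math.log(10000.0) / embedding_dim)
    )                                                                       # [D/2]
    pe = torch.zeros(seq_len, embedding_dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even channels
    pe[:, 1::2] = torch.cos(position * div_term)   # odd channels
    return pe.unsqueeze(0)                         # [1, T, D]

print(sine_positional_table(8, 16).shape)  # torch.Size([1, 8, 16])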
@@ -892,7 +892,7 @@ class SynthesizerTrn(nn.Module):
         # self.enc_p.encoder_text.requires_grad_(False)
         # self.enc_p.mrte.requires_grad_(False)

-    def forward(self, codes, text, refer):
+    def forward(self, codes, text, refer, noise_scale=0.5):
         refer_mask = torch.ones_like(refer[:1,:1,:])
         ge = self.ref_enc(refer * refer_mask, refer_mask)

@@ -905,7 +905,7 @@ class SynthesizerTrn(nn.Module):
             quantized, text, ge
         )

-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

         z = self.flow(z_p, y_mask, g=ge, reverse=True)

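Taken together, the two hunks above expose a noise_scale knob (default 0.5) on this forward pass: the latent z_p is drawn from the prior N(m_p, exp(logs_p)^2) and the sampled deviation is scaled, so 0.0 gives deterministic decoding and 1.0 keeps the full prior variance. A minimal sketch with illustrative shapes (not the module itself):

import torch

def sample_prior(m_p: torch.Tensor, logs_p: torch.Tensor, noise_scale: float = 0.5) -> torch.Tensor:
    # z_p = mean + unit Gaussian noise * std * noise_scale
    return m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

m_p = torch.zeros(1, 192, 100)     # [batch, channels, frames], shapes assumed
logs_p = torch.zeros(1, 192, 100)  # log-std of 0 -> std of 1
print(sample_prior(m_p, logs_p, noise_scale=0.5).std())  # roughly 0.5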
@@ -4,7 +4,7 @@ import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
 cnhubert.cnhubert_base_path = cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
@@ -266,6 +266,22 @@ class SSLModel(nn.Module):
     def forward(self, ref_audio_16k):
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)

+    def export(self, ref_audio_16k, project_name):
+        self.ssl.model.eval()
+        torch.onnx.export(
+            self,
+            (ref_audio_16k),
+            f"onnx/{project_name}/{project_name}_cnhubert.onnx",
+            input_names=["ref_audio_16k"],
+            output_names=["last_hidden_state"],
+            dynamic_axes={
+                "ref_audio_16k": {1: "text_length"},
+                "last_hidden_state": {2: "pred_length"}
+            },
+            opset_version=17,
+            verbose=False
+        )
+

 def export(vits_path, gpt_path, project_name):
     vits = VitsModel(vits_path)
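Once SSLModel.export above has been run, the resulting graph can be driven from onnxruntime alone. A small sketch (assuming the file onnx/<project>/<project>_cnhubert.onnx exists for the example project name used elsewhere in this commit) that exercises the dynamic text_length axis by feeding audio of arbitrary length:

import numpy as np
import onnxruntime

project_name = "nahida"  # example project from this commit
sess = onnxruntime.InferenceSession(
    f"onnx/{project_name}/{project_name}_cnhubert.onnx",
    providers=["CPUExecutionProvider"],
)
ref_audio_16k = np.random.randn(1, 16000 * 3).astype(np.float32)  # 3 s of 16 kHz audio
(last_hidden_state,) = sess.run(None, {"ref_audio_16k": ref_audio_16k})
print(last_hidden_state.shape)  # (1, 768, frames) after the transpose in SSLModel.forward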
@@ -300,6 +316,7 @@ def export(vits_path, gpt_path, project_name):

     soundfile.write("out.wav", a, vits.hps.data.sampling_rate)

+    ssl.export(ref_audio_16k, project_name)
     gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)

     MoeVSConf = {
@@ -326,8 +343,8 @@ if __name__ == "__main__":
     except:
         pass

-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
     exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)

GPT_SoVITS/onnx_inference.py (new file, 140 lines)
@@ -0,0 +1,140 @@
+import torch
+import torchaudio
+from torch import nn
+
+import onnxruntime
+
+import os
+from text import cleaned_text_to_sequence
+from text.japanese import g2p
+import soundfile
+
+import ffmpeg
+import numpy as np
+import librosa
+from my_utils import load_audio
+
+
+class T2SModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.hz = 50
+        self.max_sec = 54
+        self.top_k = 5
+        self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
+        self.sess_encoder = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_encoder.onnx", providers=["CPUExecutionProvider"])
+        self.sess_fsdec = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_fsdec.onnx", providers=["CPUExecutionProvider"])
+        self.sess_sdec = onnxruntime.InferenceSession(f"./onnx/nahida/nahida_t2s_sdec.onnx", providers=["CPUExecutionProvider"])
+
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
+        early_stop_num = self.early_stop_num
+
+        EOS = 1024
+
+        #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
+        x, prompts = self.sess_encoder.run(None, {"ref_seq": ref_seq.detach().numpy(), "text_seq": text_seq.detach().numpy(), "ref_bert": ref_bert.detach().numpy(), "text_bert": text_bert.detach().numpy(), "ssl_content": ssl_content.detach().numpy()})
+        x = torch.from_numpy(x)
+        prompts = torch.from_numpy(prompts)
+
+        prefix_len = prompts.shape[1]
+
+        #[1,N,512] [1,N]
+        y, k, v, y_emb, x_example = self.sess_fsdec.run(None, {"x": x.detach().numpy(), "prompts": prompts.detach().numpy()})
+        y = torch.from_numpy(y)
+        k = torch.from_numpy(k)
+        v = torch.from_numpy(v)
+        y_emb = torch.from_numpy(y_emb)
+        x_example = torch.from_numpy(x_example)
+
+        stop = False
+        for idx in range(1, 1500):
+            #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
+            y, k, v, y_emb, logits, samples = self.sess_sdec.run(None, {"iy": y.detach().numpy(), "ik": k.detach().numpy(), "iv": v.detach().numpy(), "iy_emb": y_emb.detach().numpy(), "ix_example": x_example.detach().numpy()})
+            y = torch.from_numpy(y)
+            k = torch.from_numpy(k)
+            v = torch.from_numpy(v)
+            y_emb = torch.from_numpy(y_emb)
+            logits = torch.from_numpy(logits)
+            samples = torch.from_numpy(samples)
+            if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
+                stop = True
+            if torch.argmax(logits, dim=-1)[0] == EOS or samples[0, 0] == EOS:
+                stop = True
+            if stop:
+                break
+        y[0, -1] = 0
+
+        return y[:, -idx:-1].unsqueeze(0)
+
+
+class GptSoVits(nn.Module):
+    def __init__(self, t2s):
+        super().__init__()
+        self.t2s = t2s
+        self.sess = onnxruntime.InferenceSession("./onnx/nahida/nahida_vits.onnx", providers=["CPUExecutionProvider"])
+
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content):
+        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        audio1 = self.sess.run(None, {
+            "text_seq": text_seq.detach().cpu().numpy(),
+            "pred_semantic": pred_semantic.detach().cpu().numpy(),
+            "ref_audio": ref_audio.detach().cpu().numpy()
+        })
+        return torch.from_numpy(audio1[0])
+
+
+class SSLModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sess = onnxruntime.InferenceSession("./onnx/nahida/nahida_cnhubert.onnx", providers=["CPUExecutionProvider"])
+
+    def forward(self, ref_audio_16k):
+        last_hidden_state = self.sess.run(None, {
+            "ref_audio_16k": ref_audio_16k.detach().cpu().numpy()
+        })
+        return torch.from_numpy(last_hidden_state[0])
+
+
+def inference():
+    gpt = T2SModel()
+    gpt_sovits = GptSoVits(gpt)
+    ssl = SSLModel()
+
+    ref_audio = torch.randn((1, 48000 * 5)).float()
+
+    input_audio = "JSUT.wav"
+    ref_phones = g2p("水をマレーシアから買わなくてはならない。")
+
+    ref_audio = torch.tensor([load_audio(input_audio, 48000)]).float()
+
+    ref_seq = torch.LongTensor([cleaned_text_to_sequence(ref_phones)])
+
+    text_phones = g2p("音声合成のテストを行なっています。")
+    text_seq = torch.LongTensor([cleaned_text_to_sequence(text_phones)])
+
+    # empty for ja or en
+    ref_bert = torch.zeros((ref_seq.shape[1], 1024)).float()
+    text_bert = torch.zeros((text_seq.shape[1], 1024)).float()
+
+    ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
+    vits_hps_data_sampling_rate = 32000
+    ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits_hps_data_sampling_rate).float()
+
+    zero_wav = np.zeros(
+        int(vits_hps_data_sampling_rate * 0.3),
+        dtype=np.float32,
+    )
+    wav16k, sr = librosa.load(input_audio, sr=16000)
+    wav16k = torch.from_numpy(wav16k)
+    zero_wav_torch = torch.from_numpy(zero_wav)
+    wav16k = torch.cat([wav16k, zero_wav_torch]).unsqueeze(0)
+    ref_audio_16k = wav16k
+
+    ssl_content = ssl(ref_audio_16k).float()
+
+    a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content)
+    soundfile.write("out.wav", a.cpu().detach().numpy(), vits_hps_data_sampling_rate)
+
+
+if __name__ == "__main__":
+    inference()
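The new script hard-codes the ./onnx/nahida/... session paths. A small sketch (the helper below is hypothetical, not part of the commit) of how the same sessions could be parameterized by project name so the script can drive any exported project directory:

import onnxruntime

def load_sessions(project_name: str, onnx_dir: str = "onnx"):
    def sess(name: str) -> onnxruntime.InferenceSession:
        return onnxruntime.InferenceSession(
            f"{onnx_dir}/{project_name}/{project_name}_{name}.onnx",
            providers=["CPUExecutionProvider"],
        )
    # encoder, first-stage decoder, step decoder, vits vocoder, cnhubert SSL encoder
    return sess("t2s_encoder"), sess("t2s_fsdec"), sess("t2s_sdec"), sess("vits"), sess("cnhubert")

# Example: encoder, fsdec, sdec, vits, cnhubert = load_sessions("nahida")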