Update onnx export script
commit b9f2400e82
parent 939971afe3
@@ -4,7 +4,14 @@ import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+#cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+
+import os
+cnhubert_base_path = os.environ.get(
+    "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+)
+
 cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
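The hunk above stops hard-coding the Chinese HuBERT directory and reads it from the cnhubert_base_path environment variable instead, falling back to the bundled pretrained path. Because the lookup runs at import time, the variable has to be set before the script starts; a minimal sketch, where the script path is an assumption about how the file is invoked and not something stated in the commit:

import os
import subprocess

# Override the HuBERT checkpoint directory for one export run.  The script
# path below is assumed; the environment variable name comes from the hunk.
env = dict(os.environ, cnhubert_base_path="/path/to/chinese-hubert-base")
subprocess.run(["python", "GPT_SoVITS/onnx_export.py"], env=env, check=True)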
@@ -103,22 +110,50 @@ class T2SModel(nn.Module):
         self.stage_decoder = self.t2s_model.stage_decoder
         #self.t2s_model = torch.jit.script(self.t2s_model)

-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, debug=False):
         early_stop_num = self.t2s_model.early_stop_num

+        if debug:
+            import onnxruntime
+            sess_encoder = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_encoder.onnx", providers=["CPU"])
+            sess_fsdec = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_fsdec.onnx", providers=["CPU"])
+            sess_sdec = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_sdec.onnx", providers=["CPU"])
+
         #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        if debug:
+            x, prompts = sess_encoder.run(None, {"ref_seq":ref_seq.detach().numpy(), "text_seq":text_seq.detach().numpy(), "ref_bert":ref_bert.detach().numpy(), "text_bert":text_bert.detach().numpy(), "ssl_content":ssl_content.detach().numpy()})
+            x = torch.from_numpy(x)
+            prompts = torch.from_numpy(prompts)
+        else:
+            x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)

         prefix_len = prompts.shape[1]

         #[1,N,512] [1,N]
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        if debug:
+            y, k, v, y_emb, x_example = sess_fsdec.run(None, {"x":x.detach().numpy(), "prompts":prompts.detach().numpy()})
+            y = torch.from_numpy(y)
+            k = torch.from_numpy(k)
+            v = torch.from_numpy(v)
+            y_emb = torch.from_numpy(y_emb)
+            x_example = torch.from_numpy(x_example)
+        else:
+            y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)

         stop = False
         for idx in range(1, 1500):
             #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
-            enco = self.stage_decoder(y, k, v, y_emb, x_example)
-            y, k, v, y_emb, logits, samples = enco
+            if debug:
+                y, k, v, y_emb, logits, samples = sess_sdec.run(None, {"iy":y.detach().numpy(), "ik":k.detach().numpy(), "iv":v.detach().numpy(), "iy_emb":y_emb.detach().numpy(), "ix_example":x_example.detach().numpy()})
+                y = torch.from_numpy(y)
+                k = torch.from_numpy(k)
+                v = torch.from_numpy(v)
+                y_emb = torch.from_numpy(y_emb)
+                logits = torch.from_numpy(logits)
+                samples = torch.from_numpy(samples)
+            else:
+                enco = self.stage_decoder(y, k, v, y_emb, x_example)
+                y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
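With debug=False the forward pass above keeps the original PyTorch flow; with debug=True each stage is re-run through the ONNX graphs previously exported under onnx/nahida/, and the numpy outputs are wrapped back into tensors so the loop keeps working. That makes the exported encoder and decoders directly comparable with their PyTorch counterparts. A minimal sketch of that parity check on a toy module, where the module, file name, and tolerance are illustrative and not part of the commit:

import numpy as np
import torch
import onnxruntime

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x) * 0.5

model = Toy().eval()
x = torch.randn(1, 16)

# Export the module, then run the same input through PyTorch and onnxruntime.
torch.onnx.export(model, (x,), "toy.onnx", input_names=["x"], output_names=["y"])
sess = onnxruntime.InferenceSession("toy.onnx", providers=["CPUExecutionProvider"])

y_onnx = sess.run(None, {"x": x.numpy()})[0]
y_torch = model(x).detach().numpy()
print(np.abs(y_onnx - y_torch).max())  # expected to be around 1e-6 or smaller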
@@ -226,11 +261,11 @@ class GptSoVits(nn.Module):
         self.t2s = t2s

     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
-        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, debug)
         audio = self.vits(text_seq, pred_semantic, ref_audio)
         if debug:
             import onnxruntime
-            sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
+            sess = onnxruntime.InferenceSession("onnx/nahida/nahida_vits.onnx", providers=["CPU"])
             audio1 = sess.run(None, {
                 "text_seq" : text_seq.detach().cpu().numpy(),
                 "pred_semantic" : pred_semantic.detach().cpu().numpy(),
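One note on the provider argument used throughout these debug branches: onnxruntime registers its CPU backend as "CPUExecutionProvider", and recent releases validate the providers list, so "CPU" may be warned about or rejected depending on the installed version. A hedged alternative, reusing the file path from the hunk above:

import onnxruntime

# Same session as the debug branch, but with the canonical provider name.
sess = onnxruntime.InferenceSession(
    "onnx/nahida/nahida_vits.onnx",
    providers=["CPUExecutionProvider"],
)
print(sess.get_providers())  # confirms which providers were actually attached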
@@ -263,21 +298,47 @@ class SSLModel(nn.Module):
         super().__init__()
         self.ssl = ssl_model

-    def forward(self, ref_audio_16k):
+    def forward(self, ref_audio_16k, debug = False):
+        if debug:
+            import onnxruntime
+            sess = onnxruntime.InferenceSession("onnx/nahida/nahida_cnhubert.onnx", providers=["CPU"])
+            last_hidden_state = sess.run(None, {
+                "ref_audio_16k" : ref_audio_16k.detach().cpu().numpy()
+            })
+            return torch.from_numpy(last_hidden_state[0])
+
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)

+    def export(self, ref_audio_16k, project_name):
+        torch.onnx.export(
+            self,
+            (ref_audio_16k),
+            f"onnx/{project_name}/{project_name}_cnhubert.onnx",
+            input_names=["ref_audio_16k"],
+            output_names=["last_hidden_state"],
+            dynamic_axes={
+                "ref_audio_16k": {1 : "text_length"},
+                "last_hidden_state": {2 : "pred_length"}
+            },
+            opset_version=17,
+            verbose=False
+        )
+

 def export(vits_path, gpt_path, project_name):
     vits = VitsModel(vits_path)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    #ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    #text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['m', 'i', 'z', 'u', 'o', 'm', 'a', 'r', 'e', 'e', 'sh', 'i', 'a', 'k', 'a', 'r', 'a', 'k', 'a', 'w', 'a', 'n', 'a', 'k', 'U', 't', 'e', 'w', 'a', 'n', 'a', 'r', 'a', 'n', 'a', 'i', '.'])])
+    text_seq = torch.LongTensor([cleaned_text_to_sequence(['m', 'i', 'z', 'u', 'w', 'a', ',', 'i', 'r', 'i', 'm', 'a', 's', 'e', 'N', 'k', 'a', '?'])])
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
-    # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
+    ref_audio = torch.tensor([load_audio("/Users/kyakuno/Desktop/大阪万博/voices/JSUT.wav", 48000)]).float()
     ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
     ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()

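The new SSLModel.export writes the cnhubert frontend with dynamic_axes, so the graph is not pinned to the length of the reference clip used during export. A small sketch of the same dynamic_axes pattern on a stand-in module; the module, axis names, and output file below are made up for illustration:

import torch

class Downsample(torch.nn.Module):
    def forward(self, wav):            # wav: [batch, samples]
        return wav[:, None, ::320]     # -> [batch, 1, frames]

# Dim 1 of the input and dim 2 of the output stay symbolic, so one graph
# accepts reference audio of any length, as in SSLModel.export above.
torch.onnx.export(
    Downsample().eval(),
    (torch.randn(1, 16000),),
    "downsample.onnx",
    input_names=["ref_audio_16k"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "ref_audio_16k": {1: "audio_length"},
        "last_hidden_state": {2: "frame_length"},
    },
    opset_version=17,
)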
@@ -286,16 +347,17 @@ def export(vits_path, gpt_path, project_name):
     except:
         pass

-    ssl_content = ssl(ref_audio_16k).float()
-    debug = False
+    debug = True
+    ssl_content = ssl(ref_audio_16k, debug=debug).float()

     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
         soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
         soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
         return

+    ssl.export(ref_audio_16k, project_name)
+
     a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()

     soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
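With debug = True the script now synthesises the same utterance twice and writes out1.wav and out2.wav, which by the look of GptSoVits.forward are the PyTorch and ONNX renderings respectively, so the two paths can be compared by ear or numerically. A rough numerical comparison, assuming both files were produced by a debug run:

import numpy as np
import soundfile

a, sr_a = soundfile.read("out1.wav")  # PyTorch rendering (per the hunk above)
b, sr_b = soundfile.read("out2.wav")  # ONNX rendering
n = min(len(a), len(b))
print(sr_a, sr_b, float(np.abs(a[:n] - b[:n]).mean()))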
@@ -326,8 +388,8 @@ if __name__ == "__main__":
     except:
         pass

-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"#"GPT_weights/nahida-e25.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"#"SoVITS_weights/nahida_e30_s3930.pth"
     exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)

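The __main__ block now points at the stock pretrained s1/s2 checkpoints and keeps the old fine-tuned paths as trailing comments, while exp_path stays "nahida", so the graphs land under onnx/nahida/. A hedged sketch of calling the exporter with user-trained weights instead; the module name and weight paths are assumptions, not part of the commit:

# Assumes the script is importable as onnx_export; the weight paths are
# placeholders for whatever a training run produced.
from onnx_export import export

export(
    vits_path="SoVITS_weights/my_voice_e30_s3930.pth",
    gpt_path="GPT_weights/my_voice-e25.ckpt",
    project_name="my_voice",
)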