From b9f2400e826848474be8a50267d070337a9fe78f Mon Sep 17 00:00:00 2001
From: Kazuki Kyakuno
Date: Mon, 18 Mar 2024 18:31:44 +0900
Subject: [PATCH] Update onnx export script

---
 GPT_SoVITS/onnx_export.py | 98 ++++++++++++++++++++++++++++++++-------
 1 file changed, 80 insertions(+), 18 deletions(-)

diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py
index b82e987f..b6e60684 100644
--- a/GPT_SoVITS/onnx_export.py
+++ b/GPT_SoVITS/onnx_export.py
@@ -4,7 +4,14 @@ import torch
 import torchaudio
 from torch import nn
 from feature_extractor import cnhubert
-cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+
+#cnhubert_base_path = "pretrained_models/chinese-hubert-base"
+
+import os
+cnhubert_base_path = os.environ.get(
+    "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+)
+
 cnhubert.cnhubert_base_path=cnhubert_base_path
 ssl_model = cnhubert.get_model()
 from text import cleaned_text_to_sequence
@@ -103,22 +110,50 @@ class T2SModel(nn.Module):
         self.stage_decoder = self.t2s_model.stage_decoder
         #self.t2s_model = torch.jit.script(self.t2s_model)
 
-    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
+    def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, debug=False):
         early_stop_num = self.t2s_model.early_stop_num
 
+        if debug:
+            import onnxruntime
+            sess_encoder = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_encoder.onnx", providers=["CPU"])
+            sess_fsdec = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_fsdec.onnx", providers=["CPU"])
+            sess_sdec = onnxruntime.InferenceSession(f"onnx/nahida/nahida_t2s_sdec.onnx", providers=["CPU"])
+
         #[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
-        x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        if debug:
+            x, prompts = sess_encoder.run(None, {"ref_seq":ref_seq.detach().numpy(), "text_seq":text_seq.detach().numpy(), "ref_bert":ref_bert.detach().numpy(), "text_bert":text_bert.detach().numpy(), "ssl_content":ssl_content.detach().numpy()})
+            x = torch.from_numpy(x)
+            prompts = torch.from_numpy(prompts)
+        else:
+            x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
 
         prefix_len = prompts.shape[1]
 
         #[1,N,512] [1,N]
-        y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
+        if debug:
+            y, k, v, y_emb, x_example = sess_fsdec.run(None, {"x":x.detach().numpy(), "prompts":prompts.detach().numpy()})
+            y = torch.from_numpy(y)
+            k = torch.from_numpy(k)
+            v = torch.from_numpy(v)
+            y_emb = torch.from_numpy(y_emb)
+            x_example = torch.from_numpy(x_example)
+        else:
+            y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
 
         stop = False
         for idx in range(1, 1500):
             #[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
-            enco = self.stage_decoder(y, k, v, y_emb, x_example)
-            y, k, v, y_emb, logits, samples = enco
+            if debug:
+                y, k, v, y_emb, logits, samples = sess_sdec.run(None, {"iy":y.detach().numpy(), "ik":k.detach().numpy(), "iv":v.detach().numpy(), "iy_emb":y_emb.detach().numpy(), "ix_example":x_example.detach().numpy()})
+                y = torch.from_numpy(y)
+                k = torch.from_numpy(k)
+                v = torch.from_numpy(v)
+                y_emb = torch.from_numpy(y_emb)
+                logits = torch.from_numpy(logits)
+                samples = torch.from_numpy(samples)
+            else:
+                enco = self.stage_decoder(y, k, v, y_emb, x_example)
+                y, k, v, y_emb, logits, samples = enco
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 stop = True
             if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
@@ -226,11 +261,11 @@ class GptSoVits(nn.Module):
         self.t2s = t2s
 
     def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
-        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
+        pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content, debug)
         audio = self.vits(text_seq, pred_semantic, ref_audio)
         if debug:
             import onnxruntime
-            sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
+            sess = onnxruntime.InferenceSession("onnx/nahida/nahida_vits.onnx", providers=["CPU"])
             audio1 = sess.run(None, {
                 "text_seq" : text_seq.detach().cpu().numpy(),
                 "pred_semantic" : pred_semantic.detach().cpu().numpy(),
@@ -263,21 +298,47 @@ class SSLModel(nn.Module):
         super().__init__()
         self.ssl = ssl_model
 
-    def forward(self, ref_audio_16k):
+    def forward(self, ref_audio_16k, debug = False):
+        if debug:
+            import onnxruntime
+            sess = onnxruntime.InferenceSession("onnx/nahida/nahida_cnhubert.onnx", providers=["CPU"])
+            last_hidden_state = sess.run(None, {
+                "ref_audio_16k" : ref_audio_16k.detach().cpu().numpy()
+            })
+            return torch.from_numpy(last_hidden_state[0])
+
         return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
 
+    def export(self, ref_audio_16k, project_name):
+        torch.onnx.export(
+            self,
+            (ref_audio_16k),
+            f"onnx/{project_name}/{project_name}_cnhubert.onnx",
+            input_names=["ref_audio_16k"],
+            output_names=["last_hidden_state"],
+            dynamic_axes={
+                "ref_audio_16k": {1 : "text_length"},
+                "last_hidden_state": {2 : "pred_length"}
+            },
+            opset_version=17,
+            verbose=False
+        )
+
 def export(vits_path, gpt_path, project_name):
     vits = VitsModel(vits_path)
     gpt = T2SModel(gpt_path, vits)
     gpt_sovits = GptSoVits(vits, gpt)
     ssl = SSLModel()
-    ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
-    text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    #ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    #text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])])
+    ref_seq = torch.LongTensor([cleaned_text_to_sequence(['m', 'i', 'z', 'u', 'o', 'm', 'a', 'r', 'e', 'e', 'sh', 'i', 'a', 'k', 'a', 'r', 'a', 'k', 'a', 'w', 'a', 'n', 'a', 'k', 'U', 't', 'e', 'w', 'a', 'n', 'a', 'r', 'a', 'n', 'a', 'i', '.'])])
+    text_seq = torch.LongTensor([cleaned_text_to_sequence(['m', 'i', 'z', 'u', 'w', 'a', ',', 'i', 'r', 'i', 'm', 'a', 's', 'e', 'N', 'k', 'a', '?'])])
     ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
     text_bert = torch.randn((text_seq.shape[1], 1024)).float()
     ref_audio = torch.randn((1, 48000 * 5)).float()
-    # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
+
+    ref_audio = torch.tensor([load_audio("/Users/kyakuno/Desktop/大阪万博/voices/JSUT.wav", 48000)]).float()
+
     ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
     ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
 
@@ -286,16 +347,17 @@ def export(vits_path, gpt_path, project_name):
     except:
         pass
 
-    ssl_content = ssl(ref_audio_16k).float()
-
-    debug = False
+    debug = True
+    ssl_content = ssl(ref_audio_16k, debug=debug).float()
 
     if debug:
         a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
         soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
         soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
         return
-
+
+    ssl.export(ref_audio_16k, project_name)
+
     a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
     soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
 
@@ -326,8 +388,8 @@ if __name__ == "__main__":
     except:
         pass
 
-    gpt_path = "GPT_weights/nahida-e25.ckpt"
-    vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
+    gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"#"GPT_weights/nahida-e25.ckpt"
+    vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"#"SoVITS_weights/nahida_e30_s3930.pth"
     exp_path = "nahida"
     export(vits_path, gpt_path, exp_path)
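
For reference, a minimal, hypothetical sketch of driving the exported T2S graphs standalone with onnxruntime, mirroring the debug path added by this patch. The file names, input/output names and tensor layouts follow the exporter; the placeholder inputs, the EOS id (1024) and early_stop_num are assumptions, not values taken from this patch:

# Hypothetical standalone driver for the exported graphs (illustrative only).
import numpy as np
import onnxruntime

project = "nahida"
providers = ["CPUExecutionProvider"]

enc = onnxruntime.InferenceSession(f"onnx/{project}/{project}_t2s_encoder.onnx", providers=providers)
fsdec = onnxruntime.InferenceSession(f"onnx/{project}/{project}_t2s_fsdec.onnx", providers=providers)
sdec = onnxruntime.InferenceSession(f"onnx/{project}/{project}_t2s_sdec.onnx", providers=providers)

# Placeholder inputs; in practice ref_seq/text_seq come from cleaned_text_to_sequence,
# the BERT features from the text frontend, and ssl_content from the exported cnhubert graph.
ref_seq = np.zeros((1, 37), dtype=np.int64)               # [1, N]
text_seq = np.zeros((1, 18), dtype=np.int64)              # [1, N]
ref_bert = np.zeros((37, 1024), dtype=np.float32)         # [N, 1024]
text_bert = np.zeros((18, 1024), dtype=np.float32)        # [N, 1024]
ssl_content = np.zeros((1, 768, 250), dtype=np.float32)   # [1, 768, N]

# Encode the prompt and run the first autoregressive step.
x, prompts = enc.run(None, {"ref_seq": ref_seq, "text_seq": text_seq,
                            "ref_bert": ref_bert, "text_bert": text_bert,
                            "ssl_content": ssl_content})
prefix_len = prompts.shape[1]
y, k, v, y_emb, x_example = fsdec.run(None, {"x": x, "prompts": prompts})

EOS = 1024            # assumed semantic EOS id; the exporter reads it from t2s_model.EOS
early_stop_num = -1   # assumed; the exporter reads it from the GPT config

# Autoregressive decoding with the stage decoder, as in the debug branch above.
for idx in range(1, 1500):
    y, k, v, y_emb, logits, samples = sdec.run(
        None, {"iy": y, "ik": k, "iv": v, "iy_emb": y_emb, "ix_example": x_example})
    if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
        break
    if np.argmax(logits, axis=-1)[0] == EOS or samples[0, 0] == EOS:
        break

# The generated tail of `y` is the predicted semantic sequence that is then fed,
# together with text_seq and the reference audio, to the exported VITS graph.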