diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index f08679f..b82e987 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -140,6 +140,7 @@ class T2SModel(nn.Module): ) onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") return + torch.onnx.export( self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), @@ -147,16 +148,16 @@ class T2SModel(nn.Module): input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], output_names=["x", "prompts"], dynamic_axes={ - "ref_seq": [1], - "text_seq": [1], - "ref_bert": [0], - "text_bert": [0], - "ssl_content": [2], + "ref_seq": {1 : "ref_length"}, + "text_seq": {1 : "text_length"}, + "ref_bert": {0 : "ref_length"}, + "text_bert": {0 : "text_length"}, + "ssl_content": {2 : "ssl_length"}, }, opset_version=16 ) x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - torch.exp + torch.onnx.export( self.first_stage_decoder, (x, prompts), @@ -164,10 +165,10 @@ class T2SModel(nn.Module): input_names=["x", "prompts"], output_names=["y", "k", "v", "y_emb", "x_example"], dynamic_axes={ - "x": [1], - "prompts": [1], + "x": {1 : "x_length"}, + "prompts": {1 : "prompts_length"}, }, - verbose=True, + verbose=False, opset_version=16 ) y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) @@ -179,13 +180,13 @@ class T2SModel(nn.Module): input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], output_names=["y", "k", "v", "y_emb", "logits", "samples"], dynamic_axes={ - "iy": [1], - "ik": [1], - "iv": [1], - "iy_emb": [1], - "ix_example": [1], + "iy": {1 : "iy_length"}, + "ik": {1 : "ik_length"}, + "iv": {1 : "iv_length"}, + "iy_emb": {1 : "iy_emb_length"}, + "ix_example": {1 : "ix_example_length"}, }, - verbose=True, + verbose=False, opset_version=16 ) @@ -224,9 +225,19 @@ class GptSoVits(nn.Module): self.vits = vits self.t2s = t2s - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - return self.vits(text_seq, pred_semantic, ref_audio) + audio = self.vits(text_seq, pred_semantic, ref_audio) + if debug: + import onnxruntime + sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) + audio1 = sess.run(None, { + "text_seq" : text_seq.detach().cpu().numpy(), + "pred_semantic" : pred_semantic.detach().cpu().numpy(), + "ref_audio" : ref_audio.detach().cpu().numpy() + }) + return audio, audio1 + return audio def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) @@ -238,11 +249,12 @@ class GptSoVits(nn.Module): input_names=["text_seq", "pred_semantic", "ref_audio"], output_names=["audio"], dynamic_axes={ - "text_seq": [1], - "pred_semantic": [2], - "ref_audio": [1], + "text_seq": {1 : "text_length"}, + "pred_semantic": {2 : "pred_length"}, + "ref_audio": {1 : "audio_length"}, }, - opset_version=17 + opset_version=17, + verbose=False ) @@ -261,7 +273,7 @@ def export(vits_path, gpt_path, project_name): gpt_sovits = GptSoVits(vits, gpt) ssl = SSLModel() ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) - text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) + text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float() ref_audio = torch.randn((1, 48000 * 5)).float() @@ -275,10 +287,18 @@ def export(vits_path, gpt_path, project_name): pass ssl_content = ssl(ref_audio_16k).float() + + debug = False + if debug: + a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) + soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) + soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) + return + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() - # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) + soundfile.write("out.wav", a, vits.hps.data.sampling_rate) gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) @@ -306,9 +326,9 @@ if __name__ == "__main__": except: pass - gpt_path = "pt_model/koharu-e20.ckpt" - vits_path = "pt_model/koharu_e20_s4960.pth" - exp_path = "koharu" + gpt_path = "GPT_weights/nahida-e25.ckpt" + vits_path = "SoVITS_weights/nahida_e30_s3930.pth" + exp_path = "nahida" export(vits_path, gpt_path, exp_path) # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) \ No newline at end of file