From 8c01e275ecede4ce37712fa7351dbcd055cc8909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:38:38 +0800 Subject: [PATCH 1/6] Add files via upload --- GPT_SoVITS/onnx_export.py | 72 +++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index f08679f..b82e987 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -140,6 +140,7 @@ class T2SModel(nn.Module): ) onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") return + torch.onnx.export( self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), @@ -147,16 +148,16 @@ class T2SModel(nn.Module): input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], output_names=["x", "prompts"], dynamic_axes={ - "ref_seq": [1], - "text_seq": [1], - "ref_bert": [0], - "text_bert": [0], - "ssl_content": [2], + "ref_seq": {1 : "ref_length"}, + "text_seq": {1 : "text_length"}, + "ref_bert": {0 : "ref_length"}, + "text_bert": {0 : "text_length"}, + "ssl_content": {2 : "ssl_length"}, }, opset_version=16 ) x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - torch.exp + torch.onnx.export( self.first_stage_decoder, (x, prompts), @@ -164,10 +165,10 @@ class T2SModel(nn.Module): input_names=["x", "prompts"], output_names=["y", "k", "v", "y_emb", "x_example"], dynamic_axes={ - "x": [1], - "prompts": [1], + "x": {1 : "x_length"}, + "prompts": {1 : "prompts_length"}, }, - verbose=True, + verbose=False, opset_version=16 ) y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) @@ -179,13 +180,13 @@ class T2SModel(nn.Module): input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], output_names=["y", "k", "v", "y_emb", "logits", "samples"], dynamic_axes={ - "iy": [1], - "ik": [1], - "iv": [1], - "iy_emb": [1], - "ix_example": [1], + "iy": {1 : "iy_length"}, + "ik": {1 : "ik_length"}, + "iv": {1 : "iv_length"}, + "iy_emb": {1 : "iy_emb_length"}, + "ix_example": {1 : "ix_example_length"}, }, - verbose=True, + verbose=False, opset_version=16 ) @@ -224,9 +225,19 @@ class GptSoVits(nn.Module): self.vits = vits self.t2s = t2s - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) - return self.vits(text_seq, pred_semantic, ref_audio) + audio = self.vits(text_seq, pred_semantic, ref_audio) + if debug: + import onnxruntime + sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) + audio1 = sess.run(None, { + "text_seq" : text_seq.detach().cpu().numpy(), + "pred_semantic" : pred_semantic.detach().cpu().numpy(), + "ref_audio" : ref_audio.detach().cpu().numpy() + }) + return audio, audio1 + return audio def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) @@ -238,11 +249,12 @@ class GptSoVits(nn.Module): input_names=["text_seq", "pred_semantic", "ref_audio"], output_names=["audio"], dynamic_axes={ - "text_seq": [1], - 
"pred_semantic": [2], - "ref_audio": [1], + "text_seq": {1 : "text_length"}, + "pred_semantic": {2 : "pred_length"}, + "ref_audio": {1 : "audio_length"}, }, - opset_version=17 + opset_version=17, + verbose=False ) @@ -261,7 +273,7 @@ def export(vits_path, gpt_path, project_name): gpt_sovits = GptSoVits(vits, gpt) ssl = SSLModel() ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) - text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) + text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"])]) ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() text_bert = torch.randn((text_seq.shape[1], 1024)).float() ref_audio = torch.randn((1, 48000 * 5)).float() @@ -275,10 +287,18 @@ def export(vits_path, gpt_path, project_name): pass ssl_content = ssl(ref_audio_16k).float() + + debug = False + if debug: + a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) + soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) + soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) + return + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() - # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) + soundfile.write("out.wav", a, vits.hps.data.sampling_rate) gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) @@ -306,9 +326,9 @@ if __name__ == "__main__": except: pass - gpt_path = "pt_model/koharu-e20.ckpt" - vits_path = "pt_model/koharu_e20_s4960.pth" - exp_path = "koharu" + gpt_path = "GPT_weights/nahida-e25.ckpt" + vits_path = "SoVITS_weights/nahida_e30_s3930.pth" + exp_path = "nahida" export(vits_path, gpt_path, exp_path) # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) \ No newline at end of file From 08aed05796f28244bf70e26ffa9f8ba88031e938 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:39:29 +0800 Subject: [PATCH 2/6] Add files via upload --- GPT_SoVITS/module/models_onnx.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py index 35fd291..232fd74 100644 --- a/GPT_SoVITS/module/models_onnx.py +++ b/GPT_SoVITS/module/models_onnx.py @@ -896,9 +896,6 @@ class SynthesizerTrn(nn.Module): refer_mask = torch.ones_like(refer[:1,:1,:]) ge = self.ref_enc(refer * refer_mask, refer_mask) - y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) - text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) - quantized = self.quantizer.decode(codes) if self.semantic_frame_rate == "25hz": dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) @@ -907,6 +904,7 @@ class SynthesizerTrn(nn.Module): x, m_p, logs_p, y_mask = self.enc_p( quantized, text, ge ) + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) z = self.flow(z_p, y_mask, g=ge, reverse=True) From 0b2e3760c268b0cbafb0f3e00207f776b80ac7aa Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:39:56 +0800 Subject: [PATCH 3/6] Add files via upload --- GPT_SoVITS/module/attentions_onnx.py | 47 +++++++++++----------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/GPT_SoVITS/module/attentions_onnx.py b/GPT_SoVITS/module/attentions_onnx.py index df0ae82..bc63a06 100644 --- a/GPT_SoVITS/module/attentions_onnx.py +++ b/GPT_SoVITS/module/attentions_onnx.py @@ -188,38 +188,27 @@ class MultiHeadAttention(nn.Module): query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3) - scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query / math.sqrt(self.k_channels), key_relative_embeddings - ) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) scores_local = self._relative_position_to_absolute_position(rel_logits) scores = scores + scores_local + if mask is not None: scores = scores.masked_fill(mask == 0, -1e4) - if self.block_length is not None: - block_mask = ( - torch.ones_like(scores) - .triu(-self.block_length) - .tril(self.block_length) - ) - scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) p_attn = self.drop(p_attn) output = torch.matmul(p_attn, value) + if self.window_size is not None: relative_weights = self._absolute_position_to_relative_position(p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s - ) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings - ) - output = ( - output.transpose(2, 3).contiguous().view(b, d, -1) - ) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + + output = (output.transpose(2, 3).contiguous().view(b, d, -1)) return output, p_attn def _matmul_with_relative_values(self, x, y): @@ -243,16 +232,16 @@ class MultiHeadAttention(nn.Module): def _get_relative_embeddings(self, relative_embeddings, length): max_relative_position = 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. 
- pad_length = max(length - (self.window_size + 1), 0) - slice_start_position = max((self.window_size + 1) - length, 0) + pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1) + pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length + pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64)) + slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64)) + slice_end_position = slice_start_position + 2 * length - 1 - if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, - commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), - ) - else: - padded_relative_embeddings = relative_embeddings + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) used_relative_embeddings = padded_relative_embeddings[ :, slice_start_position:slice_end_position ] From d1c3cf70e54cb7411367bde0608ff759e9402793 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:40:07 +0800 Subject: [PATCH 4/6] Add files via upload From ffa49a7a1c4f51b224c8b2cf06b8aa8c3184efb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:41:06 +0800 Subject: [PATCH 5/6] Add files via upload From dd1186c029929e05cbf490dd7e404a59f9686cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?= =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?= =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?= <40709280+NaruseMioShirakana@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:41:37 +0800 Subject: [PATCH 6/6] Add files via upload --- GPT_SoVITS/AR/models/t2s_model_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index 263b933..92f2d74 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -57,7 +57,7 @@ def logits_to_probs( logits = logits / max(temperature, 1e-5) if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + v, _ = torch.topk(logits, top_k) pivot = v.select(-1, -1).unsqueeze(-1) logits = torch.where(logits < pivot, inf_tensor_value, logits)
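
As a quick sanity check of the exported graphs, the VITS ONNX file can be run directly with onnxruntime, mirroring the debug branch added in patch 1/6. The sketch below is not part of the patches: the onnx/{project_name}/{project_name}_vits.onnx path is inferred from the debug branch, "CPUExecutionProvider" is the standard onnxruntime provider name (the debug branch passes "CPU"), and the dummy shapes are placeholders for the real T2S decoder output. The input and output names (text_seq, pred_semantic, ref_audio -> audio) match the export call in onnx_export.py.

import numpy as np
import onnxruntime

# Minimal sketch: run the exported VITS graph on dummy inputs.
sess = onnxruntime.InferenceSession(
    "onnx/nahida/nahida_vits.onnx",         # assumed onnx/{project_name}/{project_name}_vits.onnx layout
    providers=["CPUExecutionProvider"],     # standard provider name; the patch's debug branch passes "CPU"
)

text_seq = np.zeros((1, 24), dtype=np.int64)             # phoneme ids, dynamic axis 1 ("text_length")
pred_semantic = np.zeros((1, 1, 200), dtype=np.int64)    # semantic tokens, dynamic axis 2 ("pred_length")
ref_audio = np.random.randn(1, 48000 * 5).astype(np.float32)  # reference audio, dynamic axis 1 ("audio_length")

(audio,) = sess.run(
    ["audio"],
    {"text_seq": text_seq, "pred_semantic": pred_semantic, "ref_audio": ref_audio},
)
print(audio.shape, audio.dtype)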