From 26228402e3e94fe95bd9ad863deaed46989058cf Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Mon, 25 Aug 2025 12:06:26 -0400 Subject: [PATCH] feat:solve unified kv cache shape handling, todo: clean up upper level to unify first and following step --- GPT_SoVITS/AR/models/t2s_model_onnx.py | 12 ++-- .../AR/modules/patched_mha_with_cache_onnx.py | 38 +++++------ GPT_SoVITS/onnx_export_v1v2.py | 63 ++++++++++--------- playground/freerun.py | 6 +- 4 files changed, 57 insertions(+), 62 deletions(-) diff --git a/GPT_SoVITS/AR/models/t2s_model_onnx.py b/GPT_SoVITS/AR/models/t2s_model_onnx.py index 438b843e..5050d78d 100644 --- a/GPT_SoVITS/AR/models/t2s_model_onnx.py +++ b/GPT_SoVITS/AR/models/t2s_model_onnx.py @@ -128,7 +128,7 @@ class T2SFirstStageDecoder(nn.Module): self.early_stop_num = early_stop_num self.num_layers = num_layers - def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None): + def forward(self, x, prompt, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None): if top_k is None: top_k = torch.LongTensor([15]).to(device=x.device) if top_p is None: @@ -146,7 +146,7 @@ class T2SFirstStageDecoder(nn.Module): "k": None, "v": None, "y_emb": None, - "first_infer": 1, + "first_infer": first_infer, "stage": 0, } @@ -216,7 +216,7 @@ class T2SStageDecoder(nn.Module): self.early_stop_num = early_stop_num self.num_layers = num_layers - def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None): + def forward(self, y, k, v, y_emb, x_example, top_k = None, top_p = None, repetition_penalty = None, temperature = None, first_infer = None): if top_k is None: top_k = torch.LongTensor([15]).to(device=y.device) if top_p is None: @@ -231,7 +231,7 @@ class T2SStageDecoder(nn.Module): "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)), "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)), "y_emb": y_emb, - "first_infer": 0, + "first_infer": first_infer, "stage": 0, } @@ -336,11 +336,11 @@ class Text2SemanticDecoder(nn.Module): prefix_len = prompts.shape[1] x = self.onnx_encoder(x, bert_feature) - y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k) + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts, top_k=top_k, first_infer=torch.LongTensor([1])) stop = False for idx in tqdm(range(1, 1500)): - enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k) + enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, first_infer=torch.LongTensor([0])) y, k, v, y_emb, logits, samples = enco if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: stop = True diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py index 03fbeba9..04ad5d87 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -48,33 +48,23 @@ def multi_head_attention_forward_patched( proj_qkv = linear(query, in_proj_weight, in_proj_bias) proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] - # 首轮qkv会产生多个batch,后续每轮只会产生一个batch, - # onnx导出时处理batch变化导致的输出形状变化非常无力, - # 已尝试过where方法,索引方法,尽管可以动态运行正常导出, - # 但都无法在onnx运行时正确处理kv cache形状导致抛出错误 - # 此实现需要整体重写,将kvcache增长和prefill交给外部调用 - if cache["first_infer"] == 1: - cache["k"][cache["stage"]] = k - cache["v"][cache["stage"]] = v - else: - cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) - cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) - k = cache["k"][cache["stage"]] - v = cache["v"][cache["stage"]] + # 使用动态形状推断来统一处理kv cache首步和后续步骤形状差异 # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards # # cache_k, cache_v : [1, N, 1, 512] size increasement is prepared outside - # first_infer_mask = cache["first_infer"] - # cache_k = cache["k"][cache["stage"]] - # cache_v = cache["v"][cache["stage"]] - # # Magic to get an index of either -1 or -N according to if first_infer_mask is set - # index_offset = torch.min(torch.tensor([-1]).to(k.device).to(torch.int64), -1 * first_infer_mask * k.shape[0]) - # cache_k[0, index_offset :, :, :] = k - # cache_v[0, index_offset :, :, :] = v - # cache["k"][cache["stage"]] = cache_k - # cache["v"][cache["stage"]] = cache_v - # k = cache_k - # v = cache_v + first_infer_mask = cache["first_infer"] + cache_k = cache["k"][cache["stage"]] + cache_v = cache["v"][cache["stage"]] + # Magic to get an index of either -1 or -N according to if first_infer_mask is set + minus_one = torch.tensor([-1]).to(k.device).to(torch.int64) + multipled = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0] + index_offset = torch.min(minus_one, multipled) + cache_k[index_offset :, :, :] = k + cache_v[index_offset :, :, :] = v + cache["k"][cache["stage"]] = cache_k + cache["v"][cache["stage"]] = cache_v + k = cache_k + v = cache_v cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] diff --git a/GPT_SoVITS/onnx_export_v1v2.py b/GPT_SoVITS/onnx_export_v1v2.py index 556174ed..f8ffbc28 100644 --- a/GPT_SoVITS/onnx_export_v1v2.py +++ b/GPT_SoVITS/onnx_export_v1v2.py @@ -112,14 +112,15 @@ class T2SInitStep(nn.Module): self.fsdc = t2s.first_stage_decoder self.vits = vits - def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None): + first_infer = first_infer.to(torch.int64) codes = self.vits.extract_latent(ssl_content) prompt_semantic = codes[0, 0] bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) bert = bert.unsqueeze(0) prompt = prompt_semantic.unsqueeze(0) - [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + [y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer) fake_logits = torch.zeros((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export return y, k, v, y_emb, x_example, fake_logits, fake_samples @@ -129,8 +130,9 @@ class T2SStageStep(nn.Module): super().__init__() self.stage_decoder = stage_decoder - def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None): - [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + def forward(self, iy, ik, iv, iy_emb, ix_example, top_k=None, top_p=None, repetition_penalty=None, temperature=None, first_infer=None): + first_infer = first_infer.to(torch.int64) + [y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=first_infer) fake_x_example = torch.zeros((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export return y, k, v, y_emb, fake_x_example, logits, samples @@ -155,10 +157,10 @@ class T2SModel(nn.Module): def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=None, top_p=None, repetition_penalty=None, temperature=None): # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([1])) for idx in range(5): # This is a fake one! DO NOT take this as reference - enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + enco = self.stage_decoder(y, k, v, y_emb, x_example, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.LongTensor([0])) y, k, v, y_emb, logits, samples = enco # if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: # break @@ -168,9 +170,9 @@ class T2SModel(nn.Module): def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, top_k=None, top_p=None, repetition_penalty=None, temperature=None): torch.onnx.export( self.init_step, - (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature), + (ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k, top_p, repetition_penalty, temperature, torch.Tensor([True]).to(torch.bool)), f"onnx/{project_name}/{project_name}_t2s_init_step.onnx", - input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature"], + input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"], output_names=["y", "k", "v", "y_emb", "x_example", 'logits', 'samples'], dynamic_axes={ "ref_text_phones": {1: "ref_length"}, @@ -180,16 +182,17 @@ class T2SModel(nn.Module): "hubert_ssl_content": {2: "ssl_length"}, }, opset_version=16, + do_constant_folding=False ) - simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") - y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature) + # simplify_onnx_model(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") + y, k, v, y_emb, x_example, fake_logits, fake_samples = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature, first_infer=torch.Tensor([True]).to(torch.bool)) stage_step = T2SStageStep(self.stage_decoder) torch.onnx.export( stage_step, - (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature), + (y, k, v, y_emb, x_example, top_k, top_p, repetition_penalty, temperature, torch.Tensor([False]).to(torch.bool)), f"onnx/{project_name}/{project_name}_t2s_stage_step.onnx", - input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature"], + input_names=["iy", "ik", "iv", "iy_emb", "ix_example", "top_k", "top_p", "repetition_penalty", "temperature", "if_init_step"], output_names=["y", "k", "v", "y_emb","x_example", "logits", "samples"], dynamic_axes={ "iy": {1: "iy_length"}, @@ -329,7 +332,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi # Define the inputs for the main graph # 1. The boolean condition to select the branch - cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, []) + # cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, []) main_outputs = [output for output in init_step_model.graph.output] @@ -346,7 +349,7 @@ def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combi main_graph = helper.make_graph( nodes=[if_node], name="t2s_combined_graph", - inputs=[cond_input] + data_inputs_init + data_inputs_stage, + inputs= data_inputs_init + data_inputs_stage, outputs=main_outputs ) @@ -483,29 +486,29 @@ if __name__ == "__main__": # 因为io太频繁,可能导致模型导出出错(wsl非常明显),请自行重试 - gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" - vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" - exp_path = "v1_export" - version = "v1" - export(vits_path, gpt_path, exp_path, version) + # gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" + # vits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" + # exp_path = "v1_export" + # version = "v1" + # export(vits_path, gpt_path, exp_path, version) gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" vits_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" exp_path = "v2_export" version = "v2" - export(vits_path, gpt_path, exp_path, version) + export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True) - gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" - vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" - exp_path = "v2pro_export" - version = "v2Pro" - export(vits_path, gpt_path, exp_path, version) + # gpt_path = "GPT_SoVITS/pretrained_models/s1v3.ckpt" + # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth" + # exp_path = "v2pro_export" + # version = "v2Pro" + # export(vits_path, gpt_path, exp_path, version) - gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" - vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" - exp_path = "v2proplus_export" - version = "v2ProPlus" - export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True) + # gpt_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" + # vits_path = "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth" + # exp_path = "v2proplus_export" + # version = "v2ProPlus" + # export(vits_path, gpt_path, exp_path, version, t2s_model_combine = True, half_precision=True) diff --git a/playground/freerun.py b/playground/freerun.py index 4f88f1ae..93045381 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -98,7 +98,8 @@ t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx") "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, - "temperature": temperature + "temperature": temperature, + "if_init_step": np.array([True], dtype=bool) }) # t2s_stage_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") @@ -120,7 +121,8 @@ for idx in tqdm(range(1, 1500)): "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty, - "temperature": temperature + "temperature": temperature, + "if_init_step": np.array([False], dtype=bool) }) if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token break