From e4d1894a8fd811d24fb92067cd09421386207608 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Sun, 24 Aug 2025 00:46:29 -0400 Subject: [PATCH] feat:experiments with for onnx with attention, but does not work well todo:clean code and try v3v4 --- .../AR/modules/patched_mha_with_cache_onnx.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py index bd39ff6e..03fbeba9 100644 --- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py @@ -48,7 +48,11 @@ def multi_head_attention_forward_patched( proj_qkv = linear(query, in_proj_weight, in_proj_bias) proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] - + # 首轮qkv会产生多个batch,后续每轮只会产生一个batch, + # onnx导出时处理batch变化导致的输出形状变化非常无力, + # 已尝试过where方法,索引方法,尽管可以动态运行正常导出, + # 但都无法在onnx运行时正确处理kv cache形状导致抛出错误 + # 此实现需要整体重写,将kvcache增长和prefill交给外部调用 if cache["first_infer"] == 1: cache["k"][cache["stage"]] = k cache["v"][cache["stage"]] = v @@ -57,6 +61,21 @@ def multi_head_attention_forward_patched( cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) k = cache["k"][cache["stage"]] v = cache["v"][cache["stage"]] + + # # k,v : [N, 1, 512] at first time, [1, 1, 512] afterwards + # # cache_k, cache_v : [1, N, 1, 512] size increasement is prepared outside + # first_infer_mask = cache["first_infer"] + # cache_k = cache["k"][cache["stage"]] + # cache_v = cache["v"][cache["stage"]] + # # Magic to get an index of either -1 or -N according to if first_infer_mask is set + # index_offset = torch.min(torch.tensor([-1]).to(k.device).to(torch.int64), -1 * first_infer_mask * k.shape[0]) + # cache_k[0, index_offset :, :, :] = k + # cache_v[0, index_offset :, :, :] = v + # cache["k"][cache["stage"]] = cache_k + # cache["v"][cache["stage"]] = cache_v + # k = cache_k + # v = cache_v + cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] attn_mask = _canonical_mask(