feat: experiments with ONNX export for attention, but it does not work well. todo: clean up code and try v3/v4

zpeng11 2025-08-24 00:46:29 -04:00
parent 5982080939
commit e4d1894a8f


@@ -48,7 +48,11 @@ def multi_head_attention_forward_patched(
proj_qkv = linear(query, in_proj_weight, in_proj_bias)
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
# On the first inference pass, qkv yields multiple batches; every later pass yields only one
# ONNX export copes very poorly with the output-shape changes caused by this varying batch size
# A where-based approach and an index-based approach were both tried; they run dynamically and export fine,
# but neither handles the kv cache shape correctly at ONNX runtime, so errors are thrown
# This implementation needs a full rewrite: kv cache growth and prefill should be handed to the external caller (see the rough sketch after this diff)
if cache["first_infer"] == 1:
cache["k"][cache["stage"]] = k
cache["v"][cache["stage"]] = v
@@ -57,6 +61,21 @@ def multi_head_attention_forward_patched(
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
k = cache["k"][cache["stage"]]
v = cache["v"][cache["stage"]]
# # k, v: [N, 1, 512] on the first pass, [1, 1, 512] afterwards
# # cache_k, cache_v: [1, N, 1, 512]; the size increase is prepared outside
# first_infer_mask = cache["first_infer"]
# cache_k = cache["k"][cache["stage"]]
# cache_v = cache["v"][cache["stage"]]
# # Trick to get an index offset of either -N or -1, depending on whether first_infer_mask is set
# index_offset = torch.min(torch.tensor([-1]).to(k.device).to(torch.int64), -1 * first_infer_mask * k.shape[0])
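# # e.g. first_infer_mask == 1 and k.shape[0] == N  ->  min(-1, -N) == -N, so the whole prefill is written
# #      first_infer_mask == 0                      ->  min(-1,  0) == -1, so only the newest step is written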
# cache_k[0, index_offset:, :, :] = k
# cache_v[0, index_offset:, :, :] = v
# cache["k"][cache["stage"]] = cache_k
# cache["v"][cache["stage"]] = cache_v
# k = cache_k
# v = cache_v
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
attn_mask = _canonical_mask(
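
Below is a minimal, illustrative sketch (not part of this commit) of the rewrite direction described in the comments above: the caller pre-allocates fixed-size kv cache tensors and passes an explicit write position, so the patched attention never concatenates and every tensor keeps a static shape. All names here (attention_step, max_len, write_pos) are assumptions made for illustration; it is single-head with no projections, and causal masking inside the prompt is omitted for brevity.

import torch
import torch.nn.functional as F

def attention_step(q, k, v, cache_k, cache_v, write_pos):
    # q, k, v: [1, T, E]; T is the prompt length on prefill and 1 on later steps.
    # cache_k, cache_v: [1, max_len, E], pre-allocated by the caller.
    T = k.shape[1]
    # Write the new keys/values at an explicit offset instead of concatenating,
    # so the cache tensors never change shape.
    cache_k[:, write_pos:write_pos + T] = k
    cache_v[:, write_pos:write_pos + T] = v
    end = write_pos + T
    # Attend over the full fixed-length cache and mask the unused tail,
    # which keeps every shape in the graph static.
    valid = (torch.arange(cache_k.shape[1]) < end).view(1, 1, -1)
    out = F.scaled_dot_product_attention(q, cache_k, cache_v, attn_mask=valid)
    return out, end  # the caller tracks the new write position

# Usage sketch: prefill and per-step cache growth are owned by the caller.
E, max_len = 512, 64
cache_k = torch.zeros(1, max_len, E)
cache_v = torch.zeros(1, max_len, E)
prompt = torch.randn(1, 7, E)
out, pos = attention_step(prompt, prompt, prompt, cache_k, cache_v, 0)  # prefill
step = torch.randn(1, 1, E)
out, pos = attention_step(step, step, step, cache_k, cache_v, pos)      # one decode step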