Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 08:49:59 +08:00)
feat: experiments with attention for ONNX, but it does not work well; todo: clean up the code and try v3/v4
parent 5982080939
commit e4d1894a8f
@@ -48,7 +48,11 @@ def multi_head_attention_forward_patched(
     proj_qkv = linear(query, in_proj_weight, in_proj_bias)
     proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
     q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
+    # On the first round, qkv produces multiple batches; every later round produces only one batch,
+    # and ONNX export copes very poorly with the output-shape change caused by that batch change.
+    # A where-based approach and an index-based approach were both tried; they run dynamically and export fine,
+    # but neither lets the ONNX runtime handle the kv-cache shapes correctly, so it raises errors.
+    # This implementation needs a full rewrite that hands kv-cache growth and prefill to the external caller.
     if cache["first_infer"] == 1:
         cache["k"][cache["stage"]] = k
         cache["v"][cache["stage"]] = v
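The last added comment points toward letting the caller own the kv cache. Below is only a rough sketch of that direction, not code from this repository: the names (step_attention, cache_len, T_max, n_head, head_dim) are made up, attn_mask handling is omitted, and the cache is a buffer the caller preallocates and grows.

import torch

# Hypothetical sketch (not the repo's API): the caller preallocates k_cache/v_cache and tracks
# cache_len; the attention step only writes the new positions and attends over the valid prefix,
# so the attention graph itself never has to reason about cache growth or prefill.
def step_attention(q, k, v, k_cache, v_cache, cache_len):
    # q, k, v          : [T_new, n_head, head_dim]  (T_new > 1 on prefill, 1 afterwards)
    # k_cache, v_cache : [T_max, n_head, head_dim]  preallocated by the caller
    # cache_len        : int, number of positions already filled by earlier steps
    t_new = k.shape[0]
    k_cache[cache_len:cache_len + t_new] = k
    v_cache[cache_len:cache_len + t_new] = v
    keys = k_cache[:cache_len + t_new]        # [T, n_head, head_dim]
    vals = v_cache[:cache_len + t_new]
    scores = q.transpose(0, 1) @ keys.permute(1, 2, 0) / keys.shape[-1] ** 0.5
    attn = torch.softmax(scores, dim=-1)      # [n_head, T_new, T]; no attn_mask in this sketch
    out = attn @ vals.transpose(0, 1)         # [n_head, T_new, head_dim]
    return out.transpose(0, 1), cache_len + t_new

The caller would do prefill by passing the whole prompt once and then call the same function one token at a time, keeping all cache bookkeeping outside the exported attention graph.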
@@ -57,6 +61,21 @@ def multi_head_attention_forward_patched(
         cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
     k = cache["k"][cache["stage"]]
     v = cache["v"][cache["stage"]]
+
+    # # k, v : [N, 1, 512] on the first pass, [1, 1, 512] afterwards
+    # # cache_k, cache_v : [1, N, 1, 512]; the size increase is prepared outside
+    # first_infer_mask = cache["first_infer"]
+    # cache_k = cache["k"][cache["stage"]]
+    # cache_v = cache["v"][cache["stage"]]
+    # # Magic to get an index of either -1 or -N depending on whether first_infer_mask is set
+    # index_offset = torch.min(torch.tensor([-1]).to(k.device).to(torch.int64), -1 * first_infer_mask * k.shape[0])
+    # cache_k[0, index_offset:, :, :] = k
+    # cache_v[0, index_offset:, :, :] = v
+    # cache["k"][cache["stage"]] = cache_k
+    # cache["v"][cache["stage"]] = cache_v
+    # k = cache_k
+    # v = cache_v
+
     cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
 
     attn_mask = _canonical_mask(
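The commented-out experiment hinges on one trick: torch.min(torch.tensor([-1]), -1 * first_infer_mask * k.shape[0]) evaluates to -N when first_infer_mask is 1 and to -1 otherwise, so a single slice assignment covers both prefill and single-step decode without a data-dependent branch. A small standalone check of that behaviour, with illustrative shapes and names taken loosely from the comments:

import torch

def write_kv(cache_k, k, first_infer_mask):
    # cache_k: [1, T_max, 1, 512] preallocated outside; k: [N, 1, 512] on prefill, [1, 1, 512] later.
    # index_offset is -N on the first pass and -1 afterwards, so the same assignment fills the
    # last N slots during prefill and overwrites only the last slot during decode.
    index_offset = torch.min(
        torch.tensor([-1], dtype=torch.int64, device=k.device),
        -1 * first_infer_mask * k.shape[0],
    )
    cache_k[0, index_offset:, :, :] = k
    return cache_k

cache_k = torch.zeros(1, 8, 1, 512)
cache_k = write_kv(cache_k, torch.randn(5, 1, 512), torch.tensor([1]))  # prefill: fills the last 5 slots
cache_k = write_kv(cache_k, torch.randn(1, 1, 512), torch.tensor([0]))  # decode: touches only the last slot

According to the new comments, variants like this run dynamically and export, but the ONNX runtime still fails on the resulting kv-cache shapes, which is presumably why the block stays commented out.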