GPT-SoVITS/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py

import torch
from torch.nn.functional import *
from torch.nn.functional import (
    _canonical_mask,
)
from typing import Tuple, Optional


def multi_head_attention_forward_patched(
    query,
    key,
    value,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
    cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
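    """ONNX-export-friendly replacement for multi_head_attention_forward with a KV cache.

    Unlike the stock implementation, attention weights are never computed and only
    the attention output tensor is returned. ``cache`` is a dict holding per-stage
    key/value tensors that are updated in place on every decoding step.
    """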
    # set up shape vars
    _, _, embed_dim = query.shape
    attn_mask = _canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query.dtype,
        check_other=False,
    )
    head_dim = embed_dim // num_heads
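    # Fused QKV projection: one linear maps the input to 3*embed_dim, which is
    # then split along the last dimension into q, k and v, each with the same
    # shape as the input query.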
    proj_qkv = linear(query, in_proj_weight, in_proj_bias)
    proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
    # Use dynamic shape inference so the first KV-cache step and all following
    # steps (which produce differently shaped k/v) share a single code path.
    # k, v: [N, 1, 512] on the first step, [1, 1, 512] afterwards
    # cache_k, cache_v: [1, N, 1, 512]; the size increase is prepared outside
    first_infer_mask = cache["first_infer"]
    cache_k = cache["k"][cache["stage"]]
    cache_v = cache["v"][cache["stage"]]
    # Magic to get an index of either -1 or -N depending on whether first_infer_mask is set
    minus_one = torch.tensor([-1]).to(k.device).to(torch.int64)
    multiplied = minus_one * first_infer_mask * torch.onnx.operators.shape_as_tensor(query)[0]
    index_offset = torch.min(minus_one, multiplied)
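    # When first_infer_mask is 1, index_offset == -N (N = current query length), so the
    # slice below overwrites the whole cache slot; otherwise it is -1 and only the
    # latest single-step k/v is written at the end of the cache.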
    cache_k[index_offset:, :, :] = k
    cache_v[index_offset:, :, :] = v
    cache["k"][cache["stage"]] = cache_k
    cache["v"][cache["stage"]] = cache_v
    k = cache_k
    v = cache_v
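    # Rotate to the next attention layer's cache slot (wraps around after all_stage layers)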
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
attn_mask = _canonical_mask(
mask=attn_mask,
mask_name="attn_mask",
other_type=None,
other_name="",
target_type=q.dtype,
check_other=False,
)
attn_mask = attn_mask.unsqueeze(0)
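    # Split the embedding dimension into heads: (L, 1, E) -> (num_heads, L, head_dim)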
    q = q.view(-1, num_heads, head_dim).transpose(0, 1)
    k = k.view(-1, num_heads, head_dim).transpose(0, 1)
    v = v.view(-1, num_heads, head_dim).transpose(0, 1)
    dropout_p = 0.0
    attn_mask = attn_mask.unsqueeze(0)
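    # Add a leading batch dimension so scaled_dot_product_attention sees (1, num_heads, L, head_dim)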
    q = q.view(num_heads, -1, head_dim).unsqueeze(0)
    k = k.view(num_heads, -1, head_dim).unsqueeze(0)
    v = v.view(num_heads, -1, head_dim).unsqueeze(0)
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(-1, 1, attn_output.size(1))
    return attn_output
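

# A minimal sketch of the cache layout this function expects, inferred only from
# the keys accessed above; the names and shapes below are illustrative
# assumptions, not part of this module:
#
#   cache = {
#       "all_stage": num_layers,                     # number of attention layers sharing the cache
#       "stage": 0,                                  # index of the slot used by the next call
#       "first_infer": torch.tensor([1]),            # 1 on the prompt step, 0 afterwards
#       "k": [torch.empty(0, 1, 512)] * num_layers,  # pre-allocated per-layer key caches
#       "v": [torch.empty(0, 1, 512)] * num_layers,  # pre-allocated per-layer value caches
#   }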