Update patched_mha_with_cache.py

RVC-Boss 2025-03-26 17:39:19 +08:00 committed by GitHub
parent b65ea9181e
commit ee4a466f79


@@ -12,33 +12,33 @@ import torch
 def multi_head_attention_forward_patched(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    embed_dim_to_check: int,
-    num_heads: int,
-    in_proj_weight: Optional[Tensor],
-    in_proj_bias: Optional[Tensor],
-    bias_k: Optional[Tensor],
-    bias_v: Optional[Tensor],
-    add_zero_attn: bool,
+    query,
+    key,
+    value,
+    embed_dim_to_check,
+    num_heads,
+    in_proj_weight,
+    in_proj_bias,
+    bias_k,
+    bias_v,
+    add_zero_attn,
     dropout_p: float,
-    out_proj_weight: Tensor,
-    out_proj_bias: Optional[Tensor],
-    training: bool = True,
-    key_padding_mask: Optional[Tensor] = None,
-    need_weights: bool = True,
-    attn_mask: Optional[Tensor] = None,
-    use_separate_proj_weight: bool = False,
-    q_proj_weight: Optional[Tensor] = None,
-    k_proj_weight: Optional[Tensor] = None,
-    v_proj_weight: Optional[Tensor] = None,
-    static_k: Optional[Tensor] = None,
-    static_v: Optional[Tensor] = None,
-    average_attn_weights: bool = True,
-    is_causal: bool = False,
+    out_proj_weight,
+    out_proj_bias,
+    training = True,
+    key_padding_mask = None,
+    need_weights = True,
+    attn_mask = None,
+    use_separate_proj_weight = False,
+    q_proj_weight = None,
+    k_proj_weight = None,
+    v_proj_weight = None,
+    static_k = None,
+    static_v = None,
+    average_attn_weights = True,
+    is_causal = False,
     cache=None,
-) -> Tuple[Tensor, Optional[Tensor]]:
+):
     r"""
     Args:
         query, key, value: map a query and a set of key-value pairs to an output.
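The parameter list above mirrors `torch.nn.functional.multi_head_attention_forward`, with `cache=None` appended at the end; the commit itself only drops the type annotations. A cache argument like this typically lets autoregressive decoding reuse previously projected key/value tensors instead of recomputing them at every step. A minimal sketch of that pattern follows; the `append_kv_cache` helper and dict-of-dicts layout are hypothetical illustrations, not the actual cache structure used in `patched_mha_with_cache.py`.

```python
import torch

def append_kv_cache(cache, layer_idx, k_new, v_new):
    """Concatenate this step's key/value states onto any cached ones.

    Hypothetical helper: the real cache layout in patched_mha_with_cache.py
    may differ; this only illustrates the incremental-decoding idea behind
    passing a `cache` object into the attention forward.
    """
    if layer_idx in cache:
        # Grow the cached keys/values along the sequence dimension.
        k = torch.cat([cache[layer_idx]["k"], k_new], dim=1)
        v = torch.cat([cache[layer_idx]["v"], v_new], dim=1)
    else:
        k, v = k_new, v_new
    cache[layer_idx] = {"k": k, "v": v}
    return k, v

# Two decoding steps: batch=1, one new token per step, head_dim=64.
cache = {}
k, v = append_kv_cache(cache, 0, torch.randn(1, 1, 64), torch.randn(1, 1, 64))
k, v = append_kv_cache(cache, 0, torch.randn(1, 1, 64), torch.randn(1, 1, 64))
print(k.shape)  # torch.Size([1, 2, 64]) -- keys now cover both decoded tokens
```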