Fixed issues such as missing imports for types like Optional.

Fixed issues such as missing imports for types like `Optional`.
This commit is contained in:
RVC-Boss 2026-04-18 17:33:53 +08:00 committed by GitHub
parent 938f05fce8
commit 02425ea256
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,30 +8,30 @@ def multi_head_attention_forward_patched(
query,
key,
value,
embed_dim_to_check: int,
num_heads: int,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p,
out_proj_weight,
out_proj_bias,
training=True,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
use_separate_proj_weight=False,
q_proj_weight=None,
k_proj_weight=None,
v_proj_weight=None,
static_k=None,
static_v=None,
average_attn_weights=True,
is_causal=False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
):
# set up shape vars
_, _, embed_dim = query.shape
attn_mask = _canonical_mask(