mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-05-22 06:35:37 +08:00
Update patched_mha_with_cache.py
This commit is contained in:
parent
b65ea9181e
commit
ee4a466f79
@ -12,33 +12,33 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def multi_head_attention_forward_patched(
|
def multi_head_attention_forward_patched(
|
||||||
query: Tensor,
|
query,
|
||||||
key: Tensor,
|
key,
|
||||||
value: Tensor,
|
value,
|
||||||
embed_dim_to_check: int,
|
embed_dim_to_check,
|
||||||
num_heads: int,
|
num_heads,
|
||||||
in_proj_weight: Optional[Tensor],
|
in_proj_weight,
|
||||||
in_proj_bias: Optional[Tensor],
|
in_proj_bias,
|
||||||
bias_k: Optional[Tensor],
|
bias_k,
|
||||||
bias_v: Optional[Tensor],
|
bias_v,
|
||||||
add_zero_attn: bool,
|
add_zero_attn,
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
out_proj_weight: Tensor,
|
out_proj_weight,
|
||||||
out_proj_bias: Optional[Tensor],
|
out_proj_bias,
|
||||||
training: bool = True,
|
training = True,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask = None,
|
||||||
need_weights: bool = True,
|
need_weights = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask = None,
|
||||||
use_separate_proj_weight: bool = False,
|
use_separate_proj_weight = False,
|
||||||
q_proj_weight: Optional[Tensor] = None,
|
q_proj_weight = None,
|
||||||
k_proj_weight: Optional[Tensor] = None,
|
k_proj_weight = None,
|
||||||
v_proj_weight: Optional[Tensor] = None,
|
v_proj_weight = None,
|
||||||
static_k: Optional[Tensor] = None,
|
static_k = None,
|
||||||
static_v: Optional[Tensor] = None,
|
static_v = None,
|
||||||
average_attn_weights: bool = True,
|
average_attn_weights = True,
|
||||||
is_causal: bool = False,
|
is_causal = False,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
query, key, value: map a query and a set of key-value pairs to an output.
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user