diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
index 8144c9c6..79077120 100644
--- a/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
+++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
@@ -8,30 +8,30 @@ def multi_head_attention_forward_patched(
     query,
     key,
     value,
-    embed_dim_to_check: int,
-    num_heads: int,
+    embed_dim_to_check,
+    num_heads,
     in_proj_weight,
-    in_proj_bias: Optional[Tensor],
-    bias_k: Optional[Tensor],
-    bias_v: Optional[Tensor],
-    add_zero_attn: bool,
-    dropout_p: float,
-    out_proj_weight: Tensor,
-    out_proj_bias: Optional[Tensor],
-    training: bool = True,
-    key_padding_mask: Optional[Tensor] = None,
-    need_weights: bool = True,
-    attn_mask: Optional[Tensor] = None,
-    use_separate_proj_weight: bool = False,
-    q_proj_weight: Optional[Tensor] = None,
-    k_proj_weight: Optional[Tensor] = None,
-    v_proj_weight: Optional[Tensor] = None,
-    static_k: Optional[Tensor] = None,
-    static_v: Optional[Tensor] = None,
-    average_attn_weights: bool = True,
-    is_causal: bool = False,
+    in_proj_bias,
+    bias_k,
+    bias_v,
+    add_zero_attn,
+    dropout_p,
+    out_proj_weight,
+    out_proj_bias,
+    training=True,
+    key_padding_mask=None,
+    need_weights=True,
+    attn_mask=None,
+    use_separate_proj_weight=False,
+    q_proj_weight=None,
+    k_proj_weight=None,
+    v_proj_weight=None,
+    static_k=None,
+    static_v=None,
+    average_attn_weights=True,
+    is_causal=False,
     cache=None,
-) -> Tuple[Tensor, Optional[Tensor]]:
+):
     # set up shape vars
     _, _, embed_dim = query.shape
     attn_mask = _canonical_mask(
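
Note on the change: the patch only strips the Python type annotations (and the `Tuple[...]` return annotation) from the patched MHA signature. Annotations are not enforced at runtime, so the call convention is unchanged; the signature still mirrors `torch.nn.functional.multi_head_attention_forward` with one extra trailing `cache` argument. A minimal sketch below exercises that shared positional signature through the upstream torch API; the shapes and weights are made up for illustration, and the `cache` handling is not shown since its layout is internal to GPT-SoVITS.

# Illustrative sketch, not from the repo: same positional argument order
# as the patched function (minus `cache`), via the upstream API.
import torch
import torch.nn.functional as F

E, H, L, N = 64, 4, 5, 1                 # embed dim, heads, seq len, batch
q = k = v = torch.randn(L, N, E)         # (seq, batch, embed); not batch_first
in_w = torch.randn(3 * E, E) * 0.02      # packed q/k/v projection weight
in_b = torch.zeros(3 * E)
out_w = torch.randn(E, E) * 0.02
out_b = torch.zeros(E)

attn_out, attn_weights = F.multi_head_attention_forward(
    q, k, v, E, H,
    in_w, in_b,
    None, None,                          # bias_k, bias_v
    False, 0.0,                          # add_zero_attn, dropout_p
    out_w, out_b,
    training=False,
    need_weights=True,
)
print(attn_out.shape)                    # torch.Size([5, 1, 64])

Dropping annotations like this is a common cleanup when a file no longer imports `Optional`, `Tensor`, or `Tuple` (keeping them would raise NameError at definition time) and can also simplify tracing for ONNX export; the diff itself does not state the motivation.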