From ee4a466f79b4f643251aa2f873f541f85df11d91 Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Wed, 26 Mar 2025 17:39:19 +0800
Subject: [PATCH] Update patched_mha_with_cache.py

---
 .../AR/modules/patched_mha_with_cache.py | 50 +++++++++----------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
index 7be241d..cab6afe 100644
--- a/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
+++ b/GPT_SoVITS/AR/modules/patched_mha_with_cache.py
@@ -12,33 +12,33 @@ import torch
 
 
 def multi_head_attention_forward_patched(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    embed_dim_to_check: int,
-    num_heads: int,
-    in_proj_weight: Optional[Tensor],
-    in_proj_bias: Optional[Tensor],
-    bias_k: Optional[Tensor],
-    bias_v: Optional[Tensor],
-    add_zero_attn: bool,
+    query,
+    key,
+    value,
+    embed_dim_to_check,
+    num_heads,
+    in_proj_weight,
+    in_proj_bias,
+    bias_k,
+    bias_v,
+    add_zero_attn,
     dropout_p: float,
-    out_proj_weight: Tensor,
-    out_proj_bias: Optional[Tensor],
-    training: bool = True,
-    key_padding_mask: Optional[Tensor] = None,
-    need_weights: bool = True,
-    attn_mask: Optional[Tensor] = None,
-    use_separate_proj_weight: bool = False,
-    q_proj_weight: Optional[Tensor] = None,
-    k_proj_weight: Optional[Tensor] = None,
-    v_proj_weight: Optional[Tensor] = None,
-    static_k: Optional[Tensor] = None,
-    static_v: Optional[Tensor] = None,
-    average_attn_weights: bool = True,
-    is_causal: bool = False,
+    out_proj_weight,
+    out_proj_bias,
+    training = True,
+    key_padding_mask = None,
+    need_weights = True,
+    attn_mask = None,
+    use_separate_proj_weight = False,
+    q_proj_weight = None,
+    k_proj_weight = None,
+    v_proj_weight = None,
+    static_k = None,
+    static_v = None,
+    average_attn_weights = True,
+    is_causal = False,
     cache=None,
-) -> Tuple[Tensor, Optional[Tensor]]:
+):
     r"""
     Args:
        query, key, value: map a query and a set of key-value pairs to an output.
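
For context: the hunk above only drops the Python type annotations from
`multi_head_attention_forward_patched`; its positional signature still mirrors
the stock `torch.nn.functional.multi_head_attention_forward`, with `cache` as
the one trailing addition. The snippet below is a minimal sketch of driving
that stock signature with self-contained random weights, so the argument order
shown in the hunk can be checked in isolation. The tensor shapes and the
choice to exercise the no-cache path via the stock function (rather than the
patched one) are assumptions for illustration, not part of this patch.

import torch
import torch.nn.functional as F

L, N, E, H = 4, 2, 16, 4  # target length, batch size, embed dim, num heads
query = key = value = torch.randn(L, N, E)

# Self-contained projection weights (random here; a real module supplies
# trained parameters from nn.MultiheadAttention).
in_proj_weight = torch.randn(3 * E, E)
in_proj_bias = torch.zeros(3 * E)
out_proj_weight = torch.randn(E, E)
out_proj_bias = torch.zeros(E)

# Positional order matches the signature in the hunk above:
# (query, key, value, embed_dim_to_check, num_heads, in_proj_weight,
#  in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p,
#  out_proj_weight, out_proj_bias, ...)
attn_output, attn_weights = F.multi_head_attention_forward(
    query, key, value, E, H,
    in_proj_weight, in_proj_bias,
    None, None, False, 0.0,
    out_proj_weight, out_proj_bias,
    training=False,
    need_weights=True,
)
print(attn_output.shape)   # torch.Size([4, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 4, 4]) with averaged head weights

The patched variant would be called the same way plus a `cache=...` keyword;
the structure of that cache object is defined elsewhere in GPT_SoVITS and is
not shown in this hunk.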