From 651b0075ad17290c30ec6cf812a4e7cd967577ac Mon Sep 17 00:00:00 2001
From: Jarod Mica
Date: Sun, 25 May 2025 00:34:35 -0700
Subject: [PATCH] Remove the monkey patch causing issues when used with other projects

---
 GPT_SoVITS/AR/modules/activation.py | 57 ++++++++++++++++-------------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/GPT_SoVITS/AR/modules/activation.py b/GPT_SoVITS/AR/modules/activation.py
index f05d5e5f..afc70091 100644
--- a/GPT_SoVITS/AR/modules/activation.py
+++ b/GPT_SoVITS/AR/modules/activation.py
@@ -9,11 +9,8 @@
 from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
-from torch.nn import functional as F
 
 from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
 
-F.multi_head_attention_forward = multi_head_attention_forward_patched
-
 class MultiheadAttention(Module):
     r"""Allows the model to jointly attend to information
@@ -361,20 +358,20 @@ class MultiheadAttention(Module):
                 query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
         if not self._qkv_same_embed_dim:
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
+            attn_output, attn_output_weights = multi_head_attention_forward_patched(
                 query,
                 key,
                 value,
-                self.embed_dim,
-                self.num_heads,
-                self.in_proj_weight,
-                self.in_proj_bias,
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout,
-                self.out_proj.weight,
-                self.out_proj.bias,
+                embed_dim_to_check=self.embed_dim,
+                num_heads=self.num_heads,
+                in_proj_weight=self.in_proj_weight,
+                in_proj_bias=self.in_proj_bias,
+                bias_k=self.bias_k,
+                bias_v=self.bias_v,
+                add_zero_attn=self.add_zero_attn,
+                dropout_p=self.dropout,
+                out_proj_weight=self.out_proj.weight,
+                out_proj_bias=self.out_proj.bias,
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 need_weights=need_weights,
@@ -383,29 +380,39 @@ class MultiheadAttention(Module):
                 q_proj_weight=self.q_proj_weight,
                 k_proj_weight=self.k_proj_weight,
                 v_proj_weight=self.v_proj_weight,
+                static_k=None,
+                static_v=None,
                 average_attn_weights=average_attn_weights,
+                is_causal=False,
                 cache=cache,
             )
         else:
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
+            attn_output, attn_output_weights = multi_head_attention_forward_patched(
                 query,
                 key,
                 value,
-                self.embed_dim,
-                self.num_heads,
-                self.in_proj_weight,
-                self.in_proj_bias,
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout,
-                self.out_proj.weight,
-                self.out_proj.bias,
+                embed_dim_to_check=self.embed_dim,
+                num_heads=self.num_heads,
+                in_proj_weight=self.in_proj_weight,
+                in_proj_bias=self.in_proj_bias,
+                bias_k=self.bias_k,
+                bias_v=self.bias_v,
+                add_zero_attn=self.add_zero_attn,
+                dropout_p=self.dropout,
+                out_proj_weight=self.out_proj.weight,
+                out_proj_bias=self.out_proj.bias,
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 need_weights=need_weights,
                 attn_mask=attn_mask,
+                use_separate_proj_weight=False,
+                q_proj_weight=self.q_proj_weight,
+                k_proj_weight=self.k_proj_weight,
+                v_proj_weight=self.v_proj_weight,
+                static_k=None,
+                static_v=None,
                 average_attn_weights=average_attn_weights,
+                is_causal=False,
                 cache=cache,
             )
         if self.batch_first and is_batched:
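
Background on the change, with a small illustrative sketch (not code from this repository; the helper name noisy_mha_forward and the toy tensor sizes are invented for demonstration): assigning to torch.nn.functional at import time rebinds F.multi_head_attention_forward for every caller in the process, so any other project that imports GPT-SoVITS and then uses a stock torch.nn.MultiheadAttention is silently routed through the patched implementation. Calling multi_head_attention_forward_patched directly, as the hunks above do, keeps the custom attention path local to GPT_SoVITS/AR/modules/activation.py.

    # Illustrative sketch only; noisy_mha_forward is a made-up stand-in, not GPT-SoVITS code.
    import torch
    import torch.nn as nn
    from torch.nn import functional as F

    _original = F.multi_head_attention_forward

    def noisy_mha_forward(*args, **kwargs):
        # Stand-in for a patched attention kernel: after the assignment below,
        # every caller of F.multi_head_attention_forward in this process is
        # rerouted here, not just the code that installed the patch.
        print("patched multi_head_attention_forward called")
        return _original(*args, **kwargs)

    F.multi_head_attention_forward = noisy_mha_forward  # the kind of global patch this commit removes

    # A stock PyTorch layer from an unrelated project now hits the patch too.
    mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
    x = torch.randn(1, 4, 8)
    out, _ = mha(x, x, x)  # prints "patched multi_head_attention_forward called"

    F.multi_head_attention_forward = _original  # restore the original binding

With the patch removed, the call sites pass every argument to multi_head_attention_forward_patched explicitly by keyword (including static_k, static_v, and is_causal), so they no longer depend on the positional layout of torch's own multi_head_attention_forward signature.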