Remove the monkey patch causing issues when used with other projects

Jarod Mica 2025-05-25 00:34:35 -07:00
parent fe531567f1
commit 651b0075ad

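For context before the diff: a minimal sketch (not part of this commit) of why the module-level monkey patch is a problem when GPT-SoVITS is imported alongside other projects. Assigning to torch.nn.functional swaps the attention implementation for every caller in the process, not just this repo. The function body below is a stand-in; the real multi_head_attention_forward_patched lives in GPT_SoVITS.AR.modules.patched_mha_with_cache and, as the name and the cache=cache argument suggest, adds cache handling on top of the stock call.

import torch.nn.functional as F

_stock_mha_forward = F.multi_head_attention_forward

def multi_head_attention_forward_patched(*args, cache=None, **kwargs):
    # Stand-in only: the real function extends the stock implementation
    # with a cache argument; here we simply delegate.
    return _stock_mha_forward(*args, **kwargs)

# Before this commit, importing the module executed the assignment below,
# silently rerouting every F.multi_head_attention_forward call in the process:
#     F.multi_head_attention_forward = multi_head_attention_forward_patched
# After this commit, F is left untouched and the two call sites in
# MultiheadAttention.forward call the patched function directly.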

@@ -9,11 +9,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
-F.multi_head_attention_forward = multi_head_attention_forward_patched
class MultiheadAttention(Module):
r"""Allows the model to jointly attend to information
@@ -361,20 +358,20 @@ class MultiheadAttention(Module):
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
-attn_output, attn_output_weights = F.multi_head_attention_forward(
+attn_output, attn_output_weights = multi_head_attention_forward_patched(
query,
key,
value,
-self.embed_dim,
-self.num_heads,
-self.in_proj_weight,
-self.in_proj_bias,
-self.bias_k,
-self.bias_v,
-self.add_zero_attn,
-self.dropout,
-self.out_proj.weight,
-self.out_proj.bias,
+embed_dim_to_check=self.embed_dim,
+num_heads=self.num_heads,
+in_proj_weight=self.in_proj_weight,
+in_proj_bias=self.in_proj_bias,
+bias_k=self.bias_k,
+bias_v=self.bias_v,
+add_zero_attn=self.add_zero_attn,
+dropout_p=self.dropout,
+out_proj_weight=self.out_proj.weight,
+out_proj_bias=self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
@@ -383,29 +380,39 @@ class MultiheadAttention(Module):
q_proj_weight=self.q_proj_weight,
k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
+static_k=None,
+static_v=None,
average_attn_weights=average_attn_weights,
+is_causal=False,
cache=cache,
)
else:
-attn_output, attn_output_weights = F.multi_head_attention_forward(
+attn_output, attn_output_weights = multi_head_attention_forward_patched(
query,
key,
value,
-self.embed_dim,
-self.num_heads,
-self.in_proj_weight,
-self.in_proj_bias,
-self.bias_k,
-self.bias_v,
-self.add_zero_attn,
-self.dropout,
-self.out_proj.weight,
-self.out_proj.bias,
+embed_dim_to_check=self.embed_dim,
+num_heads=self.num_heads,
+in_proj_weight=self.in_proj_weight,
+in_proj_bias=self.in_proj_bias,
+bias_k=self.bias_k,
+bias_v=self.bias_v,
+add_zero_attn=self.add_zero_attn,
+dropout_p=self.dropout,
+out_proj_weight=self.out_proj.weight,
+out_proj_bias=self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
+use_separate_proj_weight=False,
+q_proj_weight=self.q_proj_weight,
+k_proj_weight=self.k_proj_weight,
+v_proj_weight=self.v_proj_weight,
+static_k=None,
+static_v=None,
average_attn_weights=average_attn_weights,
+is_causal=False,
cache=cache,
)
if self.batch_first and is_batched:
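A quick way to confirm the side effect is gone (hypothetical snippet; the module path GPT_SoVITS.AR.modules.activation for the edited file is an assumption): importing the changed module should no longer mutate torch.

import torch.nn.functional as F

stock = F.multi_head_attention_forward
import GPT_SoVITS.AR.modules.activation  # path assumed for the file edited above
assert F.multi_head_attention_forward is stock, "monkey patch is still applied"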