mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-23 21:19:47 +08:00
Remove the monkey patch causing issues when using with other proejcts
This commit is contained in:
parent
fe531567f1
commit
651b0075ad
@ -9,11 +9,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
|
||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
r"""Allows the model to jointly attend to information
|
||||
@ -361,20 +358,20 @@ class MultiheadAttention(Module):
|
||||
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
||||
|
||||
if not self._qkv_same_embed_dim:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
attn_output, attn_output_weights = multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
embed_dim_to_check=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
in_proj_weight=self.in_proj_weight,
|
||||
in_proj_bias=self.in_proj_bias,
|
||||
bias_k=self.bias_k,
|
||||
bias_v=self.bias_v,
|
||||
add_zero_attn=self.add_zero_attn,
|
||||
dropout_p=self.dropout,
|
||||
out_proj_weight=self.out_proj.weight,
|
||||
out_proj_bias=self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
@ -383,29 +380,39 @@ class MultiheadAttention(Module):
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
static_k=None,
|
||||
static_v=None,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=False,
|
||||
cache=cache,
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
attn_output, attn_output_weights = multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj_weight,
|
||||
self.in_proj_bias,
|
||||
self.bias_k,
|
||||
self.bias_v,
|
||||
self.add_zero_attn,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
embed_dim_to_check=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
in_proj_weight=self.in_proj_weight,
|
||||
in_proj_bias=self.in_proj_bias,
|
||||
bias_k=self.bias_k,
|
||||
bias_v=self.bias_v,
|
||||
add_zero_attn=self.add_zero_attn,
|
||||
dropout_p=self.dropout,
|
||||
out_proj_weight=self.out_proj.weight,
|
||||
out_proj_bias=self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
use_separate_proj_weight=False,
|
||||
q_proj_weight=self.q_proj_weight,
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
static_k=None,
|
||||
static_v=None,
|
||||
average_attn_weights=average_attn_weights,
|
||||
is_causal=False,
|
||||
cache=cache,
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
|
Loading…
x
Reference in New Issue
Block a user