Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Remove the monkey patch that causes issues when used alongside other projects
commit 651b0075ad
parent fe531567f1
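Background: the line removed below rebound torch.nn.functional.multi_head_attention_forward at import time. Because torch.nn.functional is a single module object shared by the whole Python process, that assignment redirected every nn.MultiheadAttention in every library that happened to run in the same process as GPT-SoVITS. A minimal sketch of the failure mode (patched_stand_in below is illustrative, not the project's actual patched implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

original = F.multi_head_attention_forward

def patched_stand_in(*args, **kwargs):
    # Stand-in for multi_head_attention_forward_patched; the real patch
    # threads a KV cache through the attention forward pass.
    print("patched attention invoked")
    return original(*args, **kwargs)

# Equivalent of the removed module-level assignment:
F.multi_head_attention_forward = patched_stand_in

# Any stock nn.MultiheadAttention created anywhere in the process now routes
# through the stand-in, even in code that never imported GPT-SoVITS:
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(4, 1, 8)  # (seq_len, batch, embed_dim)
mha(x, x, x)              # prints "patched attention invoked"

F.multi_head_attention_forward = original  # restore the global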
@@ -9,11 +9,8 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter
 
-from torch.nn import functional as F
 from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
 
-F.multi_head_attention_forward = multi_head_attention_forward_patched
-
 
 class MultiheadAttention(Module):
     r"""Allows the model to jointly attend to information
@@ -361,20 +358,20 @@ class MultiheadAttention(Module):
             query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
         if not self._qkv_same_embed_dim:
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
+            attn_output, attn_output_weights = multi_head_attention_forward_patched(
                 query,
                 key,
                 value,
-                self.embed_dim,
-                self.num_heads,
-                self.in_proj_weight,
-                self.in_proj_bias,
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout,
-                self.out_proj.weight,
-                self.out_proj.bias,
+                embed_dim_to_check=self.embed_dim,
+                num_heads=self.num_heads,
+                in_proj_weight=self.in_proj_weight,
+                in_proj_bias=self.in_proj_bias,
+                bias_k=self.bias_k,
+                bias_v=self.bias_v,
+                add_zero_attn=self.add_zero_attn,
+                dropout_p=self.dropout,
+                out_proj_weight=self.out_proj.weight,
+                out_proj_bias=self.out_proj.bias,
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 need_weights=need_weights,
@@ -383,29 +380,39 @@ class MultiheadAttention(Module):
                 q_proj_weight=self.q_proj_weight,
                 k_proj_weight=self.k_proj_weight,
                 v_proj_weight=self.v_proj_weight,
+                static_k=None,
+                static_v=None,
                 average_attn_weights=average_attn_weights,
+                is_causal=False,
                 cache=cache,
             )
         else:
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
+            attn_output, attn_output_weights = multi_head_attention_forward_patched(
                 query,
                 key,
                 value,
-                self.embed_dim,
-                self.num_heads,
-                self.in_proj_weight,
-                self.in_proj_bias,
-                self.bias_k,
-                self.bias_v,
-                self.add_zero_attn,
-                self.dropout,
-                self.out_proj.weight,
-                self.out_proj.bias,
+                embed_dim_to_check=self.embed_dim,
+                num_heads=self.num_heads,
+                in_proj_weight=self.in_proj_weight,
+                in_proj_bias=self.in_proj_bias,
+                bias_k=self.bias_k,
+                bias_v=self.bias_v,
+                add_zero_attn=self.add_zero_attn,
+                dropout_p=self.dropout,
+                out_proj_weight=self.out_proj.weight,
+                out_proj_bias=self.out_proj.bias,
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 need_weights=need_weights,
                 attn_mask=attn_mask,
+                use_separate_proj_weight=False,
+                q_proj_weight=self.q_proj_weight,
+                k_proj_weight=self.k_proj_weight,
+                v_proj_weight=self.v_proj_weight,
+                static_k=None,
+                static_v=None,
                 average_attn_weights=average_attn_weights,
+                is_causal=False,
                 cache=cache,
             )
         if self.batch_first and is_batched:
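With the global patch gone, both branches call multi_head_attention_forward_patched directly, and the positional arguments are converted to keywords (embed_dim_to_check=..., dropout_p=..., and so on) so the calls stay unambiguous against the patched signature, which adds parameters such as cache and is_causal on top of the stock F.multi_head_attention_forward interface. A quick regression-style sketch of the intended effect; the import path is an assumption, since the diff does not show the file name being edited:

import torch.nn.functional as F

before = F.multi_head_attention_forward

# Hypothetical path for the module edited in this diff.
from GPT_SoVITS.AR.modules import transformer  # noqa: F401

# Importing the module no longer rebinds the shared global.
assert F.multi_head_attention_forward is before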