# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched


class MultiheadAttention(Module):
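    """Multi-head attention modeled on ``torch.nn.MultiheadAttention``.

    The forward pass is delegated to ``multi_head_attention_forward_patched`` so that
    an external KV cache can be threaded through the attention call (the ONNX-friendly
    decoding path). ``linear1_cls`` / ``linear2_cls`` allow swapping in custom
    projection layers for the input and output projections.
    """
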
    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        # per-head width, e.g. embed_dim=512 with num_heads=8 gives head_dim=64
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                self.q_proj_weight = Parameter(
                    torch.empty(
                        (embed_dim, embed_dim),
                        **factory_kwargs,
                    )
                )
                self.k_proj_weight = Parameter(
                    torch.empty(
                        (embed_dim, self.kdim),
                        **factory_kwargs,
                    )
                )
                self.v_proj_weight = Parameter(
                    torch.empty(
                        (embed_dim, self.vdim),
                        **factory_kwargs,
                    )
                )
                self.register_parameter("in_proj_weight", None)
            else:
                self.in_proj_weight = Parameter(
                    torch.empty(
                        (3 * embed_dim, embed_dim),
                        **factory_kwargs,
                    )
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs),
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim,
                    3 * embed_dim,
                    bias=bias,
                    **factory_kwargs,
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim,
                    embed_dim,
                    bias=bias,
                    **factory_kwargs,
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
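        # Only `query` is actually used: key/value are overwritten below, so this module
        # performs self-attention over `query`. Despite the Tuple annotation inherited
        # from torch.nn.MultiheadAttention, this ONNX path returns a single tensor (the
        # attention output). `cache` is forwarded untouched to
        # multi_head_attention_forward_patched, which defines the KV-cache layout.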
        # kept from the reference torch implementation; the flag is not used on this path
        any_nested = query.is_nested or key.is_nested or value.is_nested
        # self-attention only: key/value are replaced by query, and the first two dims
        # are swapped into the layout expected by the patched kernel
        query = key = value = query.transpose(1, 0)
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        # swap back to the caller's original layout
        return attn_output.transpose(1, 0)
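

# Minimal construction sketch (not part of the original module): it only illustrates the
# parameter layout created by the default Linear projections. A real forward pass also
# needs the KV-cache object expected by multi_head_attention_forward_patched, whose
# layout is defined in AR/modules/patched_mha_with_cache_onnx.py.
if __name__ == "__main__":
    mha = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.0, batch_first=True)
    print(mha.in_proj_weight.shape)  # torch.Size([1536, 512]): stacked q/k/v projections
    print(mha.out_proj.weight.shape)  # torch.Size([512, 512])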