diff --git a/GPT_SoVITS/module/activation_onnx.py b/GPT_SoVITS/module/activation_onnx.py
deleted file mode 100644
index b54acd99..00000000
--- a/GPT_SoVITS/module/activation_onnx.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
-from typing import Optional
-from typing import Tuple
-import torch
-from torch import Tensor
-from torch.nn import Linear
-from torch.nn import Module
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.init import xavier_uniform_
-from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
-from torch.nn.parameter import Parameter
-
-from torch.nn import functional as F
-from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
-
-
-class MultiheadAttention(Module):
-    __constants__ = ["batch_first"]
-    bias_k: Optional[torch.Tensor]
-    bias_v: Optional[torch.Tensor]
-
-    def __init__(
-        self,
-        embed_dim,
-        num_heads,
-        dropout=0.0,
-        bias=True,
-        add_bias_kv=False,
-        add_zero_attn=False,
-        kdim=None,
-        vdim=None,
-        batch_first=False,
-        linear1_cls=Linear,
-        linear2_cls=Linear,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super(MultiheadAttention, self).__init__()
-        self.embed_dim = embed_dim
-        self.kdim = kdim if kdim is not None else embed_dim
-        self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
-
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.batch_first = batch_first
-        self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
-
-        if add_bias_kv:
-            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
-            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
-        else:
-            self.bias_k = self.bias_v = None
-
-        if linear1_cls == Linear:
-            if not self._qkv_same_embed_dim:
-                self.q_proj_weight = Parameter(
-                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
-                )
-                self.k_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
-                )
-                self.v_proj_weight = Parameter(
-                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
-                )
-                self.register_parameter("in_proj_weight", None)
-            else:
-                self.in_proj_weight = Parameter(
-                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
-                )
-                self.register_parameter("q_proj_weight", None)
-                self.register_parameter("k_proj_weight", None)
-                self.register_parameter("v_proj_weight", None)
-
-            if bias:
-                self.in_proj_bias = Parameter(
-                    torch.empty(3 * embed_dim, **factory_kwargs)
-                )
-            else:
-                self.register_parameter("in_proj_bias", None)
-            self.out_proj = NonDynamicallyQuantizableLinear(
-                embed_dim, embed_dim, bias=bias, **factory_kwargs
-            )
-
-            self._reset_parameters()
-        else:
-            if not self._qkv_same_embed_dim:
-                raise NotImplementedError
-            else:
-                self.in_proj_linear = linear1_cls(
-                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
-                )
-                self.in_proj_weight = self.in_proj_linear.weight
-
-                self.register_parameter("q_proj_weight", None)
-                self.register_parameter("k_proj_weight", None)
-                self.register_parameter("v_proj_weight", None)
-
-                if bias:
-                    self.in_proj_bias = self.in_proj_linear.bias
-                else:
-                    self.register_parameter("in_proj_bias", None)
-
-                self.out_proj = linear2_cls(
-                    embed_dim, embed_dim, bias=bias, **factory_kwargs
-                )
-
-                if self.bias_k is not None:
-                    xavier_normal_(self.bias_k)
-                if self.bias_v is not None:
-                    xavier_normal_(self.bias_v)
-
-        self.add_zero_attn = add_zero_attn
-
-    def _reset_parameters(self):
-        if self._qkv_same_embed_dim:
-            xavier_uniform_(self.in_proj_weight)
-        else:
-            xavier_uniform_(self.q_proj_weight)
-            xavier_uniform_(self.k_proj_weight)
-            xavier_uniform_(self.v_proj_weight)
-
-        if self.in_proj_bias is not None:
-            constant_(self.in_proj_bias, 0.0)
-            constant_(self.out_proj.bias, 0.0)
-
-        if self.bias_k is not None:
-            xavier_normal_(self.bias_k)
-        if self.bias_v is not None:
-            xavier_normal_(self.bias_v)
-
-    def __setstate__(self, state):
-        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
-        if "_qkv_same_embed_dim" not in state:
-            state["_qkv_same_embed_dim"] = True
-
-        super(MultiheadAttention, self).__setstate__(state)
-
-    def forward(
-        self,
-        query: Tensor,
-        key: Tensor,
-        value: Tensor,
-        key_padding_mask: Optional[Tensor] = None,
-        need_weights: bool = True,
-        attn_mask: Optional[Tensor] = None,
-        average_attn_weights: bool = True,
-        cache=None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
-        any_nested = query.is_nested or key.is_nested or value.is_nested
-        query = key = value = query.transpose(1, 0)
-        attn_output = multi_head_attention_forward_patched(
-            query,
-            key,
-            value,
-            self.embed_dim,
-            self.num_heads,
-            self.in_proj_weight,
-            self.in_proj_bias,
-            self.bias_k,
-            self.bias_v,
-            self.add_zero_attn,
-            self.dropout,
-            self.out_proj.weight,
-            self.out_proj.bias,
-            training=self.training,
-            key_padding_mask=key_padding_mask,
-            need_weights=need_weights,
-            attn_mask=attn_mask,
-            average_attn_weights=average_attn_weights,
-            cache=cache,
-        )
-        return attn_output.transpose(1, 0)
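
For context on what this removal drops: the deleted wrapper treated every call as self-attention, transposed the first two dimensions of its input before invoking multi_head_attention_forward_patched, and transposed the result back on return, forwarding an extra cache argument (presumably a key/value cache for ONNX-friendly incremental decoding). The sketch below is a minimal illustration of that layout contract only, using the stock torch.nn.MultiheadAttention as a stand-in; the stand-in layer, the batch-first reading of the input, and the tensor shapes are assumptions for illustration, and the stock layer has no equivalent of the cache argument or the patched forward.

import torch
from torch import nn

# Stand-in sketch (assumption): nn.MultiheadAttention is used only to show the
# transposes the deleted forward() performed around its patched attention call;
# it does not replicate the cache-aware multi_head_attention_forward_patched.
embed_dim, num_heads = 512, 8
mha = nn.MultiheadAttention(embed_dim, num_heads).eval()  # expects (seq, batch, embed)

x = torch.randn(2, 10, embed_dim)   # input, here read as (batch, seq, embed)
q = k = v = x.transpose(1, 0)       # deleted code: query = key = value = query.transpose(1, 0)
attn_out, _ = mha(q, k, v, need_weights=False)
y = attn_out.transpose(1, 0)        # deleted code: return attn_output.transpose(1, 0)
print(y.shape)                      # torch.Size([2, 10, 512])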