mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 22:50:00 +08:00
Add files via upload
This commit is contained in:
parent
8582131bd8
commit
4cd1d83de1
178
GPT_SoVITS/module/activation_onnx.py
Normal file
178
GPT_SoVITS/module/activation_onnx.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Linear
|
||||||
|
from torch.nn import Module
|
||||||
|
from torch.nn.init import constant_
|
||||||
|
from torch.nn.init import xavier_normal_
|
||||||
|
from torch.nn.init import xavier_uniform_
|
||||||
|
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(Module):
|
||||||
|
__constants__ = ["batch_first"]
|
||||||
|
bias_k: Optional[torch.Tensor]
|
||||||
|
bias_v: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.0,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False,
|
||||||
|
kdim=None,
|
||||||
|
vdim=None,
|
||||||
|
batch_first=False,
|
||||||
|
linear1_cls=Linear,
|
||||||
|
linear2_cls=Linear,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.kdim = kdim if kdim is not None else embed_dim
|
||||||
|
self.vdim = vdim if vdim is not None else embed_dim
|
||||||
|
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.batch_first = batch_first
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
|
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
if linear1_cls == Linear:
|
||||||
|
if not self._qkv_same_embed_dim:
|
||||||
|
self.q_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.k_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.v_proj_weight = Parameter(
|
||||||
|
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.register_parameter("in_proj_weight", None)
|
||||||
|
else:
|
||||||
|
self.in_proj_weight = Parameter(
|
||||||
|
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.register_parameter("q_proj_weight", None)
|
||||||
|
self.register_parameter("k_proj_weight", None)
|
||||||
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.in_proj_bias = Parameter(
|
||||||
|
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||||
|
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
else:
|
||||||
|
if not self._qkv_same_embed_dim:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
self.in_proj_linear = linear1_cls(
|
||||||
|
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.in_proj_weight = self.in_proj_linear.weight
|
||||||
|
|
||||||
|
self.register_parameter("q_proj_weight", None)
|
||||||
|
self.register_parameter("k_proj_weight", None)
|
||||||
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.in_proj_bias = self.in_proj_linear.bias
|
||||||
|
else:
|
||||||
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
|
||||||
|
self.out_proj = linear2_cls(
|
||||||
|
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
if self._qkv_same_embed_dim:
|
||||||
|
xavier_uniform_(self.in_proj_weight)
|
||||||
|
else:
|
||||||
|
xavier_uniform_(self.q_proj_weight)
|
||||||
|
xavier_uniform_(self.k_proj_weight)
|
||||||
|
xavier_uniform_(self.v_proj_weight)
|
||||||
|
|
||||||
|
if self.in_proj_bias is not None:
|
||||||
|
constant_(self.in_proj_bias, 0.0)
|
||||||
|
constant_(self.out_proj.bias, 0.0)
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
||||||
|
if "_qkv_same_embed_dim" not in state:
|
||||||
|
state["_qkv_same_embed_dim"] = True
|
||||||
|
|
||||||
|
super(MultiheadAttention, self).__setstate__(state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: Tensor,
|
||||||
|
key: Tensor,
|
||||||
|
value: Tensor,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
average_attn_weights: bool = True,
|
||||||
|
cache=None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||||
|
query = key = value = query.transpose(1, 0)
|
||||||
|
attn_output = multi_head_attention_forward_patched(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
self.embed_dim,
|
||||||
|
self.num_heads,
|
||||||
|
self.in_proj_weight,
|
||||||
|
self.in_proj_bias,
|
||||||
|
self.bias_k,
|
||||||
|
self.bias_v,
|
||||||
|
self.add_zero_attn,
|
||||||
|
self.dropout,
|
||||||
|
self.out_proj.weight,
|
||||||
|
self.out_proj.bias,
|
||||||
|
training=self.training,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=need_weights,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
average_attn_weights=average_attn_weights,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return attn_output.transpose(1, 0)
|
64
GPT_SoVITS/module/embedding_onnx.py
Normal file
64
GPT_SoVITS/module/embedding_onnx.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class TokenEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout)
|
||||||
|
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight(self) -> torch.Tensor:
|
||||||
|
return self.word_embeddings.weight
|
||||||
|
|
||||||
|
def embedding(self, index: int) -> torch.Tensor:
|
||||||
|
return self.word_embeddings.weight[index : index + 1]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = self.word_embeddings(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SinePositionalEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_dim: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scale: bool = False,
|
||||||
|
alpha: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||||
|
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||||
|
self.dropout = torch.nn.Dropout(p=dropout)
|
||||||
|
self.reverse = False
|
||||||
|
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
||||||
|
self.pe = self.extend_pe(2000)
|
||||||
|
|
||||||
|
def extend_pe(self, x):
|
||||||
|
position = torch.cumsum(torch.ones((x,1)), dim=0)
|
||||||
|
scpe = (position * self.div_term).unsqueeze(0)
|
||||||
|
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
||||||
|
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
||||||
|
return pe
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
pe = self.pe[:,:x.size(1),:]
|
||||||
|
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
||||||
|
output = output * self.x_scale + self.alpha * pe
|
||||||
|
return self.dropout(output)
|
92
GPT_SoVITS/module/patched_mha_with_cache_onnx.py
Normal file
92
GPT_SoVITS/module/patched_mha_with_cache_onnx.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from torch.nn.functional import *
|
||||||
|
from torch.nn.functional import (
|
||||||
|
_mha_shape_check,
|
||||||
|
_canonical_mask,
|
||||||
|
_none_or_dtype,
|
||||||
|
_in_projection_packed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def multi_head_attention_forward_patched(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
embed_dim_to_check: int,
|
||||||
|
num_heads: int,
|
||||||
|
in_proj_weight,
|
||||||
|
in_proj_bias: Optional[Tensor],
|
||||||
|
bias_k: Optional[Tensor],
|
||||||
|
bias_v: Optional[Tensor],
|
||||||
|
add_zero_attn: bool,
|
||||||
|
dropout_p: float,
|
||||||
|
out_proj_weight: Tensor,
|
||||||
|
out_proj_bias: Optional[Tensor],
|
||||||
|
training: bool = True,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
need_weights: bool = True,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
use_separate_proj_weight: bool = False,
|
||||||
|
q_proj_weight: Optional[Tensor] = None,
|
||||||
|
k_proj_weight: Optional[Tensor] = None,
|
||||||
|
v_proj_weight: Optional[Tensor] = None,
|
||||||
|
static_k: Optional[Tensor] = None,
|
||||||
|
static_v: Optional[Tensor] = None,
|
||||||
|
average_attn_weights: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
cache=None,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
|
||||||
|
# set up shape vars
|
||||||
|
_, _, embed_dim = query.shape
|
||||||
|
attn_mask = _canonical_mask(
|
||||||
|
mask=attn_mask,
|
||||||
|
mask_name="attn_mask",
|
||||||
|
other_type=None,
|
||||||
|
other_name="",
|
||||||
|
target_type=query.dtype,
|
||||||
|
check_other=False,
|
||||||
|
)
|
||||||
|
head_dim = embed_dim // num_heads
|
||||||
|
|
||||||
|
proj_qkv = linear(query, in_proj_weight, in_proj_bias)
|
||||||
|
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
||||||
|
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
|
||||||
|
|
||||||
|
if cache["first_infer"] == 1:
|
||||||
|
cache["k"][cache["stage"]] = k
|
||||||
|
cache["v"][cache["stage"]] = v
|
||||||
|
else:
|
||||||
|
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
|
||||||
|
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
|
||||||
|
k = cache["k"][cache["stage"]]
|
||||||
|
v = cache["v"][cache["stage"]]
|
||||||
|
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
||||||
|
|
||||||
|
attn_mask = _canonical_mask(
|
||||||
|
mask=attn_mask,
|
||||||
|
mask_name="attn_mask",
|
||||||
|
other_type=None,
|
||||||
|
other_name="",
|
||||||
|
target_type=q.dtype,
|
||||||
|
check_other=False,
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
|
||||||
|
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
k = k.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
v = v.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||||
|
|
||||||
|
dropout_p = 0.0
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask, dropout_p, is_causal
|
||||||
|
)
|
||||||
|
attn_output = (
|
||||||
|
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||||
|
)
|
||||||
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
|
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
||||||
|
|
||||||
|
return attn_output
|
292
GPT_SoVITS/module/transformer_onnx.py
Normal file
292
GPT_SoVITS/module/transformer_onnx.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
|
||||||
|
import copy
|
||||||
|
import numbers
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from AR.modules.activation_onnx import MultiheadAttention
|
||||||
|
from AR.modules.scaling import BalancedDoubleSwish
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
_shape_t = Union[int, List[int], torch.Size]
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
||||||
|
normalized_shape: Tuple[int, ...]
|
||||||
|
eps: float
|
||||||
|
elementwise_affine: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
normalized_shape: _shape_t,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
elementwise_affine: bool = True,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
if isinstance(normalized_shape, numbers.Integral):
|
||||||
|
# mypy error: incompatible types in assignment
|
||||||
|
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||||
|
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||||
|
self.eps = eps
|
||||||
|
self.elementwise_affine = elementwise_affine
|
||||||
|
if self.elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||||
|
)
|
||||||
|
self.bias = nn.Parameter(
|
||||||
|
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight", None)
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
if self.elementwise_affine:
|
||||||
|
nn.init.ones_(self.weight)
|
||||||
|
nn.init.zeros_(self.bias)
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
input, embedding = input
|
||||||
|
return (
|
||||||
|
F.layer_norm(
|
||||||
|
input,
|
||||||
|
self.normalized_shape,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.eps,
|
||||||
|
),
|
||||||
|
embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding is None
|
||||||
|
return F.layer_norm(
|
||||||
|
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
return (
|
||||||
|
"{normalized_shape}, eps={eps}, "
|
||||||
|
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityNorm(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
eps: float = 1e-5,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
super(IdentityNorm, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
return input
|
||||||
|
|
||||||
|
assert embedding is None
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
||||||
|
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
||||||
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
norm: the layer normalization component (optional).
|
||||||
|
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
||||||
|
(and convert back on output). This will improve the overall performance of
|
||||||
|
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
||||||
|
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
||||||
|
>>> src = torch.rand(10, 32, 512)
|
||||||
|
>>> out = transformer_encoder(src)
|
||||||
|
"""
|
||||||
|
__constants__ = ["norm"]
|
||||||
|
|
||||||
|
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||||
|
super(TransformerEncoder, self).__init__()
|
||||||
|
self.layers = _get_clones(encoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = norm
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
return_layer_states: bool = False,
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
output = src
|
||||||
|
for mod in self.layers:
|
||||||
|
output = mod(
|
||||||
|
output,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
__constants__ = ["batch_first", "norm_first"]
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
nhead: int,
|
||||||
|
dim_feedforward: int = 2048,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
||||||
|
batch_first: bool = False,
|
||||||
|
norm_first: bool = False,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
linear1_self_attention_cls: nn.Module = nn.Linear,
|
||||||
|
linear2_self_attention_cls: nn.Module = nn.Linear,
|
||||||
|
linear1_feedforward_cls: nn.Module = nn.Linear,
|
||||||
|
linear2_feedforward_cls: nn.Module = nn.Linear,
|
||||||
|
layer_norm_cls: nn.Module = LayerNorm,
|
||||||
|
layer_norm_eps: float = 1e-5,
|
||||||
|
adaptive_layer_norm=False,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super(TransformerEncoderLayer, self).__init__()
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
d_model, # 512 16
|
||||||
|
nhead,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=batch_first,
|
||||||
|
linear1_cls=linear1_self_attention_cls,
|
||||||
|
linear2_cls=linear2_self_attention_cls,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.linear1 = linear1_feedforward_cls(
|
||||||
|
d_model, dim_feedforward, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = linear2_feedforward_cls(
|
||||||
|
dim_feedforward, d_model, **factory_kwargs
|
||||||
|
)
|
||||||
|
self.norm_first = norm_first
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
if isinstance(activation, str):
|
||||||
|
activation = _get_activation_fn(activation)
|
||||||
|
elif isinstance(activation, partial):
|
||||||
|
activation = activation(d_model)
|
||||||
|
elif activation == BalancedDoubleSwish:
|
||||||
|
activation = BalancedDoubleSwish(d_model)
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
if layer_norm_cls == IdentityNorm:
|
||||||
|
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
else:
|
||||||
|
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
||||||
|
|
||||||
|
if adaptive_layer_norm:
|
||||||
|
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
||||||
|
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
||||||
|
else:
|
||||||
|
self.norm1 = norm1
|
||||||
|
self.norm2 = norm2
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(TransformerEncoderLayer, self).__setstate__(state)
|
||||||
|
if not hasattr(self, "activation"):
|
||||||
|
self.activation = F.relu
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
x = src
|
||||||
|
stage_embedding = None
|
||||||
|
x = self.norm1(
|
||||||
|
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
|
||||||
|
stage_embedding,
|
||||||
|
)
|
||||||
|
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _sa_block(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
attn_mask: Optional[Tensor],
|
||||||
|
key_padding_mask: Optional[Tensor],
|
||||||
|
cache=None,
|
||||||
|
) -> Tensor:
|
||||||
|
x = self.self_attn(
|
||||||
|
x,
|
||||||
|
x,
|
||||||
|
x,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
key_padding_mask=key_padding_mask,
|
||||||
|
need_weights=False,
|
||||||
|
cache=cache,
|
||||||
|
)
|
||||||
|
return self.dropout1(x)
|
||||||
|
|
||||||
|
def _ff_block(self, x: Tensor) -> Tensor:
|
||||||
|
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||||
|
return self.dropout2(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveLayerNorm(nn.Module):
|
||||||
|
r"""Adaptive Layer Normalization"""
|
||||||
|
|
||||||
|
def __init__(self, d_model, norm) -> None:
|
||||||
|
super(AdaptiveLayerNorm, self).__init__()
|
||||||
|
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
||||||
|
self.norm = norm
|
||||||
|
self.d_model = d_model
|
||||||
|
self.eps = self.norm.eps
|
||||||
|
|
||||||
|
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
||||||
|
if isinstance(input, tuple):
|
||||||
|
input, embedding = input
|
||||||
|
weight, bias = torch.split(
|
||||||
|
self.project_layer(embedding),
|
||||||
|
split_size_or_sections=self.d_model,
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return (weight * self.norm(input) + bias, embedding)
|
||||||
|
|
||||||
|
weight, bias = torch.split(
|
||||||
|
self.project_layer(embedding),
|
||||||
|
split_size_or_sections=self.d_model,
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return weight * self.norm(input) + bias
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clones(module, N):
|
||||||
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
Loading…
x
Reference in New Issue
Block a user