diff --git a/GPT_SoVITS/module/transformer_onnx.py b/GPT_SoVITS/module/transformer_onnx.py deleted file mode 100644 index a3f68b43..00000000 --- a/GPT_SoVITS/module/transformer_onnx.py +++ /dev/null @@ -1,292 +0,0 @@ -# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py -import copy -import numbers -from functools import partial -from typing import Any -from typing import Callable -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union - -import torch -from AR.modules.activation_onnx import MultiheadAttention -from AR.modules.scaling import BalancedDoubleSwish -from torch import nn -from torch import Tensor -from torch.nn import functional as F - -_shape_t = Union[int, List[int], torch.Size] - - -class LayerNorm(nn.Module): - __constants__ = ["normalized_shape", "eps", "elementwise_affine"] - normalized_shape: Tuple[int, ...] - eps: float - elementwise_affine: bool - - def __init__( - self, - normalized_shape: _shape_t, - eps: float = 1e-5, - elementwise_affine: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(LayerNorm, self).__init__() - if isinstance(normalized_shape, numbers.Integral): - # mypy error: incompatible types in assignment - normalized_shape = (normalized_shape,) # type: ignore[assignment] - self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - self.bias = nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - else: - self.register_parameter("weight", None) - self.register_parameter("bias", None) - - self.reset_parameters() - - def reset_parameters(self) -> None: - if self.elementwise_affine: - nn.init.ones_(self.weight) - nn.init.zeros_(self.bias) - - def forward(self, input: Tensor, embedding: Any = None) -> Tensor: - if isinstance(input, tuple): - input, embedding = input - return ( - F.layer_norm( - input, - self.normalized_shape, - self.weight, - self.bias, - self.eps, - ), - embedding, - ) - - assert embedding is None - return F.layer_norm( - input, self.normalized_shape, self.weight, self.bias, self.eps - ) - - def extra_repr(self) -> str: - return ( - "{normalized_shape}, eps={eps}, " - "elementwise_affine={elementwise_affine}".format(**self.__dict__) - ) - - -class IdentityNorm(nn.Module): - def __init__( - self, - d_model: int, - eps: float = 1e-5, - device=None, - dtype=None, - ) -> None: - super(IdentityNorm, self).__init__() - - def forward(self, input: Tensor, embedding: Any = None) -> Tensor: - if isinstance(input, tuple): - return input - - assert embedding is None - return input - - -class TransformerEncoder(nn.Module): - r"""TransformerEncoder is a stack of N encoder layers. Users can build the - BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. - - Args: - encoder_layer: an instance of the TransformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). - enable_nested_tensor: if True, input will automatically convert to nested tensor - (and convert back on output). This will improve the overall performance of - TransformerEncoder when padding rate is high. Default: ``True`` (enabled). - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = transformer_encoder(src) - """ - __constants__ = ["norm"] - - def __init__(self, encoder_layer, num_layers, norm=None): - super(TransformerEncoder, self).__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - return_layer_states: bool = False, - cache=None, - ) -> Tensor: - output = src - for mod in self.layers: - output = mod( - output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - cache=cache, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class TransformerEncoderLayer(nn.Module): - __constants__ = ["batch_first", "norm_first"] - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - batch_first: bool = False, - norm_first: bool = False, - device=None, - dtype=None, - linear1_self_attention_cls: nn.Module = nn.Linear, - linear2_self_attention_cls: nn.Module = nn.Linear, - linear1_feedforward_cls: nn.Module = nn.Linear, - linear2_feedforward_cls: nn.Module = nn.Linear, - layer_norm_cls: nn.Module = LayerNorm, - layer_norm_eps: float = 1e-5, - adaptive_layer_norm=False, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(TransformerEncoderLayer, self).__init__() - self.self_attn = MultiheadAttention( - d_model, # 512 16 - nhead, - dropout=dropout, - batch_first=batch_first, - linear1_cls=linear1_self_attention_cls, - linear2_cls=linear2_self_attention_cls, - **factory_kwargs, - ) - self.linear1 = linear1_feedforward_cls( - d_model, dim_feedforward, **factory_kwargs - ) - self.dropout = nn.Dropout(dropout) - self.linear2 = linear2_feedforward_cls( - dim_feedforward, d_model, **factory_kwargs - ) - self.norm_first = norm_first - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - if isinstance(activation, str): - activation = _get_activation_fn(activation) - elif isinstance(activation, partial): - activation = activation(d_model) - elif activation == BalancedDoubleSwish: - activation = BalancedDoubleSwish(d_model) - self.activation = activation - - norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) - if layer_norm_cls == IdentityNorm: - norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - else: - norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) - - if adaptive_layer_norm: - self.norm1 = AdaptiveLayerNorm(d_model, norm1) - self.norm2 = AdaptiveLayerNorm(d_model, norm2) - else: - self.norm1 = norm1 - self.norm2 = norm2 - - def __setstate__(self, state): - super(TransformerEncoderLayer, self).__setstate__(state) - if not hasattr(self, "activation"): - self.activation = F.relu - - def forward( - self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - cache=None, - ) -> Tensor: - x = src - stage_embedding = None - x = self.norm1( - x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), - stage_embedding, - ) - x = self.norm2(x + self._ff_block(x), stage_embedding) - - return x - - def _sa_block( - self, - x: Tensor, - attn_mask: Optional[Tensor], - key_padding_mask: Optional[Tensor], - cache=None, - ) -> Tensor: - x = self.self_attn( - x, - x, - x, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False, - cache=cache, - ) - return self.dropout1(x) - - def _ff_block(self, x: Tensor) -> Tensor: - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - return self.dropout2(x) - - -class AdaptiveLayerNorm(nn.Module): - r"""Adaptive Layer Normalization""" - - def __init__(self, d_model, norm) -> None: - super(AdaptiveLayerNorm, self).__init__() - self.project_layer = nn.Linear(d_model, 2 * d_model) - self.norm = norm - self.d_model = d_model - self.eps = self.norm.eps - - def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: - if isinstance(input, tuple): - input, embedding = input - weight, bias = torch.split( - self.project_layer(embedding), - split_size_or_sections=self.d_model, - dim=-1, - ) - return (weight * self.norm(input) + bias, embedding) - - weight, bias = torch.split( - self.project_layer(embedding), - split_size_or_sections=self.d_model, - dim=-1, - ) - return weight * self.norm(input) + bias - - -def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)])