From 76d2cbaa1f209da44162aef9112ea466be1c593c Mon Sep 17 00:00:00 2001
From: csh <458761603@qq.com>
Date: Mon, 23 Sep 2024 21:30:48 +0800
Subject: [PATCH] Delete some unused code & add argument type annotations for
 TorchScript export

---
 GPT_SoVITS/AR/models/t2s_model.py    | 10 +--
 GPT_SoVITS/module/attentions_onnx.py | 92 ++++++++++++++++++++++------
 GPT_SoVITS/module/commons.py         |  8 +--
 GPT_SoVITS/module/models_onnx.py     | 44 ++++---------
 4 files changed, 93 insertions(+), 61 deletions(-)

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 31acadcc..3cae7b37 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -83,7 +83,7 @@ class T2SMLP:
 class T2SBlock:
     def __init__(
         self,
-        num_heads,
+        num_heads: int,
         hidden_dim: int,
         mlp: T2SMLP,
         qkv_w,
@@ -92,12 +92,12 @@ class T2SBlock:
         out_b,
         norm_w1,
         norm_b1,
-        norm_eps1,
+        norm_eps1: float,
         norm_w2,
         norm_b2,
-        norm_eps2,
+        norm_eps2: float,
     ):
-        self.num_heads = num_heads
+        self.num_heads: int = num_heads
         self.mlp = mlp
         self.hidden_dim: int = hidden_dim
         self.qkv_w = qkv_w
@@ -266,7 +266,7 @@ class Text2SemanticDecoder(nn.Module):
         self.norm_first = norm_first
         self.vocab_size = config["model"]["vocab_size"]
         self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
-        self.p_dropout = config["model"]["dropout"]
+        self.p_dropout = float(config["model"]["dropout"])
         self.EOS = config["model"]["EOS"]
         self.norm_first = norm_first
         assert self.EOS == self.vocab_size - 1
diff --git a/GPT_SoVITS/module/attentions_onnx.py b/GPT_SoVITS/module/attentions_onnx.py
index bc63a06f..097b1b9c 100644
--- a/GPT_SoVITS/module/attentions_onnx.py
+++ b/GPT_SoVITS/module/attentions_onnx.py
@@ -4,8 +4,8 @@
 from torch import nn
 from torch.nn import functional as F
 
 from module import commons
-from module.modules import LayerNorm
+from typing import Optional
 
 class LayerNorm(nn.Module):
     def __init__(self, channels, eps=1e-5):
@@ -59,6 +59,7 @@ class Encoder(nn.Module):
         # self.cond_layer = weight_norm(cond_layer, name='weight')
         # self.gin_channels = 256
         self.cond_layer_idx = self.n_layers
+        self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
         if "gin_channels" in kwargs:
             self.gin_channels = kwargs["gin_channels"]
             if self.gin_channels != 0:
@@ -98,22 +99,36 @@
             )
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
-    def forward(self, x, x_mask, g=None):
+    # def forward(self, x, x_mask, g=None):
+    #     attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+    #     x = x * x_mask
+    #     for i in range(self.n_layers):
+    #         if i == self.cond_layer_idx and g is not None:
+    #             g = self.spk_emb_linear(g.transpose(1, 2))
+    #             g = g.transpose(1, 2)
+    #             x = x + g
+    #             x = x * x_mask
+    #         y = self.attn_layers[i](x, x, attn_mask)
+    #         y = self.drop(y)
+    #         x = self.norm_layers_1[i](x + y)
+
+    #         y = self.ffn_layers[i](x, x_mask)
+    #         y = self.drop(y)
+    #         x = self.norm_layers_2[i](x + y)
+    #     x = x * x_mask
+    #     return x
+
+    def forward(self, x, x_mask):
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
-        for i in range(self.n_layers):
-            if i == self.cond_layer_idx and g is not None:
-                g = self.spk_emb_linear(g.transpose(1, 2))
-                g = g.transpose(1, 2)
-                x = x + g
-                x = x * x_mask
-            y = self.attn_layers[i](x, x, attn_mask)
+        for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2):
+            y = attn_layers(x, x, attn_mask)
             y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
+            x = norm_layers_1(x + y)
 
-            y = self.ffn_layers[i](x, x_mask)
+            y = ffn_layers(x, x_mask)
             y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
+            x = norm_layers_2(x + y)
         x = x * x_mask
         return x
 
@@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
             self.conv_k.weight.copy_(self.conv_q.weight)
             self.conv_k.bias.copy_(self.conv_q.bias)
 
-    def forward(self, x, c, attn_mask=None):
+    def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
         q = self.conv_q(x)
         k = self.conv_k(c)
         v = self.conv_v(c)
 
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+        # x, self.attn = self.attention(q, k, v, mask=attn_mask)
+        x, _ = self.attention(q, k, v, mask=attn_mask)
 
         x = self.conv_o(x)
         return x
 
-    def attention(self, query, key, value, mask=None):
+    def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
         # reshape [b, d, t] -> [b, n_h, t, d_k]
         b, d, t_s, _ = (*key.size(), query.size(2))
         query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@@ -304,7 +320,7 @@ class FFN(nn.Module):
         filter_channels,
         kernel_size,
         p_dropout=0.0,
-        activation=None,
+        activation="",
         causal=False,
     ):
         super().__init__()
@@ -316,10 +332,11 @@ class FFN(nn.Module):
         self.activation = activation
         self.causal = causal
 
-        if causal:
-            self.padding = self._causal_padding
-        else:
-            self.padding = self._same_padding
+        # From the surrounding context, this is always False
+        # if causal:
+        #     self.padding = self._causal_padding
+        # else:
+        #     self.padding = self._same_padding
 
         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
@@ -334,6 +351,9 @@ class FFN(nn.Module):
         x = self.drop(x)
         x = self.conv_2(self.padding(x * x_mask))
         return x * x_mask
+
+    def padding(self, x):
+        return self._same_padding(x)
 
     def _causal_padding(self, x):
         if self.kernel_size == 1:
@@ -352,3 +372,35 @@ class FFN(nn.Module):
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         x = F.pad(x, commons.convert_pad_shape(padding))
         return x
+
+
+class MRTE(nn.Module):
+    def __init__(
+        self,
+        content_enc_channels=192,
+        hidden_size=512,
+        out_channels=192,
+        kernel_size=5,
+        n_heads=4,
+        ge_layer=2,
+    ):
+        super(MRTE, self).__init__()
+        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
+        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
+        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
+        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
+
+    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
+        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
+
+        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
+        text_enc = self.text_pre(text * text_mask)
+        x = (
+            self.cross_attention(
+                ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
+            )
+            + ssl_enc
+            + ge
+        )
+        x = self.c_post(x * ssl_mask)
+        return x
diff --git a/GPT_SoVITS/module/commons.py b/GPT_SoVITS/module/commons.py
index e96cf923..6083535f 100644
--- a/GPT_SoVITS/module/commons.py
+++ b/GPT_SoVITS/module/commons.py
@@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+# def convert_pad_shape(pad_shape):
+#     l = pad_shape[::-1]
+#     pad_shape = [item for sublist in l for item in sublist]
+#     return pad_shape
 
 
 def intersperse(lst, item):
diff --git a/GPT_SoVITS/module/models_onnx.py b/GPT_SoVITS/module/models_onnx.py
index 77ae3074..b39f4b85 100644
--- a/GPT_SoVITS/module/models_onnx.py
+++ b/GPT_SoVITS/module/models_onnx.py
@@ -1,5 +1,6 @@
 import copy
 import math
+from typing import Optional
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -11,12 +12,10 @@
 from module import attentions_onnx as attentions
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from module.commons import init_weights, get_padding
-from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
 # from text import symbols
 from text import symbols as symbols_v1
 from text import symbols2 as symbols_v2
-from torch.cuda.amp import autocast
 
 class StochasticDurationPredictor(nn.Module):
@@ -218,7 +217,7 @@ class TextEncoder(nn.Module):
             symbols = symbols_v2.symbols
         self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
 
-        self.mrte = MRTE()
+        self.mrte = attentions.MRTE()
 
         self.encoder2 = attentions.Encoder(
             hidden_channels,
@@ -249,25 +248,6 @@ class TextEncoder(nn.Module):
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return y, m, logs, y_mask
 
-    def extract_latent(self, x):
-        x = self.ssl_proj(x)
-        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
-        return codes.transpose(0, 1)
-
-    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
-        quantized = self.quantizer.decode(codes)
-
-        y = self.vq_proj(quantized) * y_mask
-        y = self.encoder_ssl(y * y_mask, y_mask)
-
-        y = self.mrte(y, y_mask, refer, refer_mask, ge)
-
-        y = self.encoder2(y * y_mask, y_mask)
-
-        stats = self.proj(y) * y_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask, quantized
-
 
 class ResidualCouplingBlock(nn.Module):
     def __init__(
@@ -448,7 +428,7 @@ class Generator(torch.nn.Module):
         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
 
-    def forward(self, x, g=None):
+    def forward(self, x, g: Optional[torch.Tensor] = None):
         x = self.conv_pre(x)
         if g is not None:
             x = x + self.cond(g)
@@ -870,15 +850,15 @@ class SynthesizerTrn(nn.Module):
             upsample_kernel_sizes,
             gin_channels=gin_channels,
         )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            16,
-            gin_channels=gin_channels,
-        )
+        # self.enc_q = PosteriorEncoder(
+        #     spec_channels,
+        #     inter_channels,
+        #     hidden_channels,
+        #     5,
+        #     1,
+        #     16,
+        #     gin_channels=gin_channels,
+        # )
         self.flow = ResidualCouplingBlock(
             inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
         )
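
Note on the motivation (illustration only, not part of the patch): torch.jit.script assumes every unannotated parameter is a Tensor, so an argument that may be None has to be declared Optional[torch.Tensor], and values pulled from an untyped config dict are easiest to pin down with explicit annotations or casts, which is what num_heads: int, norm_eps1: float, and float(config["model"]["dropout"]) above provide. The sketch below assumes only a standard PyTorch install; ToyCond is a hypothetical module, not code from this repository.

from typing import Optional

import torch
from torch import nn


class ToyCond(nn.Module):
    def __init__(self, channels: int = 4, eps: float = 1e-5):
        super().__init__()
        # Explicit attribute types, mirroring self.num_heads: int and
        # self.p_dropout = float(...) in the patch.
        self.channels: int = channels
        self.eps: float = float(eps)
        self.proj = nn.Conv1d(channels, channels, 1)
        self.cond = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.proj(x)
        if g is not None:  # TorchScript narrows Optional[Tensor] to Tensor inside this branch
            x = x + self.cond(g)
        return x


scripted = torch.jit.script(ToyCond())      # compiles because of the Optional annotation
out = scripted(torch.randn(1, 4, 8), None)  # None is accepted for the Optional argument

The rewritten Encoder.forward presumably serves the same goal: TorchScript can iterate nn.ModuleList containers directly, whereas indexing them with a runtime loop variable is not supported, and the commented-out branches (FFN.padding, enc_q) drop Python-level control flow and modules that the exported graph never uses.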