delete some unused code & add argument type annotations for TorchScript export

csh 2024-09-23 21:30:48 +08:00
parent 0464160c1e
commit 76d2cbaa1f
4 changed files with 93 additions and 61 deletions
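Background for the changes below: torch.jit.script treats any parameter without an annotation as a Tensor, so non-Tensor arguments (head counts, epsilons, optional masks) need explicit type hints before the model can be exported. A minimal sketch of the rule, with invented names, not code from this repository:

from typing import Optional
import torch

# illustrative sketch only
class Block(torch.nn.Module):
    def __init__(self, num_heads: int, eps: float):
        super().__init__()
        # annotated attributes are stored as int/float in the scripted module
        self.num_heads: int = num_heads
        self.eps: float = eps
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # without Optional[...], the None default clashes with the inferred Tensor type
        y = self.proj(x)
        if mask is not None:
            y = y * mask
        return y

scripted = torch.jit.script(Block(num_heads=4, eps=1e-5))  # compiles cleanly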

View File

@@ -83,7 +83,7 @@ class T2SMLP:
class T2SBlock:
def __init__(
self,
num_heads,
num_heads: int,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
@@ -92,12 +92,12 @@ class T2SBlock:
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_eps1: float,
norm_w2,
norm_b2,
norm_eps2,
norm_eps2: float,
):
self.num_heads = num_heads
self.num_heads:int = num_heads
self.mlp = mlp
self.hidden_dim: int = hidden_dim
self.qkv_w = qkv_w
@@ -266,7 +266,7 @@ class Text2SemanticDecoder(nn.Module):
self.norm_first = norm_first
self.vocab_size = config["model"]["vocab_size"]
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
self.p_dropout = config["model"]["dropout"]
self.p_dropout = float(config["model"]["dropout"])
self.EOS = config["model"]["EOS"]
self.norm_first = norm_first
assert self.EOS == self.vocab_size - 1
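The float(...) cast on the dropout entry is in the same spirit: a scripted module fixes an attribute's type from its initial value, so a config parser that hands back an int (say 0) would leave self.p_dropout typed as int. Coercing it keeps the attribute a float no matter how the config file was written. A hedged sketch with an invented config dict:

import torch

# illustrative sketch only; the config shape here is made up
config = {"model": {"dropout": 0}}  # a parser may return int 0 for a "0" entry

class Dec(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.p_dropout = float(config["model"]["dropout"])  # pinned to float

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.dropout(x, self.p_dropout, self.training)

torch.jit.script(Dec())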

View File

@@ -4,8 +4,8 @@ from torch import nn
from torch.nn import functional as F
from module import commons
from module.modules import LayerNorm
from typing import Optional
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
@@ -59,6 +59,7 @@ class Encoder(nn.Module):
# self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256
self.cond_layer_idx = self.n_layers
self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0:
@@ -98,22 +99,36 @@ class Encoder(nn.Module):
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None):
# def forward(self, x, x_mask, g=None):
# attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
# x = x * x_mask
# for i in range(self.n_layers):
# if i == self.cond_layer_idx and g is not None:
# g = self.spk_emb_linear(g.transpose(1, 2))
# g = g.transpose(1, 2)
# x = x + g
# x = x * x_mask
# y = self.attn_layers[i](x, x, attn_mask)
# y = self.drop(y)
# x = self.norm_layers_1[i](x + y)
# y = self.ffn_layers[i](x, x_mask)
# y = self.drop(y)
# x = self.norm_layers_2[i](x + y)
# x = x * x_mask
# return x
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
if i == self.cond_layer_idx and g is not None:
g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2)
x = x + g
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
x = norm_layers_1(x + y)
y = self.ffn_layers[i](x, x_mask)
y = ffn_layers(x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = norm_layers_2(x + y)
x = x * x_mask
return x
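The loop rewrite above swaps indexed access (self.attn_layers[i]) for direct iteration. TorchScript cannot index an nn.ModuleList with a loop variable, but it does unroll iteration over ModuleLists, including zip over several lists of equal length, which is what the new code relies on. A small sketch of the pattern, with invented names, not tied to this repo:

import torch
from torch import nn

# illustrative sketch only
class Stack(nn.Module):
    def __init__(self, n: int = 2, dim: int = 8):
        super().__init__()
        self.attn_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n)])
        self.norm_layers = nn.ModuleList([nn.LayerNorm(dim) for _ in range(n)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scriptable: iteration over ModuleLists is unrolled at compile time
        for attn, norm in zip(self.attn_layers, self.norm_layers):
            x = norm(x + attn(x))
        return x

torch.jit.script(Stack())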
@@ -172,17 +187,18 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
# x, self.attn = self.attention(q, k, v, mask=attn_mask)
x, _ = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
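Dropping the cached attention map (x, _ = self.attention(...) instead of assigning self.attn) is another scripting simplification: a scripted module may only assign to attributes whose type is declared up front, so keeping self.attn would force a class-level Optional[torch.Tensor] declaration. A brief sketch of what keeping it would require, with invented names:

import torch
from typing import Optional

# illustrative sketch only
class AttnKeep(torch.nn.Module):
    # declaration required if forward is going to write self.attn
    attn: Optional[torch.Tensor]

    def __init__(self):
        super().__init__()
        self.attn = None

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        w = torch.softmax(q, dim=-1)
        self.attn = w  # only scriptable because of the declaration above
        return w

torch.jit.script(AttnKeep())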
@@ -304,7 +320,7 @@ class FFN(nn.Module):
filter_channels,
kernel_size,
p_dropout=0.0,
activation=None,
activation="",
causal=False,
):
super().__init__()
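Changing the default from activation=None to activation="" is also about attribute typing: with None the stored value would be inferred as NoneType unless annotated Optional[str], which makes the later self.activation == "gelu" branch awkward under scripting, while an empty string keeps it a plain str comparison. A tiny hedged sketch:

import torch
from torch import nn

# illustrative sketch only
class TinyFFN(nn.Module):
    def __init__(self, activation: str = ""):
        super().__init__()
        # empty string instead of None keeps the attribute a non-Optional str
        self.activation: str = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation == "gelu":
            return x * torch.sigmoid(1.702 * x)
        return torch.relu(x)

torch.jit.script(TinyFFN("gelu"))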
@@ -316,10 +332,11 @@ class FFN(nn.Module):
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
# judging from the surrounding context, this is always False here
# if causal:
# self.padding = self._causal_padding
# else:
# self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
@@ -335,6 +352,9 @@ class FFN(nn.Module):
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def padding(self, x):
return self._same_padding(x)
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
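Binding a method to an attribute (the old self.padding = self._causal_padding / self._same_padding switch) is not something torch.jit.script can represent, which is why padding() is now an ordinary method; as the comment in the previous hunk notes, causal is always False on this export path, so it just forwards to _same_padding(). A generic sketch of the same refactor, with invented names:

import torch
from torch import nn

# illustrative sketch only
class PadDemo(nn.Module):
    def __init__(self, kernel_size: int = 5):
        super().__init__()
        self.kernel_size = kernel_size
        # not scriptable: self.padding = self._same_padding (a bound method as attribute)

    def padding(self, x: torch.Tensor) -> torch.Tensor:
        # scriptable: dispatch through a regular method instead
        return self._same_padding(x)

    def _same_padding(self, x: torch.Tensor) -> torch.Tensor:
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        return torch.nn.functional.pad(x, [pad_l, pad_r])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.padding(x)

torch.jit.script(PadDemo())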
@@ -352,3 +372,35 @@ class FFN(nn.Module):
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
class MRTE(nn.Module):
def __init__(
self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer=2,
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.c_post(x * ssl_mask)
return x

View File

@@ -13,10 +13,10 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
# def convert_pad_shape(pad_shape):
# l = pad_shape[::-1]
# pad_shape = [item for sublist in l for item in sublist]
# return pad_shape
def intersperse(lst, item):

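The convert_pad_shape being commented out here appears to be a duplicate of an identical helper that the attention code keeps using; for reference, it flattens a per-dimension pad spec into the last-dimension-first list that F.pad expects. A quick worked example (illustrative values):

import torch
import torch.nn.functional as F

# illustrative sketch only
pad_shape = [[0, 0], [0, 0], [1, 2]]  # [batch, channel, time] padding spec
flat = [item for sublist in pad_shape[::-1] for item in sublist]
# flat == [1, 2, 0, 0, 0, 0]: pad the last (time) dim by (1, 2), leave the rest alone
x = torch.zeros(4, 8, 10)
assert F.pad(x, flat).shape == (4, 8, 13)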
View File

@@ -1,5 +1,6 @@
import copy
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
@@ -11,12 +12,10 @@ from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module):
@@ -218,7 +217,7 @@ class TextEncoder(nn.Module):
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE()
self.mrte = attentions.MRTE()
self.encoder2 = attentions.Encoder(
hidden_channels,
@@ -249,25 +248,6 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(
@@ -448,7 +428,7 @@ class Generator(torch.nn.Module):
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
def forward(self, x, g:Optional[torch.Tensor]=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@@ -870,15 +850,15 @@ class SynthesizerTrn(nn.Module):
upsample_kernel_sizes,
gin_channels=gin_channels,
)
self.enc_q = PosteriorEncoder(
spec_channels,
inter_channels,
hidden_channels,
5,
1,
16,
gin_channels=gin_channels,
)
# self.enc_q = PosteriorEncoder(
# spec_channels,
# inter_channels,
# hidden_channels,
# 5,
# 1,
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
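Commenting out self.enc_q drops the posterior encoder, which only the training objective needs, from the exported module: anything built in __init__ is carried into (and serialized with) the scripted model even if the inference forward never touches it. A hedged sketch of the idea with placeholder submodules, not this repo's actual classes:

import torch
from torch import nn
from typing import Optional

# illustrative sketch only
class ExportSynth(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.flow = nn.Linear(dim, dim)  # stand-in for the inference path
        # self.enc_q = PosteriorEncoder(...)  # training-only, omitted for export
        self.cond = nn.Linear(dim, dim)

    def forward(self, z: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.flow(z)
        if g is not None:  # Optional[...] keeps the speaker-conditioning branch scriptable
            x = x + self.cond(g)
        return x

torch.jit.script(ExportSynth())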