fix: SynthesizerTrn

This commit is contained in:
CyberWon 2024-08-09 10:01:58 +08:00
parent bd53aa8200
commit 902e0121ac

View File

@ -1,6 +1,11 @@
import warnings
warnings.filterwarnings("ignore")
import copy import copy
import math import math
from typing import List import os
import pdb
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@ -14,20 +19,22 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding from module.commons import init_weights, get_padding
from module.mrte_model import MRTE from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
from text import symbols # from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
import contextlib import contextlib
class StochasticDurationPredictor(nn.Module): class StochasticDurationPredictor(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
filter_channels, filter_channels,
kernel_size, kernel_size,
p_dropout, p_dropout,
n_flows=4, n_flows=4,
gin_channels=0, gin_channels=0,
): ):
super().__init__() super().__init__()
filter_channels = in_channels # it needs to be removed from future version. filter_channels = in_channels # it needs to be removed from future version.
@ -86,8 +93,8 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_convs(h_w, x_mask) h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask h_w = self.post_proj(h_w) * x_mask
e_q = ( e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask * x_mask
) )
z_q = e_q z_q = e_q
for flow in self.post_flows: for flow in self.post_flows:
@ -100,8 +107,8 @@ class StochasticDurationPredictor(nn.Module):
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
) )
logq = ( logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2])
- logdet_tot_q - logdet_tot_q
) )
logdet_tot = 0 logdet_tot = 0
@ -112,16 +119,16 @@ class StochasticDurationPredictor(nn.Module):
z, logdet = flow(z, x_mask, g=x, reverse=reverse) z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet logdet_tot = logdet_tot + logdet
nll = ( nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2])
- logdet_tot - logdet_tot
) )
return nll + logq # [b] return nll + logq # [b]
else: else:
flows = list(reversed(self.flows)) flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = ( z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale * noise_scale
) )
for flow in flows: for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse) z = flow(z, x_mask, g=x, reverse=reverse)
@ -132,7 +139,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
def __init__( def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
): ):
super().__init__() super().__init__()
@ -175,15 +182,16 @@ class DurationPredictor(nn.Module):
class TextEncoder(nn.Module): class TextEncoder(nn.Module):
def __init__( def __init__(
self, self,
out_channels, out_channels,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
latent_channels=192, latent_channels=192,
version="v1",
): ):
super().__init__() super().__init__()
self.out_channels = out_channels self.out_channels = out_channels
@ -194,6 +202,7 @@ class TextEncoder(nn.Module):
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.latent_channels = latent_channels self.latent_channels = latent_channels
self.version = version
self.ssl_proj = nn.Conv1d(768, hidden_channels, 1) self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
@ -209,6 +218,11 @@ class TextEncoder(nn.Module):
self.encoder_text = attentions.Encoder( self.encoder_text = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
) )
if self.version == "v1":
symbols = symbols_v1.symbols
else:
symbols = symbols_v2.symbols
self.text_embedding = nn.Embedding(len(symbols), hidden_channels) self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
self.mrte = MRTE() self.mrte = MRTE()
@ -224,7 +238,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, y_lengths, text, text_lengths, ge, test=None): def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype y.dtype
) )
@ -241,9 +255,10 @@ class TextEncoder(nn.Module):
text = self.text_embedding(text).transpose(1, 2) text = self.text_embedding(text).transpose(1, 2)
text = self.encoder_text(text * text_mask, text_mask) text = self.encoder_text(text * text_mask, text_mask)
y = self.mrte(y, y_mask, text, text_mask, ge) y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask) y = self.encoder2(y * y_mask, y_mask)
if (speed != 1):
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1) m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask return y, m, logs, y_mask
@ -270,14 +285,14 @@ class TextEncoder(nn.Module):
class ResidualCouplingBlock(nn.Module): class ResidualCouplingBlock(nn.Module):
def __init__( def __init__(
self, self,
channels, channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
dilation_rate, dilation_rate,
n_layers, n_layers,
n_flows=4, n_flows=4,
gin_channels=0, gin_channels=0,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -315,14 +330,14 @@ class ResidualCouplingBlock(nn.Module):
class PosteriorEncoder(nn.Module): class PosteriorEncoder(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
out_channels, out_channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
dilation_rate, dilation_rate,
n_layers, n_layers,
gin_channels=0, gin_channels=0,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -359,14 +374,14 @@ class PosteriorEncoder(nn.Module):
class WNEncoder(nn.Module): class WNEncoder(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
out_channels, out_channels,
hidden_channels, hidden_channels,
kernel_size, kernel_size,
dilation_rate, dilation_rate,
n_layers, n_layers,
gin_channels=0, gin_channels=0,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -401,15 +416,15 @@ class WNEncoder(nn.Module):
class Generator(torch.nn.Module): class Generator(torch.nn.Module):
def __init__( def __init__(
self, self,
initial_channel, initial_channel,
resblock, resblock,
resblock_kernel_sizes, resblock_kernel_sizes,
resblock_dilation_sizes, resblock_dilation_sizes,
upsample_rates, upsample_rates,
upsample_initial_channel, upsample_initial_channel,
upsample_kernel_sizes, upsample_kernel_sizes,
gin_channels=0, gin_channels=0,
): ):
super(Generator, self).__init__() super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes) self.num_kernels = len(resblock_kernel_sizes)
@ -424,7 +439,7 @@ class Generator(torch.nn.Module):
self.ups.append( self.ups.append(
weight_norm( weight_norm(
ConvTranspose1d( ConvTranspose1d(
upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** i),
upsample_initial_channel // (2 ** (i + 1)), upsample_initial_channel // (2 ** (i + 1)),
k, k,
u, u,
@ -437,7 +452,7 @@ class Generator(torch.nn.Module):
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1)) ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate( for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes) zip(resblock_kernel_sizes, resblock_dilation_sizes)
): ):
self.resblocks.append(resblock(ch, k, d)) self.resblocks.append(resblock(ch, k, d))
@ -679,9 +694,9 @@ class Quantizer_module(torch.nn.Module):
def forward(self, x): def forward(self, x):
d = ( d = (
torch.sum(x**2, 1, keepdim=True) torch.sum(x ** 2, 1, keepdim=True)
+ torch.sum(self.embedding.weight**2, 1) + torch.sum(self.embedding.weight ** 2, 1)
- 2 * torch.matmul(x, self.embedding.weight.T) - 2 * torch.matmul(x, self.embedding.weight.T)
) )
min_indicies = torch.argmin(d, 1) min_indicies = torch.argmin(d, 1)
z_q = self.embedding(min_indicies) z_q = self.embedding(min_indicies)
@ -736,16 +751,16 @@ class Quantizer(torch.nn.Module):
class CodePredictor(nn.Module): class CodePredictor(nn.Module):
def __init__( def __init__(
self, self,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
n_q=8, n_q=8,
dims=1024, dims=1024,
ssl_dim=768, ssl_dim=768,
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -804,28 +819,29 @@ class SynthesizerTrn(nn.Module):
""" """
def __init__( def __init__(
self, self,
spec_channels, spec_channels,
segment_size, segment_size,
inter_channels, inter_channels,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
resblock, resblock,
resblock_kernel_sizes, resblock_kernel_sizes,
resblock_dilation_sizes, resblock_dilation_sizes,
upsample_rates, upsample_rates,
upsample_initial_channel, upsample_initial_channel,
upsample_kernel_sizes, upsample_kernel_sizes,
n_speakers=0, n_speakers=0,
gin_channels=0, gin_channels=0,
use_sdp=True, use_sdp=True,
semantic_frame_rate=None, semantic_frame_rate=None,
freeze_quantizer=None, freeze_quantizer=None,
**kwargs version="v1",
**kwargs
): ):
super().__init__() super().__init__()
self.spec_channels = spec_channels self.spec_channels = spec_channels
@ -845,6 +861,7 @@ class SynthesizerTrn(nn.Module):
self.segment_size = segment_size self.segment_size = segment_size
self.n_speakers = n_speakers self.n_speakers = n_speakers
self.gin_channels = gin_channels self.gin_channels = gin_channels
self.version = version
self.use_sdp = use_sdp self.use_sdp = use_sdp
self.enc_p = TextEncoder( self.enc_p = TextEncoder(
@ -855,6 +872,7 @@ class SynthesizerTrn(nn.Module):
n_layers, n_layers,
kernel_size, kernel_size,
p_dropout, p_dropout,
version=version,
) )
self.dec = Generator( self.dec = Generator(
inter_channels, inter_channels,
@ -879,9 +897,11 @@ class SynthesizerTrn(nn.Module):
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
) )
self.ref_enc = modules.MelStyleEncoder( # self.version=os.environ.get("version","v1")
spec_channels, style_vector_dim=gin_channels if (self.version == "v1"):
) self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
else:
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)
ssl_dim = 768 ssl_dim = 768
assert semantic_frame_rate in ["25hz", "50hz"] assert semantic_frame_rate in ["25hz", "50hz"]
@ -893,20 +913,15 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer self.freeze_quantizer = freeze_quantizer
# if freeze_quantizer:
# self.ssl_proj.requires_grad_(False)
# self.quantizer.requires_grad_(False)
#self.quantizer.eval()
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
def forward(self, ssl, y, y_lengths, text, text_lengths): def forward(self, ssl, y, y_lengths, text, text_lengths):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype y.dtype
) )
ge = self.ref_enc(y * y_mask, y_mask) if (self.version == "v1"):
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
with autocast(enabled=False): with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad: with maybe_no_grad:
@ -947,7 +962,10 @@ class SynthesizerTrn(nn.Module):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to( y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
y.dtype y.dtype
) )
ge = self.ref_enc(y * y_mask, y_mask) if (self.version == "v1"):
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0]) quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
@ -967,14 +985,28 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p) return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad() @torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5): def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
ge = None def get_ge(refer):
if refer is not None: ge = None
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) if refer is not None:
refer_mask = torch.unsqueeze( refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
commons.sequence_mask(refer_lengths, refer.size(2)), 1 refer_mask = torch.unsqueeze(
).to(refer.dtype) commons.sequence_mask(refer_lengths, refer.size(2)), 1
ge = self.ref_enc(refer * refer_mask, refer_mask) ).to(refer.dtype)
if (self.version == "v1"):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
if (type(refer) == list):
ges = []
for _refer in refer:
ge = get_ge(_refer)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@ -984,9 +1016,8 @@ class SynthesizerTrn(nn.Module):
quantized = F.interpolate( quantized = F.interpolate(
quantized, size=int(quantized.shape[-1] * 2), mode="nearest" quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
) )
x, m_p, logs_p, y_mask = self.enc_p( x, m_p, logs_p, y_mask = self.enc_p(
quantized, y_lengths, text, text_lengths, ge quantized, y_lengths, text, text_lengths, ge, speed
) )
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
@ -995,55 +1026,6 @@ class SynthesizerTrn(nn.Module):
o = self.dec((z * y_mask)[:, :, :], g=ge) o = self.dec((z * y_mask)[:, :, :], g=ge)
return o return o
@torch.no_grad()
def batched_decode(self, codes, y_lengths, text, text_lengths, refer, noise_scale=0.5):
    """Decode a padded batch of codes into one waveform per batch item.

    The prior/flow stages run once on the whole padded batch; the final
    ``self.dec`` call runs item by item on the unpadded part of each latent.

    Args:
        codes: Batched codes handed to ``self.quantizer.decode``.
        y_lengths: Per-item valid lengths of ``codes``; multiplied by 2
            before being passed to ``self.enc_p`` and used for trimming.
        text: Batched token ids forwarded to ``self.enc_p``.
        text_lengths: Per-item valid lengths of ``text``.
        refer: Reference input for ``self.ref_enc`` or ``None``. A single
            full-length mask is built from ``refer.size(2)``, so this is
            presumably one reference shared by the batch -- TODO confirm.
        noise_scale: Multiplier applied to the sampled prior noise.

    Returns:
        list: one detached 1-D tensor per batch item, taken from
        ``self.dec(...)[0, 0, :]``.
    """
    ge = None
    if refer is not None:
        refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
        refer_mask = torch.unsqueeze(
            commons.sequence_mask(refer_lengths, refer.size(2)), 1
        ).to(refer.dtype)
        # Keep the reference encoder input consistent with the other call
        # sites: non-v1 models only consume the first 704 channels. Objects
        # without a ``version`` attribute keep the original (v1) behaviour.
        if getattr(self, "version", "v1") == "v1":
            ge = self.ref_enc(refer * refer_mask, refer_mask)
        else:
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)

    y_lengths = (y_lengths * 2).long().to(codes.device)
    text_lengths = text_lengths.long().to(text.device)

    # Original author's note: decoding after padding is assumed to be fine;
    # the effect is unknown, but the result seems to sound OK.
    quantized = self.quantizer.decode(codes)
    if self.semantic_frame_rate == "25hz":
        quantized = F.interpolate(
            quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
        )
    _, m_p, logs_p, y_mask = self.enc_p(
        quantized, y_lengths, text, text_lengths, ge
    )
    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

    z = self.flow(z_p, y_mask, g=ge, reverse=True)
    z_masked = z * y_mask

    # Serial on purpose: strip the padding first, then decode each item.
    # The original author noted that decoding the padded batch in parallel
    # and trimming afterwards caused problems.
    # Lengths are converted to Python ints once instead of once per item.
    return [
        self.dec(z_item[:, :length].unsqueeze(0), g=ge)[0, 0, :].detach()
        for z_item, length in zip(z_masked, y_lengths.tolist())
    ]
def extract_latent(self, x): def extract_latent(self, x):
ssl = self.ssl_proj(x) ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)