mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
211 lines
6.0 KiB
Python
211 lines
6.0 KiB
Python
import logging
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor, nn
|
|
from torch.nn.utils.parametrizations import weight_norm
|
|
|
|
from ..hparams import HParams
|
|
from .mrstft import get_stft_cfgs
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class PeriodNetwork(nn.Module):
|
|
def __init__(self, period):
|
|
super().__init__()
|
|
self.period = period
|
|
wn = weight_norm
|
|
self.convs = nn.ModuleList(
|
|
[
|
|
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
|
|
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
|
|
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
|
|
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
|
|
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
|
|
]
|
|
)
|
|
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
|
|
|
def forward(self, x):
|
|
"""
|
|
Args:
|
|
x: [B, 1, T]
|
|
"""
|
|
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
|
|
|
|
# 1d to 2d
|
|
b, c, t = x.shape
|
|
if t % self.period != 0: # pad first
|
|
n_pad = self.period - (t % self.period)
|
|
x = F.pad(x, (0, n_pad), "reflect")
|
|
t = t + n_pad
|
|
x = x.view(b, c, t // self.period, self.period)
|
|
|
|
for l in self.convs:
|
|
x = l(x)
|
|
x = F.leaky_relu(x, 0.2)
|
|
x = self.conv_post(x)
|
|
x = torch.flatten(x, 1, -1)
|
|
|
|
return x
|
|
|
|
|
|
class SpecNetwork(nn.Module):
|
|
def __init__(self, stft_cfg: dict):
|
|
super().__init__()
|
|
wn = weight_norm
|
|
self.stft_cfg = stft_cfg
|
|
self.convs = nn.ModuleList(
|
|
[
|
|
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
|
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
|
]
|
|
)
|
|
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
|
|
|
def forward(self, x):
|
|
"""
|
|
Args:
|
|
x: [B, 1, T]
|
|
"""
|
|
x = self.spectrogram(x)
|
|
x = x.unsqueeze(1)
|
|
for l in self.convs:
|
|
x = l(x)
|
|
x = F.leaky_relu(x, 0.2)
|
|
x = self.conv_post(x)
|
|
x = x.flatten(1, -1)
|
|
return x
|
|
|
|
def spectrogram(self, x):
|
|
"""
|
|
Args:
|
|
x: [B, 1, T]
|
|
"""
|
|
x = x.squeeze(1)
|
|
dtype = x.dtype
|
|
stft_cfg = dict(self.stft_cfg)
|
|
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
|
|
mag = x.norm(p=2, dim=-1) # [B, F, TT]
|
|
mag = mag.to(dtype) # [B, F, TT]
|
|
return mag
|
|
|
|
|
|
class MD(nn.ModuleList):
    """Base class for an ensemble of sub-discriminators whose losses are averaged.

    Subclasses implement ``_create_network`` to build one sub-network per entry
    of the list given to the constructor. ``loss_type_`` must be called with
    ``"hinge"`` or ``"wgan"`` before ``forward`` is used.
    """

    def __init__(self, l: list):
        networks = [self._create_network(spec) for spec in l]
        super().__init__(networks)
        self._loss_type = None

    def loss_type_(self, loss_type):
        """Select the loss formulation used by ``forward``."""
        self._loss_type = loss_type

    def _create_network(self, _):
        """Build one sub-network from a constructor list entry (subclass hook)."""
        raise NotImplementedError

    def _forward_each(self, d, x, y):
        """Return the scalar loss of sub-network ``d`` on input ``x``.

        ``y`` is the target label. With 0 the loss decreases as d(x) gets
        smaller; with 1 it decreases as d(x) gets larger. The hinge variant
        uses a margin of 1; the WGAN variant uses the signed mean score.
        """
        assert self._loss_type is not None, "loss_type is not set."
        kind = self._loss_type

        # Validate the configuration before running the sub-network.
        if kind not in ("hinge", "wgan"):
            raise ValueError(f"Invalid loss_type: {kind}")
        if y not in (0, 1):
            raise ValueError(f"Invalid y: {y}")

        score = d(x)
        if kind == "hinge":
            # y == 0: push d(x) towards -1; y == 1: push d(x) towards +1.
            return F.relu(1 + score if y == 0 else 1 - score).mean()

        mean_score = score.mean()
        return mean_score if y == 0 else -mean_score

    def forward(self, x, y) -> Tensor:
        """Average the per-sub-network losses into a single scalar tensor."""
        per_network = torch.stack([self._forward_each(net, x, y) for net in self])
        return per_network.mean()
|
|
|
|
|
|
class MPD(MD):
    """Multi-period ensemble: one PeriodNetwork per period."""

    def __init__(self):
        # Period of each sub-network, in construction order.
        periods = [2, 3, 7, 13, 17]
        super().__init__(periods)

    def _create_network(self, period):
        """Build the sub-network for a single period."""
        return PeriodNetwork(period=period)
|
|
|
|
|
|
class MRD(MD):
    """Multi-resolution ensemble: one SpecNetwork per STFT configuration."""

    def __init__(self, stft_cfgs):
        # Each entry of stft_cfgs is handed to _create_network in turn.
        configs = stft_cfgs
        super().__init__(configs)

    def _create_network(self, stft_cfg):
        """Build the sub-network for a single STFT configuration."""
        return SpecNetwork(stft_cfg=stft_cfg)
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Waveform discriminator built from a multi-period (MPD) and a
    multi-resolution spectrogram (MRD) ensemble.

    ``forward(fake)`` scores ``fake`` against label 1 with the WGAN loss.
    ``forward(fake, real)`` scores ``fake`` against label 0 and ``real``
    against label 1 with the hinge loss.
    """

    @property
    def wav_rate(self):
        return self.hp.wav_rate

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        self.stft_cfgs = get_stft_cfgs(hp)
        self.mpd = MPD()
        self.mrd = MRD(self.stft_cfgs)
        # Zero-size, non-persistent buffer. It moves with the module on
        # .to(...), so inputs can be converted to the module's device/dtype.
        self.dummy_float: Tensor
        self.register_buffer("dummy_float", torch.zeros(0), persistent=False)

    def loss_type_(self, loss_type):
        """Set the loss formulation on both ensembles."""
        self.mpd.loss_type_(loss_type)
        self.mrd.loss_type_(loss_type)

    def forward(self, fake, real=None):
        """
        Args:
            fake: [B T]
            real: [B T]

        Returns:
            Dict of scalar losses. Keys are ``"mpd"`` and ``"mrd"`` when
            ``real`` is None, otherwise ``"mpd_fake"``, ``"mpd_real"``,
            ``"mrd_fake"`` and ``"mrd_real"``.
        """
        fake = fake.to(self.dummy_float)

        if real is None:
            self.loss_type_("wgan")
        else:
            # Relative length mismatch in either direction. Both signals are
            # trimmed to the shorter length below, so only small mismatches
            # are tolerated. (A signed difference would accept any `fake`
            # that is shorter than `real`.)
            length_difference = abs(fake.shape[-1] - real.shape[-1]) / real.shape[-1]
            assert length_difference < 0.05, (
                f"length_difference should be smaller than 5%, got {length_difference:.2%}"
            )

            self.loss_type_("hinge")
            real = real.to(self.dummy_float)

            # Trim both signals to their common length.
            fake = fake[..., : real.shape[-1]]
            real = real[..., : fake.shape[-1]]

        losses = {}

        assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
        assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."

        fake = fake.unsqueeze(1)

        if real is None:
            losses["mpd"] = self.mpd(fake, 1)
            losses["mrd"] = self.mrd(fake, 1)
        else:
            real = real.unsqueeze(1)
            losses["mpd_fake"] = self.mpd(fake, 0)
            losses["mpd_real"] = self.mpd(real, 1)
            losses["mrd_fake"] = self.mrd(fake, 0)
            losses["mrd_real"] = self.mrd(real, 1)

        return losses
|