mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 23:48:48 +08:00
129 lines
4.2 KiB
Python
129 lines
4.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2019 Tomoki Hayashi
|
|
# MIT License (https://opensource.org/licenses/MIT)
|
|
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from ..hparams import HParams
|
|
|
|
|
|
def _make_stft_cfg(hop_length, win_length=None):
    """Build the STFT configuration dict for a single resolution.

    Args:
        hop_length (int): Hop size in samples.
        win_length (int, optional): Window length in samples. When omitted,
            ``4 * hop_length`` is used.

    Returns:
        dict: Keys ``n_fft``, ``hop_length`` and ``win_length``. ``n_fft`` is
        the smallest power of two that is >= ``win_length`` (for
        ``win_length >= 1``).
    """
    window_size = 4 * hop_length if win_length is None else win_length
    # Representing (n - 1) needs (n - 1).bit_length() bits, so shifting 1 left
    # by that many bits rounds n up to the next power of two.
    fft_size = 1 << (window_size - 1).bit_length()
    return {"n_fft": fft_size, "hop_length": hop_length, "win_length": window_size}
|
|
|
|
|
|
def get_stft_cfgs(hp: HParams):
    """Return the STFT configurations used by the multi-resolution loss.

    Three resolutions are produced, with hop lengths of 100, 256 and 512
    samples; window and FFT sizes are derived by ``_make_stft_cfg``.

    Args:
        hp (HParams): Hyper-parameters. Only ``hp.wav_rate`` is read, and it
            must equal 44100 (the hop sizes are presumably tuned for 44.1 kHz
            audio — NOTE(review): confirm).

    Returns:
        list[dict]: One STFT config dict per resolution.
    """
    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
    hop_lengths = (100, 256, 512)
    return list(map(_make_stft_cfg, hop_lengths))
|
|
|
|
|
|
def stft(x, n_fft, hop_length, win_length, window):
    """Compute the STFT magnitude of a batch of signals.

    The transform runs in float32 regardless of the input dtype, and the
    magnitude is cast back to the input dtype afterwards.

    Args:
        x (Tensor): Signals, expected shape (b, t); the final transpose
            requires a batched input.
        n_fft (int): FFT size.
        hop_length (int): Hop size in samples.
        win_length (int): Window length in samples.
        window (Tensor): Window tensor of length ``win_length``.

    Returns:
        Tensor: Magnitude spectrogram of shape (b, frames, n_fft // 2 + 1)
        with the same dtype as ``x``.
    """
    original_dtype = x.dtype
    spectrum = torch.stft(
        x.float(),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    magnitude = spectrum.abs().to(original_dtype)
    return magnitude.transpose(2, 1)  # (b, f, t) -> (b, t, f)
|
|
|
|
|
|
class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence loss between two magnitude spectrograms.

    Computes ``||y_mag - x_mag||_F / max(||y_mag||_F, eps)``. The floor on the
    denominator keeps the loss finite when the reference spectrogram is all
    zeros (e.g. a silent target), where a plain ratio would be NaN or inf.

    Attributes:
        eps (float): Lower bound applied to the reference norm.
    """

    # Class-level default so instances unpickled from before ``eps`` existed
    # still resolve the attribute.
    eps: float = 1e-8

    def __init__(self, eps: float = 1e-8):
        """Initialize the loss.

        Args:
            eps (float): Floor applied to ``||y_mag||_F`` to avoid division
                by zero.
        """
        super().__init__()
        self.eps = eps

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Scalar spectral convergence loss value.
        """
        # No ``dim`` is passed, so both norms run over every element of the
        # tensor (the whole batch at once).
        diff_norm = torch.norm(y_mag - x_mag, p="fro")
        # clamp_min leaves the value untouched whenever ||y_mag|| >= eps, so
        # results only differ from the plain ratio in the degenerate case.
        ref_norm = torch.norm(y_mag, p="fro").clamp_min(self.eps)
        return diff_norm / ref_norm
|
|
|
|
|
|
class LogSTFTMagnitudeLoss(nn.Module):
    """Mean absolute error between log1p-compressed magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Scalar log STFT magnitude loss (mean L1 distance).
        """
        # log1p(m) = log(1 + m) stays finite at zero magnitude, unlike log(m).
        log_pred = torch.log1p(x_mag)
        log_target = torch.log1p(y_mag)
        return F.l1_loss(log_pred, log_target)
|
|
|
|
|
|
class STFTLoss(nn.Module):
    """Spectral-convergence and log-magnitude losses at one STFT resolution."""

    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
        """Initialize the single-resolution loss.

        Args:
            hp: Hyper-parameters; stored on the module but not read here.
            stft_cfg (dict): Keyword arguments for ``stft`` (``n_fft``,
                ``hop_length``, ``win_length``).
            window (str): Name of a window factory on the ``torch`` namespace,
                e.g. ``"hann_window"``.
        """
        super().__init__()
        self.hp = hp
        self.stft_cfg = stft_cfg
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        window_factory = getattr(torch, window)
        # Non-persistent buffer: moves with .to()/.cuda() but is not stored in
        # the state_dict.
        self.register_buffer("window", window_factory(stft_cfg["win_length"]), persistent=False)

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            dict: ``"sc"`` -> spectral convergence loss (Tensor) and
            ``"mag"`` -> log STFT magnitude loss (Tensor).
        """
        cfg = dict(self.stft_cfg)
        x_mag = stft(x, window=self.window, **cfg)  # (b t) -> (b t f)
        y_mag = stft(y, window=self.window, **cfg)
        return dict(
            sc=self.spectral_convergenge_loss(x_mag, y_mag),
            mag=self.log_stft_magnitude_loss(x_mag, y_mag),
        )
|
|
|
|
|
|
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss averaged over the configs from get_stft_cfgs."""

    def __init__(self, hp: HParams, window="hann_window"):
        """Initialize the multi-resolution STFT loss module.

        Args:
            hp (HParams): Hyper-parameters passed to ``get_stft_cfgs`` and to
                every per-resolution ``STFTLoss``.
            window (str): Name of a ``torch`` window function, e.g.
                ``"hann_window"``.
        """
        super().__init__()
        self.stft_losses = nn.ModuleList(
            [STFTLoss(hp, cfg, window=window) for cfg in get_stft_cfgs(hp)]
        )
        self.hp = hp

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (b t).
            y (Tensor): Groundtruth signal (b t).

        Returns:
            dict: Each loss name produced by the per-resolution losses
            (``"sc"`` and ``"mag"``) mapped to its mean over all resolutions,
            cast to the dtype of ``x``.
        """
        assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."

        input_dtype = x.dtype
        x, y = x.float(), y.float()

        # Crop both signals to the shorter length so the spectrograms align.
        common_len = min(x.shape[-1], y.shape[-1])
        x, y = x[..., :common_len], y[..., :common_len]

        per_name = {}
        for loss_fn in self.stft_losses:
            for name, value in loss_fn(x, y).items():
                per_name.setdefault(name, []).append(value)

        return {
            name: torch.stack(values, dim=0).mean().to(input_dtype)
            for name, values in per_name.items()
        }
|