# Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
# (synced 2025-04-05 19:41:56 +08:00; 255 lines, 8.0 KiB, Python)
# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
#   LICENSE is in incl_licenses directory.

import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
from librosa.filters import mel as librosa_mel_fn
|
|
from scipy import signal
|
|
|
|
import typing
|
|
from typing import Optional, List, Union, Dict, Tuple
|
|
from collections import namedtuple
|
|
import math
|
|
import functools
|
|
|
|
|
|
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    For every scale, both signals are turned into mel spectrograms with the
    same STFT / filterbank settings. The loss is then
    ``log_weight * loss_fn(log10 mels) + mag_weight * loss_fn(linear mels)``,
    summed over all scales.

    Parameters
    ----------
    sampling_rate : int
        Sample rate of the audio; used to build the mel filterbanks.
    n_mels : Sequence[int]
        Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320].
    window_lengths : Sequence[int], optional
        Length of each window of each STFT, by default
        [32, 64, 128, 256, 512, 1024, 2048]. The hop length of each scale is
        ``window_length // 4``.
    loss_fn : typing.Callable, optional
        How to compare each loss. By default a fresh ``nn.L1Loss()`` is
        created for every instance.
    clamp_eps : float, optional
        Lower clamp applied to the mel magnitudes before the log, by default 1e-5.
    mag_weight : float, optional
        Weight of the raw (linear) mel magnitude portion of the loss,
        by default 0.0 (disabled).
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0.
    pow : float, optional
        Power to raise magnitude to before taking log, by default 1.0.
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the
        instance; ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False.
    mel_fmin : Sequence[float], optional
        Lowest filterbank frequency per scale, by default 0 for every scale.
    mel_fmax : Sequence[Optional[float]], optional
        Highest filterbank frequency per scale, by default None for every
        scale. The value is passed through to ``librosa.filters.mel``.
    window_type : str, optional
        Window name passed to ``scipy.signal.get_window``, by default "hann".

    Notes
    -----
    The per-scale settings are combined with ``zip``. If ``n_mels``,
    ``mel_fmin``, ``mel_fmax`` and ``window_lengths`` differ in length, only
    the shortest common prefix of scales is used.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
    """

    def __init__(
        self,
        sampling_rate: int,
        n_mels: typing.Sequence[int] = (5, 10, 20, 40, 80, 160, 320),
        window_lengths: typing.Sequence[int] = (32, 64, 128, 256, 512, 1024, 2048),
        loss_fn: Optional[typing.Callable] = None,
        clamp_eps: float = 1e-5,
        mag_weight: float = 0.0,
        log_weight: float = 1.0,
        pow: float = 1.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: typing.Sequence[float] = (0, 0, 0, 0, 0, 0, 0),
        mel_fmax: typing.Sequence[Optional[float]] = (
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        ),
        window_type: str = "hann",
    ):
        super().__init__()
        self.sampling_rate = sampling_rate

        STFTParams = namedtuple(
            "STFTParams",
            ["window_length", "hop_length", "window_type", "match_stride"],
        )

        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        # Copy the per-scale settings so instances never alias each other's
        # (or the caller's) lists; mutating one instance cannot leak elsewhere.
        self.n_mels = list(n_mels)
        # Build a fresh criterion per instance instead of sharing one module
        # object created at function-definition time.
        self.loss_fn = nn.L1Loss() if loss_fn is None else loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = list(mel_fmin)
        self.mel_fmax = list(mel_fmax)
        self.pow = pow

    @staticmethod
    @functools.lru_cache(None)
    def get_window(
        window_type,
        window_length,
    ):
        """Return (and memoize) the scipy window array for the given type/length."""
        return signal.get_window(window_type, window_length)

    @staticmethod
    @functools.lru_cache(None)
    def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
        """Return (and memoize) the librosa mel filterbank numpy array."""
        return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    def mel_spectrogram(
        self,
        wav,
        n_mels,
        fmin,
        fmax,
        window_length,
        hop_length,
        match_stride,
        window_type,
    ):
        """
        Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
        https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py

        Parameters
        ----------
        wav : torch.Tensor
            3-D signal unpacked as ``(B, C, T)``.
        n_mels, fmin, fmax
            Mel filterbank settings for this scale.
        window_length, hop_length : int
            STFT size and hop.
        match_stride : bool
            If True, pad so that ``num_frames * hop_length == num_samples``.
            Then drop the two extra frames added at each end.
        window_type : str
            scipy window name.

        Returns
        -------
        torch.Tensor
            Linear-magnitude mel spectrogram of shape ``(B, C, n_mels, frames)``.
        """
        B, C, T = wav.shape

        if match_stride:
            assert (
                hop_length == window_length // 4
            ), "For match_stride, hop must equal n_fft // 4"
            right_pad = math.ceil(T / hop_length) * hop_length - T
            pad = (window_length - hop_length) // 2
        else:
            right_pad = 0
            pad = 0

        wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")

        window = self.get_window(window_type, window_length)
        # Follow the input's device and dtype so torch.stft gets a matching window.
        window = torch.from_numpy(window).to(device=wav.device, dtype=wav.dtype)

        stft = torch.stft(
            # Use the *padded* length here. Reshaping with the original ``T``
            # scrambles samples across rows (or fails) whenever padding was added.
            wav.reshape(-1, wav.shape[-1]),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        _, nf, nt = stft.shape
        stft = stft.reshape(B, C, nf, nt)
        if match_stride:
            # Drop first two and last two frames, which are added because of
            # padding. Now num_frames * hop_length = num_samples.
            stft = stft[..., 2:-2]
        magnitude = torch.abs(stft)

        nf = magnitude.shape[2]
        mel_basis = self.get_mel_filters(
            self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
        )
        mel_basis = torch.from_numpy(mel_basis).to(
            device=wav.device, dtype=magnitude.dtype
        )
        # (B, C, frames, freq) @ (freq, n_mels) -> (B, C, frames, n_mels)
        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
        mel_spectrogram = mel_spectrogram.transpose(-1, 2)

        return mel_spectrogram

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : torch.Tensor
            Estimate signal, shaped ``(B, C, T)``.
        y : torch.Tensor
            Reference signal, shaped ``(B, C, T)``.

        Returns
        -------
        torch.Tensor
            Mel loss, summed over all scales. It is the float ``0.0`` if there
            are no scales.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "n_mels": n_mels,
                "fmin": fmin,
                "fmax": fmax,
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "match_stride": s.match_stride,
                "window_type": s.window_type,
            }

            x_mels = self.mel_spectrogram(x, **kwargs)
            y_mels = self.mel_spectrogram(y, **kwargs)
            x_logmels = torch.log10(x_mels.clamp(min=self.clamp_eps).pow(self.pow))
            y_logmels = torch.log10(y_mels.clamp(min=self.clamp_eps).pow(self.pow))

            loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
            # The magnitude term compares the *linear* mel magnitudes, not the
            # log mels. Otherwise it would just duplicate the log term.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)

        return loss


# Loss functions
def feature_loss(
    fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
    """Feature-matching loss between real and generated discriminator features.

    Pairs up the feature maps of each discriminator (outer lists) layer by
    layer (inner lists). It sums the mean absolute difference of every pair
    and scales the total by 2.

    Returns the integer ``0`` when there are no feature-map pairs.
    """
    per_layer_l1 = (
        torch.mean(torch.abs(real_layer - fake_layer))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_layer, fake_layer in zip(real_maps, fake_maps)
    )
    # Factor 2 == lambda=2.0 for the feature matching loss.
    return sum(per_layer_l1) * 2
|
|
|
|
|
|
def discriminator_loss(
    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[float], List[float]]:
    """Least-squares discriminator loss, summed over paired discriminator outputs.

    For each pair ``(dr, dg)`` this computes ``mean((1 - dr) ** 2)`` for the
    real output and ``mean(dg ** 2)`` for the generated output.

    Returns
    -------
    loss
        Sum of all real and generated terms. It is the integer ``0`` if the
        inputs are empty.
    r_losses, g_losses
        Per-discriminator real / generated terms as Python floats. They are
        detached via ``.item()``, which makes them suitable for logging.
    """
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses
|
|
|
|
|
|
def generator_loss(
    disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Least-squares adversarial loss for the generator.

    Each discriminator output ``dg`` contributes ``mean((1 - dg) ** 2)``.

    Returns
    -------
    tuple
        The sum of all per-discriminator terms, and the list of those terms
        as tensors. Empty input yields ``(0, [])``.
    """
    per_disc = [torch.mean((1 - fake_score) ** 2) for fake_score in disc_outputs]
    return sum(per_disc), per_disc
|