mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
Merge e0c0410f5873401748d08e87cd9e9e2e5bea29d6 into 82a5672361baea2351e1c370f49cd0d7f58ccd8b
This commit is contained in:
commit
77e9bce4dc
@ -17,6 +17,11 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
|||||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||||
import pdb
|
import pdb
|
||||||
import torch
|
import torch
|
||||||
|
from resemble_enhance.enhancer.inference import denoise, enhance
|
||||||
|
import torchaudio
|
||||||
|
import gc
|
||||||
|
import librosa
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
if os.path.exists("./gweight.txt"):
|
if os.path.exists("./gweight.txt"):
|
||||||
with open("./gweight.txt", 'r', encoding="utf-8") as file:
|
with open("./gweight.txt", 'r', encoding="utf-8") as file:
|
||||||
@ -83,6 +88,31 @@ if is_half == True:
|
|||||||
else:
|
else:
|
||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(device)
|
||||||
|
|
||||||
|
def clear_gpu_cash():
|
||||||
|
# del model
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def _fn(path, solver="Midpoint", nfe=64, tau=0.5,chunk_seconds=10,chunks_overlap=0.5, denoising=True):
|
||||||
|
if path is None:
|
||||||
|
return None, None
|
||||||
|
print(path)
|
||||||
|
sf.write('./output.wav', path[1], path[0], 'PCM_24')
|
||||||
|
|
||||||
|
solver = solver.lower()
|
||||||
|
nfe = int(nfe)
|
||||||
|
lambd = 0.9 if denoising else 0.1
|
||||||
|
|
||||||
|
dwav, sr = torchaudio.load('./output.wav')
|
||||||
|
dwav = dwav.mean(dim=0)
|
||||||
|
|
||||||
|
wav2, new_sr = enhance(dwav = dwav, sr = sr, device = device, nfe=nfe,chunk_seconds=chunk_seconds,chunks_overlap=chunks_overlap, solver=solver, lambd=lambd, tau=tau)
|
||||||
|
|
||||||
|
wav2 = wav2.cpu().numpy()
|
||||||
|
|
||||||
|
clear_gpu_cash()
|
||||||
|
return (new_sr, wav2)
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -590,6 +620,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
|||||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||||
output = gr.Audio(label=i18n("输出的语音"))
|
output = gr.Audio(label=i18n("输出的语音"))
|
||||||
|
up_button = gr.Button(i18n("音频降噪增强"), variant="primary")
|
||||||
|
|
||||||
inference_button.click(
|
inference_button.click(
|
||||||
get_tts_wav,
|
get_tts_wav,
|
||||||
@ -597,6 +628,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
|||||||
[output],
|
[output],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
up_button.click(_fn, [output], [output])
|
||||||
|
|
||||||
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
|
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
|
||||||
|
55
GPT_SoVITS/resemble_enhance/common.py
Normal file
55
GPT_SoVITS/resemble_enhance/common.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Normalizer(nn.Module):
|
||||||
|
def __init__(self, momentum=0.01, eps=1e-9):
|
||||||
|
super().__init__()
|
||||||
|
self.momentum = momentum
|
||||||
|
self.eps = eps
|
||||||
|
self.running_mean_unsafe: Tensor
|
||||||
|
self.running_var_unsafe: Tensor
|
||||||
|
self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
|
||||||
|
self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def started(self):
|
||||||
|
return not torch.isnan(self.running_mean_unsafe)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running_mean(self):
|
||||||
|
if not self.started:
|
||||||
|
return torch.zeros_like(self.running_mean_unsafe)
|
||||||
|
return self.running_mean_unsafe
|
||||||
|
|
||||||
|
@property
|
||||||
|
def running_std(self):
|
||||||
|
if not self.started:
|
||||||
|
return torch.ones_like(self.running_var_unsafe)
|
||||||
|
return (self.running_var_unsafe + self.eps).sqrt()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _ema(self, a: Tensor, x: Tensor):
|
||||||
|
return (1 - self.momentum) * a + self.momentum * x
|
||||||
|
|
||||||
|
def update_(self, x):
|
||||||
|
if not self.started:
|
||||||
|
self.running_mean_unsafe = x.mean()
|
||||||
|
self.running_var_unsafe = x.var()
|
||||||
|
else:
|
||||||
|
self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
|
||||||
|
self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, update=True):
|
||||||
|
if self.training and update:
|
||||||
|
self.update_(x)
|
||||||
|
self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
|
||||||
|
x = (x - self.running_mean) / self.running_std
|
||||||
|
return x
|
||||||
|
|
||||||
|
def inverse(self, x: Tensor):
|
||||||
|
return x * self.running_std + self.running_mean
|
48
GPT_SoVITS/resemble_enhance/data/__init__.py
Normal file
48
GPT_SoVITS/resemble_enhance/data/__init__.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from ..hparams import HParams
|
||||||
|
from .dataset import Dataset
|
||||||
|
from .utils import mix_fg_bg, rglob_audio_files
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
|
||||||
|
paths = rglob_audio_files(hp.fg_dir)
|
||||||
|
logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
|
||||||
|
|
||||||
|
random.Random(seed).shuffle(paths)
|
||||||
|
train_paths = paths[:-val_size]
|
||||||
|
val_paths = paths[-val_size:]
|
||||||
|
|
||||||
|
train_ds = Dataset(train_paths, hp, training=True, mode=mode)
|
||||||
|
val_ds = Dataset(val_paths, hp, training=False, mode=mode)
|
||||||
|
|
||||||
|
logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
|
||||||
|
|
||||||
|
return train_ds, val_ds
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloaders(hp: HParams, mode):
|
||||||
|
train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train_ds,
|
||||||
|
batch_size=hp.batch_size_per_gpu,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=hp.nj,
|
||||||
|
drop_last=True,
|
||||||
|
collate_fn=train_ds.collate_fn,
|
||||||
|
)
|
||||||
|
val_dl = DataLoader(
|
||||||
|
val_ds,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=hp.nj,
|
||||||
|
drop_last=False,
|
||||||
|
collate_fn=val_ds.collate_fn,
|
||||||
|
)
|
||||||
|
return train_dl, val_dl
|
171
GPT_SoVITS/resemble_enhance/data/dataset.py
Normal file
171
GPT_SoVITS/resemble_enhance/data/dataset.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import torchaudio.functional as AF
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from torch.utils.data import Dataset as DatasetBase
|
||||||
|
|
||||||
|
from ..hparams import HParams
|
||||||
|
from .distorter import Distorter
|
||||||
|
from .utils import rglob_audio_files
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(x):
|
||||||
|
return x / (np.abs(x).max() + 1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
def _collate(batch, key, tensor=True, pad=True):
|
||||||
|
l = [d[key] for d in batch]
|
||||||
|
if l[0] is None:
|
||||||
|
return None
|
||||||
|
if tensor:
|
||||||
|
l = [torch.from_numpy(x) for x in l]
|
||||||
|
if pad:
|
||||||
|
assert tensor, "Can't pad non-tensor"
|
||||||
|
l = pad_sequence(l, batch_first=True)
|
||||||
|
return l
|
||||||
|
|
||||||
|
|
||||||
|
def praat_augment(wav, sr):
|
||||||
|
try:
|
||||||
|
import parselmouth
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
|
||||||
|
# "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
|
||||||
|
# https://github.com/YannickJadoul/Parselmouth/issues/68
|
||||||
|
# note that this function may hang if the praat version is 0.4.3
|
||||||
|
assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
|
||||||
|
sound = parselmouth.Sound(wav, sr)
|
||||||
|
formant_shift_ratio = random.uniform(1.1, 1.5)
|
||||||
|
pitch_range_factor = random.uniform(0.5, 2.0)
|
||||||
|
sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
|
||||||
|
wav = np.array(sound.values)[0].astype(np.float32)
|
||||||
|
return wav
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(DatasetBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fg_paths: list[Path],
|
||||||
|
hp: HParams,
|
||||||
|
training=True,
|
||||||
|
max_retries=100,
|
||||||
|
silent_fg_prob=0.01,
|
||||||
|
mode=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
|
||||||
|
|
||||||
|
self.hp = hp
|
||||||
|
self.fg_paths = fg_paths
|
||||||
|
self.bg_paths = rglob_audio_files(hp.bg_dir)
|
||||||
|
|
||||||
|
if len(self.fg_paths) == 0:
|
||||||
|
raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
|
||||||
|
|
||||||
|
if len(self.bg_paths) == 0:
|
||||||
|
raise ValueError(f"No background audio files found in {hp.bg_dir}")
|
||||||
|
|
||||||
|
logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
|
||||||
|
|
||||||
|
self.training = training
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.silent_fg_prob = silent_fg_prob
|
||||||
|
|
||||||
|
self.mode = mode
|
||||||
|
self.distorter = Distorter(hp, training=training, mode=mode)
|
||||||
|
|
||||||
|
def _load_wav(self, path, length=None, random_crop=True):
|
||||||
|
wav, sr = torchaudio.load(path)
|
||||||
|
|
||||||
|
wav = AF.resample(
|
||||||
|
waveform=wav,
|
||||||
|
orig_freq=sr,
|
||||||
|
new_freq=self.hp.wav_rate,
|
||||||
|
lowpass_filter_width=64,
|
||||||
|
rolloff=0.9475937167399596,
|
||||||
|
resampling_method="sinc_interp_kaiser",
|
||||||
|
beta=14.769656459379492,
|
||||||
|
)
|
||||||
|
|
||||||
|
wav = wav.float().numpy()
|
||||||
|
|
||||||
|
if wav.ndim == 2:
|
||||||
|
wav = np.mean(wav, axis=0)
|
||||||
|
|
||||||
|
if length is None and self.training:
|
||||||
|
length = int(self.hp.training_seconds * self.hp.wav_rate)
|
||||||
|
|
||||||
|
if length is not None:
|
||||||
|
if random_crop:
|
||||||
|
start = random.randint(0, max(0, len(wav) - length))
|
||||||
|
wav = wav[start : start + length]
|
||||||
|
else:
|
||||||
|
wav = wav[:length]
|
||||||
|
|
||||||
|
if length is not None and len(wav) < length:
|
||||||
|
wav = np.pad(wav, (0, length - len(wav)))
|
||||||
|
|
||||||
|
wav = _normalize(wav)
|
||||||
|
|
||||||
|
return wav
|
||||||
|
|
||||||
|
def _getitem_unsafe(self, index: int):
|
||||||
|
fg_path = self.fg_paths[index]
|
||||||
|
|
||||||
|
if self.training and random.random() < self.silent_fg_prob:
|
||||||
|
fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
|
||||||
|
else:
|
||||||
|
fg_wav = self._load_wav(fg_path)
|
||||||
|
if random.random() < self.hp.praat_augment_prob and self.training:
|
||||||
|
fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
|
||||||
|
|
||||||
|
if self.hp.load_fg_only:
|
||||||
|
bg_wav = None
|
||||||
|
fg_dwav = None
|
||||||
|
bg_dwav = None
|
||||||
|
else:
|
||||||
|
fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
|
||||||
|
if self.training:
|
||||||
|
bg_path = random.choice(self.bg_paths)
|
||||||
|
else:
|
||||||
|
# Deterministic for validation
|
||||||
|
bg_path = self.bg_paths[index % len(self.bg_paths)]
|
||||||
|
bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
|
||||||
|
bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
fg_wav=fg_wav,
|
||||||
|
bg_wav=bg_wav,
|
||||||
|
fg_dwav=fg_dwav,
|
||||||
|
bg_dwav=bg_dwav,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
for i in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
return self._getitem_unsafe(index)
|
||||||
|
except Exception as e:
|
||||||
|
if i == self.max_retries - 1:
|
||||||
|
raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
|
||||||
|
logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
|
||||||
|
index = np.random.randint(0, len(self))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.fg_paths)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collate_fn(batch):
|
||||||
|
return dict(
|
||||||
|
fg_wavs=_collate(batch, "fg_wav"),
|
||||||
|
bg_wavs=_collate(batch, "bg_wav"),
|
||||||
|
fg_dwavs=_collate(batch, "fg_dwav"),
|
||||||
|
bg_dwavs=_collate(batch, "bg_dwav"),
|
||||||
|
)
|
1
GPT_SoVITS/resemble_enhance/data/distorter/__init__.py
Normal file
1
GPT_SoVITS/resemble_enhance/data/distorter/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .distorter import Distorter
|
104
GPT_SoVITS/resemble_enhance/data/distorter/base.py
Normal file
104
GPT_SoVITS/resemble_enhance/data/distorter/base.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
_DEBUG = bool(os.environ.get("DEBUG", False))
|
||||||
|
|
||||||
|
|
||||||
|
class Effect:
|
||||||
|
def apply(self, wav: np.ndarray, sr: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wav: (T)
|
||||||
|
sr: sample rate
|
||||||
|
Returns:
|
||||||
|
wav: (T) with the same sample rate of `sr`
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, wav: np.ndarray, sr: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wav: (T)
|
||||||
|
sr: sample rate
|
||||||
|
Returns:
|
||||||
|
wav: (T) with the same sample rate of `sr`
|
||||||
|
"""
|
||||||
|
assert len(wav.shape) == 1, wav.shape
|
||||||
|
|
||||||
|
if _DEBUG:
|
||||||
|
start = time.time()
|
||||||
|
else:
|
||||||
|
start = None
|
||||||
|
|
||||||
|
shape = wav.shape
|
||||||
|
assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
|
||||||
|
wav = self.apply(wav, sr)
|
||||||
|
assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
|
||||||
|
|
||||||
|
if start is not None:
|
||||||
|
end = time.time()
|
||||||
|
print(f"{self.__class__.__name__}: {end - start:.3f} sec")
|
||||||
|
|
||||||
|
return wav
|
||||||
|
|
||||||
|
|
||||||
|
class Chain(Effect):
|
||||||
|
def __init__(self, *effects):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.effects = effects
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
for effect in self.effects:
|
||||||
|
wav = effect(wav, sr)
|
||||||
|
return wav
|
||||||
|
|
||||||
|
|
||||||
|
class Maybe(Effect):
|
||||||
|
def __init__(self, prob, effect):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.prob = prob
|
||||||
|
self.effect = effect
|
||||||
|
|
||||||
|
if _DEBUG:
|
||||||
|
warnings.warn("DEBUG mode is on. Maybe -> Must.")
|
||||||
|
self.prob = 1
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
if random.random() > self.prob:
|
||||||
|
return wav
|
||||||
|
return self.effect(wav, sr)
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(Effect):
|
||||||
|
def __init__(self, *effects, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.effects = effects
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
return np.random.choice(self.effects, **self.kwargs)(wav, sr)
|
||||||
|
|
||||||
|
|
||||||
|
class Permutation(Effect):
|
||||||
|
def __init__(self, *effects, n: int | None = None):
|
||||||
|
super().__init__()
|
||||||
|
self.effects = effects
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
if self.n is None:
|
||||||
|
n = np.random.binomial(len(self.effects), 0.5)
|
||||||
|
else:
|
||||||
|
n = self.n
|
||||||
|
if n == 0:
|
||||||
|
return wav
|
||||||
|
perms = itertools.permutations(self.effects, n)
|
||||||
|
effects = random.choice(list(perms))
|
||||||
|
return Chain(*effects)(wav, sr)
|
85
GPT_SoVITS/resemble_enhance/data/distorter/custom.py
Normal file
85
GPT_SoVITS/resemble_enhance/data/distorter/custom.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cached_property
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
from scipy import signal
|
||||||
|
|
||||||
|
from ..utils import walk_paths
|
||||||
|
from .base import Effect
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RandomRIR(Effect):
|
||||||
|
rir_dir: Path | None
|
||||||
|
rir_rate: int = 44_000
|
||||||
|
rir_suffix: str = ".npy"
|
||||||
|
deterministic: bool = False
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def rir_paths(self):
|
||||||
|
if self.rir_dir is None:
|
||||||
|
return []
|
||||||
|
return list(walk_paths(self.rir_dir, self.rir_suffix))
|
||||||
|
|
||||||
|
def _sample_rir(self):
|
||||||
|
if len(self.rir_paths) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.deterministic:
|
||||||
|
rir_path = self.rir_paths[0]
|
||||||
|
else:
|
||||||
|
rir_path = random.choice(self.rir_paths)
|
||||||
|
|
||||||
|
rir = np.squeeze(np.load(rir_path))
|
||||||
|
assert isinstance(rir, np.ndarray)
|
||||||
|
|
||||||
|
return rir
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
# ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
|
||||||
|
|
||||||
|
if len(self.rir_paths) == 0:
|
||||||
|
return wav
|
||||||
|
|
||||||
|
length = len(wav)
|
||||||
|
|
||||||
|
wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
|
||||||
|
rir = self._sample_rir()
|
||||||
|
|
||||||
|
wav = signal.convolve(wav, rir, mode="same")
|
||||||
|
|
||||||
|
actlev = np.max(np.abs(wav))
|
||||||
|
if actlev > 0.99:
|
||||||
|
wav = (wav / actlev) * 0.98
|
||||||
|
|
||||||
|
wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
|
||||||
|
|
||||||
|
if abs(length - len(wav)) > 10:
|
||||||
|
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
|
||||||
|
|
||||||
|
if length > len(wav):
|
||||||
|
wav = np.pad(wav, (0, length - len(wav)))
|
||||||
|
elif length < len(wav):
|
||||||
|
wav = wav[:length]
|
||||||
|
|
||||||
|
return wav
|
||||||
|
|
||||||
|
|
||||||
|
class RandomGaussianNoise(Effect):
|
||||||
|
def __init__(self, alpha_range=(0.8, 1)):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha_range = alpha_range
|
||||||
|
|
||||||
|
def apply(self, wav, sr):
|
||||||
|
noise = np.random.randn(*wav.shape)
|
||||||
|
noise_energy = np.sum(noise**2)
|
||||||
|
wav_energy = np.sum(wav**2)
|
||||||
|
noise = noise * np.sqrt(wav_energy / noise_energy)
|
||||||
|
alpha = random.uniform(*self.alpha_range)
|
||||||
|
return wav * alpha + noise * (1 - alpha)
|
32
GPT_SoVITS/resemble_enhance/data/distorter/distorter.py
Normal file
32
GPT_SoVITS/resemble_enhance/data/distorter/distorter.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from ...hparams import HParams
|
||||||
|
from .base import Chain, Choice, Permutation
|
||||||
|
from .custom import RandomGaussianNoise, RandomRIR
|
||||||
|
|
||||||
|
|
||||||
|
class Distorter(Chain):
|
||||||
|
def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
|
||||||
|
# Lazy import
|
||||||
|
from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
|
||||||
|
|
||||||
|
if training:
|
||||||
|
permutation = Permutation(
|
||||||
|
RandomRIR(hp.rir_dir),
|
||||||
|
RandomReverb(),
|
||||||
|
RandomGaussianNoise(),
|
||||||
|
RandomOverdrive(),
|
||||||
|
RandomEqualizer(),
|
||||||
|
Choice(
|
||||||
|
RandomLowpassDistorter(),
|
||||||
|
RandomBandpassDistorter(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if mode == "denoiser":
|
||||||
|
super().__init__(permutation)
|
||||||
|
else:
|
||||||
|
# 80%: distortion, 20%: clean
|
||||||
|
super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
|
||||||
|
else:
|
||||||
|
super().__init__(
|
||||||
|
RandomRIR(hp.rir_dir, deterministic=True),
|
||||||
|
RandomReverb(deterministic=True),
|
||||||
|
)
|
176
GPT_SoVITS/resemble_enhance/data/distorter/sox.py
Normal file
176
GPT_SoVITS/resemble_enhance/data/distorter/sox.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
import augment
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"augment is not installed, please install it first using:"
|
||||||
|
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import Effect
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
_DEBUG = bool(os.environ.get("DEBUG", False))
|
||||||
|
|
||||||
|
|
||||||
|
class AttachableEffect(Effect):
|
||||||
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def apply(self, wav: np.ndarray, sr: int):
|
||||||
|
chain = augment.EffectChain()
|
||||||
|
chain = self.attach(chain)
|
||||||
|
tensor = torch.from_numpy(wav)[None].float() # (1, T)
|
||||||
|
tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
|
||||||
|
wav = tensor.numpy()[0] # (T,)
|
||||||
|
return wav
|
||||||
|
|
||||||
|
|
||||||
|
class SoxEffect(AttachableEffect):
|
||||||
|
def __init__(self, effect_name: str, *args, **kwargs):
|
||||||
|
self.effect_name = effect_name
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
||||||
|
_logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
|
||||||
|
if not hasattr(chain, self.effect_name):
|
||||||
|
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
|
||||||
|
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Maybe(AttachableEffect):
|
||||||
|
"""
|
||||||
|
Attach an effect with a probability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, prob: float, effect: AttachableEffect):
|
||||||
|
self.prob = prob
|
||||||
|
self.effect = effect
|
||||||
|
if _DEBUG:
|
||||||
|
warnings.warn("DEBUG mode is on. Maybe -> Must.")
|
||||||
|
self.prob = 1
|
||||||
|
|
||||||
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
||||||
|
if random.random() > self.prob:
|
||||||
|
return chain
|
||||||
|
return self.effect.attach(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class Chain(AttachableEffect):
|
||||||
|
"""
|
||||||
|
Attach a chain of effects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *effects: AttachableEffect):
|
||||||
|
self.effects = effects
|
||||||
|
|
||||||
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
||||||
|
for effect in self.effects:
|
||||||
|
chain = effect.attach(chain)
|
||||||
|
return chain
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(AttachableEffect):
|
||||||
|
"""
|
||||||
|
Attach one of the effects randomly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *effects: AttachableEffect):
|
||||||
|
self.effects = effects
|
||||||
|
|
||||||
|
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
|
||||||
|
return random.choice(self.effects).attach(chain)
|
||||||
|
|
||||||
|
|
||||||
|
class Generator:
|
||||||
|
def __call__(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Uniform(Generator):
|
||||||
|
def __init__(self, low, high):
|
||||||
|
self.low = low
|
||||||
|
self.high = high
|
||||||
|
|
||||||
|
def __call__(self) -> str:
|
||||||
|
return str(random.uniform(self.low, self.high))
|
||||||
|
|
||||||
|
|
||||||
|
class Randint(Generator):
|
||||||
|
def __init__(self, low, high):
|
||||||
|
self.low = low
|
||||||
|
self.high = high
|
||||||
|
|
||||||
|
def __call__(self) -> str:
|
||||||
|
return str(random.randint(self.low, self.high))
|
||||||
|
|
||||||
|
|
||||||
|
class Concat(Generator):
|
||||||
|
def __init__(self, *parts: Generator | str):
|
||||||
|
self.parts = parts
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
return "".join([part if isinstance(part, str) else part() for part in self.parts])
|
||||||
|
|
||||||
|
|
||||||
|
class RandomLowpassDistorter(SoxEffect):
|
||||||
|
def __init__(self, low=2000, high=16000):
|
||||||
|
super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBandpassDistorter(SoxEffect):
|
||||||
|
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
|
||||||
|
super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fn(low, high, min_width, max_width):
|
||||||
|
start = random.randint(low, high)
|
||||||
|
stop = start + random.randint(min_width, max_width)
|
||||||
|
return f"{start}-{stop}"
|
||||||
|
|
||||||
|
|
||||||
|
class RandomEqualizer(SoxEffect):
|
||||||
|
def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
|
||||||
|
super().__init__(
|
||||||
|
"equalizer",
|
||||||
|
Uniform(low, high),
|
||||||
|
lambda: f"{random.randint(q_low, q_high)}q",
|
||||||
|
lambda: random.randint(db_low, db_high),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomOverdrive(SoxEffect):
|
||||||
|
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
|
||||||
|
super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
|
||||||
|
|
||||||
|
|
||||||
|
class RandomReverb(Chain):
|
||||||
|
def __init__(self, deterministic=False):
|
||||||
|
super().__init__(
|
||||||
|
SoxEffect(
|
||||||
|
"reverb",
|
||||||
|
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
||||||
|
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
||||||
|
Uniform(50, 50) if deterministic else Uniform(0, 100),
|
||||||
|
),
|
||||||
|
SoxEffect("channels", 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Flanger(SoxEffect):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("flanger")
|
||||||
|
|
||||||
|
|
||||||
|
class Phaser(SoxEffect):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("phaser")
|
43
GPT_SoVITS/resemble_enhance/data/utils.py
Normal file
43
GPT_SoVITS/resemble_enhance/data/utils.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def walk_paths(root, suffix):
|
||||||
|
for path in Path(root).iterdir():
|
||||||
|
if path.is_dir():
|
||||||
|
yield from walk_paths(path, suffix)
|
||||||
|
elif path.suffix == suffix:
|
||||||
|
yield path
|
||||||
|
|
||||||
|
|
||||||
|
def rglob_audio_files(path: Path):
|
||||||
|
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
|
||||||
|
|
||||||
|
|
||||||
|
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fg: (b, t)
|
||||||
|
bg: (b, t)
|
||||||
|
"""
|
||||||
|
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
|
||||||
|
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
|
||||||
|
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
|
||||||
|
|
||||||
|
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
|
||||||
|
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
fg = fg / (fg_energy + eps).sqrt()
|
||||||
|
bg = bg / (bg_energy + eps).sqrt()
|
||||||
|
|
||||||
|
if callable(alpha):
|
||||||
|
alpha = alpha()
|
||||||
|
|
||||||
|
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
|
||||||
|
|
||||||
|
mx = alpha * fg + (1 - alpha) * bg
|
||||||
|
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
|
||||||
|
|
||||||
|
return mx
|
0
GPT_SoVITS/resemble_enhance/denoiser/__init__.py
Normal file
0
GPT_SoVITS/resemble_enhance/denoiser/__init__.py
Normal file
30
GPT_SoVITS/resemble_enhance/denoiser/__main__.py
Normal file
30
GPT_SoVITS/resemble_enhance/denoiser/__main__.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from .inference import denoise
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
|
||||||
|
parser.add_argument("out_dir", type=Path, help="Output folder")
|
||||||
|
parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
|
||||||
|
parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda", help="Device")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for path in args.in_dir.glob(f"**/*{args.suffix}"):
|
||||||
|
print(f"Processing {path} ..")
|
||||||
|
dwav, sr = torchaudio.load(path)
|
||||||
|
hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
|
||||||
|
out_path = args.out_dir / path.relative_to(args.in_dir)
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torchaudio.save(out_path, hwav[None], sr)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
181
GPT_SoVITS/resemble_enhance/denoiser/denoiser.py
Normal file
181
GPT_SoVITS/resemble_enhance/denoiser/denoiser.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from ..melspec import MelSpectrogram
|
||||||
|
from .hparams import HParams
|
||||||
|
from .unet import UNet
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(x: Tensor) -> Tensor:
|
||||||
|
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
class Denoiser(nn.Module):
|
||||||
|
@property
|
||||||
|
def stft_cfg(self) -> dict:
|
||||||
|
hop_size = self.hp.hop_size
|
||||||
|
return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_fft(self):
|
||||||
|
return self.stft_cfg["n_fft"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eps(self):
|
||||||
|
return 1e-7
|
||||||
|
|
||||||
|
def __init__(self, hp: HParams):
|
||||||
|
super().__init__()
|
||||||
|
self.hp = hp
|
||||||
|
self.net = UNet(input_dim=3, output_dim=3)
|
||||||
|
self.mel_fn = MelSpectrogram(hp)
|
||||||
|
|
||||||
|
self.dummy: Tensor
|
||||||
|
self.register_buffer("dummy", torch.zeros(1), persistent=False)
|
||||||
|
|
||||||
|
def to_mel(self, x: Tensor, drop_last=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b t), wavs
|
||||||
|
Returns:
|
||||||
|
o: (b c t), mels
|
||||||
|
"""
|
||||||
|
if drop_last:
|
||||||
|
return self.mel_fn(x)[..., :-1] # (b d t)
|
||||||
|
return self.mel_fn(x)
|
||||||
|
|
||||||
|
def _stft(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b t)
|
||||||
|
Returns:
|
||||||
|
mag: (b f t) in [0, inf)
|
||||||
|
cos: (b f t) in [-1, 1]
|
||||||
|
sin: (b f t) in [-1, 1]
|
||||||
|
"""
|
||||||
|
dtype = x.dtype
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if x.is_mps:
|
||||||
|
x = x.cpu()
|
||||||
|
|
||||||
|
window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
|
||||||
|
s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1)
|
||||||
|
|
||||||
|
s = s[..., :-1] # (b f t)
|
||||||
|
|
||||||
|
mag = s.abs() # (b f t)
|
||||||
|
|
||||||
|
phi = s.angle() # (b f t)
|
||||||
|
cos = phi.cos() # (b f t)
|
||||||
|
sin = phi.sin() # (b f t)
|
||||||
|
|
||||||
|
mag = mag.to(dtype=dtype, device=device)
|
||||||
|
cos = cos.to(dtype=dtype, device=device)
|
||||||
|
sin = sin.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return mag, cos, sin
|
||||||
|
|
||||||
|
def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
mag: (b f t) in [0, inf)
|
||||||
|
cos: (b f t) in [-1, 1]
|
||||||
|
sin: (b f t) in [-1, 1]
|
||||||
|
Returns:
|
||||||
|
x: (b t)
|
||||||
|
"""
|
||||||
|
device = mag.device
|
||||||
|
dtype = mag.dtype
|
||||||
|
|
||||||
|
if mag.is_mps:
|
||||||
|
mag = mag.cpu()
|
||||||
|
cos = cos.cpu()
|
||||||
|
sin = sin.cpu()
|
||||||
|
|
||||||
|
real = mag * cos # (b f t)
|
||||||
|
imag = mag * sin # (b f t)
|
||||||
|
|
||||||
|
s = torch.complex(real, imag) # (b f t)
|
||||||
|
|
||||||
|
if s.isnan().any():
|
||||||
|
logger.warning("NaN detected in ISTFT input.")
|
||||||
|
|
||||||
|
s = F.pad(s, (0, 1), "replicate") # (b f t+1)
|
||||||
|
|
||||||
|
window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
|
||||||
|
x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
|
||||||
|
|
||||||
|
if x.isnan().any():
|
||||||
|
logger.warning("NaN detected in ISTFT output, set to zero.")
|
||||||
|
x = torch.where(x.isnan(), torch.zeros_like(x), x)
|
||||||
|
|
||||||
|
x = x.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _magphase(self, real, imag):
|
||||||
|
mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
|
||||||
|
cos = real / mag
|
||||||
|
sin = imag / mag
|
||||||
|
return mag, cos, sin
|
||||||
|
|
||||||
|
def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
mag: (b f t)
|
||||||
|
cos: (b f t)
|
||||||
|
sin: (b f t)
|
||||||
|
Returns:
|
||||||
|
mag_mask: (b f t) in [0, 1], magnitude mask
|
||||||
|
cos_res: (b f t) in [-1, 1], phase residual
|
||||||
|
sin_res: (b f t) in [-1, 1], phase residual
|
||||||
|
"""
|
||||||
|
x = torch.stack([mag, cos, sin], dim=1) # (b 3 f t)
|
||||||
|
mag_mask, real, imag = self.net(x).unbind(1) # (b 3 f t)
|
||||||
|
mag_mask = mag_mask.sigmoid() # (b f t)
|
||||||
|
real = real.tanh() # (b f t)
|
||||||
|
imag = imag.tanh() # (b f t)
|
||||||
|
_, cos_res, sin_res = self._magphase(real, imag) # (b f t)
|
||||||
|
return mag_mask, sin_res, cos_res
|
||||||
|
|
||||||
|
def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
|
||||||
|
"""Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
|
||||||
|
sep_mag = F.relu(mag * mag_mask)
|
||||||
|
sep_cos = cos * cos_res - sin * sin_res
|
||||||
|
sep_sin = sin * cos_res + cos * sin_res
|
||||||
|
return sep_mag, sep_cos, sep_sin
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, y: Tensor | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b t), a mixed audio
|
||||||
|
y: (b t), a fg audio
|
||||||
|
"""
|
||||||
|
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
|
||||||
|
x = x.to(self.dummy)
|
||||||
|
x = _normalize(x)
|
||||||
|
|
||||||
|
if y is not None:
|
||||||
|
assert y.dim() == 2, f"Expected (b t), got {y.size()}"
|
||||||
|
y = y.to(self.dummy)
|
||||||
|
y = _normalize(y)
|
||||||
|
|
||||||
|
mag, cos, sin = self._stft(x) # (b 2f t)
|
||||||
|
mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
|
||||||
|
sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
|
||||||
|
|
||||||
|
o = self._istft(sep_mag, sep_cos, sep_sin)
|
||||||
|
|
||||||
|
npad = x.shape[-1] - o.shape[-1]
|
||||||
|
o = F.pad(o, (0, npad))
|
||||||
|
|
||||||
|
if y is not None:
|
||||||
|
self.losses = dict(l1=F.l1_loss(o, y))
|
||||||
|
|
||||||
|
return o
|
9
GPT_SoVITS/resemble_enhance/denoiser/hparams.py
Normal file
9
GPT_SoVITS/resemble_enhance/denoiser/hparams.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..hparams import HParams as HParamsBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HParams(HParamsBase):
|
||||||
|
batch_size_per_gpu: int = 128
|
||||||
|
distort_prob: float = 0.5
|
29
GPT_SoVITS/resemble_enhance/denoiser/inference.py
Normal file
29
GPT_SoVITS/resemble_enhance/denoiser/inference.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import logging
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..inference import inference
|
||||||
|
from .train import Denoiser, HParams
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def load_denoiser(run_dir, device):
|
||||||
|
if run_dir is None:
|
||||||
|
return Denoiser(HParams())
|
||||||
|
hp = HParams.load(run_dir)
|
||||||
|
denoiser = Denoiser(hp)
|
||||||
|
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
|
||||||
|
state_dict = torch.load(path, map_location="cpu")["module"]
|
||||||
|
denoiser.load_state_dict(state_dict)
|
||||||
|
denoiser.eval()
|
||||||
|
denoiser.to(device)
|
||||||
|
return denoiser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def denoise(dwav, sr, run_dir, device):
|
||||||
|
denoiser = load_denoiser(run_dir, device)
|
||||||
|
return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
|
112
GPT_SoVITS/resemble_enhance/denoiser/train.py
Normal file
112
GPT_SoVITS/resemble_enhance/denoiser/train.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
import torch
|
||||||
|
from deepspeed import DeepSpeedConfig
|
||||||
|
from torch import Tensor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..data import create_dataloaders, mix_fg_bg
|
||||||
|
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
|
||||||
|
from ..utils.distributed import is_local_leader
|
||||||
|
from .denoiser import Denoiser
|
||||||
|
from .hparams import HParams
|
||||||
|
|
||||||
|
|
||||||
|
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
|
||||||
|
if hp is None:
|
||||||
|
hp = HParams.load(run_dir)
|
||||||
|
assert isinstance(hp, HParams)
|
||||||
|
model = Denoiser(hp)
|
||||||
|
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
|
||||||
|
if training:
|
||||||
|
engine.load_checkpoint()
|
||||||
|
else:
|
||||||
|
engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def save_wav(path: Path, wav: Tensor, rate: int):
|
||||||
|
wav = wav.detach().cpu().numpy()
|
||||||
|
soundfile.write(path, wav, samplerate=rate)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("run_dir", type=Path)
|
||||||
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup_logging(args.run_dir)
|
||||||
|
hp = HParams.load(args.run_dir, yaml=args.yaml)
|
||||||
|
|
||||||
|
if is_local_leader():
|
||||||
|
hp.save_if_not_exists(args.run_dir)
|
||||||
|
hp.print()
|
||||||
|
|
||||||
|
train_dl, val_dl = create_dataloaders(hp, mode="denoiser")
|
||||||
|
|
||||||
|
def feed_G(engine: Engine, batch: dict[str, Tensor]):
|
||||||
|
alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
|
||||||
|
if random.random() < hp.distort_prob:
|
||||||
|
fg_wavs = batch["fg_dwavs"]
|
||||||
|
else:
|
||||||
|
fg_wavs = batch["fg_wavs"]
|
||||||
|
mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn)
|
||||||
|
pred = engine(mx_dwavs, fg_wavs)
|
||||||
|
losses = engine.gather_attribute("losses", prefix="losses")
|
||||||
|
return pred, losses
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def eval_fn(engine: Engine, eval_dir, n_saved=10):
|
||||||
|
model = engine.module
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
step = engine.global_step
|
||||||
|
|
||||||
|
for i, batch in enumerate(tqdm(val_dl), 1):
|
||||||
|
batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
|
||||||
|
|
||||||
|
fg_dwavs = batch["fg_dwavs"] # 1 t
|
||||||
|
mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"])
|
||||||
|
pred_fg_dwavs = model(mx_dwavs) # 1 t
|
||||||
|
|
||||||
|
mx_mels = model.to_mel(mx_dwavs) # 1 c t
|
||||||
|
fg_mels = model.to_mel(fg_dwavs) # 1 c t
|
||||||
|
pred_fg_mels = model.to_mel(pred_fg_dwavs) # 1 c t
|
||||||
|
|
||||||
|
rate = model.hp.wav_rate
|
||||||
|
get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
|
||||||
|
|
||||||
|
save_wav(get_path("_input.wav"), mx_dwavs[0], rate=rate)
|
||||||
|
save_wav(get_path("_predict.wav"), pred_fg_dwavs[0], rate=rate)
|
||||||
|
save_wav(get_path("_target.wav"), fg_dwavs[0], rate=rate)
|
||||||
|
|
||||||
|
save_mels(
|
||||||
|
get_path(".png"),
|
||||||
|
cond_mel=mx_mels[0].cpu().numpy(),
|
||||||
|
pred_mel=pred_fg_mels[0].cpu().numpy(),
|
||||||
|
targ_mel=fg_mels[0].cpu().numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if i >= n_saved:
|
||||||
|
break
|
||||||
|
|
||||||
|
train_loop = TrainLoop(
|
||||||
|
run_dir=args.run_dir,
|
||||||
|
train_dl=train_dl,
|
||||||
|
load_G=partial(load_G, hp=hp),
|
||||||
|
device=args.device,
|
||||||
|
feed_G=feed_G,
|
||||||
|
eval_fn=eval_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_loop.run(max_steps=hp.max_steps)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
144
GPT_SoVITS/resemble_enhance/denoiser/unet.py
Normal file
144
GPT_SoVITS/resemble_enhance/denoiser/unet.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class PreactResBlock(nn.Sequential):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__(
|
||||||
|
nn.GroupNorm(dim // 16, dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(dim, dim, 3, padding=1),
|
||||||
|
nn.GroupNorm(dim // 16, dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(dim, dim, 3, padding=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class UNetBlock(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
|
||||||
|
super().__init__()
|
||||||
|
if output_dim is None:
|
||||||
|
output_dim = input_dim
|
||||||
|
self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
|
||||||
|
self.res_block1 = PreactResBlock(output_dim)
|
||||||
|
self.res_block2 = PreactResBlock(output_dim)
|
||||||
|
self.downsample = self.upsample = nn.Identity()
|
||||||
|
if scale_factor > 1:
|
||||||
|
self.upsample = nn.Upsample(scale_factor=scale_factor)
|
||||||
|
elif scale_factor < 1:
|
||||||
|
self.downsample = nn.Upsample(scale_factor=scale_factor)
|
||||||
|
|
||||||
|
def forward(self, x, h=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c h w), last output
|
||||||
|
h: (b c h w), skip output
|
||||||
|
Returns:
|
||||||
|
o: (b c h w), output
|
||||||
|
s: (b c h w), skip output
|
||||||
|
"""
|
||||||
|
x = self.upsample(x)
|
||||||
|
if h is not None:
|
||||||
|
assert x.shape == h.shape, f"{x.shape} != {h.shape}"
|
||||||
|
x = x + h
|
||||||
|
x = self.pre_conv(x)
|
||||||
|
x = self.res_block1(x)
|
||||||
|
x = self.res_block2(x)
|
||||||
|
return self.downsample(x), x
|
||||||
|
|
||||||
|
|
||||||
|
class UNet(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
|
||||||
|
super().__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||||
|
self.encoder_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
|
||||||
|
for i in range(num_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.middle_blocks = nn.ModuleList(
|
||||||
|
[UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
|
||||||
|
)
|
||||||
|
self.decoder_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
|
||||||
|
for i in reversed(range(num_blocks))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv2d(hidden_dim, output_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale_factor(self):
|
||||||
|
return 2 ** len(self.encoder_blocks)
|
||||||
|
|
||||||
|
def pad_to_fit(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c h w), input
|
||||||
|
Returns:
|
||||||
|
x: (b c h' w'), padded input
|
||||||
|
"""
|
||||||
|
hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
|
||||||
|
wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
|
||||||
|
return F.pad(x, (0, wpad, 0, hpad))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c h w), input
|
||||||
|
Returns:
|
||||||
|
o: (b c h w), output
|
||||||
|
"""
|
||||||
|
shape = x.shape
|
||||||
|
|
||||||
|
x = self.pad_to_fit(x)
|
||||||
|
x = self.input_proj(x)
|
||||||
|
|
||||||
|
s_list = []
|
||||||
|
for block in self.encoder_blocks:
|
||||||
|
x, s = block(x)
|
||||||
|
s_list.append(s)
|
||||||
|
|
||||||
|
for block in self.middle_blocks:
|
||||||
|
x, _ = block(x)
|
||||||
|
|
||||||
|
for block, s in zip(self.decoder_blocks, reversed(s_list)):
|
||||||
|
x, _ = block(x, s)
|
||||||
|
|
||||||
|
x = self.head(x)
|
||||||
|
x = x[..., : shape[2], : shape[3]]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test(self, shape=(3, 512, 256)):
|
||||||
|
import ptflops
|
||||||
|
|
||||||
|
macs, params = ptflops.get_model_complexity_info(
|
||||||
|
self,
|
||||||
|
shape,
|
||||||
|
as_strings=True,
|
||||||
|
print_per_layer_stat=True,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"macs: {macs}")
|
||||||
|
print(f"params: {params}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model = UNet(3, 3)
|
||||||
|
model.test()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
0
GPT_SoVITS/resemble_enhance/enhancer/__init__.py
Normal file
0
GPT_SoVITS/resemble_enhance/enhancer/__init__.py
Normal file
129
GPT_SoVITS/resemble_enhance/enhancer/__main__.py
Normal file
129
GPT_SoVITS/resemble_enhance/enhancer/__main__.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .inference import denoise, enhance
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
|
||||||
|
parser.add_argument("out_dir", type=Path, help="Output folder")
|
||||||
|
parser.add_argument(
|
||||||
|
"--run_dir",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="Path to the enhancer run folder, if None, use the default model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--suffix",
|
||||||
|
type=str,
|
||||||
|
default=".wav",
|
||||||
|
help="Audio file suffix",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device to use for computation, recommended to use CUDA",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--denoise_only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only apply denoising without enhancement",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lambd",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Denoise strength for enhancement (0.0 to 1.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tau",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="CFM prior temperature (0.0 to 1.0)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--solver",
|
||||||
|
type=str,
|
||||||
|
default="midpoint",
|
||||||
|
choices=["midpoint", "rk4", "euler"],
|
||||||
|
help="Numerical solver to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nfe",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Number of function evaluations",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--parallel_mode",
|
||||||
|
action="store_true",
|
||||||
|
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = args.device
|
||||||
|
|
||||||
|
if device == "cuda" and not torch.cuda.is_available():
|
||||||
|
print("CUDA is not available but --device is set to cuda, using CPU instead")
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
run_dir = args.run_dir
|
||||||
|
|
||||||
|
paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
|
||||||
|
|
||||||
|
if args.parallel_mode:
|
||||||
|
random.shuffle(paths)
|
||||||
|
|
||||||
|
if len(paths) == 0:
|
||||||
|
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
pbar = tqdm(paths)
|
||||||
|
|
||||||
|
for path in pbar:
|
||||||
|
out_path = args.out_dir / path.relative_to(args.in_dir)
|
||||||
|
if args.parallel_mode and out_path.exists():
|
||||||
|
continue
|
||||||
|
pbar.set_description(f"Processing {out_path}")
|
||||||
|
dwav, sr = torchaudio.load(path)
|
||||||
|
dwav = dwav.mean(0)
|
||||||
|
if args.denoise_only:
|
||||||
|
hwav, sr = denoise(
|
||||||
|
dwav=dwav,
|
||||||
|
sr=sr,
|
||||||
|
device=device,
|
||||||
|
run_dir=args.run_dir,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hwav, sr = enhance(
|
||||||
|
dwav=dwav,
|
||||||
|
sr=sr,
|
||||||
|
device=device,
|
||||||
|
nfe=args.nfe,
|
||||||
|
solver=args.solver,
|
||||||
|
lambd=args.lambd,
|
||||||
|
tau=args.tau,
|
||||||
|
run_dir=run_dir,
|
||||||
|
)
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torchaudio.save(out_path, hwav[None], sr)
|
||||||
|
|
||||||
|
# Cool emoji effect saying the job is done
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
28
GPT_SoVITS/resemble_enhance/enhancer/download.py
Normal file
28
GPT_SoVITS/resemble_enhance/enhancer/download.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
RUN_NAME = "enhancer_stage2"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_url(relpath):
|
||||||
|
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
|
||||||
|
|
||||||
|
|
||||||
|
def get_path(relpath):
|
||||||
|
return Path(__file__).parent.parent / "model_repo" / RUN_NAME / relpath
|
||||||
|
|
||||||
|
|
||||||
|
def download():
|
||||||
|
relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
|
||||||
|
for relpath in relpaths:
|
||||||
|
path = get_path(relpath)
|
||||||
|
if path.exists():
|
||||||
|
continue
|
||||||
|
url = get_url(relpath)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
torch.hub.download_url_to_file(url, str(path))
|
||||||
|
return get_path("")
|
195
GPT_SoVITS/resemble_enhance/enhancer/enhancer.py
Normal file
195
GPT_SoVITS/resemble_enhance/enhancer/enhancer.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.distributions import Beta
|
||||||
|
|
||||||
|
from ..common import Normalizer
|
||||||
|
from ..denoiser.inference import load_denoiser
|
||||||
|
from ..melspec import MelSpectrogram
|
||||||
|
from ..utils.distributed import global_leader_only
|
||||||
|
from ..utils.train_loop import TrainLoop
|
||||||
|
from .hparams import HParams
|
||||||
|
from .lcfm import CFM, IRMAE, LCFM
|
||||||
|
from .univnet import UnivNet
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe(fn):
|
||||||
|
def _fn(*args):
|
||||||
|
if args[0] is None:
|
||||||
|
return None
|
||||||
|
return fn(*args)
|
||||||
|
|
||||||
|
return _fn
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_wav(x: Tensor):
|
||||||
|
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
class Enhancer(nn.Module):
|
||||||
|
def __init__(self, hp: HParams):
|
||||||
|
super().__init__()
|
||||||
|
self.hp = hp
|
||||||
|
|
||||||
|
n_mels = self.hp.num_mels
|
||||||
|
vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
|
||||||
|
latent_dim = self.hp.lcfm_latent_dim
|
||||||
|
|
||||||
|
self.lcfm = LCFM(
|
||||||
|
IRMAE(
|
||||||
|
input_dim=n_mels,
|
||||||
|
output_dim=vocoder_input_dim,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
),
|
||||||
|
CFM(
|
||||||
|
cond_dim=n_mels,
|
||||||
|
output_dim=self.hp.lcfm_latent_dim,
|
||||||
|
solver_nfe=self.hp.cfm_solver_nfe,
|
||||||
|
solver_method=self.hp.cfm_solver_method,
|
||||||
|
time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
|
||||||
|
),
|
||||||
|
z_scale=self.hp.lcfm_z_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lcfm.set_mode_(self.hp.lcfm_training_mode)
|
||||||
|
|
||||||
|
self.mel_fn = MelSpectrogram(hp)
|
||||||
|
self.vocoder = UnivNet(self.hp, vocoder_input_dim)
|
||||||
|
self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
|
||||||
|
self.normalizer = Normalizer()
|
||||||
|
|
||||||
|
self._eval_lambd = 0.0
|
||||||
|
|
||||||
|
self.dummy: Tensor
|
||||||
|
self.register_buffer("dummy", torch.zeros(1))
|
||||||
|
|
||||||
|
if self.hp.enhancer_stage1_run_dir is not None:
|
||||||
|
pretrained_path = self.hp.enhancer_stage1_run_dir / "ds/G/default/mp_rank_00_model_states.pt"
|
||||||
|
self._load_pretrained(pretrained_path)
|
||||||
|
|
||||||
|
logger.info(f"{self.__class__.__name__} summary")
|
||||||
|
logger.info(f"{self.summarize()}")
|
||||||
|
|
||||||
|
def _load_pretrained(self, path):
|
||||||
|
# Clone is necessary as otherwise it holds a reference to the original model
|
||||||
|
cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
|
||||||
|
denoiser_state_dict = {k: v.clone() for k, v in self.denoiser.state_dict().items()}
|
||||||
|
state_dict = torch.load(path, map_location="cpu")["module"]
|
||||||
|
self.load_state_dict(state_dict, strict=False)
|
||||||
|
self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm
|
||||||
|
self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser
|
||||||
|
logger.info(f"Loaded pretrained model from {path}")
|
||||||
|
|
||||||
|
def summarize(self):
|
||||||
|
npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
|
||||||
|
npa = lambda m: sum(p.numel() for p in m.parameters())
|
||||||
|
rows = []
|
||||||
|
for name, module in self.named_children():
|
||||||
|
rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
|
||||||
|
rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
return df.to_markdown(index=False)
|
||||||
|
|
||||||
|
def to_mel(self, x: Tensor, drop_last=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b t), wavs
|
||||||
|
Returns:
|
||||||
|
o: (b c t), mels
|
||||||
|
"""
|
||||||
|
if drop_last:
|
||||||
|
return self.mel_fn(x)[..., :-1] # (b d t)
|
||||||
|
return self.mel_fn(x)
|
||||||
|
|
||||||
|
@global_leader_only
|
||||||
|
@torch.no_grad()
|
||||||
|
def _visualize(self, original_mel, denoised_mel):
|
||||||
|
loop = TrainLoop.get_running_loop()
|
||||||
|
if loop is None or loop.global_step % 100 != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
plt.figure(figsize=(6, 6))
|
||||||
|
plt.subplot(211)
|
||||||
|
plt.title("Original")
|
||||||
|
plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none")
|
||||||
|
plt.subplot(212)
|
||||||
|
plt.title("Denoised")
|
||||||
|
plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none")
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
path = loop.get_running_loop_viz_path("input", ".png")
|
||||||
|
plt.savefig(path, dpi=300)
|
||||||
|
|
||||||
|
def _may_denoise(self, x: Tensor, y: Tensor | None = None):
|
||||||
|
if self.hp.lcfm_training_mode == "cfm":
|
||||||
|
return self.denoiser(x, y)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def configurate_(self, nfe, solver, lambd, tau):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
nfe: number of function evaluations
|
||||||
|
solver: solver method
|
||||||
|
lambd: denoiser strength [0, 1]
|
||||||
|
tau: prior temperature [0, 1]
|
||||||
|
"""
|
||||||
|
self.lcfm.cfm.solver.configurate_(nfe, solver)
|
||||||
|
self.lcfm.eval_tau_(tau)
|
||||||
|
self._eval_lambd = lambd
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b t), mix wavs (fg + bg)
|
||||||
|
y: (b t), fg clean wavs
|
||||||
|
z: (b t), fg distorted wavs
|
||||||
|
Returns:
|
||||||
|
o: (b t), reconstructed wavs
|
||||||
|
"""
|
||||||
|
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
|
||||||
|
assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
|
||||||
|
|
||||||
|
if self.hp.lcfm_training_mode == "cfm":
|
||||||
|
self.normalizer.eval()
|
||||||
|
|
||||||
|
x = _normalize_wav(x)
|
||||||
|
y = _maybe(_normalize_wav)(y)
|
||||||
|
z = _maybe(_normalize_wav)(z)
|
||||||
|
|
||||||
|
x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t)
|
||||||
|
|
||||||
|
if self.hp.lcfm_training_mode == "cfm":
|
||||||
|
if self.training:
|
||||||
|
lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
|
||||||
|
lambd = lambd[:, None, None]
|
||||||
|
x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
|
||||||
|
x_mel_denoised = x_mel_denoised.detach()
|
||||||
|
x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
|
||||||
|
self._visualize(x_mel_original, x_mel_denoised)
|
||||||
|
else:
|
||||||
|
lambd = self._eval_lambd
|
||||||
|
if lambd == 0:
|
||||||
|
x_mel_denoised = x_mel_original
|
||||||
|
else:
|
||||||
|
x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
|
||||||
|
x_mel_denoised = x_mel_denoised.detach()
|
||||||
|
x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
|
||||||
|
else:
|
||||||
|
x_mel_denoised = x_mel_original
|
||||||
|
|
||||||
|
y_mel = _maybe(self.to_mel)(y) # (b d t)
|
||||||
|
y_mel = _maybe(self.normalizer)(y_mel)
|
||||||
|
|
||||||
|
lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)
|
||||||
|
|
||||||
|
if lcfm_decoded is None:
|
||||||
|
o = None
|
||||||
|
else:
|
||||||
|
o = self.vocoder(lcfm_decoded, y)
|
||||||
|
|
||||||
|
return o
|
23
GPT_SoVITS/resemble_enhance/enhancer/hparams.py
Normal file
23
GPT_SoVITS/resemble_enhance/enhancer/hparams.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..hparams import HParams as HParamsBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HParams(HParamsBase):
|
||||||
|
cfm_solver_method: str = "midpoint"
|
||||||
|
cfm_solver_nfe: int = 64
|
||||||
|
cfm_time_mapping_divisor: int = 4
|
||||||
|
univnet_nc: int = 96
|
||||||
|
|
||||||
|
lcfm_latent_dim: int = 64
|
||||||
|
lcfm_training_mode: str = "ae"
|
||||||
|
lcfm_z_scale: float = 5
|
||||||
|
|
||||||
|
vocoder_extra_dim: int = 32
|
||||||
|
|
||||||
|
gan_training_start_step: int | None = 5_000
|
||||||
|
enhancer_stage1_run_dir: Path | None = None
|
||||||
|
|
||||||
|
denoiser_run_dir: Path | None = None
|
49
GPT_SoVITS/resemble_enhance/enhancer/inference.py
Normal file
49
GPT_SoVITS/resemble_enhance/enhancer/inference.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import logging
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..inference import inference
|
||||||
|
from .download import download
|
||||||
|
from .train import Enhancer, HParams
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
# Check if the current system is Windows
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
# Make changes specific to Windows
|
||||||
|
temp = pathlib.PosixPath
|
||||||
|
pathlib.PosixPath = pathlib.WindowsPath
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_enhancer(run_dir, device):
|
||||||
|
if run_dir is None:
|
||||||
|
run_dir = download()
|
||||||
|
hp = HParams.load(run_dir)
|
||||||
|
enhancer = Enhancer(hp)
|
||||||
|
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
|
||||||
|
state_dict = torch.load(path, map_location="cpu")["module"]
|
||||||
|
enhancer.load_state_dict(state_dict)
|
||||||
|
enhancer.eval()
|
||||||
|
enhancer.to(device)
|
||||||
|
return enhancer
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def denoise(dwav, sr, device, run_dir=None):
|
||||||
|
enhancer = load_enhancer(run_dir, device)
|
||||||
|
return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def enhance(chunk_seconds, chunks_overlap,dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None):
|
||||||
|
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
|
||||||
|
assert solver in ("midpoint", "rk4", "euler"), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
|
||||||
|
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
|
||||||
|
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
|
||||||
|
enhancer = load_enhancer(run_dir, device)
|
||||||
|
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
|
||||||
|
return inference(model=enhancer, chunk_seconds=chunk_seconds, overlap_seconds=chunks_overlap, dwav=dwav, sr=sr, device=device)
|
2
GPT_SoVITS/resemble_enhance/enhancer/lcfm/__init__.py
Normal file
2
GPT_SoVITS/resemble_enhance/enhancer/lcfm/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .irmae import IRMAE
|
||||||
|
from .lcfm import CFM, LCFM
|
372
GPT_SoVITS/resemble_enhance/enhancer/lcfm/cfm.py
Normal file
372
GPT_SoVITS/resemble_enhance/enhancer/lcfm/cfm.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
from .wn import WN
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VelocityField(Protocol):
|
||||||
|
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class Solver:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method="midpoint",
|
||||||
|
nfe=32,
|
||||||
|
viz_name="solver",
|
||||||
|
viz_every=100,
|
||||||
|
mel_fn=None,
|
||||||
|
time_mapping_divisor=4,
|
||||||
|
verbose=False,
|
||||||
|
):
|
||||||
|
self.configurate_(nfe=nfe, method=method)
|
||||||
|
|
||||||
|
self.verbose = verbose
|
||||||
|
self.viz_every = viz_every
|
||||||
|
self.viz_name = viz_name
|
||||||
|
|
||||||
|
self._camera = None
|
||||||
|
self._mel_fn = mel_fn
|
||||||
|
self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
|
||||||
|
|
||||||
|
def configurate_(self, nfe=None, method=None):
|
||||||
|
if nfe is None:
|
||||||
|
nfe = self.nfe
|
||||||
|
|
||||||
|
if method is None:
|
||||||
|
method = self.method
|
||||||
|
|
||||||
|
if nfe == 1 and method in ("midpoint", "rk4"):
|
||||||
|
logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
|
||||||
|
method = "euler"
|
||||||
|
|
||||||
|
self.nfe = nfe
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
@property
|
||||||
|
def time_mapping(self):
|
||||||
|
return self._time_mapping
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def exponential_decay_mapping(t, n=4):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
n: target step
|
||||||
|
"""
|
||||||
|
|
||||||
|
def h(t, a):
|
||||||
|
return (a**t - 1) / (a - 1)
|
||||||
|
|
||||||
|
# Solve h(1/n) = 0.5
|
||||||
|
a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
|
||||||
|
|
||||||
|
t = h(t, a=a)
|
||||||
|
|
||||||
|
return t
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _maybe_camera_snap(self, *, ψt, t):
|
||||||
|
camera = self._camera
|
||||||
|
if camera is not None:
|
||||||
|
if ψt.shape[1] == 1:
|
||||||
|
# Waveform, b 1 t, plot every 100 samples
|
||||||
|
plt.subplot(211)
|
||||||
|
plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
|
||||||
|
if self._mel_fn is not None:
|
||||||
|
plt.subplot(212)
|
||||||
|
mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
|
||||||
|
plt.imshow(mel, origin="lower", interpolation="none")
|
||||||
|
elif ψt.shape[1] == 2:
|
||||||
|
# Complex
|
||||||
|
plt.subplot(121)
|
||||||
|
plt.imshow(
|
||||||
|
ψt.detach().cpu().numpy()[0, 0],
|
||||||
|
origin="lower",
|
||||||
|
interpolation="none",
|
||||||
|
)
|
||||||
|
plt.subplot(122)
|
||||||
|
plt.imshow(
|
||||||
|
ψt.detach().cpu().numpy()[0, 1],
|
||||||
|
origin="lower",
|
||||||
|
interpolation="none",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Spectrogram, b c t
|
||||||
|
plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
|
||||||
|
camera.snap()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _euler_step(t, ψt, dt, f: VelocityField):
|
||||||
|
return ψt + dt * f(t=t, ψt=ψt, dt=dt)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _midpoint_step(t, ψt, dt, f: VelocityField):
|
||||||
|
return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _rk4_step(t, ψt, dt, f: VelocityField):
|
||||||
|
k1 = f(t=t, ψt=ψt, dt=dt)
|
||||||
|
k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
|
||||||
|
k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
|
||||||
|
k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
|
||||||
|
return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _step(self):
|
||||||
|
if self.method == "euler":
|
||||||
|
return self._euler_step
|
||||||
|
elif self.method == "midpoint":
|
||||||
|
return self._midpoint_step
|
||||||
|
elif self.method == "rk4":
|
||||||
|
return self._rk4_step
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {self.method}")
|
||||||
|
|
||||||
|
def get_running_train_loop(self):
|
||||||
|
try:
|
||||||
|
# Lazy import
|
||||||
|
from ...utils.train_loop import TrainLoop
|
||||||
|
|
||||||
|
return TrainLoop.get_running_loop()
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def visualizing(self):
|
||||||
|
loop = self.get_running_train_loop()
|
||||||
|
if loop is None:
|
||||||
|
return
|
||||||
|
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
|
||||||
|
return loop.global_step % self.viz_every == 0 and not out_path.exists()
|
||||||
|
|
||||||
|
def _reset_camera(self):
|
||||||
|
try:
|
||||||
|
from celluloid import Camera
|
||||||
|
|
||||||
|
self._camera = Camera(plt.figure())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _maybe_dump_camera(self):
|
||||||
|
camera = self._camera
|
||||||
|
loop = self.get_running_train_loop()
|
||||||
|
if camera is not None and loop is not None:
|
||||||
|
animation = camera.animate()
|
||||||
|
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
|
||||||
|
out_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
animation.save(out_path, writer="pillow", fps=4)
|
||||||
|
plt.close()
|
||||||
|
self._camera = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_steps(self):
|
||||||
|
n = self.nfe
|
||||||
|
if self.method == "euler":
|
||||||
|
pass
|
||||||
|
elif self.method == "midpoint":
|
||||||
|
n //= 2
|
||||||
|
elif self.method == "rk4":
|
||||||
|
n //= 4
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {self.method}")
|
||||||
|
return n
|
||||||
|
|
||||||
|
def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
|
||||||
|
ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
|
||||||
|
|
||||||
|
if self.visualizing:
|
||||||
|
self._reset_camera()
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
steps = trange(self.n_steps, desc="CFM inference")
|
||||||
|
else:
|
||||||
|
steps = range(self.n_steps)
|
||||||
|
|
||||||
|
ψt = ψ0
|
||||||
|
|
||||||
|
for i in steps:
|
||||||
|
dt = ts[i + 1] - ts[i]
|
||||||
|
t = ts[i]
|
||||||
|
self._maybe_camera_snap(ψt=ψt, t=t)
|
||||||
|
ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
|
||||||
|
|
||||||
|
self._maybe_camera_snap(ψt=ψt, t=ts[-1])
|
||||||
|
|
||||||
|
ψ1 = ψt
|
||||||
|
del ψt
|
||||||
|
|
||||||
|
self._maybe_dump_camera()
|
||||||
|
|
||||||
|
return ψ1
|
||||||
|
|
||||||
|
def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
|
||||||
|
return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
|
||||||
|
|
||||||
|
|
||||||
|
class SinusodialTimeEmbedding(nn.Module):
|
||||||
|
def __init__(self, d_embed):
|
||||||
|
super().__init__()
|
||||||
|
self.d_embed = d_embed
|
||||||
|
assert d_embed % 2 == 0
|
||||||
|
|
||||||
|
def forward(self, t):
|
||||||
|
t = t.unsqueeze(-1) # ... 1
|
||||||
|
p = torch.linspace(0, 4, self.d_embed // 2).to(t)
|
||||||
|
while p.dim() < t.dim():
|
||||||
|
p = p.unsqueeze(0) # ... d/2
|
||||||
|
sin = torch.sin(t * 10**p)
|
||||||
|
cos = torch.cos(t * 10**p)
|
||||||
|
return torch.cat([sin, cos], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class CFM(nn.Module):
|
||||||
|
"""
|
||||||
|
This mixin is for general diffusion models.
|
||||||
|
|
||||||
|
ψ0 stands for the gaussian noise, and ψ1 is the data point.
|
||||||
|
|
||||||
|
Here we follow the CFM style:
|
||||||
|
The generation process (reverse process) is from t=0 to t=1.
|
||||||
|
The forward process is from t=1 to t=0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cond_dim: int
|
||||||
|
output_dim: int
|
||||||
|
time_emb_dim: int = 128
|
||||||
|
viz_name: str = "cfm"
|
||||||
|
solver_nfe: int = 32
|
||||||
|
solver_method: str = "midpoint"
|
||||||
|
time_mapping_divisor: int = 4
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.solver = Solver(
|
||||||
|
viz_name=self.viz_name,
|
||||||
|
viz_every=1,
|
||||||
|
nfe=self.solver_nfe,
|
||||||
|
method=self.solver_method,
|
||||||
|
time_mapping_divisor=self.time_mapping_divisor,
|
||||||
|
)
|
||||||
|
self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
|
||||||
|
self.net = WN(
|
||||||
|
input_dim=self.output_dim,
|
||||||
|
output_dim=self.output_dim,
|
||||||
|
local_dim=self.cond_dim,
|
||||||
|
global_dim=self.time_emb_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
|
||||||
|
"""
|
||||||
|
Perturb ψ1 to ψt.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _sample_ψ0(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t), which implies the shape of ψ0
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
shape[1] = self.output_dim
|
||||||
|
if self.training:
|
||||||
|
g = None
|
||||||
|
else:
|
||||||
|
g = torch.Generator(device=x.device)
|
||||||
|
g.manual_seed(0) # deterministic sampling during eval
|
||||||
|
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
|
||||||
|
return ψ0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma(self):
|
||||||
|
return 1e-4
|
||||||
|
|
||||||
|
def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
|
||||||
|
"""
|
||||||
|
Eq (22)
|
||||||
|
"""
|
||||||
|
while t.dim() < ψ1.dim():
|
||||||
|
t = t.unsqueeze(-1)
|
||||||
|
μ = t * ψ1 + (1 - t) * ψ0
|
||||||
|
return μ + torch.randn_like(μ) * self.sigma
|
||||||
|
|
||||||
|
def _to_u(self, *, ψ1, ψ0: Tensor):
|
||||||
|
"""
|
||||||
|
Eq (21)
|
||||||
|
"""
|
||||||
|
return ψ1 - ψ0
|
||||||
|
|
||||||
|
def _to_v(self, *, ψt, x, t: float | Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
ψt: (b c t)
|
||||||
|
x: (b c t)
|
||||||
|
t: (b)
|
||||||
|
Returns:
|
||||||
|
v: (b c t)
|
||||||
|
"""
|
||||||
|
if isinstance(t, (float, int)):
|
||||||
|
t = torch.full(ψt.shape[:1], t).to(ψt)
|
||||||
|
t = t.clamp(0, 1) # [0, 1)
|
||||||
|
g = self.emb(t) # (b d)
|
||||||
|
v = self.net(ψt, l=x, g=g)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def compute_losses(self, x, y, ψ0) -> dict:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t)
|
||||||
|
y: (b c t)
|
||||||
|
Returns:
|
||||||
|
losses: dict
|
||||||
|
"""
|
||||||
|
t = torch.rand(len(x), device=x.device, dtype=x.dtype)
|
||||||
|
t = self.solver.time_mapping(t)
|
||||||
|
|
||||||
|
if ψ0 is None:
|
||||||
|
ψ0 = self._sample_ψ0(x)
|
||||||
|
|
||||||
|
ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
|
||||||
|
|
||||||
|
v = self._to_v(ψt=ψt, t=t, x=x)
|
||||||
|
u = self._to_u(ψ1=y, ψ0=ψ0)
|
||||||
|
|
||||||
|
losses = dict(l1=F.l1_loss(v, u))
|
||||||
|
|
||||||
|
return losses
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def sample(self, x, ψ0=None, t0=0.0):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t)
|
||||||
|
Returns:
|
||||||
|
y: (b ... t)
|
||||||
|
"""
|
||||||
|
if ψ0 is None:
|
||||||
|
ψ0 = self._sample_ψ0(x)
|
||||||
|
f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
|
||||||
|
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
|
||||||
|
return ψ1
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
|
||||||
|
if y is None:
|
||||||
|
y = self.sample(x, ψ0=ψ0, t0=t0)
|
||||||
|
else:
|
||||||
|
self.losses = self.compute_losses(x, y, ψ0=ψ0)
|
||||||
|
return y
|
123
GPT_SoVITS/resemble_enhance/enhancer/lcfm/irmae.py
Normal file
123
GPT_SoVITS/resemble_enhance/enhancer/lcfm/irmae.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
from ...common import Normalizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IRMAEOutput:
|
||||||
|
latent: Tensor # latent vector
|
||||||
|
decoded: Tensor | None # decoder output, include extra dim
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Sequential):
|
||||||
|
def __init__(self, channels, dilations=[1, 2, 4, 8]):
|
||||||
|
wn = weight_norm
|
||||||
|
super().__init__(
|
||||||
|
nn.GroupNorm(32, channels),
|
||||||
|
nn.GELU(),
|
||||||
|
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
|
||||||
|
nn.GroupNorm(32, channels),
|
||||||
|
nn.GELU(),
|
||||||
|
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
|
||||||
|
nn.GroupNorm(32, channels),
|
||||||
|
nn.GELU(),
|
||||||
|
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
|
||||||
|
nn.GroupNorm(32, channels),
|
||||||
|
nn.GELU(),
|
||||||
|
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
return x + super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class IRMAE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
latent_dim,
|
||||||
|
hidden_dim=1024,
|
||||||
|
num_irms=4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_dim: input dimension
|
||||||
|
output_dim: output dimension
|
||||||
|
latent_dim: latent dimension
|
||||||
|
hidden_dim: hidden layer dimension
|
||||||
|
num_irm_matrics: number of implicit rank minimization matrices
|
||||||
|
norm: normalization layer
|
||||||
|
"""
|
||||||
|
self.input_dim = input_dim
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
|
||||||
|
*[ResBlock(hidden_dim) for _ in range(4)],
|
||||||
|
# Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
|
||||||
|
*[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder = nn.Sequential(
|
||||||
|
nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
|
||||||
|
*[ResBlock(hidden_dim) for _ in range(4)],
|
||||||
|
nn.Conv1d(hidden_dim, output_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Conv1d(hidden_dim, input_dim, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.estimator = Normalizer()
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t) tensor
|
||||||
|
"""
|
||||||
|
z = self.encoder(x) # (b c t)
|
||||||
|
_ = self.estimator(z) # Estimate the glboal mean and std of z
|
||||||
|
self.stats = {}
|
||||||
|
self.stats["z_mean"] = z.mean().item()
|
||||||
|
self.stats["z_std"] = z.std().item()
|
||||||
|
self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
|
||||||
|
self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
|
||||||
|
self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
|
||||||
|
return z
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
z: (b c t) tensor
|
||||||
|
"""
|
||||||
|
return self.decoder(z)
|
||||||
|
|
||||||
|
def forward(self, x, skip_decoding=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t) tensor
|
||||||
|
skip_decoding: if True, skip the decoding step
|
||||||
|
"""
|
||||||
|
z = self.encode(x) # q(z|x)
|
||||||
|
|
||||||
|
if skip_decoding:
|
||||||
|
# This speeds up the training in cfm only mode
|
||||||
|
decoded = None
|
||||||
|
else:
|
||||||
|
decoded = self.decode(z) # p(x|z)
|
||||||
|
predicted = self.head(decoded)
|
||||||
|
self.losses = dict(mse=F.mse_loss(predicted, x))
|
||||||
|
|
||||||
|
return IRMAEOutput(latent=z, decoded=decoded)
|
152
GPT_SoVITS/resemble_enhance/enhancer/lcfm/lcfm.py
Normal file
152
GPT_SoVITS/resemble_enhance/enhancer/lcfm/lcfm.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .cfm import CFM
|
||||||
|
from .irmae import IRMAE, IRMAEOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_(module):
|
||||||
|
for p in module.parameters():
|
||||||
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
|
class LCFM(nn.Module):
|
||||||
|
class Mode(Enum):
|
||||||
|
AE = "ae"
|
||||||
|
CFM = "cfm"
|
||||||
|
|
||||||
|
def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.ae = ae
|
||||||
|
self.cfm = cfm
|
||||||
|
self.z_scale = z_scale
|
||||||
|
self._mode = None
|
||||||
|
self._eval_tau = 0.5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mode(self):
|
||||||
|
return self._mode
|
||||||
|
|
||||||
|
def set_mode_(self, mode):
|
||||||
|
mode = self.Mode(mode)
|
||||||
|
self._mode = mode
|
||||||
|
|
||||||
|
if mode == mode.AE:
|
||||||
|
freeze_(self.cfm)
|
||||||
|
logger.info("Freeze cfm")
|
||||||
|
elif mode == mode.CFM:
|
||||||
|
freeze_(self.ae)
|
||||||
|
logger.info("Freeze ae (encoder and decoder)")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown training mode: {mode}")
|
||||||
|
|
||||||
|
def get_running_train_loop(self):
|
||||||
|
try:
|
||||||
|
# Lazy import
|
||||||
|
from ...utils.train_loop import TrainLoop
|
||||||
|
|
||||||
|
return TrainLoop.get_running_loop()
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_step(self):
|
||||||
|
loop = self.get_running_train_loop()
|
||||||
|
if loop is None:
|
||||||
|
return None
|
||||||
|
return loop.global_step
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _visualize(self, x, y, y_):
|
||||||
|
loop = self.get_running_train_loop()
|
||||||
|
if loop is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
plt.subplot(221)
|
||||||
|
plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.title("GT")
|
||||||
|
|
||||||
|
plt.subplot(222)
|
||||||
|
y_ = y_[:, : y.shape[1]]
|
||||||
|
plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.title("Posterior")
|
||||||
|
|
||||||
|
plt.subplot(223)
|
||||||
|
z_ = self.cfm(x)
|
||||||
|
y__ = self.ae.decode(z_)
|
||||||
|
y__ = y__[:, : y.shape[1]]
|
||||||
|
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.title("C-Prior")
|
||||||
|
del y__
|
||||||
|
|
||||||
|
plt.subplot(224)
|
||||||
|
z_ = torch.randn_like(z_)
|
||||||
|
y__ = self.ae.decode(z_)
|
||||||
|
y__ = y__[:, : y.shape[1]]
|
||||||
|
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.title("Prior")
|
||||||
|
del z_, y__
|
||||||
|
|
||||||
|
path = loop.make_current_step_viz_path("recon", ".png")
|
||||||
|
path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(path, dpi=500)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def _scale(self, z: Tensor):
|
||||||
|
return z * self.z_scale
|
||||||
|
|
||||||
|
def _unscale(self, z: Tensor):
|
||||||
|
return z / self.z_scale
|
||||||
|
|
||||||
|
def eval_tau_(self, tau):
|
||||||
|
self._eval_tau = tau
|
||||||
|
|
||||||
|
def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b d t), condition mel
|
||||||
|
y: (b d t), target mel
|
||||||
|
ψ0: (b d t), starting mel
|
||||||
|
"""
|
||||||
|
if self.mode == self.Mode.CFM:
|
||||||
|
self.ae.eval() # Always set to eval when training cfm
|
||||||
|
|
||||||
|
if ψ0 is not None:
|
||||||
|
ψ0 = self._scale(self.ae.encode(ψ0))
|
||||||
|
if self.training:
|
||||||
|
tau = torch.rand_like(ψ0[:, :1, :1])
|
||||||
|
else:
|
||||||
|
tau = self._eval_tau
|
||||||
|
ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
|
||||||
|
|
||||||
|
if y is None:
|
||||||
|
if self.mode == self.Mode.AE:
|
||||||
|
with torch.no_grad():
|
||||||
|
training = self.ae.training
|
||||||
|
self.ae.eval()
|
||||||
|
z = self.ae.encode(x)
|
||||||
|
self.ae.train(training)
|
||||||
|
else:
|
||||||
|
z = self._unscale(self.cfm(x, ψ0=ψ0))
|
||||||
|
|
||||||
|
h = self.ae.decode(z)
|
||||||
|
else:
|
||||||
|
ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
|
||||||
|
|
||||||
|
if self.mode == self.Mode.CFM:
|
||||||
|
_ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
|
||||||
|
|
||||||
|
h = ae_output.decoded
|
||||||
|
|
||||||
|
if h is not None and self.global_step is not None and self.global_step % 100 == 0:
|
||||||
|
self._visualize(x[:1], y[:1], h[:1])
|
||||||
|
|
||||||
|
return h
|
147
GPT_SoVITS/resemble_enhance/enhancer/lcfm/wn.py
Normal file
147
GPT_SoVITS/resemble_enhance/enhancer/lcfm/wn.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def _fused_tanh_sigmoid(h):
|
||||||
|
a, b = h.chunk(2, dim=1)
|
||||||
|
h = a.tanh() * b.sigmoid()
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class WNLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiffWave-like WN
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
local_output_dim = hidden_dim * 2
|
||||||
|
|
||||||
|
if global_dim is not None:
|
||||||
|
self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
|
||||||
|
|
||||||
|
if local_dim is not None:
|
||||||
|
self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
|
||||||
|
|
||||||
|
self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
|
||||||
|
|
||||||
|
self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
|
||||||
|
|
||||||
|
def forward(self, z, l, g):
|
||||||
|
identity = z
|
||||||
|
|
||||||
|
if g is not None:
|
||||||
|
if g.dim() == 2:
|
||||||
|
g = g.unsqueeze(-1)
|
||||||
|
z = z + self.gconv(g)
|
||||||
|
|
||||||
|
z = self.dconv(z)
|
||||||
|
|
||||||
|
if l is not None:
|
||||||
|
z = z + self.lconv(l)
|
||||||
|
|
||||||
|
z = _fused_tanh_sigmoid(z)
|
||||||
|
|
||||||
|
h = self.out(z)
|
||||||
|
|
||||||
|
z, s = h.chunk(2, dim=1)
|
||||||
|
|
||||||
|
o = (z + identity) / math.sqrt(2)
|
||||||
|
|
||||||
|
return o, s
|
||||||
|
|
||||||
|
|
||||||
|
class WN(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
local_dim=None,
|
||||||
|
global_dim=None,
|
||||||
|
n_layers=30,
|
||||||
|
kernel_size=3,
|
||||||
|
dilation_cycle=5,
|
||||||
|
hidden_dim=512,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
assert hidden_dim % 2 == 0
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.local_dim = local_dim
|
||||||
|
self.global_dim = global_dim
|
||||||
|
|
||||||
|
self.start = nn.Conv1d(input_dim, hidden_dim, 1)
|
||||||
|
if local_dim is not None:
|
||||||
|
self.local_norm = nn.InstanceNorm1d(local_dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
WNLayer(
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
local_dim=local_dim,
|
||||||
|
global_dim=global_dim,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
dilation=2 ** (i % dilation_cycle),
|
||||||
|
)
|
||||||
|
for i in range(n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.end = nn.Conv1d(hidden_dim, output_dim, 1)
|
||||||
|
|
||||||
|
def forward(self, z, l=None, g=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
z: input (b c t)
|
||||||
|
l: local condition (b c t)
|
||||||
|
g: global condition (b d)
|
||||||
|
"""
|
||||||
|
z = self.start(z)
|
||||||
|
|
||||||
|
if l is not None:
|
||||||
|
l = self.local_norm(l)
|
||||||
|
|
||||||
|
# Skips
|
||||||
|
s_list = []
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
z, s = layer(z, l, g)
|
||||||
|
s_list.append(s)
|
||||||
|
|
||||||
|
s_list = torch.stack(s_list, dim=0).sum(dim=0)
|
||||||
|
s_list = s_list / math.sqrt(len(self.layers))
|
||||||
|
|
||||||
|
o = self.end(s_list)
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
def summarize(self, length=100):
|
||||||
|
from ptflops import get_model_complexity_info
|
||||||
|
|
||||||
|
x = torch.randn(1, self.input_dim, length)
|
||||||
|
|
||||||
|
macs, params = get_model_complexity_info(
|
||||||
|
self,
|
||||||
|
(self.input_dim, length),
|
||||||
|
as_strings=True,
|
||||||
|
print_per_layer_stat=True,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Input shape: {x.shape}")
|
||||||
|
print(f"Computational complexity: {macs}")
|
||||||
|
print(f"Number of parameters: {params}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = WN(input_dim=64, output_dim=64)
|
||||||
|
model.summarize()
|
143
GPT_SoVITS/resemble_enhance/enhancer/train.py
Normal file
143
GPT_SoVITS/resemble_enhance/enhancer/train.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import soundfile
|
||||||
|
import torch
|
||||||
|
from deepspeed import DeepSpeedConfig
|
||||||
|
from torch import Tensor
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..data import create_dataloaders, mix_fg_bg
|
||||||
|
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
|
||||||
|
from ..utils.distributed import is_local_leader
|
||||||
|
from .enhancer import Enhancer
|
||||||
|
from .hparams import HParams
|
||||||
|
from .univnet.discriminator import Discriminator
|
||||||
|
|
||||||
|
|
||||||
|
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
|
||||||
|
if hp is None:
|
||||||
|
hp = HParams.load(run_dir)
|
||||||
|
assert isinstance(hp, HParams)
|
||||||
|
model = Enhancer(hp)
|
||||||
|
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
|
||||||
|
if training:
|
||||||
|
engine.load_checkpoint()
|
||||||
|
else:
|
||||||
|
engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def load_D(run_dir: Path, hp: HParams):
|
||||||
|
if hp is None:
|
||||||
|
hp = HParams.load(run_dir)
|
||||||
|
assert isinstance(hp, HParams)
|
||||||
|
model = Discriminator(hp)
|
||||||
|
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
|
||||||
|
engine.load_checkpoint()
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def save_wav(path: Path, wav: Tensor, rate: int):
|
||||||
|
wav = wav.detach().cpu().numpy()
|
||||||
|
soundfile.write(path, wav, samplerate=rate)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("run_dir", type=Path)
|
||||||
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup_logging(args.run_dir)
|
||||||
|
hp = HParams.load(args.run_dir, yaml=args.yaml)
|
||||||
|
|
||||||
|
if is_local_leader():
|
||||||
|
hp.save_if_not_exists(args.run_dir)
|
||||||
|
hp.print()
|
||||||
|
|
||||||
|
train_dl, val_dl = create_dataloaders(hp, mode="enhancer")
|
||||||
|
|
||||||
|
def feed_G(engine: Engine, batch: dict[str, Tensor]):
|
||||||
|
if hp.lcfm_training_mode == "ae":
|
||||||
|
pred = engine(batch["fg_wavs"], batch["fg_wavs"])
|
||||||
|
elif hp.lcfm_training_mode == "cfm":
|
||||||
|
alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
|
||||||
|
mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn)
|
||||||
|
pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
|
||||||
|
losses = engine.gather_attribute("losses")
|
||||||
|
return pred, losses
|
||||||
|
|
||||||
|
def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
|
||||||
|
if batch is None:
|
||||||
|
losses = engine(fake=fake)
|
||||||
|
else:
|
||||||
|
losses = engine(fake=fake, real=batch["fg_wavs"])
|
||||||
|
return losses
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def eval_fn(engine: Engine, eval_dir, n_saved=10):
|
||||||
|
assert isinstance(hp, HParams)
|
||||||
|
|
||||||
|
model = engine.module
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
step = engine.global_step
|
||||||
|
|
||||||
|
for i, batch in enumerate(tqdm(val_dl), 1):
|
||||||
|
batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
|
||||||
|
|
||||||
|
fg_wavs = batch["fg_wavs"] # 1 t
|
||||||
|
|
||||||
|
if hp.lcfm_training_mode == "ae":
|
||||||
|
in_dwavs = fg_wavs
|
||||||
|
elif hp.lcfm_training_mode == "cfm":
|
||||||
|
in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
|
||||||
|
|
||||||
|
pred_fg_wavs = model(in_dwavs) # 1 t
|
||||||
|
|
||||||
|
in_mels = model.to_mel(in_dwavs) # 1 c t
|
||||||
|
fg_mels = model.to_mel(fg_wavs) # 1 c t
|
||||||
|
pred_fg_mels = model.to_mel(pred_fg_wavs) # 1 c t
|
||||||
|
|
||||||
|
rate = model.hp.wav_rate
|
||||||
|
get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
|
||||||
|
|
||||||
|
save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate)
|
||||||
|
save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate)
|
||||||
|
save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate)
|
||||||
|
|
||||||
|
save_mels(
|
||||||
|
get_path(".png"),
|
||||||
|
cond_mel=in_mels[0].cpu().numpy(),
|
||||||
|
pred_mel=pred_fg_mels[0].cpu().numpy(),
|
||||||
|
targ_mel=fg_mels[0].cpu().numpy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if i >= n_saved:
|
||||||
|
break
|
||||||
|
|
||||||
|
train_loop = TrainLoop(
|
||||||
|
run_dir=args.run_dir,
|
||||||
|
train_dl=train_dl,
|
||||||
|
load_G=partial(load_G, hp=hp),
|
||||||
|
load_D=partial(load_D, hp=hp),
|
||||||
|
device=args.device,
|
||||||
|
feed_G=feed_G,
|
||||||
|
feed_D=feed_D,
|
||||||
|
eval_fn=eval_fn,
|
||||||
|
gan_training_start_step=hp.gan_training_start_step,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_loop.run(max_steps=hp.max_steps)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
GPT_SoVITS/resemble_enhance/enhancer/univnet/__init__.py
Normal file
1
GPT_SoVITS/resemble_enhance/enhancer/univnet/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .univnet import UnivNet
|
@ -0,0 +1,5 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .filter import *
|
||||||
|
from .resample import *
|
@ -0,0 +1,95 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
if 'sinc' in dir(torch):
|
||||||
|
sinc = torch.sinc
|
||||||
|
else:
|
||||||
|
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/core.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def sinc(x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||||
|
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
return torch.where(x == 0,
|
||||||
|
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||||
|
even = (kernel_size % 2 == 0)
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
|
#For kaiser window
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.:
|
||||||
|
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||||
|
else:
|
||||||
|
beta = 0.
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
|
||||||
|
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||||
|
if even:
|
||||||
|
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||||
|
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||||
|
# of the constant component in the input signal.
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = 'replicate',
|
||||||
|
kernel_size: int = 12):
|
||||||
|
# kernel_size should be even number for stylegan3 setup,
|
||||||
|
# in this implementation, odd number is also possible.
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = (kernel_size % 2 == 0)
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
#input [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||||
|
mode=self.padding_mode)
|
||||||
|
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
||||||
|
stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
return out
|
@ -0,0 +1,49 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from .filter import LowPassFilter1d
|
||||||
|
from .filter import kaiser_sinc_filter1d
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.stride = ratio
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# x: [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||||
|
x = self.ratio * F.conv_transpose1d(
|
||||||
|
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
x = x[..., self.pad_left:-self.pad_right]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.lowpass(x)
|
||||||
|
|
||||||
|
return xx
|
101
GPT_SoVITS/resemble_enhance/enhancer/univnet/amp.py
Normal file
101
GPT_SoVITS/resemble_enhance/enhancer/univnet/amp.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# Refer from https://github.com/NVIDIA/BigVGAN
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
from .alias_free_torch import DownSample1d, UpSample1d
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
"""
|
||||||
|
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
References:
|
||||||
|
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snakebeta(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
|
||||||
|
"""
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
|
||||||
|
self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
|
||||||
|
self.clamp = clamp
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
|
"""
|
||||||
|
alpha = self.log_alpha.exp().clamp(*self.clamp)
|
||||||
|
alpha = alpha[None, :, None]
|
||||||
|
|
||||||
|
beta = self.log_beta.exp().clamp(*self.clamp)
|
||||||
|
beta = beta[None, :, None]
|
||||||
|
|
||||||
|
x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UpActDown(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
act,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = act
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x: [B,C,T]
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock(nn.Sequential):
|
||||||
|
def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
|
||||||
|
super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
|
||||||
|
|
||||||
|
def _make_layer(self, channels, kernel_size, dilation):
|
||||||
|
return nn.Sequential(
|
||||||
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
|
||||||
|
UpActDown(act=SnakeBeta(channels)),
|
||||||
|
weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + super().forward(x)
|
210
GPT_SoVITS/resemble_enhance/enhancer/univnet/discriminator.py
Normal file
210
GPT_SoVITS/resemble_enhance/enhancer/univnet/discriminator.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
from ..hparams import HParams
|
||||||
|
from .mrstft import get_stft_cfgs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PeriodNetwork(nn.Module):
|
||||||
|
def __init__(self, period):
|
||||||
|
super().__init__()
|
||||||
|
self.period = period
|
||||||
|
wn = weight_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
|
||||||
|
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
|
||||||
|
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
|
||||||
|
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
|
||||||
|
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, 1, T]
|
||||||
|
"""
|
||||||
|
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
|
||||||
|
|
||||||
|
# 1d to 2d
|
||||||
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # pad first
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, 0.2)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpecNetwork(nn.Module):
|
||||||
|
def __init__(self, stft_cfg: dict):
|
||||||
|
super().__init__()
|
||||||
|
wn = weight_norm
|
||||||
|
self.stft_cfg = stft_cfg
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
||||||
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||||
|
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, 1, T]
|
||||||
|
"""
|
||||||
|
x = self.spectrogram(x)
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, 0.2)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = x.flatten(1, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def spectrogram(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [B, 1, T]
|
||||||
|
"""
|
||||||
|
x = x.squeeze(1)
|
||||||
|
dtype = x.dtype
|
||||||
|
stft_cfg = dict(self.stft_cfg)
|
||||||
|
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
|
||||||
|
mag = x.norm(p=2, dim=-1) # [B, F, TT]
|
||||||
|
mag = mag.to(dtype) # [B, F, TT]
|
||||||
|
return mag
|
||||||
|
|
||||||
|
|
||||||
|
class MD(nn.ModuleList):
|
||||||
|
def __init__(self, l: list):
|
||||||
|
super().__init__([self._create_network(x) for x in l])
|
||||||
|
self._loss_type = None
|
||||||
|
|
||||||
|
def loss_type_(self, loss_type):
|
||||||
|
self._loss_type = loss_type
|
||||||
|
|
||||||
|
def _create_network(self, _):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _forward_each(self, d, x, y):
|
||||||
|
assert self._loss_type is not None, "loss_type is not set."
|
||||||
|
loss_type = self._loss_type
|
||||||
|
|
||||||
|
if loss_type == "hinge":
|
||||||
|
if y == 0:
|
||||||
|
# d(x) should be small -> -1
|
||||||
|
loss = F.relu(1 + d(x)).mean()
|
||||||
|
elif y == 1:
|
||||||
|
# d(x) should be large -> 1
|
||||||
|
loss = F.relu(1 - d(x)).mean()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid y: {y}")
|
||||||
|
elif loss_type == "wgan":
|
||||||
|
if y == 0:
|
||||||
|
loss = d(x).mean()
|
||||||
|
elif y == 1:
|
||||||
|
loss = -d(x).mean()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid y: {y}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid loss_type: {loss_type}")
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def forward(self, x, y) -> Tensor:
|
||||||
|
losses = [self._forward_each(d, x, y) for d in self]
|
||||||
|
return torch.stack(losses).mean()
|
||||||
|
|
||||||
|
|
||||||
|
class MPD(MD):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__([2, 3, 7, 13, 17])
|
||||||
|
|
||||||
|
def _create_network(self, period):
|
||||||
|
return PeriodNetwork(period)
|
||||||
|
|
||||||
|
|
||||||
|
class MRD(MD):
|
||||||
|
def __init__(self, stft_cfgs):
|
||||||
|
super().__init__(stft_cfgs)
|
||||||
|
|
||||||
|
def _create_network(self, stft_cfg):
|
||||||
|
return SpecNetwork(stft_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator(nn.Module):
|
||||||
|
@property
|
||||||
|
def wav_rate(self):
|
||||||
|
return self.hp.wav_rate
|
||||||
|
|
||||||
|
def __init__(self, hp: HParams):
|
||||||
|
super().__init__()
|
||||||
|
self.hp = hp
|
||||||
|
self.stft_cfgs = get_stft_cfgs(hp)
|
||||||
|
self.mpd = MPD()
|
||||||
|
self.mrd = MRD(self.stft_cfgs)
|
||||||
|
self.dummy_float: Tensor
|
||||||
|
self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
|
||||||
|
|
||||||
|
def loss_type_(self, loss_type):
|
||||||
|
self.mpd.loss_type_(loss_type)
|
||||||
|
self.mrd.loss_type_(loss_type)
|
||||||
|
|
||||||
|
def forward(self, fake, real=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fake: [B T]
|
||||||
|
real: [B T]
|
||||||
|
"""
|
||||||
|
fake = fake.to(self.dummy_float)
|
||||||
|
|
||||||
|
if real is None:
|
||||||
|
self.loss_type_("wgan")
|
||||||
|
else:
|
||||||
|
length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
|
||||||
|
assert length_difference < 0.05, f"length_difference should be smaller than 5%"
|
||||||
|
|
||||||
|
self.loss_type_("hinge")
|
||||||
|
real = real.to(self.dummy_float)
|
||||||
|
|
||||||
|
fake = fake[..., : real.shape[-1]]
|
||||||
|
real = real[..., : fake.shape[-1]]
|
||||||
|
|
||||||
|
losses = {}
|
||||||
|
|
||||||
|
assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
|
||||||
|
assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
|
||||||
|
|
||||||
|
fake = fake.unsqueeze(1)
|
||||||
|
|
||||||
|
if real is None:
|
||||||
|
losses["mpd"] = self.mpd(fake, 1)
|
||||||
|
losses["mrd"] = self.mrd(fake, 1)
|
||||||
|
else:
|
||||||
|
real = real.unsqueeze(1)
|
||||||
|
losses["mpd_fake"] = self.mpd(fake, 0)
|
||||||
|
losses["mpd_real"] = self.mpd(real, 1)
|
||||||
|
losses["mrd_fake"] = self.mrd(fake, 0)
|
||||||
|
losses["mrd_real"] = self.mrd(real, 1)
|
||||||
|
|
||||||
|
return losses
|
281
GPT_SoVITS/resemble_enhance/enhancer/univnet/lvcnet.py
Normal file
281
GPT_SoVITS/resemble_enhance/enhancer/univnet/lvcnet.py
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
""" refer from https://github.com/zceng/LVCNet """
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
from .amp import AMPBlock
|
||||||
|
|
||||||
|
|
||||||
|
class KernelPredictor(torch.nn.Module):
|
||||||
|
"""Kernel predictor for the location-variable convolutions"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cond_channels,
|
||||||
|
conv_in_channels,
|
||||||
|
conv_out_channels,
|
||||||
|
conv_layers,
|
||||||
|
conv_kernel_size=3,
|
||||||
|
kpnet_hidden_channels=64,
|
||||||
|
kpnet_conv_size=3,
|
||||||
|
kpnet_dropout=0.0,
|
||||||
|
kpnet_nonlinear_activation="LeakyReLU",
|
||||||
|
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cond_channels (int): number of channel for the conditioning sequence,
|
||||||
|
conv_in_channels (int): number of channel for the input sequence,
|
||||||
|
conv_out_channels (int): number of channel for the output sequence,
|
||||||
|
conv_layers (int): number of layers
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_in_channels = conv_in_channels
|
||||||
|
self.conv_out_channels = conv_out_channels
|
||||||
|
self.conv_kernel_size = conv_kernel_size
|
||||||
|
self.conv_layers = conv_layers
|
||||||
|
|
||||||
|
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||||
|
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||||
|
|
||||||
|
self.input_conv = nn.Sequential(
|
||||||
|
weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.residual_convs = nn.ModuleList()
|
||||||
|
padding = (kpnet_conv_size - 1) // 2
|
||||||
|
for _ in range(3):
|
||||||
|
self.residual_convs.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.Dropout(kpnet_dropout),
|
||||||
|
weight_norm(
|
||||||
|
nn.Conv1d(
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_conv_size,
|
||||||
|
padding=padding,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
weight_norm(
|
||||||
|
nn.Conv1d(
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_conv_size,
|
||||||
|
padding=padding,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.kernel_conv = weight_norm(
|
||||||
|
nn.Conv1d(
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_kernel_channels,
|
||||||
|
kpnet_conv_size,
|
||||||
|
padding=padding,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.bias_conv = weight_norm(
|
||||||
|
nn.Conv1d(
|
||||||
|
kpnet_hidden_channels,
|
||||||
|
kpnet_bias_channels,
|
||||||
|
kpnet_conv_size,
|
||||||
|
padding=padding,
|
||||||
|
bias=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, c):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||||
|
"""
|
||||||
|
batch, _, cond_length = c.shape
|
||||||
|
c = self.input_conv(c)
|
||||||
|
for residual_conv in self.residual_convs:
|
||||||
|
residual_conv.to(c.device)
|
||||||
|
c = c + residual_conv(c)
|
||||||
|
k = self.kernel_conv(c)
|
||||||
|
b = self.bias_conv(c)
|
||||||
|
kernels = k.contiguous().view(
|
||||||
|
batch,
|
||||||
|
self.conv_layers,
|
||||||
|
self.conv_in_channels,
|
||||||
|
self.conv_out_channels,
|
||||||
|
self.conv_kernel_size,
|
||||||
|
cond_length,
|
||||||
|
)
|
||||||
|
bias = b.contiguous().view(
|
||||||
|
batch,
|
||||||
|
self.conv_layers,
|
||||||
|
self.conv_out_channels,
|
||||||
|
cond_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return kernels, bias
|
||||||
|
|
||||||
|
|
||||||
|
class LVCBlock(torch.nn.Module):
|
||||||
|
"""the location-variable convolutions"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
cond_channels,
|
||||||
|
stride,
|
||||||
|
dilations=[1, 3, 9, 27],
|
||||||
|
lReLU_slope=0.2,
|
||||||
|
conv_kernel_size=3,
|
||||||
|
cond_hop_length=256,
|
||||||
|
kpnet_hidden_channels=64,
|
||||||
|
kpnet_conv_size=3,
|
||||||
|
kpnet_dropout=0.0,
|
||||||
|
add_extra_noise=False,
|
||||||
|
downsampling=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.add_extra_noise = add_extra_noise
|
||||||
|
|
||||||
|
self.cond_hop_length = cond_hop_length
|
||||||
|
self.conv_layers = len(dilations)
|
||||||
|
self.conv_kernel_size = conv_kernel_size
|
||||||
|
|
||||||
|
self.kernel_predictor = KernelPredictor(
|
||||||
|
cond_channels=cond_channels,
|
||||||
|
conv_in_channels=in_channels,
|
||||||
|
conv_out_channels=2 * in_channels,
|
||||||
|
conv_layers=len(dilations),
|
||||||
|
conv_kernel_size=conv_kernel_size,
|
||||||
|
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||||
|
kpnet_conv_size=kpnet_conv_size,
|
||||||
|
kpnet_dropout=kpnet_dropout,
|
||||||
|
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
||||||
|
)
|
||||||
|
|
||||||
|
if downsampling:
|
||||||
|
self.convt_pre = nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
|
||||||
|
nn.AvgPool1d(stride, stride),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if stride == 1:
|
||||||
|
self.convt_pre = nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.convt_pre = nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
weight_norm(
|
||||||
|
nn.ConvTranspose1d(
|
||||||
|
in_channels,
|
||||||
|
in_channels,
|
||||||
|
2 * stride,
|
||||||
|
stride=stride,
|
||||||
|
padding=stride // 2 + stride % 2,
|
||||||
|
output_padding=stride % 2,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.amp_block = AMPBlock(in_channels)
|
||||||
|
|
||||||
|
self.conv_blocks = nn.ModuleList()
|
||||||
|
for d in dilations:
|
||||||
|
self.conv_blocks.append(
|
||||||
|
nn.Sequential(
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
|
||||||
|
nn.LeakyReLU(lReLU_slope),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, c):
|
||||||
|
"""forward propagation of the location-variable convolutions.
|
||||||
|
Args:
|
||||||
|
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||||
|
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the output sequence (batch, in_channels, in_length)
|
||||||
|
"""
|
||||||
|
_, in_channels, _ = x.shape # (B, c_g, L')
|
||||||
|
|
||||||
|
x = self.convt_pre(x) # (B, c_g, stride * L')
|
||||||
|
|
||||||
|
# Add one amp block just after the upsampling
|
||||||
|
x = self.amp_block(x) # (B, c_g, stride * L')
|
||||||
|
|
||||||
|
kernels, bias = self.kernel_predictor(c)
|
||||||
|
|
||||||
|
if self.add_extra_noise:
|
||||||
|
# Add extra noise to part of the feature
|
||||||
|
a, b = x.chunk(2, dim=1)
|
||||||
|
b = b + torch.randn_like(b) * 0.1
|
||||||
|
x = torch.cat([a, b], dim=1)
|
||||||
|
|
||||||
|
for i, conv in enumerate(self.conv_blocks):
|
||||||
|
output = conv(x) # (B, c_g, stride * L')
|
||||||
|
|
||||||
|
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||||
|
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
||||||
|
|
||||||
|
output = self.location_variable_convolution(
|
||||||
|
output, k, b, hop_size=self.cond_hop_length
|
||||||
|
) # (B, 2 * c_g, stride * L'): LVC
|
||||||
|
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
||||||
|
output[:, in_channels:, :]
|
||||||
|
) # (B, c_g, stride * L'): GAU
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
||||||
|
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||||
|
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||||
|
Args:
|
||||||
|
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||||
|
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||||
|
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||||
|
dilation (int): the dilation of convolution.
|
||||||
|
hop_size (int): the hop_size of the conditioning sequence.
|
||||||
|
Returns:
|
||||||
|
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||||
|
"""
|
||||||
|
batch, _, in_length = x.shape
|
||||||
|
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||||
|
|
||||||
|
assert in_length == (
|
||||||
|
kernel_length * hop_size
|
||||||
|
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
|
||||||
|
|
||||||
|
padding = dilation * int((kernel_size - 1) / 2)
|
||||||
|
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
||||||
|
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||||
|
|
||||||
|
if hop_size < dilation:
|
||||||
|
x = F.pad(x, (0, dilation), "constant", 0)
|
||||||
|
x = x.unfold(
|
||||||
|
3, dilation, dilation
|
||||||
|
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||||
|
x = x[:, :, :, :, :hop_size]
|
||||||
|
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||||
|
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||||
|
|
||||||
|
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
||||||
|
o = o.to(memory_format=torch.channels_last_3d)
|
||||||
|
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
||||||
|
o = o + bias
|
||||||
|
o = o.contiguous().view(batch, out_channels, -1)
|
||||||
|
|
||||||
|
return o
|
128
GPT_SoVITS/resemble_enhance/enhancer/univnet/mrstft.py
Normal file
128
GPT_SoVITS/resemble_enhance/enhancer/univnet/mrstft.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
# Copyright 2019 Tomoki Hayashi
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ..hparams import HParams
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stft_cfg(hop_length, win_length=None):
|
||||||
|
if win_length is None:
|
||||||
|
win_length = 4 * hop_length
|
||||||
|
n_fft = 2 ** (win_length - 1).bit_length()
|
||||||
|
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stft_cfgs(hp: HParams):
|
||||||
|
assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
|
||||||
|
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
||||||
|
|
||||||
|
|
||||||
|
def stft(x, n_fft, hop_length, win_length, window):
|
||||||
|
dtype = x.dtype
|
||||||
|
x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
|
||||||
|
x = x.abs().to(dtype)
|
||||||
|
x = x.transpose(2, 1) # (b f t) -> (b t f)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralConvergengeLoss(nn.Module):
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||||
|
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||||
|
Returns:
|
||||||
|
Tensor: Spectral convergence loss value.
|
||||||
|
"""
|
||||||
|
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
|
||||||
|
|
||||||
|
|
||||||
|
class LogSTFTMagnitudeLoss(nn.Module):
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||||
|
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||||
|
Returns:
|
||||||
|
Tensor: Log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
|
||||||
|
|
||||||
|
|
||||||
|
class STFTLoss(nn.Module):
|
||||||
|
def __init__(self, hp, stft_cfg: dict, window="hann_window"):
|
||||||
|
super().__init__()
|
||||||
|
self.hp = hp
|
||||||
|
self.stft_cfg = stft_cfg
|
||||||
|
self.spectral_convergenge_loss = SpectralConvergengeLoss()
|
||||||
|
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||||
|
self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Predicted signal (B, T).
|
||||||
|
y (Tensor): Groundtruth signal (B, T).
|
||||||
|
Returns:
|
||||||
|
Tensor: Spectral convergence loss value.
|
||||||
|
Tensor: Log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
stft_cfg = dict(self.stft_cfg)
|
||||||
|
x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f)
|
||||||
|
y_mag = stft(y, **stft_cfg, window=self.window)
|
||||||
|
sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
|
||||||
|
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||||
|
return dict(sc=sc_loss, mag=mag_loss)
|
||||||
|
|
||||||
|
|
||||||
|
class MRSTFTLoss(nn.Module):
|
||||||
|
def __init__(self, hp: HParams, window="hann_window"):
|
||||||
|
"""Initialize Multi resolution STFT loss module.
|
||||||
|
Args:
|
||||||
|
resolutions (list): List of (FFT size, hop size, window length).
|
||||||
|
window (str): Window function type.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
stft_cfgs = get_stft_cfgs(hp)
|
||||||
|
self.stft_losses = nn.ModuleList()
|
||||||
|
self.hp = hp
|
||||||
|
for c in stft_cfgs:
|
||||||
|
self.stft_losses += [STFTLoss(hp, c, window=window)]
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Predicted signal (b t).
|
||||||
|
y (Tensor): Groundtruth signal (b t).
|
||||||
|
Returns:
|
||||||
|
Tensor: Multi resolution spectral convergence loss value.
|
||||||
|
Tensor: Multi resolution log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
|
||||||
|
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
x = x.float()
|
||||||
|
y = y.float()
|
||||||
|
|
||||||
|
# Align length
|
||||||
|
x = x[..., : y.shape[-1]]
|
||||||
|
y = y[..., : x.shape[-1]]
|
||||||
|
|
||||||
|
losses = {}
|
||||||
|
|
||||||
|
for f in self.stft_losses:
|
||||||
|
d = f(x, y)
|
||||||
|
for k, v in d.items():
|
||||||
|
losses.setdefault(k, []).append(v)
|
||||||
|
|
||||||
|
for k, v in losses.items():
|
||||||
|
losses[k] = torch.stack(v, dim=0).mean().to(dtype)
|
||||||
|
|
||||||
|
return losses
|
94
GPT_SoVITS/resemble_enhance/enhancer/univnet/univnet.py
Normal file
94
GPT_SoVITS/resemble_enhance/enhancer/univnet/univnet.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
from ..hparams import HParams
|
||||||
|
from .lvcnet import LVCBlock
|
||||||
|
from .mrstft import MRSTFTLoss
|
||||||
|
|
||||||
|
|
||||||
|
class UnivNet(nn.Module):
|
||||||
|
@property
|
||||||
|
def d_noise(self):
|
||||||
|
return 128
|
||||||
|
|
||||||
|
@property
|
||||||
|
def strides(self):
|
||||||
|
return [7, 5, 4, 3]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dilations(self):
|
||||||
|
return [1, 3, 9, 27]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nc(self):
|
||||||
|
return self.hp.univnet_nc
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scale_factor(self) -> int:
|
||||||
|
return self.hp.hop_size
|
||||||
|
|
||||||
|
def __init__(self, hp: HParams, d_input):
|
||||||
|
super().__init__()
|
||||||
|
self.d_input = d_input
|
||||||
|
|
||||||
|
self.hp = hp
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
LVCBlock(
|
||||||
|
self.nc,
|
||||||
|
d_input,
|
||||||
|
stride=stride,
|
||||||
|
dilations=self.dilations,
|
||||||
|
cond_hop_length=hop_length,
|
||||||
|
kpnet_conv_size=3,
|
||||||
|
)
|
||||||
|
for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
|
||||||
|
|
||||||
|
self.conv_post = nn.Sequential(
|
||||||
|
nn.LeakyReLU(0.2),
|
||||||
|
weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mrstft = MRSTFTLoss(hp)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eps(self):
|
||||||
|
return 1e-5
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (b c t), acoustic features
|
||||||
|
y: (b t), waveform
|
||||||
|
Returns:
|
||||||
|
z: (b t), waveform
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3, "x must be 3D tensor"
|
||||||
|
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
||||||
|
assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
|
||||||
|
assert npad >= 0, "npad must be positive or zero"
|
||||||
|
|
||||||
|
x = F.pad(x, (0, npad), "constant", 0)
|
||||||
|
z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
|
||||||
|
z = self.conv_pre(z) # (b c t)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
z = block(z, x) # (b c t)
|
||||||
|
|
||||||
|
z = self.conv_post(z) # (b 1 t)
|
||||||
|
z = z[..., : -self.scale_factor * npad]
|
||||||
|
z = z.squeeze(1) # (b t)
|
||||||
|
|
||||||
|
if y is not None:
|
||||||
|
self.losses = self.mrstft(z, y)
|
||||||
|
|
||||||
|
return z
|
128
GPT_SoVITS/resemble_enhance/hparams.py
Normal file
128
GPT_SoVITS/resemble_enhance/hparams.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stft_cfg(hop_length, win_length=None):
|
||||||
|
if win_length is None:
|
||||||
|
win_length = 4 * hop_length
|
||||||
|
n_fft = 2 ** (win_length - 1).bit_length()
|
||||||
|
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_rich_table(rows, columns, title=None):
|
||||||
|
table = Table(title=title, header_style=None)
|
||||||
|
for column in columns:
|
||||||
|
table.add_column(column.capitalize(), justify="left")
|
||||||
|
for row in rows:
|
||||||
|
table.add_row(*map(str, row))
|
||||||
|
return Panel(table, expand=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
|
||||||
|
console.print(_build_rich_table(d.items(), [key, value], title))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HParams:
|
||||||
|
# Dataset
|
||||||
|
fg_dir: Path = Path("data/fg")
|
||||||
|
bg_dir: Path = Path("data/bg")
|
||||||
|
rir_dir: Path = Path("data/rir")
|
||||||
|
load_fg_only: bool = False
|
||||||
|
praat_augment_prob: float = 0
|
||||||
|
|
||||||
|
# Audio settings
|
||||||
|
wav_rate: int = 44_100
|
||||||
|
n_fft: int = 2048
|
||||||
|
win_size: int = 2048
|
||||||
|
hop_size: int = 420 # 9.5ms
|
||||||
|
num_mels: int = 128
|
||||||
|
stft_magnitude_min: float = 1e-4
|
||||||
|
preemphasis: float = 0.97
|
||||||
|
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
nj: int = 64
|
||||||
|
training_seconds: float = 1.0
|
||||||
|
batch_size_per_gpu: int = 16
|
||||||
|
min_lr: float = 1e-5
|
||||||
|
max_lr: float = 1e-4
|
||||||
|
warmup_steps: int = 1000
|
||||||
|
max_steps: int = 1_000_000
|
||||||
|
gradient_clipping: float = 1.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deepspeed_config(self):
|
||||||
|
return {
|
||||||
|
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
|
||||||
|
"optimizer": {
|
||||||
|
"type": "Adam",
|
||||||
|
"params": {"lr": float(self.min_lr)},
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "WarmupDecayLR",
|
||||||
|
"params": {
|
||||||
|
"warmup_min_lr": float(self.min_lr),
|
||||||
|
"warmup_max_lr": float(self.max_lr),
|
||||||
|
"warmup_num_steps": self.warmup_steps,
|
||||||
|
"total_num_steps": self.max_steps,
|
||||||
|
"warmup_type": "linear",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"gradient_clipping": self.gradient_clipping,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stft_cfgs(self):
|
||||||
|
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
|
||||||
|
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path) -> "HParams":
|
||||||
|
logger.info(f"Reading hparams from {path}")
|
||||||
|
# First merge to fix types (e.g., str -> Path)
|
||||||
|
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
|
||||||
|
|
||||||
|
def save_if_not_exists(self, run_dir: Path):
|
||||||
|
path = run_dir / "hparams.yaml"
|
||||||
|
if path.exists():
|
||||||
|
logger.info(f"{path} already exists, not saving")
|
||||||
|
return
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
OmegaConf.save(asdict(self), str(path))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, run_dir, yaml: Path | None = None):
|
||||||
|
hps = []
|
||||||
|
|
||||||
|
if (run_dir / "hparams.yaml").exists():
|
||||||
|
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
|
||||||
|
|
||||||
|
if yaml is not None:
|
||||||
|
hps.append(cls.from_yaml(yaml))
|
||||||
|
|
||||||
|
if len(hps) == 0:
|
||||||
|
hps.append(cls())
|
||||||
|
|
||||||
|
for hp in hps[1:]:
|
||||||
|
if hp != hps[0]:
|
||||||
|
errors = {}
|
||||||
|
for k, v in asdict(hp).items():
|
||||||
|
if getattr(hps[0], k) != v:
|
||||||
|
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
||||||
|
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
|
||||||
|
|
||||||
|
return hps[0]
|
||||||
|
|
||||||
|
def print(self):
|
||||||
|
_rich_print_dict(asdict(self), title="HParams")
|
177
GPT_SoVITS/resemble_enhance/inference.py
Normal file
177
GPT_SoVITS/resemble_enhance/inference.py
Normal file
@ -0,0 +1,177 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
from torchaudio.functional import resample
|
||||||
|
from torchaudio.transforms import MelSpectrogram
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
from .hparams import HParams
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def inference_chunk(model, dwav, sr, device, npad=441):
|
||||||
|
assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
|
||||||
|
del sr
|
||||||
|
|
||||||
|
length = dwav.shape[-1]
|
||||||
|
abs_max = dwav.abs().max().clamp(min=1e-7)
|
||||||
|
|
||||||
|
assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
|
||||||
|
dwav = dwav.to(device)
|
||||||
|
dwav = dwav / abs_max # Normalize
|
||||||
|
dwav = F.pad(dwav, (0, npad))
|
||||||
|
hwav = model(dwav[None])[0].cpu() # (T,)
|
||||||
|
hwav = hwav[:length] # Trim padding
|
||||||
|
hwav = hwav * abs_max # Unnormalize
|
||||||
|
|
||||||
|
return hwav
|
||||||
|
|
||||||
|
|
||||||
|
def compute_corr(x, y):
|
||||||
|
return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_offset(chunk1, chunk2, sr=44100):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
chunk1: (T,)
|
||||||
|
chunk2: (T,)
|
||||||
|
Returns:
|
||||||
|
offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
|
||||||
|
"""
|
||||||
|
hop_length = sr // 200 # 5 ms resolution
|
||||||
|
win_length = hop_length * 4
|
||||||
|
n_fft = 2 ** (win_length - 1).bit_length()
|
||||||
|
|
||||||
|
mel_fn = MelSpectrogram(
|
||||||
|
sample_rate=sr,
|
||||||
|
n_fft=n_fft,
|
||||||
|
win_length=win_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
n_mels=80,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=sr // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec1 = mel_fn(chunk1).log1p()
|
||||||
|
spec2 = mel_fn(chunk2).log1p()
|
||||||
|
|
||||||
|
corr = compute_corr(spec1, spec2) # (F, T)
|
||||||
|
corr = corr.mean(dim=0) # (T,)
|
||||||
|
|
||||||
|
argmax = corr.argmax().item()
|
||||||
|
|
||||||
|
if argmax > len(corr) // 2:
|
||||||
|
argmax -= len(corr)
|
||||||
|
|
||||||
|
offset = -argmax * hop_length
|
||||||
|
|
||||||
|
return offset
|
||||||
|
|
||||||
|
|
||||||
|
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
|
||||||
|
signal_length = (len(chunks) - 1) * hop_length + chunk_length
|
||||||
|
overlap_length = chunk_length - hop_length
|
||||||
|
signal = torch.zeros(signal_length, device=chunks[0].device)
|
||||||
|
|
||||||
|
fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
|
||||||
|
fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
|
||||||
|
fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
|
||||||
|
fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
start = i * hop_length
|
||||||
|
end = start + chunk_length
|
||||||
|
|
||||||
|
if len(chunk) < chunk_length:
|
||||||
|
chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
|
||||||
|
|
||||||
|
if i > 0:
|
||||||
|
pre_region = chunks[i - 1][-overlap_length:]
|
||||||
|
cur_region = chunk[:overlap_length]
|
||||||
|
offset = compute_offset(pre_region, cur_region, sr=sr)
|
||||||
|
start -= offset
|
||||||
|
end -= offset
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
chunk = chunk * fadeout
|
||||||
|
elif i == len(chunks) - 1:
|
||||||
|
chunk = chunk * fadein
|
||||||
|
else:
|
||||||
|
chunk = chunk * fadein * fadeout
|
||||||
|
|
||||||
|
signal[start:end] += chunk[: len(signal[start:end])]
|
||||||
|
|
||||||
|
signal = signal[:length]
|
||||||
|
|
||||||
|
return signal
|
||||||
|
|
||||||
|
|
||||||
|
def remove_weight_norm_recursively(module):
|
||||||
|
for _, module in module.named_modules():
|
||||||
|
try:
|
||||||
|
remove_parametrizations(module, "weight")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
|
||||||
|
remove_weight_norm_recursively(model)
|
||||||
|
|
||||||
|
hp: HParams = model.hp
|
||||||
|
|
||||||
|
dwav = resample(
|
||||||
|
dwav,
|
||||||
|
orig_freq=sr,
|
||||||
|
new_freq=hp.wav_rate,
|
||||||
|
lowpass_filter_width=64,
|
||||||
|
rolloff=0.9475937167399596,
|
||||||
|
resampling_method="sinc_interp_kaiser",
|
||||||
|
beta=14.769656459379492,
|
||||||
|
)
|
||||||
|
|
||||||
|
del sr # We are now using hp.wav_rate as the sampling rate
|
||||||
|
sr = hp.wav_rate
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
chunk_length = int(sr * chunk_seconds)
|
||||||
|
overlap_length = int(sr * overlap_seconds)
|
||||||
|
hop_length = chunk_length - overlap_length
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for start in trange(0, dwav.shape[-1], hop_length):
|
||||||
|
new_chunk = inference_chunk(model, dwav[start : start + chunk_length], sr, device)
|
||||||
|
chunks.append(new_chunk)
|
||||||
|
|
||||||
|
# Delete the processed segment to free up memory
|
||||||
|
# del new_chunk
|
||||||
|
# if torch.cuda.is_available():
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Force garbage collection at this point (optional and may slow down processing)
|
||||||
|
# gc.collect()
|
||||||
|
|
||||||
|
hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr,length=dwav.shape[-1])
|
||||||
|
# Clean up chunks to free memory after merging
|
||||||
|
|
||||||
|
del chunks[:]
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
gc.collect() # Explicitly call garbage collector again
|
||||||
|
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
|
||||||
|
|
||||||
|
return hwav, sr
|
61
GPT_SoVITS/resemble_enhance/melspec.py
Normal file
61
GPT_SoVITS/resemble_enhance/melspec.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
|
||||||
|
|
||||||
|
from .hparams import HParams
|
||||||
|
|
||||||
|
|
||||||
|
class MelSpectrogram(nn.Module):
|
||||||
|
def __init__(self, hp: HParams):
|
||||||
|
"""
|
||||||
|
Torch implementation of Resemble's mel extraction.
|
||||||
|
Note that the values are NOT identical to librosa's implementation
|
||||||
|
due to floating point precisions.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.hp = hp
|
||||||
|
self.melspec = TorchMelSpectrogram(
|
||||||
|
hp.wav_rate,
|
||||||
|
n_fft=hp.n_fft,
|
||||||
|
win_length=hp.win_size,
|
||||||
|
hop_length=hp.hop_size,
|
||||||
|
f_min=0,
|
||||||
|
f_max=hp.wav_rate // 2,
|
||||||
|
n_mels=hp.num_mels,
|
||||||
|
power=1,
|
||||||
|
normalized=False,
|
||||||
|
# NOTE: Folowing librosa's default.
|
||||||
|
pad_mode="constant",
|
||||||
|
norm="slaney",
|
||||||
|
mel_scale="slaney",
|
||||||
|
)
|
||||||
|
self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
|
||||||
|
self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
|
||||||
|
self.preemphasis = hp.preemphasis
|
||||||
|
self.hop_size = hp.hop_size
|
||||||
|
|
||||||
|
def forward(self, wav, pad=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wav: [B, T]
|
||||||
|
"""
|
||||||
|
device = wav.device
|
||||||
|
if wav.is_mps:
|
||||||
|
wav = wav.cpu()
|
||||||
|
self.to(wav.device)
|
||||||
|
if self.preemphasis > 0:
|
||||||
|
wav = torch.nn.functional.pad(wav, [1, 0], value=0)
|
||||||
|
wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
|
||||||
|
mel = self.melspec(wav)
|
||||||
|
mel = self._amp_to_db(mel)
|
||||||
|
mel_normed = self._normalize(mel)
|
||||||
|
assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check
|
||||||
|
mel_normed = mel_normed.to(device)
|
||||||
|
return mel_normed # (M, T)
|
||||||
|
|
||||||
|
def _normalize(self, s, headroom_db=15):
|
||||||
|
return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
|
||||||
|
|
||||||
|
def _amp_to_db(self, x):
|
||||||
|
return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
|
Loading…
x
Reference in New Issue
Block a user