Add files via upload

Add resemble-enhance source files
刘悦 2024-03-11 14:18:49 +08:00 committed by GitHub
parent 931781774d
commit a3b108bfe6
40 changed files with 3913 additions and 0 deletions

View File

@@ -0,0 +1,55 @@
import logging
import torch
from torch import Tensor, nn
logger = logging.getLogger(__name__)
class Normalizer(nn.Module):
def __init__(self, momentum=0.01, eps=1e-9):
super().__init__()
self.momentum = momentum
self.eps = eps
self.running_mean_unsafe: Tensor
self.running_var_unsafe: Tensor
self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
@property
def started(self):
return not torch.isnan(self.running_mean_unsafe)
@property
def running_mean(self):
if not self.started:
return torch.zeros_like(self.running_mean_unsafe)
return self.running_mean_unsafe
@property
def running_std(self):
if not self.started:
return torch.ones_like(self.running_var_unsafe)
return (self.running_var_unsafe + self.eps).sqrt()
@torch.no_grad()
def _ema(self, a: Tensor, x: Tensor):
return (1 - self.momentum) * a + self.momentum * x
def update_(self, x):
if not self.started:
self.running_mean_unsafe = x.mean()
self.running_var_unsafe = x.var()
else:
self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
def forward(self, x: Tensor, update=True):
if self.training and update:
self.update_(x)
self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
x = (x - self.running_mean) / self.running_std
return x
def inverse(self, x: Tensor):
return x * self.running_std + self.running_mean
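
A minimal usage sketch of the class above (batch contents and momentum are illustrative, not from the original file): the first training-mode call seeds the running statistics, later calls blend new batch statistics in through the EMA, and inverse undoes the normalization.

import torch

norm = Normalizer(momentum=0.01)
norm.train()
x = torch.randn(4, 100) * 3 + 5           # batch with mean ~5, std ~3
y = norm(x)                               # first call seeds running mean/var
assert torch.allclose(norm.inverse(y), x, atol=1e-5)
norm.eval()
z = norm(torch.randn(4, 100))             # eval mode: statistics stay frozen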

View File

@@ -0,0 +1,48 @@
import logging
import random
from torch.utils.data import DataLoader
from ..hparams import HParams
from .dataset import Dataset
from .utils import mix_fg_bg, rglob_audio_files
logger = logging.getLogger(__name__)
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
paths = rglob_audio_files(hp.fg_dir)
logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
random.Random(seed).shuffle(paths)
train_paths = paths[:-val_size]
val_paths = paths[-val_size:]
train_ds = Dataset(train_paths, hp, training=True, mode=mode)
val_ds = Dataset(val_paths, hp, training=False, mode=mode)
logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
return train_ds, val_ds
def create_dataloaders(hp: HParams, mode):
train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
train_dl = DataLoader(
train_ds,
batch_size=hp.batch_size_per_gpu,
shuffle=True,
num_workers=hp.nj,
drop_last=True,
collate_fn=train_ds.collate_fn,
)
val_dl = DataLoader(
val_ds,
batch_size=1,
shuffle=False,
num_workers=hp.nj,
drop_last=False,
collate_fn=val_ds.collate_fn,
)
return train_dl, val_dl
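
The split is reproducible because the path list is shuffled with a fixed-seed random.Random before the last val_size entries are sliced off; a quick sketch with made-up file names:

import random

paths = [f"clip_{i:02}.wav" for i in range(20)]  # illustrative paths
random.Random(123).shuffle(paths)
train, val = paths[:-10], paths[-10:]            # identical split on every run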

View File

@@ -0,0 +1,171 @@
import logging
import random
from pathlib import Path
import numpy as np
import torch
import torchaudio
import torchaudio.functional as AF
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as DatasetBase
from ..hparams import HParams
from .distorter import Distorter
from .utils import rglob_audio_files
logger = logging.getLogger(__name__)
def _normalize(x):
return x / (np.abs(x).max() + 1e-7)
def _collate(batch, key, tensor=True, pad=True):
l = [d[key] for d in batch]
if l[0] is None:
return None
if tensor:
l = [torch.from_numpy(x) for x in l]
if pad:
assert tensor, "Can't pad non-tensor"
l = pad_sequence(l, batch_first=True)
return l
def praat_augment(wav, sr):
try:
import parselmouth
except ImportError:
raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
# "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
# https://github.com/YannickJadoul/Parselmouth/issues/68
# note that this function may hang if the praat version is 0.4.3
assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
sound = parselmouth.Sound(wav, sr)
formant_shift_ratio = random.uniform(1.1, 1.5)
pitch_range_factor = random.uniform(0.5, 2.0)
sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
wav = np.array(sound.values)[0].astype(np.float32)
return wav
class Dataset(DatasetBase):
def __init__(
self,
fg_paths: list[Path],
hp: HParams,
training=True,
max_retries=100,
silent_fg_prob=0.01,
mode=False,
):
super().__init__()
assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
self.hp = hp
self.fg_paths = fg_paths
self.bg_paths = rglob_audio_files(hp.bg_dir)
if len(self.fg_paths) == 0:
raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
if len(self.bg_paths) == 0:
raise ValueError(f"No background audio files found in {hp.bg_dir}")
logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
self.training = training
self.max_retries = max_retries
self.silent_fg_prob = silent_fg_prob
self.mode = mode
self.distorter = Distorter(hp, training=training, mode=mode)
def _load_wav(self, path, length=None, random_crop=True):
wav, sr = torchaudio.load(path)
wav = AF.resample(
waveform=wav,
orig_freq=sr,
new_freq=self.hp.wav_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
wav = wav.float().numpy()
if wav.ndim == 2:
wav = np.mean(wav, axis=0)
if length is None and self.training:
length = int(self.hp.training_seconds * self.hp.wav_rate)
if length is not None:
if random_crop:
start = random.randint(0, max(0, len(wav) - length))
wav = wav[start : start + length]
else:
wav = wav[:length]
if length is not None and len(wav) < length:
wav = np.pad(wav, (0, length - len(wav)))
wav = _normalize(wav)
return wav
def _getitem_unsafe(self, index: int):
fg_path = self.fg_paths[index]
if self.training and random.random() < self.silent_fg_prob:
fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
else:
fg_wav = self._load_wav(fg_path)
if random.random() < self.hp.praat_augment_prob and self.training:
fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
if self.hp.load_fg_only:
bg_wav = None
fg_dwav = None
bg_dwav = None
else:
fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
if self.training:
bg_path = random.choice(self.bg_paths)
else:
# Deterministic for validation
bg_path = self.bg_paths[index % len(self.bg_paths)]
bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
return dict(
fg_wav=fg_wav,
bg_wav=bg_wav,
fg_dwav=fg_dwav,
bg_dwav=bg_dwav,
)
def __getitem__(self, index: int):
for i in range(self.max_retries):
try:
return self._getitem_unsafe(index)
except Exception as e:
if i == self.max_retries - 1:
raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
index = np.random.randint(0, len(self))
def __len__(self):
return len(self.fg_paths)
@staticmethod
def collate_fn(batch):
return dict(
fg_wavs=_collate(batch, "fg_wav"),
bg_wavs=_collate(batch, "bg_wav"),
fg_dwavs=_collate(batch, "fg_dwav"),
bg_dwavs=_collate(batch, "bg_dwav"),
)
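
A small sketch of how collate_fn batches variable-length samples (the lengths are made up): _collate converts each numpy array to a tensor and zero-pads to the longest sample via pad_sequence.

import numpy as np

batch = [dict(fg_wav=np.ones(3, dtype=np.float32)), dict(fg_wav=np.ones(5, dtype=np.float32))]
print(_collate(batch, "fg_wav").shape)  # torch.Size([2, 5]); the first row is zero-padded from 3 to 5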

View File

@@ -0,0 +1 @@
from .distorter import Distorter

View File

@@ -0,0 +1,104 @@
import itertools
import os
import random
import time
import warnings
import numpy as np
_DEBUG = bool(os.environ.get("DEBUG", False))
class Effect:
def apply(self, wav: np.ndarray, sr: int):
"""
Args:
wav: (T)
sr: sample rate
Returns:
wav: (T) with the same sample rate of `sr`
"""
raise NotImplementedError
def __call__(self, wav: np.ndarray, sr: int):
"""
Args:
wav: (T)
sr: sample rate
Returns:
wav: (T) with the same sample rate of `sr`
"""
assert len(wav.shape) == 1, wav.shape
if _DEBUG:
start = time.time()
else:
start = None
shape = wav.shape
assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
wav = self.apply(wav, sr)
assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
if start is not None:
end = time.time()
print(f"{self.__class__.__name__}: {end - start:.3f} sec")
return wav
class Chain(Effect):
def __init__(self, *effects):
super().__init__()
self.effects = effects
def apply(self, wav, sr):
for effect in self.effects:
wav = effect(wav, sr)
return wav
class Maybe(Effect):
def __init__(self, prob, effect):
super().__init__()
self.prob = prob
self.effect = effect
if _DEBUG:
warnings.warn("DEBUG mode is on. Maybe -> Must.")
self.prob = 1
def apply(self, wav, sr):
if random.random() > self.prob:
return wav
return self.effect(wav, sr)
class Choice(Effect):
def __init__(self, *effects, **kwargs):
super().__init__()
self.effects = effects
self.kwargs = kwargs
def apply(self, wav, sr):
return np.random.choice(self.effects, **self.kwargs)(wav, sr)
class Permutation(Effect):
def __init__(self, *effects, n: int | None = None):
super().__init__()
self.effects = effects
self.n = n
def apply(self, wav, sr):
if self.n is None:
n = np.random.binomial(len(self.effects), 0.5)
else:
n = self.n
if n == 0:
return wav
perms = itertools.permutations(self.effects, n)
effects = random.choice(list(perms))
return Chain(*effects)(wav, sr)
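
The combinators above compose arbitrary Effect callables; a toy sketch (the Gain effect is invented here for illustration and is not part of the file):

import numpy as np

class Gain(Effect):
    def __init__(self, g):
        self.g = g
    def apply(self, wav, sr):
        return wav * self.g

pipeline = Chain(Maybe(0.5, Gain(0.5)), Choice(Gain(2.0), Gain(0.1)))
out = pipeline(np.ones(16000, dtype=np.float32), 16000)  # shape checks in __call__ guarantee out.shape == (16000,)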

View File

@@ -0,0 +1,85 @@
import logging
import random
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
import librosa
import numpy as np
from scipy import signal
from ..utils import walk_paths
from .base import Effect
_logger = logging.getLogger(__name__)
@dataclass
class RandomRIR(Effect):
rir_dir: Path | None
rir_rate: int = 44_000
rir_suffix: str = ".npy"
deterministic: bool = False
@cached_property
def rir_paths(self):
if self.rir_dir is None:
return []
return list(walk_paths(self.rir_dir, self.rir_suffix))
def _sample_rir(self):
if len(self.rir_paths) == 0:
return None
if self.deterministic:
rir_path = self.rir_paths[0]
else:
rir_path = random.choice(self.rir_paths)
rir = np.squeeze(np.load(rir_path))
assert isinstance(rir, np.ndarray)
return rir
def apply(self, wav, sr):
# ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
if len(self.rir_paths) == 0:
return wav
length = len(wav)
wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
rir = self._sample_rir()
wav = signal.convolve(wav, rir, mode="same")
actlev = np.max(np.abs(wav))
if actlev > 0.99:
wav = (wav / actlev) * 0.98
wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
if abs(length - len(wav)) > 10:
_logger.warning(f"length mismatch: {length} vs {len(wav)}")
if length > len(wav):
wav = np.pad(wav, (0, length - len(wav)))
elif length < len(wav):
wav = wav[:length]
return wav
class RandomGaussianNoise(Effect):
def __init__(self, alpha_range=(0.8, 1)):
super().__init__()
self.alpha_range = alpha_range
def apply(self, wav, sr):
noise = np.random.randn(*wav.shape)
noise_energy = np.sum(noise**2)
wav_energy = np.sum(wav**2)
noise = noise * np.sqrt(wav_energy / noise_energy)
alpha = random.uniform(*self.alpha_range)
return wav * alpha + noise * (1 - alpha)
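
Because the noise is first rescaled to the signal's energy, alpha fixes the signal-to-noise ratio directly: the two components enter with energies alpha^2 and (1 - alpha)^2, so SNR_dB = 20 * log10(alpha / (1 - alpha)). A quick check of the default range:

import math

for alpha in (0.8, 0.9, 0.99):
    print(alpha, round(20 * math.log10(alpha / (1 - alpha)), 1))  # 12.0, 19.1, 39.9 dB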

View File

@@ -0,0 +1,32 @@
from ...hparams import HParams
from .base import Chain, Choice, Permutation
from .custom import RandomGaussianNoise, RandomRIR
class Distorter(Chain):
def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
# Lazy import
from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
if training:
permutation = Permutation(
RandomRIR(hp.rir_dir),
RandomReverb(),
RandomGaussianNoise(),
RandomOverdrive(),
RandomEqualizer(),
Choice(
RandomLowpassDistorter(),
RandomBandpassDistorter(),
),
)
if mode == "denoiser":
super().__init__(permutation)
else:
# 80%: distortion, 20%: clean
super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
else:
super().__init__(
RandomRIR(hp.rir_dir, deterministic=True),
RandomReverb(deterministic=True),
)

View File

@@ -0,0 +1,176 @@
import logging
import os
import random
import warnings
from functools import partial
import numpy as np
import torch
try:
import augment
except ImportError:
raise ImportError(
"augment is not installed, please install it first using:"
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
)
from .base import Effect
_logger = logging.getLogger(__name__)
_DEBUG = bool(os.environ.get("DEBUG", False))
class AttachableEffect(Effect):
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
raise NotImplementedError
def apply(self, wav: np.ndarray, sr: int):
chain = augment.EffectChain()
chain = self.attach(chain)
tensor = torch.from_numpy(wav)[None].float() # (1, T)
tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
wav = tensor.numpy()[0] # (T,)
return wav
class SoxEffect(AttachableEffect):
def __init__(self, effect_name: str, *args, **kwargs):
self.effect_name = effect_name
self.args = args
self.kwargs = kwargs
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
_logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
if not hasattr(chain, self.effect_name):
raise ValueError(f"EffectChain has no attribute {self.effect_name}")
return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
class Maybe(AttachableEffect):
"""
Attach an effect with a probability.
"""
def __init__(self, prob: float, effect: AttachableEffect):
self.prob = prob
self.effect = effect
if _DEBUG:
warnings.warn("DEBUG mode is on. Maybe -> Must.")
self.prob = 1
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
if random.random() > self.prob:
return chain
return self.effect.attach(chain)
class Chain(AttachableEffect):
"""
Attach a chain of effects.
"""
def __init__(self, *effects: AttachableEffect):
self.effects = effects
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
for effect in self.effects:
chain = effect.attach(chain)
return chain
class Choice(AttachableEffect):
"""
Attach one of the effects randomly.
"""
def __init__(self, *effects: AttachableEffect):
self.effects = effects
def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
return random.choice(self.effects).attach(chain)
class Generator:
def __call__(self) -> str:
raise NotImplementedError
class Uniform(Generator):
def __init__(self, low, high):
self.low = low
self.high = high
def __call__(self) -> str:
return str(random.uniform(self.low, self.high))
class Randint(Generator):
def __init__(self, low, high):
self.low = low
self.high = high
def __call__(self) -> str:
return str(random.randint(self.low, self.high))
class Concat(Generator):
def __init__(self, *parts: Generator | str):
self.parts = parts
def __call__(self):
return "".join([part if isinstance(part, str) else part() for part in self.parts])
class RandomLowpassDistorter(SoxEffect):
def __init__(self, low=2000, high=16000):
super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
class RandomBandpassDistorter(SoxEffect):
def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
@staticmethod
def _fn(low, high, min_width, max_width):
start = random.randint(low, high)
stop = start + random.randint(min_width, max_width)
return f"{start}-{stop}"
class RandomEqualizer(SoxEffect):
def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
super().__init__(
"equalizer",
Uniform(low, high),
lambda: f"{random.randint(q_low, q_high)}q",
lambda: random.randint(db_low, db_high),
)
class RandomOverdrive(SoxEffect):
def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
class RandomReverb(Chain):
def __init__(self, deterministic=False):
super().__init__(
SoxEffect(
"reverb",
Uniform(50, 50) if deterministic else Uniform(0, 100),
Uniform(50, 50) if deterministic else Uniform(0, 100),
Uniform(50, 50) if deterministic else Uniform(0, 100),
),
SoxEffect("channels", 1),
)
class Flanger(SoxEffect):
def __init__(self):
super().__init__("flanger")
class Phaser(SoxEffect):
def __init__(self):
super().__init__("phaser")

View File

@@ -0,0 +1,43 @@
from pathlib import Path
from typing import Callable
from torch import Tensor
def walk_paths(root, suffix):
for path in Path(root).iterdir():
if path.is_dir():
yield from walk_paths(path, suffix)
elif path.suffix == suffix:
yield path
def rglob_audio_files(path: Path):
return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
"""
Args:
fg: (b, t)
bg: (b, t)
"""
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
fg = fg / (fg_energy + eps).sqrt()
bg = bg / (bg_energy + eps).sqrt()
if callable(alpha):
alpha = alpha()
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
mx = alpha * fg + (1 - alpha) * bg
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
return mx
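
Both signals are RMS-normalized before mixing, so alpha sets the energy balance: foreground and background enter with energies alpha^2 and (1 - alpha)^2. A short sketch:

import torch

fg, bg = torch.randn(2, 16000), torch.randn(2, 16000)
mx = mix_fg_bg(fg, bg, alpha=0.75)      # fg/bg energy ratio (0.75/0.25)^2 = 9, i.e. ~9.5 dB
print(mx.shape, mx.abs().amax(dim=-1))  # (2, 16000); peaks are ~1 after the final normalization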

View File

@@ -0,0 +1,30 @@
import argparse
from pathlib import Path
import torch
import torchaudio
from .inference import denoise
@torch.inference_mode()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
parser.add_argument("out_dir", type=Path, help="Output folder")
parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
parser.add_argument("--device", type=str, default="cuda", help="Device")
args = parser.parse_args()
for path in args.in_dir.glob(f"**/*{args.suffix}"):
print(f"Processing {path} ..")
dwav, sr = torchaudio.load(path)
hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
out_path = args.out_dir / path.relative_to(args.in_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
torchaudio.save(out_path, hwav[None], sr)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,181 @@
import logging
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from ..melspec import MelSpectrogram
from .hparams import HParams
from .unet import UNet
logger = logging.getLogger(__name__)
def _normalize(x: Tensor) -> Tensor:
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
class Denoiser(nn.Module):
@property
def stft_cfg(self) -> dict:
hop_size = self.hp.hop_size
return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
@property
def n_fft(self):
return self.stft_cfg["n_fft"]
@property
def eps(self):
return 1e-7
def __init__(self, hp: HParams):
super().__init__()
self.hp = hp
self.net = UNet(input_dim=3, output_dim=3)
self.mel_fn = MelSpectrogram(hp)
self.dummy: Tensor
self.register_buffer("dummy", torch.zeros(1), persistent=False)
def to_mel(self, x: Tensor, drop_last=True):
"""
Args:
x: (b t), wavs
Returns:
o: (b c t), mels
"""
if drop_last:
return self.mel_fn(x)[..., :-1] # (b d t)
return self.mel_fn(x)
def _stft(self, x):
"""
Args:
x: (b t)
Returns:
mag: (b f t) in [0, inf)
cos: (b f t) in [-1, 1]
sin: (b f t) in [-1, 1]
"""
dtype = x.dtype
device = x.device
if x.is_mps:
x = x.cpu()
window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1)
s = s[..., :-1] # (b f t)
mag = s.abs() # (b f t)
phi = s.angle() # (b f t)
cos = phi.cos() # (b f t)
sin = phi.sin() # (b f t)
mag = mag.to(dtype=dtype, device=device)
cos = cos.to(dtype=dtype, device=device)
sin = sin.to(dtype=dtype, device=device)
return mag, cos, sin
def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
"""
Args:
mag: (b f t) in [0, inf)
cos: (b f t) in [-1, 1]
sin: (b f t) in [-1, 1]
Returns:
x: (b t)
"""
device = mag.device
dtype = mag.dtype
if mag.is_mps:
mag = mag.cpu()
cos = cos.cpu()
sin = sin.cpu()
real = mag * cos # (b f t)
imag = mag * sin # (b f t)
s = torch.complex(real, imag) # (b f t)
if s.isnan().any():
logger.warning("NaN detected in ISTFT input.")
s = F.pad(s, (0, 1), "replicate") # (b f t+1)
window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
if x.isnan().any():
logger.warning("NaN detected in ISTFT output, set to zero.")
x = torch.where(x.isnan(), torch.zeros_like(x), x)
x = x.to(dtype=dtype, device=device)
return x
def _magphase(self, real, imag):
mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
cos = real / mag
sin = imag / mag
return mag, cos, sin
def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
"""
Args:
mag: (b f t)
cos: (b f t)
sin: (b f t)
Returns:
            mag_mask: (b f t) in [0, 1], magnitude mask
            sin_res: (b f t) in [-1, 1], phase residual
            cos_res: (b f t) in [-1, 1], phase residual
"""
x = torch.stack([mag, cos, sin], dim=1) # (b 3 f t)
mag_mask, real, imag = self.net(x).unbind(1) # (b 3 f t)
mag_mask = mag_mask.sigmoid() # (b f t)
real = real.tanh() # (b f t)
imag = imag.tanh() # (b f t)
_, cos_res, sin_res = self._magphase(real, imag) # (b f t)
return mag_mask, sin_res, cos_res
def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
"""Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
sep_mag = F.relu(mag * mag_mask)
sep_cos = cos * cos_res - sin * sin_res
sep_sin = sin * cos_res + cos * sin_res
return sep_mag, sep_cos, sep_sin
def forward(self, x: Tensor, y: Tensor | None = None):
"""
Args:
x: (b t), a mixed audio
y: (b t), a fg audio
"""
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
x = x.to(self.dummy)
x = _normalize(x)
if y is not None:
assert y.dim() == 2, f"Expected (b t), got {y.size()}"
y = y.to(self.dummy)
y = _normalize(y)
mag, cos, sin = self._stft(x) # (b 2f t)
mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
o = self._istft(sep_mag, sep_cos, sep_sin)
npad = x.shape[-1] - o.shape[-1]
o = F.pad(o, (0, npad))
if y is not None:
self.losses = dict(l1=F.l1_loss(o, y))
return o
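
_separate rotates the noisy phase by the predicted residual via the angle-sum identities cos(a + b) = cos a cos b - sin a sin b and sin(a + b) = sin a cos b + cos a sin b; a quick numerical check of those identities:

import torch

phi, delta = torch.rand(8) * 6.28, torch.rand(8) * 6.28
cos, sin, cos_res, sin_res = phi.cos(), phi.sin(), delta.cos(), delta.sin()
assert torch.allclose(cos * cos_res - sin * sin_res, (phi + delta).cos(), atol=1e-6)
assert torch.allclose(sin * cos_res + cos * sin_res, (phi + delta).sin(), atol=1e-6)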

View File

@@ -0,0 +1,9 @@
from dataclasses import dataclass
from ..hparams import HParams as HParamsBase
@dataclass(frozen=True)
class HParams(HParamsBase):
batch_size_per_gpu: int = 128
distort_prob: float = 0.5

View File

@@ -0,0 +1,29 @@
import logging
from functools import cache
import torch
from ..inference import inference
from .train import Denoiser, HParams
logger = logging.getLogger(__name__)
@cache
def load_denoiser(run_dir, device):
if run_dir is None:
return Denoiser(HParams())
hp = HParams.load(run_dir)
denoiser = Denoiser(hp)
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
state_dict = torch.load(path, map_location="cpu")["module"]
denoiser.load_state_dict(state_dict)
denoiser.eval()
denoiser.to(device)
return denoiser
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
denoiser = load_denoiser(run_dir, device)
return inference(model=denoiser, dwav=dwav, sr=sr, device=device)

View File

@@ -0,0 +1,112 @@
import argparse
import random
from functools import partial
from pathlib import Path
import soundfile
import torch
from deepspeed import DeepSpeedConfig
from torch import Tensor
from tqdm import tqdm
from ..data import create_dataloaders, mix_fg_bg
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
from ..utils.distributed import is_local_leader
from .denoiser import Denoiser
from .hparams import HParams
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
if hp is None:
hp = HParams.load(run_dir)
assert isinstance(hp, HParams)
model = Denoiser(hp)
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
if training:
engine.load_checkpoint()
else:
engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
return engine
def save_wav(path: Path, wav: Tensor, rate: int):
wav = wav.detach().cpu().numpy()
soundfile.write(path, wav, samplerate=rate)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("run_dir", type=Path)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
setup_logging(args.run_dir)
hp = HParams.load(args.run_dir, yaml=args.yaml)
if is_local_leader():
hp.save_if_not_exists(args.run_dir)
hp.print()
train_dl, val_dl = create_dataloaders(hp, mode="denoiser")
def feed_G(engine: Engine, batch: dict[str, Tensor]):
alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
if random.random() < hp.distort_prob:
fg_wavs = batch["fg_dwavs"]
else:
fg_wavs = batch["fg_wavs"]
mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn)
pred = engine(mx_dwavs, fg_wavs)
losses = engine.gather_attribute("losses", prefix="losses")
return pred, losses
@torch.no_grad()
def eval_fn(engine: Engine, eval_dir, n_saved=10):
model = engine.module
model.eval()
step = engine.global_step
for i, batch in enumerate(tqdm(val_dl), 1):
batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
fg_dwavs = batch["fg_dwavs"] # 1 t
mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"])
pred_fg_dwavs = model(mx_dwavs) # 1 t
mx_mels = model.to_mel(mx_dwavs) # 1 c t
fg_mels = model.to_mel(fg_dwavs) # 1 c t
pred_fg_mels = model.to_mel(pred_fg_dwavs) # 1 c t
rate = model.hp.wav_rate
get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
save_wav(get_path("_input.wav"), mx_dwavs[0], rate=rate)
save_wav(get_path("_predict.wav"), pred_fg_dwavs[0], rate=rate)
save_wav(get_path("_target.wav"), fg_dwavs[0], rate=rate)
save_mels(
get_path(".png"),
cond_mel=mx_mels[0].cpu().numpy(),
pred_mel=pred_fg_mels[0].cpu().numpy(),
targ_mel=fg_mels[0].cpu().numpy(),
)
if i >= n_saved:
break
train_loop = TrainLoop(
run_dir=args.run_dir,
train_dl=train_dl,
load_G=partial(load_G, hp=hp),
device=args.device,
feed_G=feed_G,
eval_fn=eval_fn,
)
train_loop.run(max_steps=hp.max_steps)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,144 @@
import torch.nn.functional as F
from torch import nn
class PreactResBlock(nn.Sequential):
def __init__(self, dim):
super().__init__(
nn.GroupNorm(dim // 16, dim),
nn.GELU(),
nn.Conv2d(dim, dim, 3, padding=1),
nn.GroupNorm(dim // 16, dim),
nn.GELU(),
nn.Conv2d(dim, dim, 3, padding=1),
)
def forward(self, x):
return x + super().forward(x)
class UNetBlock(nn.Module):
def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
super().__init__()
if output_dim is None:
output_dim = input_dim
self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
self.res_block1 = PreactResBlock(output_dim)
self.res_block2 = PreactResBlock(output_dim)
self.downsample = self.upsample = nn.Identity()
if scale_factor > 1:
self.upsample = nn.Upsample(scale_factor=scale_factor)
elif scale_factor < 1:
self.downsample = nn.Upsample(scale_factor=scale_factor)
def forward(self, x, h=None):
"""
Args:
x: (b c h w), last output
h: (b c h w), skip output
Returns:
o: (b c h w), output
s: (b c h w), skip output
"""
x = self.upsample(x)
if h is not None:
assert x.shape == h.shape, f"{x.shape} != {h.shape}"
x = x + h
x = self.pre_conv(x)
x = self.res_block1(x)
x = self.res_block2(x)
return self.downsample(x), x
class UNet(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.encoder_blocks = nn.ModuleList(
[
UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
for i in range(num_blocks)
]
)
self.middle_blocks = nn.ModuleList(
[UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
)
self.decoder_blocks = nn.ModuleList(
[
UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
for i in reversed(range(num_blocks))
]
)
self.head = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
nn.GELU(),
nn.Conv2d(hidden_dim, output_dim, 1),
)
@property
def scale_factor(self):
return 2 ** len(self.encoder_blocks)
def pad_to_fit(self, x):
"""
Args:
x: (b c h w), input
Returns:
x: (b c h' w'), padded input
"""
hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
return F.pad(x, (0, wpad, 0, hpad))
def forward(self, x):
"""
Args:
x: (b c h w), input
Returns:
o: (b c h w), output
"""
shape = x.shape
x = self.pad_to_fit(x)
x = self.input_proj(x)
s_list = []
for block in self.encoder_blocks:
x, s = block(x)
s_list.append(s)
for block in self.middle_blocks:
x, _ = block(x)
for block, s in zip(self.decoder_blocks, reversed(s_list)):
x, _ = block(x, s)
x = self.head(x)
x = x[..., : shape[2], : shape[3]]
return x
def test(self, shape=(3, 512, 256)):
import ptflops
macs, params = ptflops.get_model_complexity_info(
self,
shape,
as_strings=True,
print_per_layer_stat=True,
verbose=True,
)
print(f"macs: {macs}")
print(f"params: {params}")
def main():
model = UNet(3, 3)
model.test()
if __name__ == "__main__":
main()
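
With the default num_blocks=4 the encoder halves each spatial axis four times, so scale_factor is 16 and pad_to_fit rounds both axes up to the next multiple of 16; the crop at the end of forward then restores the original size. For example:

import torch

net = UNet(3, 3)
x = torch.randn(1, 3, 100, 70)
print(net.pad_to_fit(x).shape)  # torch.Size([1, 3, 112, 80])
print(net(x).shape)             # torch.Size([1, 3, 100, 70])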

View File

@@ -0,0 +1,129 @@
import argparse
import random
import time
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
from .inference import denoise, enhance
@torch.inference_mode()
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
parser.add_argument("out_dir", type=Path, help="Output folder")
parser.add_argument(
"--run_dir",
type=Path,
default=None,
help="Path to the enhancer run folder, if None, use the default model",
)
parser.add_argument(
"--suffix",
type=str,
default=".wav",
help="Audio file suffix",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for computation, recommended to use CUDA",
)
parser.add_argument(
"--denoise_only",
action="store_true",
help="Only apply denoising without enhancement",
)
parser.add_argument(
"--lambd",
type=float,
default=1.0,
help="Denoise strength for enhancement (0.0 to 1.0)",
)
parser.add_argument(
"--tau",
type=float,
default=0.5,
help="CFM prior temperature (0.0 to 1.0)",
)
parser.add_argument(
"--solver",
type=str,
default="midpoint",
choices=["midpoint", "rk4", "euler"],
help="Numerical solver to use",
)
parser.add_argument(
"--nfe",
type=int,
default=64,
help="Number of function evaluations",
)
parser.add_argument(
"--parallel_mode",
action="store_true",
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
)
args = parser.parse_args()
device = args.device
if device == "cuda" and not torch.cuda.is_available():
print("CUDA is not available but --device is set to cuda, using CPU instead")
device = "cpu"
start_time = time.perf_counter()
run_dir = args.run_dir
paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
if args.parallel_mode:
random.shuffle(paths)
if len(paths) == 0:
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
return
pbar = tqdm(paths)
for path in pbar:
out_path = args.out_dir / path.relative_to(args.in_dir)
if args.parallel_mode and out_path.exists():
continue
pbar.set_description(f"Processing {out_path}")
dwav, sr = torchaudio.load(path)
dwav = dwav.mean(0)
if args.denoise_only:
hwav, sr = denoise(
dwav=dwav,
sr=sr,
device=device,
run_dir=args.run_dir,
)
else:
hwav, sr = enhance(
dwav=dwav,
sr=sr,
device=device,
nfe=args.nfe,
solver=args.solver,
lambd=args.lambd,
tau=args.tau,
run_dir=run_dir,
)
out_path.parent.mkdir(parents=True, exist_ok=True)
torchaudio.save(out_path, hwav[None], sr)
# Cool emoji effect saying the job is done
elapsed_time = time.perf_counter() - start_time
print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,28 @@
import logging
from pathlib import Path
import torch
RUN_NAME = "enhancer_stage2"
logger = logging.getLogger(__name__)
def get_url(relpath):
return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
def get_path(relpath):
return Path(__file__).parent.parent / "model_repo" / RUN_NAME / relpath
def download():
relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
for relpath in relpaths:
path = get_path(relpath)
if path.exists():
continue
url = get_url(relpath)
path.parent.mkdir(parents=True, exist_ok=True)
torch.hub.download_url_to_file(url, str(path))
return get_path("")

View File

@@ -0,0 +1,195 @@
import logging
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor, nn
from torch.distributions import Beta
from ..common import Normalizer
from ..denoiser.inference import load_denoiser
from ..melspec import MelSpectrogram
from ..utils.distributed import global_leader_only
from ..utils.train_loop import TrainLoop
from .hparams import HParams
from .lcfm import CFM, IRMAE, LCFM
from .univnet import UnivNet
logger = logging.getLogger(__name__)
def _maybe(fn):
def _fn(*args):
if args[0] is None:
return None
return fn(*args)
return _fn
def _normalize_wav(x: Tensor):
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
class Enhancer(nn.Module):
def __init__(self, hp: HParams):
super().__init__()
self.hp = hp
n_mels = self.hp.num_mels
vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
latent_dim = self.hp.lcfm_latent_dim
self.lcfm = LCFM(
IRMAE(
input_dim=n_mels,
output_dim=vocoder_input_dim,
latent_dim=latent_dim,
),
CFM(
cond_dim=n_mels,
output_dim=self.hp.lcfm_latent_dim,
solver_nfe=self.hp.cfm_solver_nfe,
solver_method=self.hp.cfm_solver_method,
time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
),
z_scale=self.hp.lcfm_z_scale,
)
self.lcfm.set_mode_(self.hp.lcfm_training_mode)
self.mel_fn = MelSpectrogram(hp)
self.vocoder = UnivNet(self.hp, vocoder_input_dim)
self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
self.normalizer = Normalizer()
self._eval_lambd = 0.0
self.dummy: Tensor
self.register_buffer("dummy", torch.zeros(1))
if self.hp.enhancer_stage1_run_dir is not None:
pretrained_path = self.hp.enhancer_stage1_run_dir / "ds/G/default/mp_rank_00_model_states.pt"
self._load_pretrained(pretrained_path)
logger.info(f"{self.__class__.__name__} summary")
logger.info(f"{self.summarize()}")
def _load_pretrained(self, path):
# Clone is necessary as otherwise it holds a reference to the original model
cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
denoiser_state_dict = {k: v.clone() for k, v in self.denoiser.state_dict().items()}
state_dict = torch.load(path, map_location="cpu")["module"]
self.load_state_dict(state_dict, strict=False)
self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm
self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser
logger.info(f"Loaded pretrained model from {path}")
def summarize(self):
npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
npa = lambda m: sum(p.numel() for p in m.parameters())
rows = []
for name, module in self.named_children():
rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
df = pd.DataFrame(rows)
return df.to_markdown(index=False)
def to_mel(self, x: Tensor, drop_last=True):
"""
Args:
x: (b t), wavs
Returns:
o: (b c t), mels
"""
if drop_last:
return self.mel_fn(x)[..., :-1] # (b d t)
return self.mel_fn(x)
@global_leader_only
@torch.no_grad()
def _visualize(self, original_mel, denoised_mel):
loop = TrainLoop.get_running_loop()
if loop is None or loop.global_step % 100 != 0:
return
plt.figure(figsize=(6, 6))
plt.subplot(211)
plt.title("Original")
plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none")
plt.subplot(212)
plt.title("Denoised")
plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none")
plt.tight_layout()
path = loop.get_running_loop_viz_path("input", ".png")
plt.savefig(path, dpi=300)
def _may_denoise(self, x: Tensor, y: Tensor | None = None):
if self.hp.lcfm_training_mode == "cfm":
return self.denoiser(x, y)
return x
def configurate_(self, nfe, solver, lambd, tau):
"""
Args:
nfe: number of function evaluations
solver: solver method
lambd: denoiser strength [0, 1]
tau: prior temperature [0, 1]
"""
self.lcfm.cfm.solver.configurate_(nfe, solver)
self.lcfm.eval_tau_(tau)
self._eval_lambd = lambd
def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
"""
Args:
x: (b t), mix wavs (fg + bg)
y: (b t), fg clean wavs
z: (b t), fg distorted wavs
Returns:
o: (b t), reconstructed wavs
"""
assert x.dim() == 2, f"Expected (b t), got {x.size()}"
assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
if self.hp.lcfm_training_mode == "cfm":
self.normalizer.eval()
x = _normalize_wav(x)
y = _maybe(_normalize_wav)(y)
z = _maybe(_normalize_wav)(z)
x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t)
if self.hp.lcfm_training_mode == "cfm":
if self.training:
lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
lambd = lambd[:, None, None]
x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
x_mel_denoised = x_mel_denoised.detach()
x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
self._visualize(x_mel_original, x_mel_denoised)
else:
lambd = self._eval_lambd
if lambd == 0:
x_mel_denoised = x_mel_original
else:
x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
x_mel_denoised = x_mel_denoised.detach()
x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
else:
x_mel_denoised = x_mel_original
y_mel = _maybe(self.to_mel)(y) # (b d t)
y_mel = _maybe(self.normalizer)(y_mel)
lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)
if lcfm_decoded is None:
o = None
else:
o = self.vocoder(lcfm_decoded, y)
return o
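
The clone in _load_pretrained is load-bearing: state_dict() returns tensors that share storage with the live parameters, so without the copies the later load_state_dict would silently overwrite the saved snapshots. A minimal demonstration of the aliasing (the Linear layer is illustrative):

import torch
from torch import nn

m = nn.Linear(2, 2)
ref = m.state_dict()                           # shares storage with m's parameters
snap = {k: v.clone() for k, v in ref.items()}  # independent copies
with torch.no_grad():
    m.weight.zero_()
print(ref["weight"].sum().item())   # 0.0, the "saved" reference changed too
print(snap["weight"].sum().item())  # nonzero, the clone kept the original values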

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from pathlib import Path
from ..hparams import HParams as HParamsBase
@dataclass(frozen=True)
class HParams(HParamsBase):
cfm_solver_method: str = "midpoint"
cfm_solver_nfe: int = 64
cfm_time_mapping_divisor: int = 4
univnet_nc: int = 96
lcfm_latent_dim: int = 64
lcfm_training_mode: str = "ae"
lcfm_z_scale: float = 5
vocoder_extra_dim: int = 32
gan_training_start_step: int | None = 5_000
enhancer_stage1_run_dir: Path | None = None
denoiser_run_dir: Path | None = None

View File

@@ -0,0 +1,49 @@
import logging
from functools import cache
import torch
from ..inference import inference
from .download import download
from .train import Enhancer, HParams
import platform
import pathlib
# Check if the current system is Windows
if platform.system() == 'Windows':
# Make changes specific to Windows
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath
logger = logging.getLogger(__name__)
def load_enhancer(run_dir, device):
if run_dir is None:
run_dir = download()
hp = HParams.load(run_dir)
enhancer = Enhancer(hp)
path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
state_dict = torch.load(path, map_location="cpu")["module"]
enhancer.load_state_dict(state_dict)
enhancer.eval()
enhancer.to(device)
return enhancer
@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
enhancer = load_enhancer(run_dir, device)
return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
@torch.inference_mode()
# chunk_seconds / chunks_overlap control inference windowing; the defaults assume 30 s chunks with 1 s overlap so callers (e.g. the CLI) may omit them
def enhance(dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None, chunk_seconds=30.0, chunks_overlap=1.0):
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
assert solver in ("midpoint", "rk4", "euler"), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
enhancer = load_enhancer(run_dir, device)
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
return inference(model=enhancer, chunk_seconds=chunk_seconds, overlap_seconds=chunks_overlap, dwav=dwav, sr=sr, device=device)

View File

@@ -0,0 +1,2 @@
from .irmae import IRMAE
from .lcfm import CFM, LCFM

View File

@@ -0,0 +1,372 @@
import logging
from dataclasses import dataclass
from functools import partial
from typing import Protocol
import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from tqdm import trange
from .wn import WN
logger = logging.getLogger(__name__)
class VelocityField(Protocol):
def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
...
class Solver:
def __init__(
self,
method="midpoint",
nfe=32,
viz_name="solver",
viz_every=100,
mel_fn=None,
time_mapping_divisor=4,
verbose=False,
):
self.configurate_(nfe=nfe, method=method)
self.verbose = verbose
self.viz_every = viz_every
self.viz_name = viz_name
self._camera = None
self._mel_fn = mel_fn
self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
def configurate_(self, nfe=None, method=None):
if nfe is None:
nfe = self.nfe
if method is None:
method = self.method
if nfe == 1 and method in ("midpoint", "rk4"):
logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
method = "euler"
self.nfe = nfe
self.method = method
@property
def time_mapping(self):
return self._time_mapping
@staticmethod
def exponential_decay_mapping(t, n=4):
"""
Args:
n: target step
"""
def h(t, a):
return (a**t - 1) / (a - 1)
# Solve h(1/n) = 0.5
a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
t = h(t, a=a)
return t
@torch.no_grad()
def _maybe_camera_snap(self, *, ψt, t):
camera = self._camera
if camera is not None:
if ψt.shape[1] == 1:
# Waveform, b 1 t, plot every 100 samples
plt.subplot(211)
plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
if self._mel_fn is not None:
plt.subplot(212)
mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
plt.imshow(mel, origin="lower", interpolation="none")
elif ψt.shape[1] == 2:
# Complex
plt.subplot(121)
plt.imshow(
ψt.detach().cpu().numpy()[0, 0],
origin="lower",
interpolation="none",
)
plt.subplot(122)
plt.imshow(
ψt.detach().cpu().numpy()[0, 1],
origin="lower",
interpolation="none",
)
else:
# Spectrogram, b c t
plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
ax = plt.gca()
ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
camera.snap()
@staticmethod
def _euler_step(t, ψt, dt, f: VelocityField):
return ψt + dt * f(t=t, ψt=ψt, dt=dt)
@staticmethod
def _midpoint_step(t, ψt, dt, f: VelocityField):
return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
@staticmethod
def _rk4_step(t, ψt, dt, f: VelocityField):
k1 = f(t=t, ψt=ψt, dt=dt)
k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
@property
def _step(self):
if self.method == "euler":
return self._euler_step
elif self.method == "midpoint":
return self._midpoint_step
elif self.method == "rk4":
return self._rk4_step
else:
raise ValueError(f"Unknown method: {self.method}")
def get_running_train_loop(self):
try:
# Lazy import
from ...utils.train_loop import TrainLoop
return TrainLoop.get_running_loop()
except ImportError:
return None
@property
def visualizing(self):
loop = self.get_running_train_loop()
if loop is None:
return
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
return loop.global_step % self.viz_every == 0 and not out_path.exists()
def _reset_camera(self):
try:
from celluloid import Camera
self._camera = Camera(plt.figure())
        except ImportError:
            pass
def _maybe_dump_camera(self):
camera = self._camera
loop = self.get_running_train_loop()
if camera is not None and loop is not None:
animation = camera.animate()
out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
out_path.parent.mkdir(exist_ok=True, parents=True)
animation.save(out_path, writer="pillow", fps=4)
plt.close()
self._camera = None
@property
def n_steps(self):
n = self.nfe
if self.method == "euler":
pass
elif self.method == "midpoint":
n //= 2
elif self.method == "rk4":
n //= 4
else:
raise ValueError(f"Unknown method: {self.method}")
return n
def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
if self.visualizing:
self._reset_camera()
if self.verbose:
steps = trange(self.n_steps, desc="CFM inference")
else:
steps = range(self.n_steps)
ψt = ψ0
for i in steps:
dt = ts[i + 1] - ts[i]
t = ts[i]
self._maybe_camera_snap(ψt=ψt, t=t)
ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
self._maybe_camera_snap(ψt=ψt, t=ts[-1])
ψ1 = ψt
del ψt
self._maybe_dump_camera()
return ψ1
def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
class SinusodialTimeEmbedding(nn.Module):
def __init__(self, d_embed):
super().__init__()
self.d_embed = d_embed
assert d_embed % 2 == 0
def forward(self, t):
t = t.unsqueeze(-1) # ... 1
p = torch.linspace(0, 4, self.d_embed // 2).to(t)
while p.dim() < t.dim():
p = p.unsqueeze(0) # ... d/2
sin = torch.sin(t * 10**p)
cos = torch.cos(t * 10**p)
return torch.cat([sin, cos], dim=-1)
@dataclass(eq=False)
class CFM(nn.Module):
"""
This mixin is for general diffusion models.
ψ0 stands for the gaussian noise, and ψ1 is the data point.
Here we follow the CFM style:
The generation process (reverse process) is from t=0 to t=1.
The forward process is from t=1 to t=0.
"""
cond_dim: int
output_dim: int
time_emb_dim: int = 128
viz_name: str = "cfm"
solver_nfe: int = 32
solver_method: str = "midpoint"
time_mapping_divisor: int = 4
def __post_init__(self):
super().__init__()
self.solver = Solver(
viz_name=self.viz_name,
viz_every=1,
nfe=self.solver_nfe,
method=self.solver_method,
time_mapping_divisor=self.time_mapping_divisor,
)
self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
self.net = WN(
input_dim=self.output_dim,
output_dim=self.output_dim,
local_dim=self.cond_dim,
global_dim=self.time_emb_dim,
)
def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
"""
Perturb ψ1 to ψt.
"""
raise NotImplementedError
def _sample_ψ0(self, x: Tensor):
"""
Args:
x: (b c t), which implies the shape of ψ0
"""
shape = list(x.shape)
shape[1] = self.output_dim
if self.training:
g = None
else:
g = torch.Generator(device=x.device)
g.manual_seed(0) # deterministic sampling during eval
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
return ψ0
@property
def sigma(self):
return 1e-4
def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
"""
Eq (22)
"""
while t.dim() < ψ1.dim():
t = t.unsqueeze(-1)
μ = t * ψ1 + (1 - t) * ψ0
return μ + torch.randn_like(μ) * self.sigma
def _to_u(self, *, ψ1, ψ0: Tensor):
"""
Eq (21)
"""
return ψ1 - ψ0
def _to_v(self, *, ψt, x, t: float | Tensor):
"""
Args:
ψt: (b c t)
x: (b c t)
t: (b)
Returns:
v: (b c t)
"""
if isinstance(t, (float, int)):
t = torch.full(ψt.shape[:1], t).to(ψt)
        t = t.clamp(0, 1)  # [0, 1]
g = self.emb(t) # (b d)
v = self.net(ψt, l=x, g=g)
return v
def compute_losses(self, x, y, ψ0) -> dict:
"""
Args:
x: (b c t)
y: (b c t)
Returns:
losses: dict
"""
t = torch.rand(len(x), device=x.device, dtype=x.dtype)
t = self.solver.time_mapping(t)
if ψ0 is None:
ψ0 = self._sample_ψ0(x)
ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
v = self._to_v(ψt=ψt, t=t, x=x)
u = self._to_u(ψ1=y, ψ0=ψ0)
losses = dict(l1=F.l1_loss(v, u))
return losses
@torch.inference_mode()
def sample(self, x, ψ0=None, t0=0.0):
"""
Args:
x: (b c t)
Returns:
y: (b ... t)
"""
if ψ0 is None:
ψ0 = self._sample_ψ0(x)
f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
return ψ1
def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
if y is None:
y = self.sample(x, ψ0=ψ0, t0=t0)
else:
self.losses = self.compute_losses(x, y, ψ0=ψ0)
return y
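
The Solver integrates dψ/dt = f from t0 to t1 on the warped time grid, spending the nfe budget across steps (midpoint uses two evaluations per step, rk4 four). A toy sketch on a field with a known solution, assuming the module's package imports resolve:

import torch

solver = Solver(method="midpoint", nfe=64)
f = lambda t, ψt, dt: ψt              # dψ/dt = ψ, so ψ(1) = e * ψ(0)
print(solver(f=f, ψ0=torch.ones(1)))  # tensor([2.718...])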

View File

@@ -0,0 +1,123 @@
import logging
from dataclasses import dataclass
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm
from ...common import Normalizer
logger = logging.getLogger(__name__)
@dataclass
class IRMAEOutput:
latent: Tensor # latent vector
decoded: Tensor | None # decoder output, include extra dim
class ResBlock(nn.Sequential):
def __init__(self, channels, dilations=[1, 2, 4, 8]):
wn = weight_norm
super().__init__(
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
nn.GroupNorm(32, channels),
nn.GELU(),
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
)
def forward(self, x: Tensor):
return x + super().forward(x)
class IRMAE(nn.Module):
def __init__(
self,
input_dim,
output_dim,
latent_dim,
hidden_dim=1024,
num_irms=4,
):
"""
Args:
input_dim: input dimension
output_dim: output dimension
latent_dim: latent dimension
hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices
norm: normalization layer
"""
self.input_dim = input_dim
super().__init__()
self.encoder = nn.Sequential(
nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
*[ResBlock(hidden_dim) for _ in range(4)],
# Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
*[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
nn.Tanh(),
)
self.decoder = nn.Sequential(
nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
*[ResBlock(hidden_dim) for _ in range(4)],
nn.Conv1d(hidden_dim, output_dim, 1),
)
self.head = nn.Sequential(
nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
nn.GELU(),
nn.Conv1d(hidden_dim, input_dim, 1),
)
self.estimator = Normalizer()
def encode(self, x):
"""
Args:
x: (b c t) tensor
"""
z = self.encoder(x) # (b c t)
        _ = self.estimator(z)  # Estimate the global mean and std of z
self.stats = {}
self.stats["z_mean"] = z.mean().item()
self.stats["z_std"] = z.std().item()
self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
return z
def decode(self, z):
"""
Args:
z: (b c t) tensor
"""
return self.decoder(z)
def forward(self, x, skip_decoding=False):
"""
Args:
x: (b c t) tensor
skip_decoding: if True, skip the decoding step
"""
z = self.encode(x) # q(z|x)
if skip_decoding:
# This speeds up the training in cfm only mode
decoded = None
else:
decoded = self.decode(z) # p(x|z)
predicted = self.head(decoded)
self.losses = dict(mse=F.mse_loss(predicted, x))
return IRMAEOutput(latent=z, decoded=decoded)
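
A shape sketch for the autoencoder above (all dimensions are illustrative): the encoder ends in num_irms stacked linear 1x1 convolutions, the implicit-rank-minimization trick from the paper linked in the comment, and the decoder output carries the extra vocoder dimensions.

import torch

ae = IRMAE(input_dim=128, output_dim=160, latent_dim=64)
out = ae(torch.randn(2, 128, 50))
print(out.latent.shape, out.decoded.shape)  # (2, 64, 50) and (2, 160, 50)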

View File

@@ -0,0 +1,152 @@
import logging
from enum import Enum
import matplotlib.pyplot as plt
import torch
from torch import Tensor, nn
from .cfm import CFM
from .irmae import IRMAE, IRMAEOutput
logger = logging.getLogger(__name__)
def freeze_(module):
for p in module.parameters():
p.requires_grad_(False)
class LCFM(nn.Module):
class Mode(Enum):
AE = "ae"
CFM = "cfm"
def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
super().__init__()
self.ae = ae
self.cfm = cfm
self.z_scale = z_scale
self._mode = None
self._eval_tau = 0.5
@property
def mode(self):
return self._mode
def set_mode_(self, mode):
mode = self.Mode(mode)
self._mode = mode
if mode == mode.AE:
freeze_(self.cfm)
logger.info("Freeze cfm")
elif mode == mode.CFM:
freeze_(self.ae)
logger.info("Freeze ae (encoder and decoder)")
else:
raise ValueError(f"Unknown training mode: {mode}")
def get_running_train_loop(self):
try:
# Lazy import
from ...utils.train_loop import TrainLoop
return TrainLoop.get_running_loop()
except ImportError:
return None
@property
def global_step(self):
loop = self.get_running_train_loop()
if loop is None:
return None
return loop.global_step
@torch.no_grad()
def _visualize(self, x, y, y_):
loop = self.get_running_train_loop()
if loop is None:
return
plt.subplot(221)
plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
plt.title("GT")
plt.subplot(222)
y_ = y_[:, : y.shape[1]]
plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
plt.title("Posterior")
plt.subplot(223)
z_ = self.cfm(x)
y__ = self.ae.decode(z_)
y__ = y__[:, : y.shape[1]]
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
plt.title("C-Prior")
del y__
plt.subplot(224)
z_ = torch.randn_like(z_)
y__ = self.ae.decode(z_)
y__ = y__[:, : y.shape[1]]
plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
plt.title("Prior")
del z_, y__
path = loop.make_current_step_viz_path("recon", ".png")
path.parent.mkdir(exist_ok=True, parents=True)
plt.tight_layout()
plt.savefig(path, dpi=500)
plt.close()
def _scale(self, z: Tensor):
return z * self.z_scale
def _unscale(self, z: Tensor):
return z / self.z_scale
def eval_tau_(self, tau):
self._eval_tau = tau
def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
"""
Args:
x: (b d t), condition mel
y: (b d t), target mel
ψ0: (b d t), starting mel
"""
if self.mode == self.Mode.CFM:
self.ae.eval() # Always set to eval when training cfm
if ψ0 is not None:
ψ0 = self._scale(self.ae.encode(ψ0))
if self.training:
tau = torch.rand_like(ψ0[:, :1, :1])
else:
tau = self._eval_tau
ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
if y is None:
if self.mode == self.Mode.AE:
with torch.no_grad():
training = self.ae.training
self.ae.eval()
z = self.ae.encode(x)
self.ae.train(training)
else:
z = self._unscale(self.cfm(x, ψ0=ψ0))
h = self.ae.decode(z)
else:
ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
if self.mode == self.Mode.CFM:
_ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
h = ae_output.decoded
if h is not None and self.global_step is not None and self.global_step % 100 == 0:
self._visualize(x[:1], y[:1], h[:1])
return h

View File

@@ -0,0 +1,147 @@
import logging
import math
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
@torch.jit.script
def _fused_tanh_sigmoid(h):
a, b = h.chunk(2, dim=1)
h = a.tanh() * b.sigmoid()
return h
class WNLayer(nn.Module):
"""
A DiffWave-like WN
"""
def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
super().__init__()
local_output_dim = hidden_dim * 2
if global_dim is not None:
self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
if local_dim is not None:
self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
def forward(self, z, l, g):
identity = z
if g is not None:
if g.dim() == 2:
g = g.unsqueeze(-1)
z = z + self.gconv(g)
z = self.dconv(z)
if l is not None:
z = z + self.lconv(l)
z = _fused_tanh_sigmoid(z)
h = self.out(z)
z, s = h.chunk(2, dim=1)
o = (z + identity) / math.sqrt(2)
return o, s
class WN(nn.Module):
def __init__(
self,
input_dim,
output_dim,
local_dim=None,
global_dim=None,
n_layers=30,
kernel_size=3,
dilation_cycle=5,
hidden_dim=512,
):
super().__init__()
assert kernel_size % 2 == 1
assert hidden_dim % 2 == 0
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.local_dim = local_dim
self.global_dim = global_dim
self.start = nn.Conv1d(input_dim, hidden_dim, 1)
if local_dim is not None:
self.local_norm = nn.InstanceNorm1d(local_dim)
self.layers = nn.ModuleList(
[
WNLayer(
hidden_dim=hidden_dim,
local_dim=local_dim,
global_dim=global_dim,
kernel_size=kernel_size,
dilation=2 ** (i % dilation_cycle),
)
for i in range(n_layers)
]
)
self.end = nn.Conv1d(hidden_dim, output_dim, 1)
def forward(self, z, l=None, g=None):
"""
Args:
z: input (b c t)
l: local condition (b c t)
g: global condition (b d)
"""
z = self.start(z)
if l is not None:
l = self.local_norm(l)
# Skips
s_list = []
for layer in self.layers:
z, s = layer(z, l, g)
s_list.append(s)
s_list = torch.stack(s_list, dim=0).sum(dim=0)
s_list = s_list / math.sqrt(len(self.layers))
o = self.end(s_list)
return o
def summarize(self, length=100):
from ptflops import get_model_complexity_info
x = torch.randn(1, self.input_dim, length)
macs, params = get_model_complexity_info(
self,
(self.input_dim, length),
as_strings=True,
print_per_layer_stat=True,
verbose=True,
)
print(f"Input shape: {x.shape}")
print(f"Computational complexity: {macs}")
print(f"Number of parameters: {params}")
if __name__ == "__main__":
model = WN(input_dim=64, output_dim=64)
model.summarize()

View File

@ -0,0 +1,143 @@
import argparse
import random
from functools import partial
from pathlib import Path
import soundfile
import torch
from deepspeed import DeepSpeedConfig
from torch import Tensor
from tqdm import tqdm
from ..data import create_dataloaders, mix_fg_bg
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
from ..utils.distributed import is_local_leader
from .enhancer import Enhancer
from .hparams import HParams
from .univnet.discriminator import Discriminator
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
if hp is None:
hp = HParams.load(run_dir)
assert isinstance(hp, HParams)
model = Enhancer(hp)
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
if training:
engine.load_checkpoint()
else:
engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
return engine
def load_D(run_dir: Path, hp: HParams | None = None):
if hp is None:
hp = HParams.load(run_dir)
assert isinstance(hp, HParams)
model = Discriminator(hp)
engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
engine.load_checkpoint()
return engine
def save_wav(path: Path, wav: Tensor, rate: int):
wav = wav.detach().cpu().numpy()
soundfile.write(path, wav, samplerate=rate)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("run_dir", type=Path)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda")
args = parser.parse_args()
setup_logging(args.run_dir)
hp = HParams.load(args.run_dir, yaml=args.yaml)
if is_local_leader():
hp.save_if_not_exists(args.run_dir)
hp.print()
train_dl, val_dl = create_dataloaders(hp, mode="enhancer")
def feed_G(engine: Engine, batch: dict[str, Tensor]):
if hp.lcfm_training_mode == "ae":
pred = engine(batch["fg_wavs"], batch["fg_wavs"])
elif hp.lcfm_training_mode == "cfm":
alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn)
pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
else:
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
losses = engine.gather_attribute("losses")
return pred, losses
def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
if batch is None:
losses = engine(fake=fake)
else:
losses = engine(fake=fake, real=batch["fg_wavs"])
return losses
@torch.no_grad()
def eval_fn(engine: Engine, eval_dir, n_saved=10):
assert isinstance(hp, HParams)
model = engine.module
model.eval()
step = engine.global_step
for i, batch in enumerate(tqdm(val_dl), 1):
batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
fg_wavs = batch["fg_wavs"] # 1 t
if hp.lcfm_training_mode == "ae":
in_dwavs = fg_wavs
elif hp.lcfm_training_mode == "cfm":
in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
else:
raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
pred_fg_wavs = model(in_dwavs) # 1 t
in_mels = model.to_mel(in_dwavs) # 1 c t
fg_mels = model.to_mel(fg_wavs) # 1 c t
pred_fg_mels = model.to_mel(pred_fg_wavs) # 1 c t
rate = model.hp.wav_rate
get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate)
save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate)
save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate)
save_mels(
get_path(".png"),
cond_mel=in_mels[0].cpu().numpy(),
pred_mel=pred_fg_mels[0].cpu().numpy(),
targ_mel=fg_mels[0].cpu().numpy(),
)
if i >= n_saved:
break
train_loop = TrainLoop(
run_dir=args.run_dir,
train_dl=train_dl,
load_G=partial(load_G, hp=hp),
load_D=partial(load_D, hp=hp),
device=args.device,
feed_G=feed_G,
feed_D=feed_D,
eval_fn=eval_fn,
gan_training_start_step=hp.gan_training_start_step,
)
train_loop.run(max_steps=hp.max_steps)
if __name__ == "__main__":
main()
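# Usage sketch (not part of the original file): the script takes a run
# directory plus an optional YAML override; the module path below is an
# assumption based on this file's relative imports:
#   python -m resemble_enhance.enhancer.train runs/enhancer --yaml config.yaml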

View File

@ -0,0 +1 @@
from .univnet import UnivNet

View File

@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
from .filter import *
from .resample import *

View File

@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
if 'sinc' in dir(torch):
sinc = torch.sinc
else:
# This code is adopted from adefossez's julius.core.sinc under the MIT License
# https://adefossez.github.io/julius/julius/core.html
# LICENSE is in incl_licenses directory.
def sinc(x: torch.Tensor):
"""
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
"""
return torch.where(x == 0,
torch.tensor(1., device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x)
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = (kernel_size % 2 == 0)
half_size = kernel_size // 2
    # For the Kaiser window
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.:
beta = 0.1102 * (A - 8.7)
elif A >= 21.:
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
else:
beta = 0.
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
if even:
time = (torch.arange(-half_size, half_size) + 0.5)
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
# Normalize filter to have sum = 1, otherwise we will have a small leakage
# of the constant component in the input signal.
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(self,
cutoff=0.5,
half_width=0.6,
stride: int = 1,
padding: bool = True,
padding_mode: str = 'replicate',
kernel_size: int = 12):
# kernel_size should be even number for stylegan3 setup,
# in this implementation, odd number is also possible.
super().__init__()
        if cutoff < 0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
    # input: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right),
mode=self.padding_mode)
out = F.conv1d(x, self.filter.expand(C, -1, -1),
stride=self.stride, groups=C)
return out
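if __name__ == "__main__":
    # Minimal usage sketch (not in the original file): low-pass filter a
    # random batch; with stride=1 and padding=True the length is preserved.
    x = torch.randn(1, 2, 1000)
    lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
    print(lpf(x).shape)  # torch.Size([1, 2, 1000])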

View File

@ -0,0 +1,49 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from torch.nn import functional as F
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode='replicate')
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left:-self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size)
    def forward(self, x):
        return self.lowpass(x)
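if __name__ == "__main__":
    import torch

    # Minimal usage sketch (not in the original file); run as a module so the
    # relative import of .filter resolves. A 2x up / 2x down round trip
    # preserves the input length.
    x = torch.randn(1, 1, 1000)
    up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
    y = up(x)
    print(y.shape, down(y).shape)  # (1, 1, 2000) (1, 1, 1000)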

View File

@ -0,0 +1,101 @@
# Adapted from https://github.com/NVIDIA/BigVGAN
import math
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from .alias_free_torch import DownSample1d, UpSample1d
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
    >>> a1 = SnakeBeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super().__init__()
self.in_features = in_features
self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
        self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))  # initialized like alpha (log 1 = 0 by default)
self.clamp = clamp
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
"""
alpha = self.log_alpha.exp().clamp(*self.clamp)
alpha = alpha[None, :, None]
beta = self.log_beta.exp().clamp(*self.clamp)
beta = beta[None, :, None]
x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
return x
class UpActDown(nn.Module):
def __init__(
self,
act,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
):
super().__init__()
self.up_ratio = up_ratio
self.down_ratio = down_ratio
self.act = act
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def forward(self, x):
# x: [B,C,T]
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
class AMPBlock(nn.Sequential):
def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
def _make_layer(self, channels, kernel_size, dilation):
return nn.Sequential(
weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
UpActDown(act=SnakeBeta(channels)),
weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
)
def forward(self, x):
return x + super().forward(x)
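if __name__ == "__main__":
    # Minimal usage sketch (not in the original file); run as a module so the
    # relative import of alias_free_torch resolves. The block is residual, so
    # the output shape matches the input shape.
    x = torch.randn(1, 16, 100)
    print(AMPBlock(16)(x).shape)  # torch.Size([1, 16, 100])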

View File

@ -0,0 +1,210 @@
import logging
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm
from ..hparams import HParams
from .mrstft import get_stft_cfgs
logger = logging.getLogger(__name__)
class PeriodNetwork(nn.Module):
def __init__(self, period):
super().__init__()
self.period = period
wn = weight_norm
self.convs = nn.ModuleList(
[
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
]
)
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
"""
Args:
x: [B, 1, T]
"""
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.2)
x = self.conv_post(x)
x = torch.flatten(x, 1, -1)
return x
class SpecNetwork(nn.Module):
def __init__(self, stft_cfg: dict):
super().__init__()
wn = weight_norm
self.stft_cfg = stft_cfg
self.convs = nn.ModuleList(
[
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
]
)
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
def forward(self, x):
"""
Args:
x: [B, 1, T]
"""
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.2)
x = self.conv_post(x)
x = x.flatten(1, -1)
return x
def spectrogram(self, x):
"""
Args:
x: [B, 1, T]
"""
x = x.squeeze(1)
dtype = x.dtype
stft_cfg = dict(self.stft_cfg)
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
mag = x.norm(p=2, dim=-1) # [B, F, TT]
mag = mag.to(dtype) # [B, F, TT]
return mag
class MD(nn.ModuleList):
def __init__(self, l: list):
super().__init__([self._create_network(x) for x in l])
self._loss_type = None
def loss_type_(self, loss_type):
self._loss_type = loss_type
def _create_network(self, _):
raise NotImplementedError
def _forward_each(self, d, x, y):
assert self._loss_type is not None, "loss_type is not set."
loss_type = self._loss_type
if loss_type == "hinge":
if y == 0:
# d(x) should be small -> -1
loss = F.relu(1 + d(x)).mean()
elif y == 1:
# d(x) should be large -> 1
loss = F.relu(1 - d(x)).mean()
else:
raise ValueError(f"Invalid y: {y}")
elif loss_type == "wgan":
if y == 0:
loss = d(x).mean()
elif y == 1:
loss = -d(x).mean()
else:
raise ValueError(f"Invalid y: {y}")
else:
raise ValueError(f"Invalid loss_type: {loss_type}")
return loss
def forward(self, x, y) -> Tensor:
losses = [self._forward_each(d, x, y) for d in self]
return torch.stack(losses).mean()
class MPD(MD):
def __init__(self):
super().__init__([2, 3, 7, 13, 17])
def _create_network(self, period):
return PeriodNetwork(period)
class MRD(MD):
def __init__(self, stft_cfgs):
super().__init__(stft_cfgs)
def _create_network(self, stft_cfg):
return SpecNetwork(stft_cfg)
class Discriminator(nn.Module):
@property
def wav_rate(self):
return self.hp.wav_rate
def __init__(self, hp: HParams):
super().__init__()
self.hp = hp
self.stft_cfgs = get_stft_cfgs(hp)
self.mpd = MPD()
self.mrd = MRD(self.stft_cfgs)
self.dummy_float: Tensor
self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
def loss_type_(self, loss_type):
self.mpd.loss_type_(loss_type)
self.mrd.loss_type_(loss_type)
def forward(self, fake, real=None):
"""
Args:
fake: [B T]
real: [B T]
"""
fake = fake.to(self.dummy_float)
if real is None:
self.loss_type_("wgan")
else:
length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
            assert length_difference < 0.05, f"length_difference should be smaller than 5%, got {length_difference:.1%}"
self.loss_type_("hinge")
real = real.to(self.dummy_float)
fake = fake[..., : real.shape[-1]]
real = real[..., : fake.shape[-1]]
losses = {}
assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
fake = fake.unsqueeze(1)
if real is None:
losses["mpd"] = self.mpd(fake, 1)
losses["mrd"] = self.mrd(fake, 1)
else:
real = real.unsqueeze(1)
losses["mpd_fake"] = self.mpd(fake, 0)
losses["mpd_real"] = self.mpd(real, 1)
losses["mrd_fake"] = self.mrd(fake, 0)
losses["mrd_real"] = self.mrd(real, 1)
return losses
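if __name__ == "__main__":
    # Minimal usage sketch (not in the original file); run as a module so the
    # relative HParams import resolves. Assumes the default 44.1 kHz HParams.
    hp = HParams()
    d = Discriminator(hp)
    fake, real = torch.randn(2, 44100), torch.randn(2, 44100)
    print({k: v.item() for k, v in d(fake, real).items()})  # four hinge losses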

View File

@ -0,0 +1,281 @@
""" refer from https://github.com/zceng/LVCNet """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from .amp import AMPBlock
class KernelPredictor(torch.nn.Module):
"""Kernel predictor for the location-variable convolutions"""
def __init__(
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
"""
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int): number of layers
"""
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
self.bias_conv = weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__(
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
add_extra_noise=False,
downsampling=False,
):
super().__init__()
self.add_extra_noise = add_extra_noise
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
)
if downsampling:
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
nn.AvgPool1d(stride, stride),
)
else:
if stride == 1:
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
)
else:
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
)
),
)
self.amp_block = AMPBlock(in_channels)
self.conv_blocks = nn.ModuleList()
for d in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
# Add one amp block just after the upsampling
x = self.amp_block(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
if self.add_extra_noise:
# Add extra noise to part of the feature
a, b = x.chunk(2, dim=1)
b = b + torch.randn_like(b) * 0.1
x = torch.cat([a, b], dim=1)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
            k = kernels[:, i, :, :, :, :]  # (B, c_g, 2 * c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(
output, k, b, hop_size=self.cond_hop_length
) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]
) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (
kernel_length * hop_size
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
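if __name__ == "__main__":
    # Minimal shape-check sketch (not in the original file); run as a module
    # so the relative import of .amp resolves. An upsampling LVC block with
    # stride 2, conditioned at cond_hop_length resolution:
    block = LVCBlock(in_channels=8, cond_channels=16, stride=2, cond_hop_length=64)
    x = torch.randn(1, 8, 128)  # input at the lower rate
    c = torch.randn(1, 16, 4)   # 4 conditioning frames -> 4 * 64 = 256 samples
    print(block(x, c).shape)    # torch.Size([1, 8, 256])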

View File

@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
import torch
import torch.nn.functional as F
from torch import nn
from ..hparams import HParams
def _make_stft_cfg(hop_length, win_length=None):
if win_length is None:
win_length = 4 * hop_length
n_fft = 2 ** (win_length - 1).bit_length()
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def get_stft_cfgs(hp: HParams):
assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
return [_make_stft_cfg(h) for h in (100, 256, 512)]
def stft(x, n_fft, hop_length, win_length, window):
dtype = x.dtype
x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
x = x.abs().to(dtype)
x = x.transpose(2, 1) # (b f t) -> (b t f)
return x
class SpectralConvergenceLoss(nn.Module):
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
class LogSTFTMagnitudeLoss(nn.Module):
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
class STFTLoss(nn.Module):
def __init__(self, hp, stft_cfg: dict, window="hann_window"):
super().__init__()
self.hp = hp
self.stft_cfg = stft_cfg
        self.spectral_convergence_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Spectral convergence loss value.
Tensor: Log STFT magnitude loss value.
"""
stft_cfg = dict(self.stft_cfg)
x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f)
y_mag = stft(y, **stft_cfg, window=self.window)
        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return dict(sc=sc_loss, mag=mag_loss)
class MRSTFTLoss(nn.Module):
def __init__(self, hp: HParams, window="hann_window"):
"""Initialize Multi resolution STFT loss module.
Args:
resolutions (list): List of (FFT size, hop size, window length).
window (str): Window function type.
"""
super().__init__()
stft_cfgs = get_stft_cfgs(hp)
self.stft_losses = nn.ModuleList()
self.hp = hp
for c in stft_cfgs:
self.stft_losses += [STFTLoss(hp, c, window=window)]
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (b t).
y (Tensor): Groundtruth signal (b t).
Returns:
Tensor: Multi resolution spectral convergence loss value.
Tensor: Multi resolution log STFT magnitude loss value.
"""
assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
dtype = x.dtype
x = x.float()
y = y.float()
# Align length
x = x[..., : y.shape[-1]]
y = y[..., : x.shape[-1]]
losses = {}
for f in self.stft_losses:
d = f(x, y)
for k, v in d.items():
losses.setdefault(k, []).append(v)
for k, v in losses.items():
losses[k] = torch.stack(v, dim=0).mean().to(dtype)
return losses
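if __name__ == "__main__":
    # Minimal usage sketch (not in the original file); run as a module so the
    # relative HParams import resolves. Assumes the default 44.1 kHz HParams.
    hp = HParams()
    loss_fn = MRSTFTLoss(hp)
    x, y = torch.randn(2, 44100), torch.randn(2, 44100)
    print({k: v.item() for k, v in loss_fn(x, y).items()})  # {'sc': ..., 'mag': ...}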

View File

@ -0,0 +1,94 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm
from ..hparams import HParams
from .lvcnet import LVCBlock
from .mrstft import MRSTFTLoss
class UnivNet(nn.Module):
@property
def d_noise(self):
return 128
@property
def strides(self):
return [7, 5, 4, 3]
@property
def dilations(self):
return [1, 3, 9, 27]
@property
def nc(self):
return self.hp.univnet_nc
@property
def scale_factor(self) -> int:
return self.hp.hop_size
def __init__(self, hp: HParams, d_input):
super().__init__()
self.d_input = d_input
self.hp = hp
self.blocks = nn.ModuleList(
[
LVCBlock(
self.nc,
d_input,
stride=stride,
dilations=self.dilations,
cond_hop_length=hop_length,
kpnet_conv_size=3,
)
for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
]
)
self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
self.conv_post = nn.Sequential(
nn.LeakyReLU(0.2),
weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
self.mrstft = MRSTFTLoss(hp)
@property
def eps(self):
return 1e-5
def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
"""
Args:
x: (b c t), acoustic features
y: (b t), waveform
Returns:
z: (b t), waveform
"""
assert x.ndim == 3, "x must be 3D tensor"
assert y is None or y.ndim == 2, "y must be 2D tensor"
assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
        assert npad >= 0, "npad must be non-negative"
x = F.pad(x, (0, npad), "constant", 0)
z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
z = self.conv_pre(z) # (b c t)
for block in self.blocks:
z = block(z, x) # (b c t)
z = self.conv_post(z) # (b 1 t)
        if npad > 0:
            z = z[..., : -self.scale_factor * npad]  # trim the tail synthesized from padding
z = z.squeeze(1) # (b t)
if y is not None:
self.losses = self.mrstft(z, y)
return z
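if __name__ == "__main__":
    # Minimal usage sketch (not in the original file); run as a module so the
    # relative imports resolve. Assumes the enhancer HParams defaults, where
    # prod(strides) == hop_size == 420 and the mel channels serve as d_input.
    hp = HParams()
    net = UnivNet(hp, d_input=hp.num_mels)
    mel = torch.randn(1, hp.num_mels, 50)
    print(net(mel).shape)  # torch.Size([1, 50 * hp.hop_size])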

View File

@ -0,0 +1,128 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
logger = logging.getLogger(__name__)
console = Console()
def _make_stft_cfg(hop_length, win_length=None):
if win_length is None:
win_length = 4 * hop_length
n_fft = 2 ** (win_length - 1).bit_length()
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _build_rich_table(rows, columns, title=None):
table = Table(title=title, header_style=None)
for column in columns:
table.add_column(column.capitalize(), justify="left")
for row in rows:
table.add_row(*map(str, row))
return Panel(table, expand=False)
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
console.print(_build_rich_table(d.items(), [key, value], title))
@dataclass(frozen=True)
class HParams:
# Dataset
fg_dir: Path = Path("data/fg")
bg_dir: Path = Path("data/bg")
rir_dir: Path = Path("data/rir")
load_fg_only: bool = False
praat_augment_prob: float = 0
# Audio settings
wav_rate: int = 44_100
n_fft: int = 2048
win_size: int = 2048
hop_size: int = 420 # 9.5ms
num_mels: int = 128
stft_magnitude_min: float = 1e-4
preemphasis: float = 0.97
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
# Training
nj: int = 64
training_seconds: float = 1.0
batch_size_per_gpu: int = 16
min_lr: float = 1e-5
max_lr: float = 1e-4
warmup_steps: int = 1000
max_steps: int = 1_000_000
gradient_clipping: float = 1.0
@property
def deepspeed_config(self):
return {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"optimizer": {
"type": "Adam",
"params": {"lr": float(self.min_lr)},
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": float(self.min_lr),
"warmup_max_lr": float(self.max_lr),
"warmup_num_steps": self.warmup_steps,
"total_num_steps": self.max_steps,
"warmup_type": "linear",
},
},
"gradient_clipping": self.gradient_clipping,
}
@property
def stft_cfgs(self):
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
return [_make_stft_cfg(h) for h in (100, 256, 512)]
@classmethod
def from_yaml(cls, path: Path) -> "HParams":
logger.info(f"Reading hparams from {path}")
# First merge to fix types (e.g., str -> Path)
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
def save_if_not_exists(self, run_dir: Path):
path = run_dir / "hparams.yaml"
if path.exists():
logger.info(f"{path} already exists, not saving")
return
path.parent.mkdir(parents=True, exist_ok=True)
OmegaConf.save(asdict(self), str(path))
@classmethod
def load(cls, run_dir, yaml: Path | None = None):
hps = []
if (run_dir / "hparams.yaml").exists():
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
if yaml is not None:
hps.append(cls.from_yaml(yaml))
if len(hps) == 0:
hps.append(cls())
for hp in hps[1:]:
if hp != hps[0]:
errors = {}
for k, v in asdict(hp).items():
if getattr(hps[0], k) != v:
errors[k] = f"{getattr(hps[0], k)} != {v}"
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
return hps[0]
def print(self):
_rich_print_dict(asdict(self), title="HParams")
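if __name__ == "__main__":
    # Minimal usage sketch (not in the original file): build the defaults and
    # pretty-print them with rich.
    HParams().print()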

View File

@ -0,0 +1,177 @@
import logging
import time
import gc
import torch
import torch.nn.functional as F
from torch.nn.utils.parametrize import remove_parametrizations
from torchaudio.functional import resample
from torchaudio.transforms import MelSpectrogram
from tqdm import trange
from .hparams import HParams
logger = logging.getLogger(__name__)
@torch.inference_mode()
def inference_chunk(model, dwav, sr, device, npad=441):
assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
del sr
length = dwav.shape[-1]
abs_max = dwav.abs().max().clamp(min=1e-7)
assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
dwav = dwav.to(device)
dwav = dwav / abs_max # Normalize
dwav = F.pad(dwav, (0, npad))
hwav = model(dwav[None])[0].cpu() # (T,)
hwav = hwav[:length] # Trim padding
hwav = hwav * abs_max # Unnormalize
return hwav
def compute_corr(x, y):
    # Circular cross-correlation along the last dimension via FFT
    return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
def compute_offset(chunk1, chunk2, sr=44100):
"""
Args:
chunk1: (T,)
chunk2: (T,)
Returns:
offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
"""
hop_length = sr // 200 # 5 ms resolution
win_length = hop_length * 4
n_fft = 2 ** (win_length - 1).bit_length()
mel_fn = MelSpectrogram(
sample_rate=sr,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=80,
f_min=0.0,
f_max=sr // 2,
)
spec1 = mel_fn(chunk1).log1p()
spec2 = mel_fn(chunk2).log1p()
corr = compute_corr(spec1, spec2) # (F, T)
corr = corr.mean(dim=0) # (T,)
argmax = corr.argmax().item()
if argmax > len(corr) // 2:
argmax -= len(corr)
offset = -argmax * hop_length
return offset
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
signal_length = (len(chunks) - 1) * hop_length + chunk_length
overlap_length = chunk_length - hop_length
signal = torch.zeros(signal_length, device=chunks[0].device)
fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
for i, chunk in enumerate(chunks):
start = i * hop_length
end = start + chunk_length
if len(chunk) < chunk_length:
chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
if i > 0:
pre_region = chunks[i - 1][-overlap_length:]
cur_region = chunk[:overlap_length]
offset = compute_offset(pre_region, cur_region, sr=sr)
start -= offset
end -= offset
if i == 0:
chunk = chunk * fadeout
elif i == len(chunks) - 1:
chunk = chunk * fadein
else:
chunk = chunk * fadein * fadeout
signal[start:end] += chunk[: len(signal[start:end])]
signal = signal[:length]
return signal
def remove_weight_norm_recursively(module):
for _, module in module.named_modules():
try:
remove_parametrizations(module, "weight")
except Exception:
pass
def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
remove_weight_norm_recursively(model)
hp: HParams = model.hp
dwav = resample(
dwav,
orig_freq=sr,
new_freq=hp.wav_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="sinc_interp_kaiser",
beta=14.769656459379492,
)
del sr # We are now using hp.wav_rate as the sampling rate
sr = hp.wav_rate
if torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
chunk_length = int(sr * chunk_seconds)
overlap_length = int(sr * overlap_seconds)
hop_length = chunk_length - overlap_length
chunks = []
for start in trange(0, dwav.shape[-1], hop_length):
new_chunk = inference_chunk(model, dwav[start : start + chunk_length], sr, device)
chunks.append(new_chunk)
        # Optionally free memory after each chunk; disabled by default since
        # frequent cache clearing and collection may slow down processing:
        # del new_chunk
        # if torch.cuda.is_available():
        #     torch.cuda.empty_cache()
        # gc.collect()
    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
# Clean up chunks to free memory after merging
del chunks[:]
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect() # Explicitly call garbage collector again
elapsed_time = time.perf_counter() - start_time
logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
return hwav, sr
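# Usage sketch (not part of the original file): given a loaded enhancer model
# and a 1-D waveform `dwav` at sample rate `sr`,
#   hwav, new_sr = inference(model, dwav, sr, device="cuda")
# resamples to hp.wav_rate, enhances in 30 s chunks with 1 s overlap, and
# cross-fades the chunks after aligning them with compute_offset.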

View File

@ -0,0 +1,61 @@
import numpy as np
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
from .hparams import HParams
class MelSpectrogram(nn.Module):
def __init__(self, hp: HParams):
"""
Torch implementation of Resemble's mel extraction.
Note that the values are NOT identical to librosa's implementation
due to floating point precisions.
"""
super().__init__()
self.hp = hp
self.melspec = TorchMelSpectrogram(
hp.wav_rate,
n_fft=hp.n_fft,
win_length=hp.win_size,
hop_length=hp.hop_size,
f_min=0,
f_max=hp.wav_rate // 2,
n_mels=hp.num_mels,
power=1,
normalized=False,
            # NOTE: Following librosa's default.
pad_mode="constant",
norm="slaney",
mel_scale="slaney",
)
self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
self.preemphasis = hp.preemphasis
self.hop_size = hp.hop_size
def forward(self, wav, pad=True):
"""
Args:
wav: [B, T]
"""
device = wav.device
        if wav.is_mps:
            # Some required ops are unsupported on MPS; fall back to CPU
            wav = wav.cpu()
            self.to(wav.device)
if self.preemphasis > 0:
wav = torch.nn.functional.pad(wav, [1, 0], value=0)
wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
mel = self.melspec(wav)
mel = self._amp_to_db(mel)
mel_normed = self._normalize(mel)
assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check
mel_normed = mel_normed.to(device)
        return mel_normed  # (B, M, T)
def _normalize(self, s, headroom_db=15):
return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
def _amp_to_db(self, x):
return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
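if __name__ == "__main__":
    # Minimal usage sketch (not in the original file); run as a module so the
    # relative HParams import resolves.
    hp = HParams()
    mel_fn = MelSpectrogram(hp)
    wav = torch.randn(1, hp.wav_rate)  # one second of audio
    print(mel_fn(wav).shape)  # (1, hp.num_mels, 1 + hp.wav_rate // hp.hop_size)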