mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 06:29:59 +08:00
support for mel_band_roformer
This commit is contained in:
parent
3737496389
commit
c5d92a22e9
@ -2,6 +2,7 @@ from functools import wraps
|
||||
from packaging import version
|
||||
from collections import namedtuple
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
@ -59,12 +60,17 @@ class Attend(nn.Module):
|
||||
return
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
||||
device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
|
||||
|
||||
if device_properties.major == 8 and device_properties.minor == 0:
|
||||
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
|
||||
self.cuda_config = FlashAttentionConfig(True, False, False)
|
||||
if device_version >= version.parse('8.0'):
|
||||
if os.name == 'nt':
|
||||
print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
|
||||
self.cuda_config = FlashAttentionConfig(False, True, True)
|
||||
else:
|
||||
print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda')
|
||||
self.cuda_config = FlashAttentionConfig(True, False, False)
|
||||
else:
|
||||
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
|
||||
print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
|
||||
self.cuda_config = FlashAttentionConfig(False, True, True)
|
||||
|
||||
def flash_attn(self, q, k, v):
|
||||
|
@ -6,6 +6,7 @@ from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from bs_roformer.attend import Attend
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
# from beartype.typing import Tuple, Optional, List, Callable
|
||||
@ -356,13 +357,18 @@ class BSRoformer(Module):
|
||||
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
||||
multi_stft_hop_size=147,
|
||||
multi_stft_normalized=False,
|
||||
multi_stft_window_fn: Callable = torch.hann_window
|
||||
multi_stft_window_fn: Callable = torch.hann_window,
|
||||
mlp_expansion_factor=4,
|
||||
use_torch_checkpoint=False,
|
||||
skip_connection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stereo = stereo
|
||||
self.audio_channels = 2 if stereo else 1
|
||||
self.num_stems = num_stems
|
||||
self.use_torch_checkpoint = use_torch_checkpoint
|
||||
self.skip_connection = skip_connection
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
@ -402,7 +408,7 @@ class BSRoformer(Module):
|
||||
|
||||
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
||||
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
|
||||
|
||||
assert len(freqs_per_bands) > 1
|
||||
assert sum(
|
||||
@ -421,7 +427,8 @@ class BSRoformer(Module):
|
||||
mask_estimator = MaskEstimator(
|
||||
dim=dim,
|
||||
dim_inputs=freqs_per_bands_with_complex,
|
||||
depth=mask_estimator_depth
|
||||
depth=mask_estimator_depth,
|
||||
mlp_expansion_factor=mlp_expansion_factor,
|
||||
)
|
||||
|
||||
self.mask_estimators.append(mask_estimator)
|
||||
@ -458,12 +465,14 @@ class BSRoformer(Module):
|
||||
|
||||
device = raw_audio.device
|
||||
|
||||
# defining whether model is loaded on MPS (MacOS GPU accelerator)
|
||||
x_is_mps = True if device.type == "mps" else False
|
||||
|
||||
if raw_audio.ndim == 2:
|
||||
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
|
||||
|
||||
channels = raw_audio.shape[1]
|
||||
assert (not self.stereo and channels == 1) or (
|
||||
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
|
||||
# to stft
|
||||
|
||||
@ -471,53 +480,79 @@ class BSRoformer(Module):
|
||||
|
||||
stft_window = self.stft_window_fn(device=device)
|
||||
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
# RuntimeError: FFT operations are only supported on MacOS 14+
|
||||
# Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
|
||||
try:
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
except:
|
||||
stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
|
||||
|
||||
stft_repr = torch.view_as_real(stft_repr)
|
||||
|
||||
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
|
||||
stft_repr = rearrange(stft_repr,
|
||||
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
|
||||
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
|
||||
|
||||
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
|
||||
# print("460:", x.dtype)#fp32
|
||||
x = self.band_split(x)
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(self.band_split, x, use_reentrant=False)
|
||||
else:
|
||||
x = self.band_split(x)
|
||||
|
||||
# axial / hierarchical attention
|
||||
|
||||
# print("487:",x.dtype)#fp16
|
||||
for transformer_block in self.layers:
|
||||
store = [None] * len(self.layers)
|
||||
for i, transformer_block in enumerate(self.layers):
|
||||
|
||||
if len(transformer_block) == 3:
|
||||
linear_transformer, time_transformer, freq_transformer = transformer_block
|
||||
|
||||
x, ft_ps = pack([x], 'b * d')
|
||||
# print("494:", x.dtype)#fp16
|
||||
x = linear_transformer(x)
|
||||
# print("496:", x.dtype)#fp16
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = linear_transformer(x)
|
||||
x, = unpack(x, ft_ps, 'b * d')
|
||||
else:
|
||||
time_transformer, freq_transformer = transformer_block
|
||||
|
||||
# print("501:", x.dtype)#fp16
|
||||
if self.skip_connection:
|
||||
# Sum all previous
|
||||
for j in range(i):
|
||||
x = x + store[j]
|
||||
|
||||
x = rearrange(x, 'b t f d -> b f t d')
|
||||
x, ps = pack([x], '* t d')
|
||||
|
||||
x = time_transformer(x)
|
||||
# print("505:", x.dtype)#fp16
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(time_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = time_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* t d')
|
||||
x = rearrange(x, 'b f t d -> b t f d')
|
||||
x, ps = pack([x], '* f d')
|
||||
|
||||
x = freq_transformer(x)
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = freq_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* f d')
|
||||
|
||||
# print("515:", x.dtype)######fp16
|
||||
if self.skip_connection:
|
||||
store[i] = x
|
||||
|
||||
x = self.final_norm(x)
|
||||
|
||||
num_stems = len(self.mask_estimators)
|
||||
# print("519:", x.dtype)#fp32
|
||||
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
|
||||
else:
|
||||
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
|
||||
|
||||
# modulate frequency representation
|
||||
@ -535,7 +570,11 @@ class BSRoformer(Module):
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
||||
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
|
||||
# same as torch.stft() fix for MacOS MPS above
|
||||
try:
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
|
||||
except:
|
||||
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
|
||||
|
||||
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
|
||||
|
||||
|
669
tools/uvr5/bs_roformer/mel_band_roformer.py
Normal file
669
tools/uvr5/bs_roformer/mel_band_roformer.py
Normal file
@ -0,0 +1,669 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum, Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
import torch.nn.functional as F
|
||||
|
||||
from bs_roformer.attend import Attend
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
# from beartype.typing import Tuple, Optional, List, Callable
|
||||
# from beartype import beartype
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
from einops import rearrange, pack, unpack, reduce, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from librosa import filters
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(v, d):
|
||||
return v if exists(v) else d
|
||||
|
||||
|
||||
def pack_one(t, pattern):
|
||||
return pack([t], pattern)
|
||||
|
||||
|
||||
def unpack_one(t, ps, pattern):
|
||||
return unpack(t, ps, pattern)[0]
|
||||
|
||||
|
||||
def pad_at_dim(t, pad, dim=-1, value=0.):
|
||||
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
||||
zeros = ((0, 0) * dims_from_right)
|
||||
return F.pad(t, (*zeros, *pad), value=value)
|
||||
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim=-1, p=2)
|
||||
|
||||
|
||||
# norm
|
||||
|
||||
class RMSNorm(Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=-1) * self.scale * self.gamma
|
||||
|
||||
|
||||
# attention
|
||||
|
||||
class FeedForward(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
mult=4,
|
||||
dropout=0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = int(dim * mult)
|
||||
self.net = nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, dim_inner),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_inner, dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Attention(Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.,
|
||||
rotary_embed=None,
|
||||
flash=True
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
dim_inner = heads * dim_head
|
||||
|
||||
self.rotary_embed = rotary_embed
|
||||
|
||||
self.attend = Attend(flash=flash, dropout=dropout)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
|
||||
|
||||
self.to_gates = nn.Linear(dim, heads)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(dim_inner, dim, bias=False),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
|
||||
|
||||
if exists(self.rotary_embed):
|
||||
q = self.rotary_embed.rotate_queries_or_keys(q)
|
||||
k = self.rotary_embed.rotate_queries_or_keys(k)
|
||||
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
gates = self.to_gates(x)
|
||||
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class LinearAttention(Module):
|
||||
"""
|
||||
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
|
||||
"""
|
||||
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head=32,
|
||||
heads=8,
|
||||
scale=8,
|
||||
flash=False,
|
||||
dropout=0.
|
||||
):
|
||||
super().__init__()
|
||||
dim_inner = dim_head * heads
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
self.to_qkv = nn.Sequential(
|
||||
nn.Linear(dim, dim_inner * 3, bias=False),
|
||||
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
|
||||
)
|
||||
|
||||
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
||||
|
||||
self.attend = Attend(
|
||||
scale=scale,
|
||||
dropout=dropout,
|
||||
flash=flash
|
||||
)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
Rearrange('b h d n -> b n (h d)'),
|
||||
nn.Linear(dim_inner, dim, bias=False)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x
|
||||
):
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x)
|
||||
|
||||
q, k = map(l2norm, (q, k))
|
||||
q = q * self.temperature.exp()
|
||||
|
||||
out = self.attend(q, k, v)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Transformer(Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.,
|
||||
ff_dropout=0.,
|
||||
ff_mult=4,
|
||||
norm_output=True,
|
||||
rotary_embed=None,
|
||||
flash_attn=True,
|
||||
linear_attn=False
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = ModuleList([])
|
||||
|
||||
for _ in range(depth):
|
||||
if linear_attn:
|
||||
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
|
||||
else:
|
||||
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
|
||||
rotary_embed=rotary_embed, flash=flash_attn)
|
||||
|
||||
self.layers.append(ModuleList([
|
||||
attn,
|
||||
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
# bandsplit module
|
||||
|
||||
class BandSplit(Module):
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_inputs: Tuple[int, ...]
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_inputs = dim_inputs
|
||||
self.to_features = ModuleList([])
|
||||
|
||||
for dim_in in dim_inputs:
|
||||
net = nn.Sequential(
|
||||
RMSNorm(dim_in),
|
||||
nn.Linear(dim_in, dim)
|
||||
)
|
||||
|
||||
self.to_features.append(net)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.split(self.dim_inputs, dim=-1)
|
||||
|
||||
outs = []
|
||||
for split_input, to_feature in zip(x, self.to_features):
|
||||
split_output = to_feature(split_input)
|
||||
outs.append(split_output)
|
||||
|
||||
return torch.stack(outs, dim=-2)
|
||||
|
||||
|
||||
def MLP(
|
||||
dim_in,
|
||||
dim_out,
|
||||
dim_hidden=None,
|
||||
depth=1,
|
||||
activation=nn.Tanh
|
||||
):
|
||||
dim_hidden = default(dim_hidden, dim_in)
|
||||
|
||||
net = []
|
||||
dims = (dim_in, *((dim_hidden,) * depth), dim_out)
|
||||
|
||||
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
is_last = ind == (len(dims) - 2)
|
||||
|
||||
net.append(nn.Linear(layer_dim_in, layer_dim_out))
|
||||
|
||||
if is_last:
|
||||
continue
|
||||
|
||||
net.append(activation())
|
||||
|
||||
return nn.Sequential(*net)
|
||||
|
||||
|
||||
class MaskEstimator(Module):
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_inputs: Tuple[int, ...],
|
||||
depth,
|
||||
mlp_expansion_factor=4
|
||||
):
|
||||
super().__init__()
|
||||
self.dim_inputs = dim_inputs
|
||||
self.to_freqs = ModuleList([])
|
||||
dim_hidden = dim * mlp_expansion_factor
|
||||
|
||||
for dim_in in dim_inputs:
|
||||
net = []
|
||||
|
||||
mlp = nn.Sequential(
|
||||
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
|
||||
nn.GLU(dim=-1)
|
||||
)
|
||||
|
||||
self.to_freqs.append(mlp)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.unbind(dim=-2)
|
||||
|
||||
outs = []
|
||||
|
||||
for band_features, mlp in zip(x, self.to_freqs):
|
||||
freq_out = mlp(band_features)
|
||||
outs.append(freq_out)
|
||||
|
||||
return torch.cat(outs, dim=-1)
|
||||
|
||||
|
||||
# main class
|
||||
|
||||
class MelBandRoformer(Module):
|
||||
|
||||
# @beartype
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
depth,
|
||||
stereo=False,
|
||||
num_stems=1,
|
||||
time_transformer_depth=2,
|
||||
freq_transformer_depth=2,
|
||||
linear_transformer_depth=0,
|
||||
num_bands=60,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
attn_dropout=0.1,
|
||||
ff_dropout=0.1,
|
||||
flash_attn=True,
|
||||
dim_freqs_in=1025,
|
||||
sample_rate=44100, # needed for mel filter bank from librosa
|
||||
stft_n_fft=2048,
|
||||
stft_hop_length=512,
|
||||
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
|
||||
stft_win_length=2048,
|
||||
stft_normalized=False,
|
||||
stft_window_fn: Optional[Callable] = None,
|
||||
mask_estimator_depth=1,
|
||||
multi_stft_resolution_loss_weight=1.,
|
||||
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
|
||||
multi_stft_hop_size=147,
|
||||
multi_stft_normalized=False,
|
||||
multi_stft_window_fn: Callable = torch.hann_window,
|
||||
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
|
||||
mlp_expansion_factor=4,
|
||||
use_torch_checkpoint=False,
|
||||
skip_connection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stereo = stereo
|
||||
self.audio_channels = 2 if stereo else 1
|
||||
self.num_stems = num_stems
|
||||
self.use_torch_checkpoint = use_torch_checkpoint
|
||||
self.skip_connection = skip_connection
|
||||
|
||||
self.layers = ModuleList([])
|
||||
|
||||
transformer_kwargs = dict(
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
attn_dropout=attn_dropout,
|
||||
ff_dropout=ff_dropout,
|
||||
flash_attn=flash_attn
|
||||
)
|
||||
|
||||
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
||||
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
|
||||
|
||||
for _ in range(depth):
|
||||
tran_modules = []
|
||||
if linear_transformer_depth > 0:
|
||||
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
|
||||
tran_modules.append(
|
||||
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
|
||||
)
|
||||
tran_modules.append(
|
||||
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
|
||||
)
|
||||
self.layers.append(nn.ModuleList(tran_modules))
|
||||
|
||||
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
|
||||
|
||||
self.stft_kwargs = dict(
|
||||
n_fft=stft_n_fft,
|
||||
hop_length=stft_hop_length,
|
||||
win_length=stft_win_length,
|
||||
normalized=stft_normalized
|
||||
)
|
||||
|
||||
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
|
||||
|
||||
# create mel filter bank
|
||||
# with librosa.filters.mel as in section 2 of paper
|
||||
|
||||
mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
|
||||
|
||||
mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
|
||||
|
||||
# for some reason, it doesn't include the first freq? just force a value for now
|
||||
|
||||
mel_filter_bank[0][0] = 1.
|
||||
|
||||
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
|
||||
# so let's force a positive value
|
||||
|
||||
mel_filter_bank[-1, -1] = 1.
|
||||
|
||||
# binary as in paper (then estimated masks are averaged for overlapping regions)
|
||||
|
||||
freqs_per_band = mel_filter_bank > 0
|
||||
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
|
||||
|
||||
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
|
||||
freq_indices = repeated_freq_indices[freqs_per_band]
|
||||
|
||||
if stereo:
|
||||
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
|
||||
freq_indices = freq_indices * 2 + torch.arange(2)
|
||||
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
|
||||
|
||||
self.register_buffer('freq_indices', freq_indices, persistent=False)
|
||||
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
|
||||
|
||||
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
|
||||
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
|
||||
|
||||
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
|
||||
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
|
||||
|
||||
# band split and mask estimator
|
||||
|
||||
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
|
||||
|
||||
self.band_split = BandSplit(
|
||||
dim=dim,
|
||||
dim_inputs=freqs_per_bands_with_complex
|
||||
)
|
||||
|
||||
self.mask_estimators = nn.ModuleList([])
|
||||
|
||||
for _ in range(num_stems):
|
||||
mask_estimator = MaskEstimator(
|
||||
dim=dim,
|
||||
dim_inputs=freqs_per_bands_with_complex,
|
||||
depth=mask_estimator_depth,
|
||||
mlp_expansion_factor=mlp_expansion_factor,
|
||||
)
|
||||
|
||||
self.mask_estimators.append(mask_estimator)
|
||||
|
||||
# for the multi-resolution stft loss
|
||||
|
||||
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
|
||||
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
|
||||
self.multi_stft_n_fft = stft_n_fft
|
||||
self.multi_stft_window_fn = multi_stft_window_fn
|
||||
|
||||
self.multi_stft_kwargs = dict(
|
||||
hop_length=multi_stft_hop_size,
|
||||
normalized=multi_stft_normalized
|
||||
)
|
||||
|
||||
self.match_input_audio_length = match_input_audio_length
|
||||
|
||||
def forward(
|
||||
self,
|
||||
raw_audio,
|
||||
target=None,
|
||||
return_loss_breakdown=False
|
||||
):
|
||||
"""
|
||||
einops
|
||||
|
||||
b - batch
|
||||
f - freq
|
||||
t - time
|
||||
s - audio channel (1 for mono, 2 for stereo)
|
||||
n - number of 'stems'
|
||||
c - complex (2)
|
||||
d - feature dimension
|
||||
"""
|
||||
|
||||
device = raw_audio.device
|
||||
|
||||
if raw_audio.ndim == 2:
|
||||
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
|
||||
|
||||
batch, channels, raw_audio_length = raw_audio.shape
|
||||
|
||||
istft_length = raw_audio_length if self.match_input_audio_length else None
|
||||
|
||||
assert (not self.stereo and channels == 1) or (
|
||||
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
|
||||
# to stft
|
||||
|
||||
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
|
||||
|
||||
stft_window = self.stft_window_fn(device=device)
|
||||
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
stft_repr = torch.view_as_real(stft_repr)
|
||||
|
||||
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
|
||||
|
||||
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
|
||||
|
||||
# index out all frequencies for all frequency ranges across bands ascending in one go
|
||||
|
||||
batch_arange = torch.arange(batch, device=device)[..., None]
|
||||
|
||||
# account for stereo
|
||||
|
||||
x = stft_repr[batch_arange, self.freq_indices]
|
||||
|
||||
# fold the complex (real and imag) into the frequencies dimension
|
||||
|
||||
x = rearrange(x, 'b f t c -> b t (f c)')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(self.band_split, x, use_reentrant=False)
|
||||
else:
|
||||
x = self.band_split(x)
|
||||
|
||||
# axial / hierarchical attention
|
||||
|
||||
store = [None] * len(self.layers)
|
||||
for i, transformer_block in enumerate(self.layers):
|
||||
|
||||
if len(transformer_block) == 3:
|
||||
linear_transformer, time_transformer, freq_transformer = transformer_block
|
||||
|
||||
x, ft_ps = pack([x], 'b * d')
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = linear_transformer(x)
|
||||
x, = unpack(x, ft_ps, 'b * d')
|
||||
else:
|
||||
time_transformer, freq_transformer = transformer_block
|
||||
|
||||
if self.skip_connection:
|
||||
# Sum all previous
|
||||
for j in range(i):
|
||||
x = x + store[j]
|
||||
|
||||
x = rearrange(x, 'b t f d -> b f t d')
|
||||
x, ps = pack([x], '* t d')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(time_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = time_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* t d')
|
||||
x = rearrange(x, 'b f t d -> b t f d')
|
||||
x, ps = pack([x], '* f d')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = freq_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* f d')
|
||||
|
||||
if self.skip_connection:
|
||||
store[i] = x
|
||||
|
||||
num_stems = len(self.mask_estimators)
|
||||
if self.use_torch_checkpoint:
|
||||
masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
|
||||
else:
|
||||
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
|
||||
|
||||
# modulate frequency representation
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
|
||||
|
||||
# complex number multiplication
|
||||
|
||||
stft_repr = torch.view_as_complex(stft_repr)
|
||||
masks = torch.view_as_complex(masks)
|
||||
|
||||
masks = masks.type(stft_repr.dtype)
|
||||
|
||||
# need to average the estimated mask for the overlapped frequencies
|
||||
|
||||
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
|
||||
|
||||
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
|
||||
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
|
||||
|
||||
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
|
||||
|
||||
masks_averaged = masks_summed / denom.clamp(min=1e-8)
|
||||
|
||||
# modulate stft repr with estimated mask
|
||||
|
||||
stft_repr = stft_repr * masks_averaged
|
||||
|
||||
# istft
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
||||
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
|
||||
length=istft_length)
|
||||
|
||||
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
|
||||
|
||||
if num_stems == 1:
|
||||
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
|
||||
|
||||
# if a target is passed in, calculate loss for learning
|
||||
|
||||
if not exists(target):
|
||||
return recon_audio
|
||||
|
||||
if self.num_stems > 1:
|
||||
assert target.ndim == 4 and target.shape[1] == self.num_stems
|
||||
|
||||
if target.ndim == 2:
|
||||
target = rearrange(target, '... t -> ... 1 t')
|
||||
|
||||
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
|
||||
|
||||
loss = F.l1_loss(recon_audio, target)
|
||||
|
||||
multi_stft_resolution_loss = 0.
|
||||
|
||||
for window_size in self.multi_stft_resolutions_window_sizes:
|
||||
res_stft_kwargs = dict(
|
||||
n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
|
||||
win_length=window_size,
|
||||
return_complex=True,
|
||||
window=self.multi_stft_window_fn(window_size, device=device),
|
||||
**self.multi_stft_kwargs,
|
||||
)
|
||||
|
||||
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
|
||||
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
|
||||
|
||||
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
|
||||
|
||||
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
|
||||
|
||||
total_loss = loss + weighted_multi_resolution_loss
|
||||
|
||||
if not return_loss_breakdown:
|
||||
return total_loss
|
||||
|
||||
return total_loss, (loss, multi_stft_resolution_loss)
|
@ -1,6 +1,4 @@
|
||||
# This code is modified from https://github.com/ZFTurbo/
|
||||
import pdb
|
||||
|
||||
import librosa
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
@ -8,61 +6,113 @@ import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch.nn as nn
|
||||
|
||||
import yaml
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
|
||||
class BsRoformer_Loader:
|
||||
|
||||
class Roformer_Loader:
|
||||
def get_config(self, config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
# use fullloader to load tag !!python/tuple, code can be improved
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return config
|
||||
|
||||
def get_default_config(self):
|
||||
default_config = None
|
||||
if self.model_type == 'bs_roformer':
|
||||
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
|
||||
# Other BS_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 512,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files
|
||||
# Other Mel_Band_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 384,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"num_bands": 60,
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"sample_rate": 44100,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
return default_config
|
||||
|
||||
|
||||
def get_model_from_config(self):
|
||||
config = {
|
||||
"attn_dropout": 0.1,
|
||||
"depth": 12,
|
||||
"dim": 512,
|
||||
"dim_freqs_in": 1025,
|
||||
"dim_head": 64,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"freq_transformer_depth": 1,
|
||||
"freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"heads": 8,
|
||||
"linear_transformer_depth": 0,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
|
||||
"num_stems": 1,
|
||||
"stereo": True,
|
||||
"stft_hop_length": 441,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_normalized": False,
|
||||
"stft_win_length": 2048,
|
||||
"time_transformer_depth": 1,
|
||||
|
||||
}
|
||||
|
||||
|
||||
model = BSRoformer(
|
||||
**dict(config)
|
||||
)
|
||||
|
||||
if self.model_type == 'bs_roformer':
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
model = BSRoformer(**dict(self.config["model"]))
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
from bs_roformer.mel_band_roformer import MelBandRoformer
|
||||
model = MelBandRoformer(**dict(self.config["model"]))
|
||||
else:
|
||||
print('Error: Unknown model: {}'.format(self.model_type))
|
||||
model = None
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def demix_track(self, model, mix, device):
|
||||
C = 352800
|
||||
# num_overlap
|
||||
N = 1
|
||||
C = self.config["audio"]["chunk_size"] # chunk_size
|
||||
N = self.config["inference"]["num_overlap"]
|
||||
fade_size = C // 10
|
||||
step = int(C // N)
|
||||
border = C - step
|
||||
batch_size = 4
|
||||
batch_size = self.config["inference"]["batch_size"]
|
||||
|
||||
length_init = mix.shape[-1]
|
||||
|
||||
progress_bar = tqdm(total=length_init // step + 1)
|
||||
progress_bar.set_description("Processing")
|
||||
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)
|
||||
|
||||
# Do pad from the beginning and end to account floating window results better
|
||||
if length_init > 2 * border and (border > 0):
|
||||
@ -82,7 +132,10 @@ class BsRoformer_Loader:
|
||||
|
||||
with torch.amp.autocast('cuda'):
|
||||
with torch.inference_mode():
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
|
||||
else:
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
|
||||
result = torch.zeros(req_shape, dtype=torch.float32)
|
||||
counter = torch.zeros(req_shape, dtype=torch.float32)
|
||||
@ -97,7 +150,7 @@ class BsRoformer_Loader:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
|
||||
else:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
|
||||
if(self.is_half==True):
|
||||
if self.is_half:
|
||||
part=part.half()
|
||||
batch_data.append(part)
|
||||
batch_locations.append((i, length))
|
||||
@ -133,78 +186,116 @@ class BsRoformer_Loader:
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
|
||||
else:
|
||||
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
|
||||
|
||||
|
||||
def run_folder(self,input, vocal_root, others_root, format):
|
||||
# start_time = time.time()
|
||||
def run_folder(self, input, vocal_root, others_root, format):
|
||||
self.model.eval()
|
||||
path = input
|
||||
os.makedirs(vocal_root, exist_ok=True)
|
||||
os.makedirs(others_root, exist_ok=True)
|
||||
file_base_name = os.path.splitext(os.path.basename(path))[0]
|
||||
|
||||
if not os.path.isdir(vocal_root):
|
||||
os.mkdir(vocal_root)
|
||||
|
||||
if not os.path.isdir(others_root):
|
||||
os.mkdir(others_root)
|
||||
sample_rate = 44100
|
||||
if 'sample_rate' in self.config["audio"]:
|
||||
sample_rate = self.config["audio"]['sample_rate']
|
||||
|
||||
try:
|
||||
mix, sr = librosa.load(path, sr=44100, mono=False)
|
||||
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
||||
except Exception as e:
|
||||
print('Can read track: {}'.format(path))
|
||||
print('Error message: {}'.format(str(e)))
|
||||
return
|
||||
|
||||
# Convert mono to stereo if needed
|
||||
if len(mix.shape) == 1:
|
||||
# in case if model only supports mono or stereo
|
||||
isstereo = self.config["model"].get("stereo", True)
|
||||
if isstereo and len(mix.shape) == 1:
|
||||
mix = np.stack([mix, mix], axis=0)
|
||||
print("Warning: Track is mono, but model is stereo, adding a second channel.")
|
||||
elif isstereo and len(mix.shape) > 2:
|
||||
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
|
||||
mix = np.stack([mix, mix], axis=0)
|
||||
print("Warning: Track has more than 2 channels, taking mean of all channels and adding a second channel.")
|
||||
elif not isstereo and len(mix.shape) != 1:
|
||||
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
|
||||
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")
|
||||
|
||||
mix_orig = mix.copy()
|
||||
|
||||
mixture = torch.tensor(mix, dtype=torch.float32)
|
||||
res = self.demix_track(self.model, mixture, self.device)
|
||||
|
||||
estimates = res['vocals'].T
|
||||
|
||||
if format in ["wav", "flac"]:
|
||||
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
|
||||
sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr)
|
||||
if self.config["training"]["target_instrument"] is not None:
|
||||
# if target instrument is specified, save target instrument as vocal and other instruments as others
|
||||
# other instruments are caculated by subtracting target instrument from mixture
|
||||
target_instrument = self.config["training"]["target_instrument"]
|
||||
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
|
||||
other = mix_orig - res[target_instrument] # caculate other instruments
|
||||
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
|
||||
self.save_audio(path_vocal, res[target_instrument].T, sr, format)
|
||||
self.save_audio(path_other, other.T, sr, format)
|
||||
else:
|
||||
path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4])
|
||||
path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4])
|
||||
sf.write(path_vocal, estimates, sr)
|
||||
sf.write(path_other, mix_orig.T - estimates, sr)
|
||||
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
||||
opt_path_other = path_other[:-4] + ".%s" % format
|
||||
if os.path.exists(path_vocal):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
|
||||
)
|
||||
if os.path.exists(opt_path_vocal):
|
||||
try:
|
||||
os.remove(path_vocal)
|
||||
except:
|
||||
pass
|
||||
if os.path.exists(path_other):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
|
||||
)
|
||||
if os.path.exists(opt_path_other):
|
||||
try:
|
||||
os.remove(path_other)
|
||||
except:
|
||||
pass
|
||||
|
||||
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
||||
# if target instrument is not specified, save the first instrument as vocal and the rest as others
|
||||
vocal_inst = self.config["training"]["instruments"][0]
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
|
||||
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
|
||||
for other in self.config["training"]["instruments"][1:]: # save other instruments
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
|
||||
self.save_audio(path_other, res[other].T, sr, format)
|
||||
|
||||
|
||||
def __init__(self, model_path, device,is_half):
|
||||
def save_audio(self, path, data, sr, format):
|
||||
# input path should be endwith '.wav'
|
||||
if format in ["wav", "flac"]:
|
||||
if format == "flac":
|
||||
path = path[:-3] + "flac"
|
||||
sf.write(path, data, sr)
|
||||
else:
|
||||
sf.write(path, data, sr)
|
||||
os.system("ffmpeg -i '{}' -vn '{}' -q:a 2 -y".format(path, path[:-3] + format))
|
||||
try: os.remove(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def __init__(self, model_path, config_path, device, is_half):
|
||||
self.device = device
|
||||
self.extract_instrumental=True
|
||||
self.is_half = is_half
|
||||
self.model_type = None
|
||||
self.config = None
|
||||
|
||||
# get model_type, first try:
|
||||
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
|
||||
self.model_type = "bs_roformer"
|
||||
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, raise an error
|
||||
raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again.")
|
||||
self.config = self.get_default_config()
|
||||
else:
|
||||
# if there is a configuration file
|
||||
self.config = self.get_config(config_path)
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, second try, get model_type from the configuration file
|
||||
if "freqs_per_bands" in self.config["model"]:
|
||||
# if freqs_per_bands in config, it's a bs_roformer model
|
||||
self.model_type = "bs_roformer"
|
||||
else:
|
||||
# else it's a mel_band_roformer model
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
print("Detected model type: {}".format(self.model_type))
|
||||
model = self.get_model_from_config()
|
||||
state_dict = torch.load(model_path,map_location="cpu")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
self.is_half=is_half
|
||||
|
||||
if(is_half==False):
|
||||
self.model = model.to(device)
|
||||
else:
|
||||
|
@ -12,7 +12,7 @@ import torch
|
||||
import sys
|
||||
from mdxnet import MDXNetDereverb
|
||||
from vr import AudioPre, AudioPreDeEcho
|
||||
from bsroformer import BsRoformer_Loader
|
||||
from bsroformer import Roformer_Loader
|
||||
|
||||
try:
|
||||
import gradio.analytics as analytics
|
||||
@ -49,13 +49,17 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
||||
is_hp3 = "HP3" in model_name
|
||||
if model_name == "onnx_dereverb_By_FoxJoy":
|
||||
pre_fun = MDXNetDereverb(15)
|
||||
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
|
||||
func = BsRoformer_Loader
|
||||
elif "roformer" in model_name.lower():
|
||||
func = Roformer_Loader
|
||||
pre_fun = func(
|
||||
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
|
||||
config_path = os.path.join(weight_uvr5_root, model_name + ".yaml"),
|
||||
device = device,
|
||||
is_half=is_half
|
||||
)
|
||||
if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
|
||||
infos.append("Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning.")
|
||||
yield "\n".join(infos)
|
||||
else:
|
||||
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
||||
pre_fun = func(
|
||||
|
Loading…
x
Reference in New Issue
Block a user