From 8abc0342d72551fd7526872c3d243370ddc92a23 Mon Sep 17 00:00:00 2001 From: KamioRinn <63162909+KamioRinn@users.noreply.github.com> Date: Sat, 27 Jul 2024 16:37:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=BAUVR5=E6=A8=A1=E5=9D=97=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0Bs=5FRoformer=E6=A8=A1=E5=9E=8B=20(#1306)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Bs_Roformer * Compatible with default model name * Add progress bar --- tools/uvr5/bs_roformer/__init__.py | 0 tools/uvr5/bs_roformer/attend.py | 120 ++++++ tools/uvr5/bs_roformer/bs_roformer.py | 577 ++++++++++++++++++++++++++ tools/uvr5/bsroformer.py | 209 ++++++++++ tools/uvr5/webui.py | 7 + 5 files changed, 913 insertions(+) create mode 100644 tools/uvr5/bs_roformer/__init__.py create mode 100644 tools/uvr5/bs_roformer/attend.py create mode 100644 tools/uvr5/bs_roformer/bs_roformer.py create mode 100644 tools/uvr5/bsroformer.py diff --git a/tools/uvr5/bs_roformer/__init__.py b/tools/uvr5/bs_roformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/uvr5/bs_roformer/attend.py b/tools/uvr5/bs_roformer/attend.py new file mode 100644 index 0000000..34476c1 --- /dev/null +++ b/tools/uvr5/bs_roformer/attend.py @@ -0,0 +1,120 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def default(v, d): + return v if exists(v) else d + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once('A100 GPU detected, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if exists(self.scale): + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. + ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/tools/uvr5/bs_roformer/bs_roformer.py b/tools/uvr5/bs_roformer/bs_roformer.py new file mode 100644 index 0000000..8670cec --- /dev/null +++ b/tools/uvr5/bs_roformer/bs_roformer.py @@ -0,0 +1,577 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from bs_roformer.attend import Attend + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. + ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129, +) + + +class BSRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1] + + assert len(freqs_per_bands) > 1 + assert sum( + freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + x = rearrange(stft_repr, 'b f t c -> b t (f c)') + + x = self.band_split(x) + + # axial / hierarchical attention + + for transformer_block in self.layers: + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/tools/uvr5/bsroformer.py b/tools/uvr5/bsroformer.py new file mode 100644 index 0000000..c55a776 --- /dev/null +++ b/tools/uvr5/bsroformer.py @@ -0,0 +1,209 @@ +# This code is modified from https://github.com/ZFTurbo/ + +import time +import librosa +from tqdm import tqdm +import os +import glob +import torch +import numpy as np +import soundfile as sf +import torch.nn as nn + +import warnings +warnings.filterwarnings("ignore") + +class BsRoformer_Loader: + def get_model_from_config(self): + config = { + "attn_dropout": 0.1, + "depth": 12, + "dim": 512, + "dim_freqs_in": 1025, + "dim_head": 64, + "ff_dropout": 0.1, + "flash_attn": True, + "freq_transformer_depth": 1, + "freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129), + "heads": 8, + "linear_transformer_depth": 0, + "mask_estimator_depth": 2, + "multi_stft_hop_size": 147, + "multi_stft_normalized": False, + "multi_stft_resolution_loss_weight": 1.0, + "multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256), + "num_stems": 1, + "stereo": True, + "stft_hop_length": 441, + "stft_n_fft": 2048, + "stft_normalized": False, + "stft_win_length": 2048, + "time_transformer_depth": 1, + + } + + from bs_roformer.bs_roformer import BSRoformer + model = BSRoformer( + **dict(config) + ) + + return model + + + def demix_track(self, model, mix, device): + C = 352800 + N = 2 + fade_size = C // 10 + step = int(C // N) + border = C - step + batch_size = 4 + + length_init = mix.shape[-1] + + progress_bar = tqdm(total=(length_init//step)+3) + progress_bar.set_description("Processing") + + # Do pad from the beginning and end to account floating window results better + if length_init > 2 * border and (border > 0): + mix = nn.functional.pad(mix, (border, border), mode='reflect') + + # Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment + window_size = C + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + window_start = torch.ones(window_size) + window_middle = torch.ones(window_size) + window_finish = torch.ones(window_size) + window_start[-fade_size:] *= fadeout # First audio chunk, no fadein + window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout + window_middle[-fade_size:] *= fadeout + window_middle[:fade_size] *= fadein + + with torch.cuda.amp.autocast(): + with torch.inference_mode(): + req_shape = (1, ) + tuple(mix.shape) + + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + i = 0 + batch_data = [] + batch_locations = [] + while i < mix.shape[1]: + part = mix[:, i:i + C].to(device) + length = part.shape[-1] + if length < C: + if length > C // 2 + 1: + part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') + else: + part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + batch_data.append(part) + batch_locations.append((i, length)) + i += step + progress_bar.update(1) + + if len(batch_data) >= batch_size or (i >= mix.shape[1]): + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + window = window_middle + if i - step == 0: # First audio chunk, no fadein + window = window_start + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window = window_finish + + for j in range(len(batch_locations)): + start, l = batch_locations[j] + result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l] + counter[..., start:start+l] += window[..., :l] + + batch_data = [] + batch_locations = [] + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if length_init > 2 * border and (border > 0): + # Remove pad + estimated_sources = estimated_sources[..., border:-border] + + progress_bar.close() + + return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)} + + + def run_folder(self,input, vocal_root, others_root, format): + # start_time = time.time() + self.model.eval() + path = input + + if not os.path.isdir(vocal_root): + os.mkdir(vocal_root) + + if not os.path.isdir(others_root): + os.mkdir(others_root) + + try: + mix, sr = librosa.load(path, sr=44100, mono=False) + except Exception as e: + print('Can read track: {}'.format(path)) + print('Error message: {}'.format(str(e))) + return + + # Convert mono to stereo if needed + if len(mix.shape) == 1: + mix = np.stack([mix, mix], axis=0) + + mix_orig = mix.copy() + + mixture = torch.tensor(mix, dtype=torch.float32) + res = self.demix_track(self.model, mixture, self.device) + + estimates = res['vocals'].T + print("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format)) + + if format in ["wav", "flac"]: + sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr) + sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr) + else: + path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4]) + path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4]) + sf.write(path_vocal, estimates, sr) + sf.write(path_other, mix_orig.T - estimates, sr) + opt_path_vocal = path_vocal[:-4] + ".%s" % format + opt_path_other = path_other[:-4] + ".%s" % format + if os.path.exists(path_vocal): + os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal) + ) + if os.path.exists(opt_path_vocal): + try: + os.remove(path_vocal) + except: + pass + if os.path.exists(path_other): + os.system( + "ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other) + ) + if os.path.exists(opt_path_other): + try: + os.remove(path_other) + except: + pass + + # print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) + + + def __init__(self, model_path, device): + self.device = device + self.extract_instrumental=True + + model = self.get_model_from_config() + state_dict = torch.load(model_path) + model.load_state_dict(state_dict) + self.model = model.to(device) + + + def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): + self.run_folder(input, vocal_root, others_root, format) + diff --git a/tools/uvr5/webui.py b/tools/uvr5/webui.py index 735cc2c..712c892 100644 --- a/tools/uvr5/webui.py +++ b/tools/uvr5/webui.py @@ -12,6 +12,7 @@ import torch import sys from mdxnet import MDXNetDereverb from vr import AudioPre, AudioPreDeEcho +from bsroformer import BsRoformer_Loader weight_uvr5_root = "tools/uvr5/uvr5_weights" uvr5_names = [] @@ -33,6 +34,12 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format is_hp3 = "HP3" in model_name if model_name == "onnx_dereverb_By_FoxJoy": pre_fun = MDXNetDereverb(15) + elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower(): + func = BsRoformer_Loader + pre_fun = func( + model_path = os.path.join(weight_uvr5_root, model_name + ".pth"), + device = device, + ) else: func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho pre_fun = func(