为UVR5模块添加Bs_Roformer模型 (#1306)

* Add Bs_Roformer

* Compatible with default model name

* Add progress bar
This commit is contained in:
KamioRinn 2024-07-27 16:37:28 +08:00 committed by GitHub
parent e851ae34c9
commit 8abc0342d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 913 additions and 0 deletions

View File

View File

@ -0,0 +1,120 @@
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# constants
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# main class
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False,
scale = None
):
super().__init__()
self.scale = scale
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.cpu_config = FlashAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(False, True, True)
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)
return out
def forward(self, q, k, v):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)
if self.flash:
return self.flash_attn(q, k, v)
# similarity
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
return out

View File

@ -0,0 +1,577 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from beartype.typing import Tuple, Optional, List, Callable
from beartype import beartype
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange
# helper functions
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
# norm
def l2norm(t):
return F.normalize(t, dim = -1, p = 2)
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.gamma
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash=flash, dropout=dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class LinearAttention(Module):
"""
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
"""
@beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
def forward(
self,
x
):
x = self.norm(x)
q, k, v = self.to_qkv(x)
q, k = map(l2norm, (q, k))
q = q * self.temperature.exp()
out = self.attend(q, k, v)
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# bandsplit module
class BandSplit(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
self.to_features.append(net)
def forward(self, x):
x = x.split(self.dim_inputs, dim=-1)
outs = []
for split_input, to_feature in zip(x, self.to_features):
split_output = to_feature(split_input)
outs.append(split_output)
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)
net = []
dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
net.append(nn.Linear(layer_dim_in, layer_dim_out))
if is_last:
continue
net.append(activation())
return nn.Sequential(*net)
class MaskEstimator(Module):
@beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
def forward(self, x):
x = x.unbind(dim=-2)
outs = []
for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)
return torch.cat(outs, dim=-1)
# main class
DEFAULT_FREQS_PER_BANDS = (
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2,
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
12, 12, 12, 12, 12, 12, 12, 12,
24, 24, 24, 24, 24, 24, 24, 24,
48, 48, 48, 48, 48, 48, 48, 48,
128, 129,
)
class BSRoformer(Module):
@beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
# in the paper, they divide into ~60 bands, test with 1 for starters
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
flash_attn=True,
dim_freqs_in=1025,
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=2,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window
):
super().__init__()
self.stereo = stereo
self.audio_channels = 2 if stereo else 1
self.num_stems = num_stems
self.layers = ModuleList([])
transformer_kwargs = dict(
dim=dim,
heads=heads,
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn,
norm_output=False
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
for _ in range(depth):
tran_modules = []
if linear_transformer_depth > 0:
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
tran_modules.append(
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
)
tran_modules.append(
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
)
self.layers.append(nn.ModuleList(tran_modules))
self.final_norm = RMSNorm(dim)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
)
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
assert len(freqs_per_bands) > 1
assert sum(
freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
depth=mask_estimator_depth
)
self.mask_estimators.append(mask_estimator)
# for the multi-resolution stft loss
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
"""
einops
b - batch
f - freq
t - time
s - audio channel (1 for mono, 2 for stereo)
n - number of 'stems'
c - complex (2)
d - feature dimension
"""
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = rearrange(stft_repr,
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
x = self.band_split(x)
# axial / hierarchical attention
for transformer_block in self.layers:
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
else:
time_transformer, freq_transformer = transformer_block
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
x = self.final_norm(x)
num_stems = len(self.mask_estimators)
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
# complex number multiplication
stft_repr = torch.view_as_complex(stft_repr)
mask = torch.view_as_complex(mask)
stft_repr = stft_repr * mask
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
# if a target is passed in, calculate loss for learning
if not exists(target):
return recon_audio
if self.num_stems > 1:
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
total_loss = loss + weighted_multi_resolution_loss
if not return_loss_breakdown:
return total_loss
return total_loss, (loss, multi_stft_resolution_loss)

209
tools/uvr5/bsroformer.py Normal file
View File

@ -0,0 +1,209 @@
# This code is modified from https://github.com/ZFTurbo/
import time
import librosa
from tqdm import tqdm
import os
import glob
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
class BsRoformer_Loader:
def get_model_from_config(self):
config = {
"attn_dropout": 0.1,
"depth": 12,
"dim": 512,
"dim_freqs_in": 1025,
"dim_head": 64,
"ff_dropout": 0.1,
"flash_attn": True,
"freq_transformer_depth": 1,
"freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
"heads": 8,
"linear_transformer_depth": 0,
"mask_estimator_depth": 2,
"multi_stft_hop_size": 147,
"multi_stft_normalized": False,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
"num_stems": 1,
"stereo": True,
"stft_hop_length": 441,
"stft_n_fft": 2048,
"stft_normalized": False,
"stft_win_length": 2048,
"time_transformer_depth": 1,
}
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer(
**dict(config)
)
return model
def demix_track(self, model, mix, device):
C = 352800
N = 2
fade_size = C // 10
step = int(C // N)
border = C - step
batch_size = 4
length_init = mix.shape[-1]
progress_bar = tqdm(total=(length_init//step)+3)
progress_bar.set_description("Processing")
# Do pad from the beginning and end to account floating window results better
if length_init > 2 * border and (border > 0):
mix = nn.functional.pad(mix, (border, border), mode='reflect')
# Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment
window_size = C
fadein = torch.linspace(0, 1, fade_size)
fadeout = torch.linspace(1, 0, fade_size)
window_start = torch.ones(window_size)
window_middle = torch.ones(window_size)
window_finish = torch.ones(window_size)
window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
window_middle[-fade_size:] *= fadeout
window_middle[:fade_size] *= fadein
with torch.cuda.amp.autocast():
with torch.inference_mode():
req_shape = (1, ) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
i = 0
batch_data = []
batch_locations = []
while i < mix.shape[1]:
part = mix[:, i:i + C].to(device)
length = part.shape[-1]
if length < C:
if length > C // 2 + 1:
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
batch_data.append(part)
batch_locations.append((i, length))
i += step
progress_bar.update(1)
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0)
x = model(arr)
window = window_middle
if i - step == 0: # First audio chunk, no fadein
window = window_start
elif i >= mix.shape[1]: # Last audio chunk, no fadeout
window = window_finish
for j in range(len(batch_locations)):
start, l = batch_locations[j]
result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
counter[..., start:start+l] += window[..., :l]
batch_data = []
batch_locations = []
estimated_sources = result / counter
estimated_sources = estimated_sources.cpu().numpy()
np.nan_to_num(estimated_sources, copy=False, nan=0.0)
if length_init > 2 * border and (border > 0):
# Remove pad
estimated_sources = estimated_sources[..., border:-border]
progress_bar.close()
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
def run_folder(self,input, vocal_root, others_root, format):
# start_time = time.time()
self.model.eval()
path = input
if not os.path.isdir(vocal_root):
os.mkdir(vocal_root)
if not os.path.isdir(others_root):
os.mkdir(others_root)
try:
mix, sr = librosa.load(path, sr=44100, mono=False)
except Exception as e:
print('Can read track: {}'.format(path))
print('Error message: {}'.format(str(e)))
return
# Convert mono to stereo if needed
if len(mix.shape) == 1:
mix = np.stack([mix, mix], axis=0)
mix_orig = mix.copy()
mixture = torch.tensor(mix, dtype=torch.float32)
res = self.demix_track(self.model, mixture, self.device)
estimates = res['vocals'].T
print("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format))
if format in ["wav", "flac"]:
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr)
else:
path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4])
path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4])
sf.write(path_vocal, estimates, sr)
sf.write(path_other, mix_orig.T - estimates, sr)
opt_path_vocal = path_vocal[:-4] + ".%s" % format
opt_path_other = path_other[:-4] + ".%s" % format
if os.path.exists(path_vocal):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
)
if os.path.exists(opt_path_vocal):
try:
os.remove(path_vocal)
except:
pass
if os.path.exists(path_other):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
)
if os.path.exists(opt_path_other):
try:
os.remove(path_other)
except:
pass
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def __init__(self, model_path, device):
self.device = device
self.extract_instrumental=True
model = self.get_model_from_config()
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)
self.model = model.to(device)
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
self.run_folder(input, vocal_root, others_root, format)

View File

@ -12,6 +12,7 @@ import torch
import sys
from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho
from bsroformer import BsRoformer_Loader
weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
@ -33,6 +34,12 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
is_hp3 = "HP3" in model_name
if model_name == "onnx_dereverb_By_FoxJoy":
pre_fun = MDXNetDereverb(15)
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
func = BsRoformer_Loader
pre_fun = func(
model_path = os.path.join(weight_uvr5_root, model_name + ".pth"),
device = device,
)
else:
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
pre_fun = func(