Support for mel_band_roformer (#2078)

* support for mel_band_roformer

* Remove unnecessary audio channel judgments

* remove context manager and fix path

* Update webui.py

* Update README.md
Sucial 2025-02-23 20:28:53 +08:00 committed by GitHub
parent fbb9f21e53
commit e061e9d38e
10 changed files with 941 additions and 176 deletions

View File

@ -152,6 +152,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additional feature), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`.
- If you want to use `bs_roformer` or `mel_band_roformer` models for UVR5, you can manually download the model and its corresponding configuration file and put them in `tools/uvr5/uvr5_weights`. **Rename the model and configuration files so that, apart from the extension, they share the same corresponding name.** In addition, the model and configuration file names **must include `roformer`** in order to be recognized as roformer-class models.
- It is recommended to **specify the model type directly** in the model and configuration file names, e.g. `mel_band_roformer`, `bs_roformer`. If it is not specified, the model type is determined by comparing features from the configuration file. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` form a pair, as do `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` (see the naming-check sketch after this list).
4. For Chinese ASR (additional feature), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in `tools/asr/models`.
5. For English or Japanese ASR (additional feature), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in `tools/asr/models`. Also, [other models](https://huggingface.co/Systran) may have a similar effect with a smaller disk footprint.
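A minimal naming-check sketch (not part of the repository) for the rule described in step 3; it assumes the `tools/uvr5/uvr5_weights` directory named above:

```python
# Minimal sketch: verify that every roformer checkpoint has a paired .yaml
# whose name differs only by the extension, as the README requires.
from pathlib import Path

weights_dir = Path("tools/uvr5/uvr5_weights")
for ckpt in weights_dir.glob("*.ckpt"):
    if "roformer" not in ckpt.name.lower():
        continue  # only roformer-class models need the paired config
    cfg = ckpt.with_suffix(".yaml")
    status = "ok" if cfg.exists() else "missing config"
    print(f"{ckpt.name} -> {cfg.name}: {status}")
```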

View File

@ -149,6 +149,11 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (vocals/accompaniment separation and reverberation removal, additional feature), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in the `tools/uvr5/uvr5_weights` directory.
- If you use `bs_roformer` or `mel_band_roformer` models in UVR5, you can manually download the model and its corresponding configuration file and place them in `tools/uvr5/uvr5_weights`. **Rename the model and configuration files so that, apart from the extension, they share the same corresponding name.** In addition, the model and configuration file names **must include "roformer"** to be recognized as roformer-class models.
- It is recommended to **specify the model type directly** in the model and configuration file names, e.g. `mel_band_roformer`, `bs_roformer`. If not specified, features from the configuration file are compared to determine the model type. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` form a pair, and `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` are also a pair.
4. For Chinese ASR (additional feature), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in the `tools/asr/models` directory.
5. For English or Japanese ASR (additional feature), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in the `tools/asr/models` directory. Also, [other models](https://huggingface.co/Systran) may offer a similar effect while taking up less disk space.

View File

@ -142,6 +142,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (vocal/accompaniment (BGM, etc.) separation and reverb removal, additional feature), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in the `tools/uvr5/uvr5_weights` directory.
- When using bs_roformer or mel_band_roformer models with UVR5, you can manually download the model and its corresponding configuration file and place them in the `tools/uvr5/uvr5_weights` folder. **Make sure the model file and the configuration file have the same name apart from the extension.** In addition, the model and configuration file names **must contain "roformer"** so that they are recognized as roformer-class models.
- It is recommended to **specify the model type directly** in the model and configuration file names, e.g. mel_band_roformer, bs_roformer. If not specified, the model type is identified by comparing features from the configuration file. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` form a pair; likewise, `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` are a pair.
4. For Chinese ASR (additional feature), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in the `tools/asr/models` directory.
5. For English or Japanese ASR (additional feature), download the model from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place it in the `tools/asr/models` directory. Also, [other models](https://huggingface.co/Systran) may be smaller in size while still offering high quality.

View File

@ -145,6 +145,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (vocals/accompaniment separation & reverb removal, additional feature), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in the `tools/uvr5/uvr5_weights` directory.
- If you use bs_roformer or mel_band_roformer models in UVR5, you can manually download the model and its corresponding configuration file and store them in the `tools/uvr5/uvr5_weights` folder. **The model file and the configuration file must have the same name apart from the extension.** In addition, the model and configuration file names must contain **"roformer"** to be recognized as roformer-class models.
- It is recommended to **specify the model type directly** in the model and configuration file names, e.g. mel_band_roformer, bs_roformer. If not specified, the model type is determined by comparing features based on the configuration file. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` form a pair, and `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` are also a pair.
4. For Chinese ASR (additional feature), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in the `tools/asr/models` directory.
5. For English or Japanese ASR (additional feature), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in the `tools/asr/models` directory. Also, [other models](https://huggingface.co/Systran) may have a similar effect while using less disk space.

View File

@ -142,6 +142,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (vocal/instrumental separation & reverb removal), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in the `tools/uvr5/uvr5_weights` directory.
- If you are using bs_roformer or mel_band_roformer models in UVR5, you can manually download the model and its corresponding configuration file and place them in the `tools/uvr5/uvr5_weights` folder. **The model file and the configuration file must have the same name apart from the extension.** In addition, the model and configuration file names must contain **"roformer"** so that they are recognized as roformer-class models.
- It is recommended to **specify the model type directly** in the model and configuration file names, e.g. mel_band_roformer, bs_roformer. If not specified, the model type is determined by comparing features from the configuration file. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` form a pair; likewise, `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml` are also a pair.
4. For Chinese ASR, download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in the `tools/asr/models` directory.
5. For English or Japanese ASR, download the model from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place it in the `tools/asr/models` directory. Also, [other models](https://huggingface.co/Systran) may achieve a similar effect while taking up less disk space.

View File

@ -1,18 +1,8 @@
from functools import wraps
from packaging import version
from collections import namedtuple
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
# constants
FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# helpers
def exists(val):
return val is not None
@ -20,21 +10,6 @@ def exists(val):
def default(v, d):
return v if exists(v) else d
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# main class
class Attend(nn.Module):
def __init__(
self,
@ -50,43 +25,16 @@ class Attend(nn.Module):
self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.cpu_config = FlashAttentionConfig(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = FlashAttentionConfig(False, True, True)
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
# _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
dropout_p = self.dropout if self.training else 0.
)
return out
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
return F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.)
def forward(self, q, k, v):
"""
@ -97,7 +45,7 @@ class Attend(nn.Module):
d - feature dimension
"""
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
# q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
scale = default(self.scale, q.shape[-1] ** -0.5)
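A hedged usage sketch of the simplified `Attend` module from this diff; the tensor shapes are illustrative only, and PyTorch 2.x is assumed because the class asserts it when `flash=True`:

```python
# Sketch: Attend now delegates straight to F.scaled_dot_product_attention.
import torch
from bs_roformer.attend import Attend

attend = Attend(dropout=0.0, flash=True)  # flash path requires torch >= 2.0
q = torch.randn(2, 8, 128, 64)            # (batch, heads, seq_len, dim_head)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = attend(q, k, v)                     # same shape as q
print(out.shape)                          # torch.Size([2, 8, 128, 64])
```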

View File

@ -6,6 +6,7 @@ from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
# from beartype.typing import Tuple, Optional, List, Callable
@ -356,13 +357,18 @@ class BSRoformer(Module):
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window
multi_stft_window_fn: Callable = torch.hann_window,
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
self.stereo = stereo
self.audio_channels = 2 if stereo else 1
self.num_stems = num_stems
self.use_torch_checkpoint = use_torch_checkpoint
self.skip_connection = skip_connection
self.layers = ModuleList([])
@ -402,7 +408,7 @@ class BSRoformer(Module):
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
assert len(freqs_per_bands) > 1
assert sum(
@ -421,7 +427,8 @@ class BSRoformer(Module):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
depth=mask_estimator_depth
depth=mask_estimator_depth,
mlp_expansion_factor=mlp_expansion_factor,
)
self.mask_estimators.append(mask_estimator)
@ -458,12 +465,14 @@ class BSRoformer(Module):
device = raw_audio.device
# defining whether model is loaded on MPS (MacOS GPU accelerator)
x_is_mps = True if device.type == "mps" else False
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
channels = raw_audio.shape[1]
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
# to stft
@ -471,53 +480,79 @@ class BSRoformer(Module):
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
# RuntimeError: FFT operations are only supported on MacOS 14+
# Since it's tedious to detect whether we're on the correct macOS version, a simple try/except is used
try:
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
except:
stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
stft_repr = rearrange(stft_repr,
'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
x = rearrange(stft_repr, 'b f t c -> b t (f c)')
# print("460:", x.dtype)#fp32
x = self.band_split(x)
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
else:
x = self.band_split(x)
# axial / hierarchical attention
# print("487:",x.dtype)#fp16
for transformer_block in self.layers:
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
# print("494:", x.dtype)#fp16
x = linear_transformer(x)
# print("496:", x.dtype)#fp16
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
else:
time_transformer, freq_transformer = transformer_block
# print("501:", x.dtype)#fp16
if self.skip_connection:
# Sum all previous
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
x = time_transformer(x)
# print("505:", x.dtype)#fp16
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
x = freq_transformer(x)
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
# print("515:", x.dtype)######fp16
if self.skip_connection:
store[i] = x
x = self.final_norm(x)
num_stems = len(self.mask_estimators)
# print("519:", x.dtype)#fp32
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
if self.use_torch_checkpoint:
mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
# modulate frequency representation
@ -535,7 +570,11 @@ class BSRoformer(Module):
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
# same as torch.stft() fix for MacOS MPS above
try:
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
except:
recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
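A hedged construction sketch for `BSRoformer` with the options added in this diff; the keyword values mirror the default bs_roformer configuration that `Roformer_Loader` defines later in this PR, and the band layout sums to the 1025 STFT bins:

```python
# Sketch (assumed entry point): exercise the constructor flags added above.
import torch
from bs_roformer.bs_roformer import BSRoformer

freqs_per_bands = (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)

model = BSRoformer(
    dim=512,
    depth=12,
    stereo=True,
    freqs_per_bands=freqs_per_bands,
    mlp_expansion_factor=4,      # new: hidden-width multiplier for the MaskEstimator MLPs
    use_torch_checkpoint=False,  # new: wrap band split / transformer blocks in checkpoint()
    skip_connection=False,       # new: sum the outputs of all previous transformer blocks
)

with torch.inference_mode():
    audio = torch.randn(1, 2, 44100)  # (batch, channels, time), stereo input
    stems = model(audio)              # separated stem(s); (batch, channels, time) when num_stems == 1
```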

View File

@ -0,0 +1,669 @@
from functools import partial
import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint
from typing import Tuple, Optional, List, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange
from librosa import filters
# helper functions
def exists(val):
return val is not None
def default(v, d):
return v if exists(v) else d
def pack_one(t, pattern):
return pack([t], pattern)
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]
def pad_at_dim(t, pad, dim=-1, value=0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
return F.pad(t, (*zeros, *pad), value=value)
def l2norm(t):
return F.normalize(t, dim=-1, p=2)
# norm
class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
return F.normalize(x, dim=-1) * self.scale * self.gamma
# attention
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.
):
super().__init__()
dim_inner = int(dim * mult)
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim_inner, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.,
rotary_embed=None,
flash=True
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(flash=flash, dropout=dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
self.to_gates = nn.Linear(dim, heads)
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False),
nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
gates = self.to_gates(x)
out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class LinearAttention(Module):
"""
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
"""
# @beartype
def __init__(
self,
*,
dim,
dim_head=32,
heads=8,
scale=8,
flash=False,
dropout=0.
):
super().__init__()
dim_inner = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
)
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
self.attend = Attend(
scale=scale,
dropout=dropout,
flash=flash
)
self.to_out = nn.Sequential(
Rearrange('b h d n -> b n (h d)'),
nn.Linear(dim_inner, dim, bias=False)
)
def forward(
self,
x
):
x = self.norm(x)
q, k, v = self.to_qkv(x)
q, k = map(l2norm, (q, k))
q = q * self.temperature.exp()
out = self.attend(q, k, v)
return self.to_out(out)
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=64,
heads=8,
attn_dropout=0.,
ff_dropout=0.,
ff_mult=4,
norm_output=True,
rotary_embed=None,
flash_attn=True,
linear_attn=False
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
if linear_attn:
attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
else:
attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
rotary_embed=rotary_embed, flash=flash_attn)
self.layers.append(ModuleList([
attn,
FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return self.norm(x)
# bandsplit module
class BandSplit(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...]
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_features = ModuleList([])
for dim_in in dim_inputs:
net = nn.Sequential(
RMSNorm(dim_in),
nn.Linear(dim_in, dim)
)
self.to_features.append(net)
def forward(self, x):
x = x.split(self.dim_inputs, dim=-1)
outs = []
for split_input, to_feature in zip(x, self.to_features):
split_output = to_feature(split_input)
outs.append(split_output)
return torch.stack(outs, dim=-2)
def MLP(
dim_in,
dim_out,
dim_hidden=None,
depth=1,
activation=nn.Tanh
):
dim_hidden = default(dim_hidden, dim_in)
net = []
dims = (dim_in, *((dim_hidden,) * depth), dim_out)
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
is_last = ind == (len(dims) - 2)
net.append(nn.Linear(layer_dim_in, layer_dim_out))
if is_last:
continue
net.append(activation())
return nn.Sequential(*net)
class MaskEstimator(Module):
# @beartype
def __init__(
self,
dim,
dim_inputs: Tuple[int, ...],
depth,
mlp_expansion_factor=4
):
super().__init__()
self.dim_inputs = dim_inputs
self.to_freqs = ModuleList([])
dim_hidden = dim * mlp_expansion_factor
for dim_in in dim_inputs:
net = []
mlp = nn.Sequential(
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
nn.GLU(dim=-1)
)
self.to_freqs.append(mlp)
def forward(self, x):
x = x.unbind(dim=-2)
outs = []
for band_features, mlp in zip(x, self.to_freqs):
freq_out = mlp(band_features)
outs.append(freq_out)
return torch.cat(outs, dim=-1)
# main class
class MelBandRoformer(Module):
# @beartype
def __init__(
self,
dim,
*,
depth,
stereo=False,
num_stems=1,
time_transformer_depth=2,
freq_transformer_depth=2,
linear_transformer_depth=0,
num_bands=60,
dim_head=64,
heads=8,
attn_dropout=0.1,
ff_dropout=0.1,
flash_attn=True,
dim_freqs_in=1025,
sample_rate=44100, # needed for mel filter bank from librosa
stft_n_fft=2048,
stft_hop_length=512,
# 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
stft_win_length=2048,
stft_normalized=False,
stft_window_fn: Optional[Callable] = None,
mask_estimator_depth=1,
multi_stft_resolution_loss_weight=1.,
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
multi_stft_hop_size=147,
multi_stft_normalized=False,
multi_stft_window_fn: Callable = torch.hann_window,
match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
mlp_expansion_factor=4,
use_torch_checkpoint=False,
skip_connection=False,
):
super().__init__()
self.stereo = stereo
self.audio_channels = 2 if stereo else 1
self.num_stems = num_stems
self.use_torch_checkpoint = use_torch_checkpoint
self.skip_connection = skip_connection
self.layers = ModuleList([])
transformer_kwargs = dict(
dim=dim,
heads=heads,
dim_head=dim_head,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
flash_attn=flash_attn
)
time_rotary_embed = RotaryEmbedding(dim=dim_head)
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
for _ in range(depth):
tran_modules = []
if linear_transformer_depth > 0:
tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
tran_modules.append(
Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
)
tran_modules.append(
Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
)
self.layers.append(nn.ModuleList(tran_modules))
self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
self.stft_kwargs = dict(
n_fft=stft_n_fft,
hop_length=stft_hop_length,
win_length=stft_win_length,
normalized=stft_normalized
)
freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
# create mel filter bank
# with librosa.filters.mel as in section 2 of paper
mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
# for some reason, it doesn't include the first freq? just force a value for now
mel_filter_bank[0][0] = 1.
# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
# so let's force a positive value
mel_filter_bank[-1, -1] = 1.
# binary as in paper (then estimated masks are averaged for overlapping regions)
freqs_per_band = mel_filter_bank > 0
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
freq_indices = repeated_freq_indices[freqs_per_band]
if stereo:
freq_indices = repeat(freq_indices, 'f -> f s', s=2)
freq_indices = freq_indices * 2 + torch.arange(2)
freq_indices = rearrange(freq_indices, 'f s -> (f s)')
self.register_buffer('freq_indices', freq_indices, persistent=False)
self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
# band split and mask estimator
freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
self.band_split = BandSplit(
dim=dim,
dim_inputs=freqs_per_bands_with_complex
)
self.mask_estimators = nn.ModuleList([])
for _ in range(num_stems):
mask_estimator = MaskEstimator(
dim=dim,
dim_inputs=freqs_per_bands_with_complex,
depth=mask_estimator_depth,
mlp_expansion_factor=mlp_expansion_factor,
)
self.mask_estimators.append(mask_estimator)
# for the multi-resolution stft loss
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
self.multi_stft_n_fft = stft_n_fft
self.multi_stft_window_fn = multi_stft_window_fn
self.multi_stft_kwargs = dict(
hop_length=multi_stft_hop_size,
normalized=multi_stft_normalized
)
self.match_input_audio_length = match_input_audio_length
def forward(
self,
raw_audio,
target=None,
return_loss_breakdown=False
):
"""
einops
b - batch
f - freq
t - time
s - audio channel (1 for mono, 2 for stereo)
n - number of 'stems'
c - complex (2)
d - feature dimension
"""
device = raw_audio.device
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
batch, channels, raw_audio_length = raw_audio.shape
istft_length = raw_audio_length if self.match_input_audio_length else None
assert (not self.stereo and channels == 1) or (
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
# to stft
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
stft_window = self.stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
stft_repr = torch.view_as_real(stft_repr)
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
# index out all frequencies for all frequency ranges across bands ascending in one go
batch_arange = torch.arange(batch, device=device)[..., None]
# account for stereo
x = stft_repr[batch_arange, self.freq_indices]
# fold the complex (real and imag) into the frequencies dimension
x = rearrange(x, 'b f t c -> b t (f c)')
if self.use_torch_checkpoint:
x = checkpoint(self.band_split, x, use_reentrant=False)
else:
x = self.band_split(x)
# axial / hierarchical attention
store = [None] * len(self.layers)
for i, transformer_block in enumerate(self.layers):
if len(transformer_block) == 3:
linear_transformer, time_transformer, freq_transformer = transformer_block
x, ft_ps = pack([x], 'b * d')
if self.use_torch_checkpoint:
x = checkpoint(linear_transformer, x, use_reentrant=False)
else:
x = linear_transformer(x)
x, = unpack(x, ft_ps, 'b * d')
else:
time_transformer, freq_transformer = transformer_block
if self.skip_connection:
# Sum all previous
for j in range(i):
x = x + store[j]
x = rearrange(x, 'b t f d -> b f t d')
x, ps = pack([x], '* t d')
if self.use_torch_checkpoint:
x = checkpoint(time_transformer, x, use_reentrant=False)
else:
x = time_transformer(x)
x, = unpack(x, ps, '* t d')
x = rearrange(x, 'b f t d -> b t f d')
x, ps = pack([x], '* f d')
if self.use_torch_checkpoint:
x = checkpoint(freq_transformer, x, use_reentrant=False)
else:
x = freq_transformer(x)
x, = unpack(x, ps, '* f d')
if self.skip_connection:
store[i] = x
num_stems = len(self.mask_estimators)
if self.use_torch_checkpoint:
masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
else:
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
# modulate frequency representation
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
# complex number multiplication
stft_repr = torch.view_as_complex(stft_repr)
masks = torch.view_as_complex(masks)
masks = masks.type(stft_repr.dtype)
# need to average the estimated mask for the overlapped frequencies
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
masks_averaged = masks_summed / denom.clamp(min=1e-8)
# modulate stft repr with estimated mask
stft_repr = stft_repr * masks_averaged
# istft
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
length=istft_length)
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
if num_stems == 1:
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
# if a target is passed in, calculate loss for learning
if not exists(target):
return recon_audio
if self.num_stems > 1:
assert target.ndim == 4 and target.shape[1] == self.num_stems
if target.ndim == 2:
target = rearrange(target, '... t -> ... 1 t')
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
loss = F.l1_loss(recon_audio, target)
multi_stft_resolution_loss = 0.
for window_size in self.multi_stft_resolutions_window_sizes:
res_stft_kwargs = dict(
n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
win_length=window_size,
return_complex=True,
window=self.multi_stft_window_fn(window_size, device=device),
**self.multi_stft_kwargs,
)
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
total_loss = loss + weighted_multi_resolution_loss
if not return_loss_breakdown:
return total_loss
return total_loss, (loss, multi_stft_resolution_loss)
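A hedged usage sketch of the new `MelBandRoformer`; the keyword values follow the default mel_band_roformer configuration defined by `Roformer_Loader` below, except `depth`, which is reduced here to keep the example light:

```python
# Sketch (assumed values noted inline); requires torch >= 2.0 for flash attention.
import torch
from bs_roformer.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(
    dim=384,
    depth=2,                        # the default config below uses 12
    stereo=True,
    num_bands=60,
    sample_rate=44100,
    stft_n_fft=2048,
    stft_hop_length=441,
    match_input_audio_length=True,  # pad the istft output back to the input length
)

with torch.inference_mode():
    mix = torch.randn(1, 2, 44100)  # (batch, channels, time)
    vocals = model(mix)             # single stem -> (1, 2, 44100)
```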

View File

@ -1,6 +1,4 @@
# This code is modified from https://github.com/ZFTurbo/
import pdb
import librosa
from tqdm import tqdm
import os
@ -8,61 +6,113 @@ import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
import yaml
import warnings
warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer
class BsRoformer_Loader:
class Roformer_Loader:
def get_config(self, config_path):
with open(config_path, 'r', encoding='utf-8') as f:
# use FullLoader to load the !!python/tuple tag; this could be improved
config = yaml.load(f, Loader=yaml.FullLoader)
return config
def get_default_config(self):
default_config = None
if self.model_type == 'bs_roformer':
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
# Other BS_Roformer models may not be compatible
default_config = {
"audio": {"chunk_size": 352800, "sample_rate": 44100},
"model": {
"dim": 512,
"depth": 12,
"stereo": True,
"num_stems": 1,
"time_transformer_depth": 1,
"freq_transformer_depth": 1,
"linear_transformer_depth": 0,
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
"dim_head": 64,
"heads": 8,
"attn_dropout": 0.1,
"ff_dropout": 0.1,
"flash_attn": True,
"dim_freqs_in": 1025,
"stft_n_fft": 2048,
"stft_hop_length": 441,
"stft_win_length": 2048,
"stft_normalized": False,
"mask_estimator_depth": 2,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
"multi_stft_hop_size": 147,
"multi_stft_normalized": False,
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2}
}
elif self.model_type == 'mel_band_roformer':
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as the default configuration file
# Other Mel_Band_Roformer models may not be compatible
default_config = {
"audio": {"chunk_size": 352800, "sample_rate": 44100},
"model": {
"dim": 384,
"depth": 12,
"stereo": True,
"num_stems": 1,
"time_transformer_depth": 1,
"freq_transformer_depth": 1,
"linear_transformer_depth": 0,
"num_bands": 60,
"dim_head": 64,
"heads": 8,
"attn_dropout": 0.1,
"ff_dropout": 0.1,
"flash_attn": True,
"dim_freqs_in": 1025,
"sample_rate": 44100,
"stft_n_fft": 2048,
"stft_hop_length": 441,
"stft_win_length": 2048,
"stft_normalized": False,
"mask_estimator_depth": 2,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
"multi_stft_hop_size": 147,
"multi_stft_normalized": False
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2}
}
return default_config
def get_model_from_config(self):
config = {
"attn_dropout": 0.1,
"depth": 12,
"dim": 512,
"dim_freqs_in": 1025,
"dim_head": 64,
"ff_dropout": 0.1,
"flash_attn": True,
"freq_transformer_depth": 1,
"freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
"heads": 8,
"linear_transformer_depth": 0,
"mask_estimator_depth": 2,
"multi_stft_hop_size": 147,
"multi_stft_normalized": False,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
"num_stems": 1,
"stereo": True,
"stft_hop_length": 441,
"stft_n_fft": 2048,
"stft_normalized": False,
"stft_win_length": 2048,
"time_transformer_depth": 1,
}
model = BSRoformer(
**dict(config)
)
if self.model_type == 'bs_roformer':
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer(**dict(self.config["model"]))
elif self.model_type == 'mel_band_roformer':
from bs_roformer.mel_band_roformer import MelBandRoformer
model = MelBandRoformer(**dict(self.config["model"]))
else:
print('Error: Unknown model: {}'.format(self.model_type))
model = None
return model
def demix_track(self, model, mix, device):
C = 352800
# num_overlap
N = 1
C = self.config["audio"]["chunk_size"] # chunk_size
N = self.config["inference"]["num_overlap"]
fade_size = C // 10
step = int(C // N)
border = C - step
batch_size = 4
batch_size = self.config["inference"]["batch_size"]
length_init = mix.shape[-1]
progress_bar = tqdm(total=length_init // step + 1)
progress_bar.set_description("Processing")
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)
# Pad at the beginning and end to better account for the floating-window results
if length_init > 2 * border and (border > 0):
@ -82,7 +132,10 @@ class BsRoformer_Loader:
with torch.amp.autocast('cuda'):
with torch.inference_mode():
req_shape = (1, ) + tuple(mix.shape)
if self.config["training"]["target_instrument"] is None:
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
else:
req_shape = (1, ) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
@ -97,7 +150,7 @@ class BsRoformer_Loader:
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
if(self.is_half==True):
if self.is_half:
part=part.half()
batch_data.append(part)
batch_locations.append((i, length))
@ -133,78 +186,109 @@ class BsRoformer_Loader:
progress_bar.close()
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
if self.config["training"]["target_instrument"] is None:
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
else:
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
def run_folder(self,input, vocal_root, others_root, format):
# start_time = time.time()
def run_folder(self, input, vocal_root, others_root, format):
self.model.eval()
path = input
os.makedirs(vocal_root, exist_ok=True)
os.makedirs(others_root, exist_ok=True)
file_base_name = os.path.splitext(os.path.basename(path))[0]
if not os.path.isdir(vocal_root):
os.mkdir(vocal_root)
if not os.path.isdir(others_root):
os.mkdir(others_root)
sample_rate = 44100
if 'sample_rate' in self.config["audio"]:
sample_rate = self.config["audio"]['sample_rate']
try:
mix, sr = librosa.load(path, sr=44100, mono=False)
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
except Exception as e:
print('Cannot read track: {}'.format(path))
print('Error message: {}'.format(str(e)))
return
# Convert mono to stereo if needed
if len(mix.shape) == 1:
mix = np.stack([mix, mix], axis=0)
# in case the model only supports mono tracks
isstereo = self.config["model"].get("stereo", True)
if not isstereo and len(mix.shape) != 1:
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")
mix_orig = mix.copy()
mixture = torch.tensor(mix, dtype=torch.float32)
res = self.demix_track(self.model, mixture, self.device)
estimates = res['vocals'].T
if format in ["wav", "flac"]:
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr)
if self.config["training"]["target_instrument"] is not None:
# if target instrument is specified, save target instrument as vocal and other instruments as others
# other instruments are calculated by subtracting the target instrument from the mixture
target_instrument = self.config["training"]["target_instrument"]
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
other = mix_orig - res[target_instrument] # calculate other instruments
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
self.save_audio(path_vocal, res[target_instrument].T, sr, format)
self.save_audio(path_other, other.T, sr, format)
else:
path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4])
path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4])
sf.write(path_vocal, estimates, sr)
sf.write(path_other, mix_orig.T - estimates, sr)
opt_path_vocal = path_vocal[:-4] + ".%s" % format
opt_path_other = path_other[:-4] + ".%s" % format
if os.path.exists(path_vocal):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
)
if os.path.exists(opt_path_vocal):
try:
os.remove(path_vocal)
except:
pass
if os.path.exists(path_other):
os.system(
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
)
if os.path.exists(opt_path_other):
try:
os.remove(path_other)
except:
pass
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
# if target instrument is not specified, save the first instrument as vocal and the rest as others
vocal_inst = self.config["training"]["instruments"][0]
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
for other in self.config["training"]["instruments"][1:]: # save other instruments
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
self.save_audio(path_other, res[other].T, sr, format)
def __init__(self, model_path, device,is_half):
def save_audio(self, path, data, sr, format):
# input path should end with '.wav'
if format in ["wav", "flac"]:
if format == "flac":
path = path[:-3] + "flac"
sf.write(path, data, sr)
else:
sf.write(path, data, sr)
os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
try: os.remove(path)
except: pass
def __init__(self, model_path, config_path, device, is_half):
self.device = device
self.extract_instrumental=True
self.is_half = is_half
self.model_type = None
self.config = None
# get model_type: first, try to infer it from the model file name
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
self.model_type = "bs_roformer"
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
self.model_type = "mel_band_roformer"
if not os.path.exists(config_path):
if self.model_type is None:
# if model_type is still None, raise an error
raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, ensure that the model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, manually place the model's configuration file into 'tools/uvr5/uvr5_weights', make sure it is named '<model_name>.yaml', and try again.")
self.config = self.get_default_config()
else:
# if there is a configuration file
self.config = self.get_config(config_path)
if self.model_type is None:
# if model_type is still None, second try, get model_type from the configuration file
if "freqs_per_bands" in self.config["model"]:
# if freqs_per_bands in config, it's a bs_roformer model
self.model_type = "bs_roformer"
else:
# else it's a mel_band_roformer model
self.model_type = "mel_band_roformer"
print("Detected model type: {}".format(self.model_type))
model = self.get_model_from_config()
state_dict = torch.load(model_path,map_location="cpu")
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
self.is_half=is_half
if(is_half==False):
self.model = model.to(device)
else:

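A hedged sketch of driving the renamed `Roformer_Loader` directly; the checkpoint and yaml names are the example pair from the README above, and the input file name is hypothetical (the webui below builds the same paths from the selected model name):

```python
# Sketch: load a roformer checkpoint plus its paired yaml and separate one track.
import torch
from bsroformer import Roformer_Loader

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Roformer_Loader(
    model_path="tools/uvr5/uvr5_weights/kim_mel_band_roformer.ckpt",
    config_path="tools/uvr5/uvr5_weights/kim_mel_band_roformer.yaml",
    device=device,
    is_half=False,
)
# run_folder(input, vocal_root, others_root, format) as defined above
loader.run_folder("mixture.wav", "output/vocals", "output/others", "wav")
```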
View File

@ -12,7 +12,7 @@ import torch
import sys
from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho
from bsroformer import BsRoformer_Loader
from bsroformer import Roformer_Loader
try:
import gradio.analytics as analytics
@ -49,13 +49,17 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
is_hp3 = "HP3" in model_name
if model_name == "onnx_dereverb_By_FoxJoy":
pre_fun = MDXNetDereverb(15)
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
func = BsRoformer_Loader
elif "roformer" in model_name.lower():
func = Roformer_Loader
pre_fun = func(
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
config_path = os.path.join(weight_uvr5_root, model_name + ".yaml"),
device = device,
is_half=is_half
)
if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
infos.append("Warning: You are using a model without a configuration file. The program will fall back to a default configuration, which cannot guarantee that every model runs successfully. You can manually place the model's configuration file into 'tools/uvr5/uvr5_weights', make sure it is named '<model_name>.yaml', and try again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can simply ignore this warning.")
yield "\n".join(infos)
else:
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
pre_fun = func(