mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 04:22:46 +08:00
Support for mel_band_roformer (#2078)
* support for mel_band_roformer
* Remove unnecessary audio channel judgments
* remove context manager and fix path
* Update webui.py
* Update README.md
parent
fbb9f21e53
commit
e061e9d38e
@ -152,6 +152,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additionally), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`.

    - If you want to use `bs_roformer` or `mel_band_roformer` models for UVR5, manually download the model and its configuration file and put both in `tools/uvr5/uvr5_weights`. **Rename the model file and the configuration file so that they share the same name apart from the extension.** In addition, both file names **must include `roformer`** to be recognized as roformer-class models.

    - We recommend **specifying the model type directly** in the model and configuration file names, for example `mel_band_roformer` or `bs_roformer`. If the type is not specified, it is determined by comparing features in the configuration file. For example, the model `bs_roformer_ep_368_sdr_12.9628.ckpt` and its configuration file `bs_roformer_ep_368_sdr_12.9628.yaml` are a pair, as are `kim_mel_band_roformer.ckpt` and `kim_mel_band_roformer.yaml`.

4. For Chinese ASR (additionally), download models from [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files), and [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) and place them in `tools/asr/models`.

5. For English or Japanese ASR (additionally), download models from [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) and place them in `tools/asr/models`. [Other models](https://huggingface.co/Systran) may give similar results with a smaller disk footprint.
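The naming rules in the README text above can be sketched as a small helper. This is only an illustration under stated assumptions: `guess_model_type` and `pair_roformer_files` are hypothetical names, the `.ckpt`/`.yaml` extensions are taken from the README examples, and the matching logic actually used by the web UI may differ.

```python
from pathlib import PurePath


def guess_model_type(stem):
    """Guess the roformer type from a file stem (hypothetical helper).

    Returns None when the name does not say which type it is; in that case
    the README says the type is determined from the configuration file.
    """
    name = stem.lower()
    if "mel_band_roformer" in name:
        return "mel_band_roformer"
    if "bs_roformer" in name:
        return "bs_roformer"
    return None


def pair_roformer_files(filenames):
    """Map each roformer stem to (model file, config file, guessed type)."""
    by_stem = {}
    for filename in filenames:
        path = PurePath(filename)
        # Group files that share a name apart from the extension.
        by_stem.setdefault(path.stem, {})[path.suffix] = filename

    pairs = {}
    for stem, files in sorted(by_stem.items()):
        has_both = ".ckpt" in files and ".yaml" in files
        # The name must contain "roformer" to count as a roformer model.
        if has_both and "roformer" in stem.lower():
            pairs[stem] = (files[".ckpt"], files[".yaml"], guess_model_type(stem))
    return pairs


example_files = [
    "bs_roformer_ep_368_sdr_12.9628.ckpt",
    "bs_roformer_ep_368_sdr_12.9628.yaml",
    "kim_mel_band_roformer.ckpt",
    "kim_mel_band_roformer.yaml",
    "my_roformer.ckpt",  # no matching .yaml, so it is skipped
]
example_pairs = pair_roformer_files(example_files)
```

A model whose name contains `roformer` but neither type prefix is still paired. Its guessed type is `None`, which corresponds to the case where the configuration file has to be inspected.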
@ -149,6 +149,11 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. 对于 UVR5(人声/伴奏分离和混响移除,额外功能),从 [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) 下载模型,并将其放置在 `tools/uvr5/uvr5_weights` 目录中。

    - 如果你在 UVR5 中使用 `bs_roformer` 或 `mel_band_roformer` 模型,你可以手动下载模型和相应的配置文件,并将它们放在 `tools/uvr5/uvr5_weights` 中。**重命名模型文件和配置文件,确保除后缀外,模型和配置文件具有相同且对应的名称**。此外,模型和配置文件名**必须包含 `roformer`**,才能被识别为 roformer 类的模型。

    - 建议在模型名称和配置文件名中**直接指定模型类型**,例如 `mel_band_roformer`、`bs_roformer`。如果未指定,将从配置文件中比对特征,以确定它是哪种类型的模型。例如,模型 `bs_roformer_ep_368_sdr_12.9628.ckpt` 和对应的配置文件 `bs_roformer_ep_368_sdr_12.9628.yaml` 是一对,`kim_mel_band_roformer.ckpt` 和 `kim_mel_band_roformer.yaml` 也是一对。


4. 对于中文 ASR(额外功能),从 [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files)、[Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files) 和 [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) 下载模型,并将它们放置在 `tools/asr/models` 目录中。

5. 对于英语或日语 ASR(额外功能),从 [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) 下载模型,并将其放置在 `tools/asr/models` 目录中。此外,[其他模型](https://huggingface.co/Systran) 可能具有类似效果且占用更少的磁盘空间。
@ -142,6 +142,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. UVR5(ボーカル/伴奏(BGM等)分離 & リバーブ除去の追加機能)の場合は、[UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) からモデルをダウンロードし、`tools/uvr5/uvr5_weights` ディレクトリに配置してください。

    - UVR5で `bs_roformer` または `mel_band_roformer` モデルを使用する場合、モデルと対応する設定ファイルを手動でダウンロードし、`tools/uvr5/uvr5_weights` フォルダに配置することができます。**モデルファイルと設定ファイルの名前は、拡張子を除いて同じであることを確認してください**。さらに、モデルと設定ファイルの名前には**「roformer」が含まれている必要があります**。これにより、roformerクラスのモデルとして認識されます。

    - モデル名と設定ファイル名には、**直接モデルタイプを指定することをお勧めします**。例:`mel_band_roformer`、`bs_roformer`。指定しない場合、設定ファイルから特徴を照合して、モデルの種類を特定します。例えば、モデル `bs_roformer_ep_368_sdr_12.9628.ckpt` と対応する設定ファイル `bs_roformer_ep_368_sdr_12.9628.yaml` はペアです。同様に、`kim_mel_band_roformer.ckpt` と `kim_mel_band_roformer.yaml` もペアです。

4. 中国語ASR(追加機能)の場合は、[Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files)、[Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files)、および [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) からモデルをダウンロードし、`tools/asr/models` ディレクトリに配置してください。

5. 英語または日本語のASR(追加機能)を使用する場合は、[Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) からモデルをダウンロードし、`tools/asr/models` ディレクトリに配置してください。また、[他のモデル](https://huggingface.co/Systran) は、より小さいサイズで高クオリティな可能性があります。
@ -145,6 +145,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. UVR5 (보컬/반주 분리 & 잔향 제거 추가 기능)의 경우, [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) 에서 모델을 다운로드하고 `tools/uvr5/uvr5_weights` 디렉토리에 배치하세요.

    - UVR5에서 `bs_roformer` 또는 `mel_band_roformer` 모델을 사용할 경우, 모델과 해당 설정 파일을 수동으로 다운로드하여 `tools/uvr5/uvr5_weights` 폴더에 저장할 수 있습니다. **모델 파일과 설정 파일의 이름은 확장자를 제외하고 동일한 이름을 가지도록 해야 합니다**. 또한, 모델과 설정 파일 이름에는 **“roformer”**가 포함되어야 roformer 클래스의 모델로 인식됩니다.

    - 모델 이름과 설정 파일 이름에 **모델 유형을 직접 지정하는 것이 좋습니다**. 예: `mel_band_roformer`, `bs_roformer`. 지정하지 않으면 설정 파일을 기준으로 특성을 비교하여 어떤 유형의 모델인지를 판단합니다. 예를 들어, 모델 `bs_roformer_ep_368_sdr_12.9628.ckpt`와 해당 설정 파일 `bs_roformer_ep_368_sdr_12.9628.yaml`은 한 쌍입니다. `kim_mel_band_roformer.ckpt`와 `kim_mel_band_roformer.yaml`도 한 쌍입니다.

4. 중국어 ASR (추가 기능)의 경우, [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files) 및 [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) 에서 모델을 다운로드하고, `tools/asr/models` 디렉토리에 배치하세요.

5. 영어 또는 일본어 ASR (추가 기능)의 경우, [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) 에서 모델을 다운로드하고, `tools/asr/models` 디렉토리에 배치하세요. 또한, [다른 모델](https://huggingface.co/Systran) 은 더 적은 디스크 용량으로 비슷한 효과를 가질 수 있습니다.
@ -142,6 +142,10 @@ docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-Docker
3. UVR5 (Vokal/Enstrümantal Ayrımı & Yankı Giderme) için, [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) üzerinden modelleri indirip `tools/uvr5/uvr5_weights` dizinine yerleştirin.

    - UVR5'te `bs_roformer` veya `mel_band_roformer` modellerini kullanıyorsanız, modeli ve ilgili yapılandırma dosyasını manuel olarak indirip `tools/uvr5/uvr5_weights` klasörüne yerleştirebilirsiniz. **Model dosyası ve yapılandırma dosyasının adı, uzantı dışında aynı olmalıdır**. Ayrıca, model ve yapılandırma dosyasının adlarında **“roformer”** kelimesi yer almalıdır, böylece roformer sınıfındaki bir model olarak tanınır.

    - Model adı ve yapılandırma dosyası adı içinde **doğrudan model tipini belirtmek önerilir**. Örneğin: `mel_band_roformer`, `bs_roformer`. Belirtilmezse, yapılandırma dosyasından özellikler karşılaştırılarak model tipi belirlenir. Örneğin, `bs_roformer_ep_368_sdr_12.9628.ckpt` modeli ve karşılık gelen yapılandırma dosyası `bs_roformer_ep_368_sdr_12.9628.yaml` bir çifttir. Aynı şekilde, `kim_mel_band_roformer.ckpt` ve `kim_mel_band_roformer.yaml` da bir çifttir.

4. Çince ASR için, [Damo ASR Model](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files), [Damo VAD Model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/files) ve [Damo Punc Model](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/files) üzerinden modelleri indirip `tools/asr/models` dizinine yerleştirin.

5. İngilizce veya Japonca ASR için, [Faster Whisper Large V3](https://huggingface.co/Systran/faster-whisper-large-v3) üzerinden modeli indirip `tools/asr/models` dizinine yerleştirin. Ayrıca, [diğer modeller](https://huggingface.co/Systran) benzer bir etki yaratabilir ve daha az disk alanı kaplayabilir.
@ -1,18 +1,8 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None
@ -20,21 +10,6 @@ def exists(val):
def default(v, d):
    return v if exists(v) else d

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
@ -50,43 +25,16 @@ class Attend(nn.Module):
        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
        # _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out
        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v,dropout_p = self.dropout if self.training else 0.)

    def forward(self, q, k, v):
        """
@ -97,7 +45,7 @@ class Attend(nn.Module):
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
        # q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)
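The `flash_attn` method above multiplies `q` by `self.scale / default_scale` before calling `F.scaled_dot_product_attention`, which by default scales the dot products by `d ** -0.5` itself. The plain-Python sketch below checks numerically that this pre-scaling reproduces attention weights computed with a custom scale. The helper names and input values are made up for illustration, and the custom scale of 8 mirrors the `scale=8` default of `LinearAttention`.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(q, keys, scale):
    """Softmax over scaled dot products of one query against every key."""
    return softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])


q = [0.5, -1.0, 2.0, 0.25]
keys = [
    [1.0, 0.0, 0.5, -0.5],
    [0.2, 0.3, -1.0, 1.5],
    [-0.7, 0.9, 0.1, 0.0],
]

custom_scale = 8.0              # the scale the caller actually wants
default_scale = len(q) ** -0.5  # what a fused kernel applies on its own

# Reference: apply the custom scale directly.
expected = attention_weights(q, keys, custom_scale)

# Trick from flash_attn: rescale q, then let the default scale be applied.
q_rescaled = [x * (custom_scale / default_scale) for x in q]
actual = attention_weights(q_rescaled, keys, default_scale)

assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(expected, actual))
```

The two results agree because `(q * s / d) · k * d` equals `q · k * s` for any nonzero default scale `d`.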
@ -6,6 +6,7 @@ from torch.nn import Module, ModuleList
import torch.nn.functional as F

from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint

from typing import Tuple, Optional, List, Callable
# from beartype.typing import Tuple, Optional, List, Callable
@ -356,13 +357,18 @@ class BSRoformer(Module):
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])
@ -402,7 +408,7 @@ class BSRoformer(Module):
        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1]
        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(
@ -421,7 +427,8 @@ class BSRoformer(Module):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)
@ -458,12 +465,14 @@ class BSRoformer(Module):
        device = raw_audio.device

        # defining whether model is loaded on MPS (MacOS GPU accelerator)
        x_is_mps = True if device.type == "mps" else False

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
                self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

        # to stft
@ -471,53 +480,79 @@ class BSRoformer(Module):
        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        # RuntimeError: FFT operations are only supported on MacOS 14+
        # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used
        try:
            stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        except:
            stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device)

        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        stft_repr = rearrange(stft_repr,
                              'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')

        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
        # print("460:", x.dtype)#fp32
        x = self.band_split(x)

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # axial / hierarchical attention

        # print("487:",x.dtype)#fp16
        for transformer_block in self.layers:
        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], 'b * d')
                # print("494:", x.dtype)#fp16
                x = linear_transformer(x)
                # print("496:", x.dtype)#fp16
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            # print("501:", x.dtype)#fp16
            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)
            # print("505:", x.dtype)#fp16
            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)
            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

            # print("515:", x.dtype)######fp16
            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)
        # print("519:", x.dtype)#fp32
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)

        if self.use_torch_checkpoint:
            mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
        else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation
@ -535,7 +570,11 @@ class BSRoformer(Module):
        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False)
        # same as torch.stft() fix for MacOS MPS above
        try:
            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
        except:
            recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
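The `skip_connection` option introduced in the `forward` hunk above adds the stored outputs of all earlier blocks to the input of the current block. The sketch below reproduces only that bookkeeping, with plain numbers standing in for tensors. `run_blocks` and the toy blocks are hypothetical, and the real loop additionally splits each block into linear, time and frequency transformers.

```python
def run_blocks(x, blocks, skip_connection=False):
    """Apply blocks in order; optionally add every earlier block output first."""
    store = [None] * len(blocks)
    for i, block in enumerate(blocks):
        if skip_connection:
            # Sum the outputs of all previous blocks into the current input.
            for j in range(i):
                x = x + store[j]
        x = block(x)
        if skip_connection:
            store[i] = x
    return x


# Toy stand-ins for transformer blocks.
toy_blocks = [lambda v: v * 2, lambda v: v + 1, lambda v: v * 3]

plain = run_blocks(1.0, toy_blocks)                          # ((1 * 2) + 1) * 3
skipped = run_blocks(1.0, toy_blocks, skip_connection=True)  # earlier outputs re-added
```

With these toy blocks the plain run gives 9.0, while the skip-connected run gives 36.0: block 1 sees 2 + 2 = 4 and outputs 5, and block 2 then sees 5 + 2 + 5 = 12.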
669
tools/uvr5/bs_roformer/mel_band_roformer.py
Normal file
@ -0,0 +1,669 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from bs_roformer.attend import Attend
from torch.utils.checkpoint import checkpoint

from typing import Tuple, Optional, List, Callable
# from beartype.typing import Tuple, Optional, List, Callable
# from beartype import beartype

from rotary_embedding_torch import RotaryEmbedding

from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange

from librosa import filters


# helper functions

def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def pad_at_dim(t, pad, dim=-1, value=0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value=value)


def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


# norm

class RMSNorm(Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma


# attention

class FeedForward(Module):
    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.
    ):
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(Module):
    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        rotary_embed=None,
        flash=True
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
    """

    # @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=False,
        dropout=0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(
        self,
        x
    ):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)


class Transformer(Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.,
        ff_dropout=0.,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
            else:
                attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
                                 rotary_embed=rotary_embed, flash=flash_attn)

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


# bandsplit module

class BandSplit(Module):
    # @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )

            self.to_features.append(net)

    def forward(self, x):
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        return torch.stack(outs, dim=-2)


def MLP(
    dim_in,
    dim_out,
    dim_hidden=None,
    depth=1,
    activation=nn.Tanh
):
    dim_hidden = default(dim_hidden, dim_in)

    net = []
    dims = (dim_in, *((dim_hidden,) * depth), dim_out)

    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        is_last = ind == (len(dims) - 2)

        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        if is_last:
            continue

        net.append(activation())

    return nn.Sequential(*net)


class MaskEstimator(Module):
    # @beartype
    def __init__(
        self,
        dim,
        dim_inputs: Tuple[int, ...],
        depth,
        mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            net = []

            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)


# main class

class MelBandRoformer(Module):

    # @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        num_bands=60,
        dim_head=64,
        heads=8,
        attn_dropout=0.1,
        ff_dropout=0.1,
        flash_attn=True,
        dim_freqs_in=1025,
        sample_rate=44100,  # needed for mel filter bank from librosa
        stft_n_fft=2048,
        stft_hop_length=512,
        # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=1,
        multi_stft_resolution_loss_weight=1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn
        )

        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
            tran_modules.append(
                Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
            )
            tran_modules.append(
                Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]

        # create mel filter bank
        # with librosa.filters.mel as in section 2 of paper

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # for some reason, it doesn't include the first freq? just force a value for now

        mel_filter_bank[0][0] = 1.

        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
        # so let's force a positive value

        mel_filter_bank[-1, -1] = 1.

        # binary as in paper (then estimated masks are averaged for overlapping regions)

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'

        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, 'f s -> (f s)')

        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')

        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)

        # band split and mask estimator

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

        self.match_input_audio_length = match_input_audio_length
|
||||
|
||||
def forward(
|
||||
self,
|
||||
raw_audio,
|
||||
target=None,
|
||||
return_loss_breakdown=False
|
||||
):
|
||||
"""
einops

b - batch
f - freq
t - time
s - audio channel (1 for mono, 2 for stereo)
n - number of 'stems'
c - complex (2)
d - feature dimension
"""
|
||||
|
||||
device = raw_audio.device
|
||||
|
||||
if raw_audio.ndim == 2:
|
||||
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
|
||||
|
||||
batch, channels, raw_audio_length = raw_audio.shape
|
||||
|
||||
istft_length = raw_audio_length if self.match_input_audio_length else None
|
||||
|
||||
assert (not self.stereo and channels == 1) or (
|
||||
self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
|
||||
|
||||
# to stft
|
||||
|
||||
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
|
||||
|
||||
stft_window = self.stft_window_fn(device=device)
|
||||
|
||||
stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
|
||||
stft_repr = torch.view_as_real(stft_repr)
|
||||
|
||||
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
|
||||
|
||||
# merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
|
||||
stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
|
||||
|
||||
# index out all frequencies for all frequency ranges across bands ascending in one go
|
||||
|
||||
batch_arange = torch.arange(batch, device=device)[..., None]
|
||||
|
||||
# account for stereo
|
||||
|
||||
x = stft_repr[batch_arange, self.freq_indices]
|
||||
|
||||
# fold the complex (real and imag) into the frequencies dimension
|
||||
|
||||
x = rearrange(x, 'b f t c -> b t (f c)')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(self.band_split, x, use_reentrant=False)
|
||||
else:
|
||||
x = self.band_split(x)
|
||||
|
||||
# axial / hierarchical attention
|
||||
|
||||
store = [None] * len(self.layers)
|
||||
for i, transformer_block in enumerate(self.layers):
|
||||
|
||||
if len(transformer_block) == 3:
|
||||
linear_transformer, time_transformer, freq_transformer = transformer_block
|
||||
|
||||
x, ft_ps = pack([x], 'b * d')
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = linear_transformer(x)
|
||||
x, = unpack(x, ft_ps, 'b * d')
|
||||
else:
|
||||
time_transformer, freq_transformer = transformer_block
|
||||
|
||||
if self.skip_connection:
|
||||
# Sum all previous
|
||||
for j in range(i):
|
||||
x = x + store[j]
|
||||
|
||||
x = rearrange(x, 'b t f d -> b f t d')
|
||||
x, ps = pack([x], '* t d')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(time_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = time_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* t d')
|
||||
x = rearrange(x, 'b f t d -> b t f d')
|
||||
x, ps = pack([x], '* f d')
|
||||
|
||||
if self.use_torch_checkpoint:
|
||||
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
||||
else:
|
||||
x = freq_transformer(x)
|
||||
|
||||
x, = unpack(x, ps, '* f d')
|
||||
|
||||
if self.skip_connection:
|
||||
store[i] = x
|
||||
|
||||
num_stems = len(self.mask_estimators)
|
||||
if self.use_torch_checkpoint:
|
||||
masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
|
||||
else:
|
||||
masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
||||
masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
|
||||
|
||||
# modulate frequency representation
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
|
||||
|
||||
# complex number multiplication
|
||||
|
||||
stft_repr = torch.view_as_complex(stft_repr)
|
||||
masks = torch.view_as_complex(masks)
|
||||
|
||||
masks = masks.type(stft_repr.dtype)
|
||||
|
||||
# need to average the estimated mask for the overlapped frequencies
|
||||
|
||||
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
|
||||
|
||||
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
|
||||
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
|
||||
|
||||
denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
|
||||
|
||||
masks_averaged = masks_summed / denom.clamp(min=1e-8)
|
||||
|
||||
# modulate stft repr with estimated mask
|
||||
|
||||
stft_repr = stft_repr * masks_averaged
|
||||
|
||||
# istft
|
||||
|
||||
stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
|
||||
|
||||
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
|
||||
length=istft_length)
|
||||
|
||||
recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
|
||||
|
||||
if num_stems == 1:
|
||||
recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
|
||||
|
||||
# if a target is passed in, calculate loss for learning
|
||||
|
||||
if not exists(target):
|
||||
return recon_audio
|
||||
|
||||
if self.num_stems > 1:
|
||||
assert target.ndim == 4 and target.shape[1] == self.num_stems
|
||||
|
||||
if target.ndim == 2:
|
||||
target = rearrange(target, '... t -> ... 1 t')
|
||||
|
||||
target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
|
||||
|
||||
loss = F.l1_loss(recon_audio, target)
|
||||
|
||||
multi_stft_resolution_loss = 0.
|
||||
|
||||
for window_size in self.multi_stft_resolutions_window_sizes:
|
||||
res_stft_kwargs = dict(
|
||||
n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
|
||||
win_length=window_size,
|
||||
return_complex=True,
|
||||
window=self.multi_stft_window_fn(window_size, device=device),
|
||||
**self.multi_stft_kwargs,
|
||||
)
|
||||
|
||||
recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
|
||||
target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
|
||||
|
||||
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
|
||||
|
||||
weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
|
||||
|
||||
total_loss = loss + weighted_multi_resolution_loss
|
||||
|
||||
if not return_loss_breakdown:
|
||||
return total_loss
|
||||
|
||||
return total_loss, (loss, multi_stft_resolution_loss)
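The band indexing and mask averaging used in `forward` above can be illustrated without PyTorch. The following is a minimal pure-Python sketch, assuming a toy filter bank with 2 bands and 3 frequency bins; the helper names `band_freq_indices`, `interleave_stereo`, and `average_overlapping_masks` are hypothetical and are not part of this commit.

```python
# Toy analogue of the band indexing in MelBandRoformer (illustrative only).
# freqs_per_band[b][f] is True when band b covers frequency bin f.

def band_freq_indices(freqs_per_band):
    """Flatten the covered bins band by band, like repeated_freq_indices[freqs_per_band]."""
    return [f for band in freqs_per_band for f, covered in enumerate(band) if covered]


def interleave_stereo(freq_indices):
    """Map each bin f to rows 2*f and 2*f + 1, matching the '(f s)' layout with s=2."""
    return [2 * f + s for f in freq_indices for s in range(2)]


def average_overlapping_masks(freq_indices, masks, num_freqs):
    """Sum mask values per bin (a scatter-add), then divide by the number of covering bands."""
    summed = [0.0] * num_freqs
    counts = [0] * num_freqs
    for f, m in zip(freq_indices, masks):
        summed[f] += m
        counts[f] += 1
    # The original clamps the denominator; in this sketch an uncovered bin simply stays 0.
    return [s / c if c else 0.0 for s, c in zip(summed, counts)]


freqs_per_band = [
    [True, True, False],   # band 0 covers bins 0 and 1
    [False, True, True],   # band 1 covers bins 1 and 2
]
indices = band_freq_indices(freqs_per_band)
stereo_indices = interleave_stereo(indices)
averaged = average_overlapping_masks(indices, [1.0, 0.5, 1.5, 2.0], num_freqs=3)
```

Bin 1 is covered by both bands, so its two mask estimates (0.5 and 1.5) are averaged, which is the role `num_bands_per_freq` plays as the denominator in the real model.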
|
@ -1,6 +1,4 @@
|
||||
# This code is modified from https://github.com/ZFTurbo/
|
||||
import pdb
|
||||
|
||||
import librosa
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
@ -8,61 +6,113 @@ import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch.nn as nn
|
||||
|
||||
import yaml
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
|
||||
class BsRoformer_Loader:
|
||||
|
||||
class Roformer_Loader:
|
||||
def get_config(self, config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
# use FullLoader to load the !!python/tuple tag; code can be improved
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return config
|
||||
|
||||
def get_default_config(self):
|
||||
default_config = None
|
||||
if self.model_type == 'bs_roformer':
|
||||
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
|
||||
# Other BS_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 512,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as the default configuration file
|
||||
# Other Mel_Band_Roformer models may not be compatible
|
||||
default_config = {
|
||||
"audio": {"chunk_size": 352800, "sample_rate": 44100},
|
||||
"model": {
|
||||
"dim": 384,
|
||||
"depth": 12,
|
||||
"stereo": True,
|
||||
"num_stems": 1,
|
||||
"time_transformer_depth": 1,
|
||||
"freq_transformer_depth": 1,
|
||||
"linear_transformer_depth": 0,
|
||||
"num_bands": 60,
|
||||
"dim_head": 64,
|
||||
"heads": 8,
|
||||
"attn_dropout": 0.1,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"dim_freqs_in": 1025,
|
||||
"sample_rate": 44100,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_hop_length": 441,
|
||||
"stft_win_length": 2048,
|
||||
"stft_normalized": False,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False
|
||||
},
|
||||
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
|
||||
"inference": {"batch_size": 2, "num_overlap": 2}
|
||||
}
|
||||
return default_config
|
||||
|
||||
|
||||
def get_model_from_config(self):
|
||||
config = {
|
||||
"attn_dropout": 0.1,
|
||||
"depth": 12,
|
||||
"dim": 512,
|
||||
"dim_freqs_in": 1025,
|
||||
"dim_head": 64,
|
||||
"ff_dropout": 0.1,
|
||||
"flash_attn": True,
|
||||
"freq_transformer_depth": 1,
|
||||
"freqs_per_bands":(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
|
||||
"heads": 8,
|
||||
"linear_transformer_depth": 0,
|
||||
"mask_estimator_depth": 2,
|
||||
"multi_stft_hop_size": 147,
|
||||
"multi_stft_normalized": False,
|
||||
"multi_stft_resolution_loss_weight": 1.0,
|
||||
"multi_stft_resolutions_window_sizes":(4096, 2048, 1024, 512, 256),
|
||||
"num_stems": 1,
|
||||
"stereo": True,
|
||||
"stft_hop_length": 441,
|
||||
"stft_n_fft": 2048,
|
||||
"stft_normalized": False,
|
||||
"stft_win_length": 2048,
|
||||
"time_transformer_depth": 1,
|
||||
|
||||
}
|
||||
|
||||
|
||||
model = BSRoformer(
|
||||
**dict(config)
|
||||
)
|
||||
|
||||
if self.model_type == 'bs_roformer':
|
||||
from bs_roformer.bs_roformer import BSRoformer
|
||||
model = BSRoformer(**dict(self.config["model"]))
|
||||
elif self.model_type == 'mel_band_roformer':
|
||||
from bs_roformer.mel_band_roformer import MelBandRoformer
|
||||
model = MelBandRoformer(**dict(self.config["model"]))
|
||||
else:
|
||||
print('Error: Unknown model: {}'.format(self.model_type))
|
||||
model = None
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def demix_track(self, model, mix, device):
|
||||
C = 352800
|
||||
# num_overlap
|
||||
N = 1
|
||||
C = self.config["audio"]["chunk_size"] # chunk_size
|
||||
N = self.config["inference"]["num_overlap"]
|
||||
fade_size = C // 10
|
||||
step = int(C // N)
|
||||
border = C - step
|
||||
batch_size = 4
|
||||
batch_size = self.config["inference"]["batch_size"]
|
||||
|
||||
length_init = mix.shape[-1]
|
||||
|
||||
progress_bar = tqdm(total=length_init // step + 1)
|
||||
progress_bar.set_description("Processing")
|
||||
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)
|
||||
|
||||
# Pad the beginning and end so the sliding-window results are better at the edges
|
||||
if length_init > 2 * border and (border > 0):
|
||||
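The chunking parameters that `demix_track` now reads from the configuration can be worked through with plain arithmetic. This is a rough sketch, assuming the default values `chunk_size=352800` and `num_overlap=2`; the function name `chunk_plan` is hypothetical, and the list of chunk start positions is an assumption, because the loop body is not shown in this diff and the edge padding is ignored here.

```python
def chunk_plan(chunk_size, num_overlap, length):
    """Derive the sliding-window parameters used for overlapped inference."""
    step = chunk_size // num_overlap      # hop between consecutive chunks
    border = chunk_size - step            # overlap shared with the next chunk
    fade_size = chunk_size // 10          # fade length at chunk edges
    progress_total = length // step + 1   # same formula as the tqdm total
    # Assumption: chunks start every `step` samples (padding not modelled).
    starts = list(range(0, length, step))
    return {
        "step": step,
        "border": border,
        "fade_size": fade_size,
        "progress_total": progress_total,
        "starts": starts,
    }


plan = chunk_plan(chunk_size=352800, num_overlap=2, length=1_000_000)
```

With these defaults each chunk spans 8 seconds at 44100 Hz and consecutive chunks overlap by half, so a larger `num_overlap` trades more computation for smoother joins between chunks.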
@ -82,7 +132,10 @@ class BsRoformer_Loader:
|
||||
|
||||
with torch.amp.autocast('cuda'):
|
||||
with torch.inference_mode():
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
|
||||
else:
|
||||
req_shape = (1, ) + tuple(mix.shape)
|
||||
|
||||
result = torch.zeros(req_shape, dtype=torch.float32)
|
||||
counter = torch.zeros(req_shape, dtype=torch.float32)
|
||||
@ -97,7 +150,7 @@ class BsRoformer_Loader:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
|
||||
else:
|
||||
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
|
||||
if(self.is_half==True):
|
||||
if self.is_half:
|
||||
part=part.half()
|
||||
batch_data.append(part)
|
||||
batch_locations.append((i, length))
|
||||
@ -133,78 +186,109 @@ class BsRoformer_Loader:
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
return {k: v for k, v in zip(['vocals', 'other'], estimated_sources)}
|
||||
if self.config["training"]["target_instrument"] is None:
|
||||
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
|
||||
else:
|
||||
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
|
||||
|
||||
|
||||
def run_folder(self,input, vocal_root, others_root, format):
|
||||
# start_time = time.time()
|
||||
def run_folder(self, input, vocal_root, others_root, format):
|
||||
self.model.eval()
|
||||
path = input
|
||||
os.makedirs(vocal_root, exist_ok=True)
|
||||
os.makedirs(others_root, exist_ok=True)
|
||||
file_base_name = os.path.splitext(os.path.basename(path))[0]
|
||||
|
||||
if not os.path.isdir(vocal_root):
|
||||
os.mkdir(vocal_root)
|
||||
|
||||
if not os.path.isdir(others_root):
|
||||
os.mkdir(others_root)
|
||||
sample_rate = 44100
|
||||
if 'sample_rate' in self.config["audio"]:
|
||||
sample_rate = self.config["audio"]['sample_rate']
|
||||
|
||||
try:
|
||||
mix, sr = librosa.load(path, sr=44100, mono=False)
|
||||
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
|
||||
except Exception as e:
|
||||
print('Cannot read track: {}'.format(path))
|
||||
print('Error message: {}'.format(str(e)))
|
||||
return
|
||||
|
||||
# Convert mono to stereo if needed
|
||||
if len(mix.shape) == 1:
|
||||
mix = np.stack([mix, mix], axis=0)
|
||||
# in case the model only supports mono tracks
|
||||
isstereo = self.config["model"].get("stereo", True)
|
||||
if not isstereo and len(mix.shape) != 1:
|
||||
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
|
||||
print("Warning: Track has more than 1 channel, but model is mono, taking mean of all channels.")
|
||||
|
||||
mix_orig = mix.copy()
|
||||
|
||||
mixture = torch.tensor(mix, dtype=torch.float32)
|
||||
res = self.demix_track(self.model, mixture, self.device)
|
||||
|
||||
estimates = res['vocals'].T
|
||||
|
||||
if format in ["wav", "flac"]:
|
||||
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
|
||||
sf.write("{}/{}_{}.{}".format(others_root, os.path.basename(path)[:-4], 'instrumental', format), mix_orig.T - estimates, sr)
|
||||
if self.config["training"]["target_instrument"] is not None:
|
||||
# if target instrument is specified, save target instrument as vocal and other instruments as others
|
||||
# other instruments are calculated by subtracting target instrument from mixture
|
||||
target_instrument = self.config["training"]["target_instrument"]
|
||||
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
|
||||
other = mix_orig - res[target_instrument] # calculate other instruments
|
||||
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
|
||||
self.save_audio(path_vocal, res[target_instrument].T, sr, format)
|
||||
self.save_audio(path_other, other.T, sr, format)
|
||||
else:
|
||||
path_vocal = "%s/%s_vocals.wav" % (vocal_root, os.path.basename(path)[:-4])
|
||||
path_other = "%s/%s_instrumental.wav" % (others_root, os.path.basename(path)[:-4])
|
||||
sf.write(path_vocal, estimates, sr)
|
||||
sf.write(path_other, mix_orig.T - estimates, sr)
|
||||
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
||||
opt_path_other = path_other[:-4] + ".%s" % format
|
||||
if os.path.exists(path_vocal):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_vocal, opt_path_vocal)
|
||||
)
|
||||
if os.path.exists(opt_path_vocal):
|
||||
try:
|
||||
os.remove(path_vocal)
|
||||
except:
|
||||
pass
|
||||
if os.path.exists(path_other):
|
||||
os.system(
|
||||
"ffmpeg -i '%s' -vn '%s' -q:a 2 -y" % (path_other, opt_path_other)
|
||||
)
|
||||
if os.path.exists(opt_path_other):
|
||||
try:
|
||||
os.remove(path_other)
|
||||
except:
|
||||
pass
|
||||
|
||||
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
||||
# if target instrument is not specified, save the first instrument as vocal and the rest as others
|
||||
vocal_inst = self.config["training"]["instruments"][0]
|
||||
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
|
||||
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
|
||||
for other in self.config["training"]["instruments"][1:]: # save other instruments
|
||||
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
|
||||
self.save_audio(path_other, res[other].T, sr, format)
|
||||
|
||||
|
||||
def __init__(self, model_path, device,is_half):
|
||||
def save_audio(self, path, data, sr, format):
|
||||
# input path should end with '.wav'
|
||||
if format in ["wav", "flac"]:
|
||||
if format == "flac":
|
||||
path = path[:-3] + "flac"
|
||||
sf.write(path, data, sr)
|
||||
else:
|
||||
sf.write(path, data, sr)
|
||||
os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format))
|
||||
try: os.remove(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def __init__(self, model_path, config_path, device, is_half):
|
||||
self.device = device
|
||||
self.extract_instrumental=True
|
||||
self.is_half = is_half
|
||||
self.model_type = None
|
||||
self.config = None
|
||||
|
||||
# get model_type, first try:
|
||||
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
|
||||
self.model_type = "bs_roformer"
|
||||
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, raise an error
|
||||
raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5_weights' and ensure that the configuration file is named '<model_name>.yaml', then try again.")
|
||||
self.config = self.get_default_config()
|
||||
else:
|
||||
# if there is a configuration file
|
||||
self.config = self.get_config(config_path)
|
||||
if self.model_type is None:
|
||||
# if model_type is still None, second try, get model_type from the configuration file
|
||||
if "freqs_per_bands" in self.config["model"]:
|
||||
# if freqs_per_bands in config, it's a bs_roformer model
|
||||
self.model_type = "bs_roformer"
|
||||
else:
|
||||
# else it's a mel_band_roformer model
|
||||
self.model_type = "mel_band_roformer"
|
||||
|
||||
print("Detected model type: {}".format(self.model_type))
|
||||
model = self.get_model_from_config()
|
||||
state_dict = torch.load(model_path,map_location="cpu")
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict)
|
||||
self.is_half=is_half
|
||||
|
||||
if(is_half==False):
|
||||
self.model = model.to(device)
|
||||
else:
|
||||
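The two-stage model type detection in `__init__` above can be condensed into a standalone function. This is a sketch under stated assumptions: the function name `detect_model_type` is hypothetical, and `config` stands in for the dictionary loaded from the YAML file (or `None` when no configuration file exists).

```python
def detect_model_type(model_path, config=None):
    """Guess the roformer variant from the file name, then from the config."""
    name = model_path.lower()

    # First try: look for a known marker in the model path.
    if "bs_roformer" in name or "bsroformer" in name:
        return "bs_roformer"
    if "mel_band_roformer" in name or "melbandroformer" in name:
        return "mel_band_roformer"

    # Without a config there is nothing left to inspect.
    if config is None:
        raise ValueError("Unknown model type: no marker in name and no config")

    # Second try: only band-split configs define 'freqs_per_bands'.
    if "freqs_per_bands" in config["model"]:
        return "bs_roformer"
    return "mel_band_roformer"


by_name = detect_model_type("kim_mel_band_roformer.ckpt")
by_config = detect_model_type("my_roformer.ckpt", {"model": {"freqs_per_bands": (2, 2)}})
fallback = detect_model_type("my_roformer.ckpt", {"model": {"num_bands": 60}})
```

Because the check runs on the whole path, a directory name containing one of the markers would also match, which is one reason to put the model type in the file name itself.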
|
@ -12,7 +12,7 @@ import torch
|
||||
import sys
|
||||
from mdxnet import MDXNetDereverb
|
||||
from vr import AudioPre, AudioPreDeEcho
|
||||
from bsroformer import BsRoformer_Loader
|
||||
from bsroformer import Roformer_Loader
|
||||
|
||||
try:
|
||||
import gradio.analytics as analytics
|
||||
@ -49,13 +49,17 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
||||
is_hp3 = "HP3" in model_name
|
||||
if model_name == "onnx_dereverb_By_FoxJoy":
|
||||
pre_fun = MDXNetDereverb(15)
|
||||
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
|
||||
func = BsRoformer_Loader
|
||||
elif "roformer" in model_name.lower():
|
||||
func = Roformer_Loader
|
||||
pre_fun = func(
|
||||
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
|
||||
config_path = os.path.join(weight_uvr5_root, model_name + ".yaml"),
|
||||
device = device,
|
||||
is_half=is_half
|
||||
)
|
||||
if not os.path.exists(os.path.join(weight_uvr5_root, model_name + ".yaml")):
|
||||
infos.append("Warning: You are using a model without a configuration file. The program will automatically use the default configuration file. However, the default configuration file cannot guarantee that all models will run successfully. You can manually place the model configuration file into 'tools/uvr5/uvr5_weights' and ensure that the configuration file is named '<model_name>.yaml', then try again. (For example, the configuration file corresponding to the model 'bs_roformer_ep_368_sdr_12.9628.ckpt' should be 'bs_roformer_ep_368_sdr_12.9628.yaml'.) Or you can just ignore this warning.")
|
||||
yield "\n".join(infos)
|
||||
else:
|
||||
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
||||
pre_fun = func(
|
||||
|
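The model and configuration pairing that `uvr` relies on can be sketched as a small helper. This is an illustration only, assuming both files live in the same weights directory; the function name `resolve_roformer_files` is hypothetical and does not exist in the repository.

```python
import os
import tempfile


def resolve_roformer_files(weights_root, model_name):
    """Return (model_path, config_path, has_config) for a roformer model name."""
    if "roformer" not in model_name.lower():
        raise ValueError("not recognized as a roformer model: {}".format(model_name))
    # The model and its config share one base name and differ only in suffix.
    model_path = os.path.join(weights_root, model_name + ".ckpt")
    config_path = os.path.join(weights_root, model_name + ".yaml")
    return model_path, config_path, os.path.exists(config_path)


with tempfile.TemporaryDirectory() as root:
    # Only the first model gets a matching configuration file.
    open(os.path.join(root, "kim_mel_band_roformer.yaml"), "w").close()
    paired = resolve_roformer_files(root, "kim_mel_band_roformer")
    unpaired = resolve_roformer_files(root, "bs_roformer_ep_368_sdr_12.9628")
```

When `has_config` is false, the loader shown earlier falls back to its built-in default configuration, which is the case the warning message in this hunk describes.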