Simplify BSR (#1356)

KamioRinn 2024-07-30 10:32:37 +08:00 committed by GitHub
parent 8abc0342d7
commit 7670bc77c3
3 changed files with 13 additions and 13 deletions


@@ -25,4 +25,5 @@ jieba_fast
 jieba
 LangSegment>=0.2.0
 Faster_Whisper
 wordsegment
+rotary_embedding_torch
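
The new requirements entry matches the unchanged "from rotary_embedding_torch import RotaryEmbedding" line in the BS-RoFormer module below, so the dependency that module already relies on is now declared explicitly. A minimal sketch of how that package is typically used; the dim value and tensor shape are illustrative and not taken from this commit:

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    rotary_emb = RotaryEmbedding(dim=32)        # rotate half of a 64-dim attention head
    q = torch.randn(1, 8, 1024, 64)             # (batch, heads, seq_len, head_dim)
    q = rotary_emb.rotate_queries_or_keys(q)    # positions are inferred from seq_len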


@@ -7,8 +7,9 @@ import torch.nn.functional as F
 from bs_roformer.attend import Attend
-from beartype.typing import Tuple, Optional, List, Callable
-from beartype import beartype
+from typing import Tuple, Optional, List, Callable
+# from beartype.typing import Tuple, Optional, List, Callable
+# from beartype import beartype
 from rotary_embedding_torch import RotaryEmbedding
@@ -125,7 +126,7 @@ class LinearAttention(Module):
     this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
     """
-    @beartype
+    # @beartype
     def __init__(
         self,
         *,
@@ -219,7 +220,7 @@ class Transformer(Module):
 # bandsplit module
 class BandSplit(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,
@@ -274,7 +275,7 @@ def MLP(
 class MaskEstimator(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,
@@ -325,7 +326,7 @@ DEFAULT_FREQS_PER_BANDS = (
 class BSRoformer(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,
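
Taken together, the edits to this file do one thing: the beartype runtime type-checker is no longer imported, its @beartype decorators are kept only as comments, and the Tuple/Optional/List/Callable annotations now come from the standard library typing module. A toy sketch of the effect, with a made-up class and signature:

    from typing import Tuple                 # was: from beartype.typing import Tuple
    # from beartype import beartype          # runtime checker no longer imported

    class Demo:
        # @beartype                          # decorator disabled, as in this commit
        def __init__(self, dims: Tuple[int, ...]):
            # The annotation still documents the intent, but invalid arguments
            # are no longer rejected at call time the way beartype would do.
            self.dims = dims

    Demo((1, 2, 3))        # accepted before and after
    Demo("not a tuple")    # beartype would raise here; plain typing does not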


@@ -1,10 +1,8 @@
 # This code is modified from https://github.com/ZFTurbo/
-import time
 import librosa
 from tqdm import tqdm
 import os
-import glob
 import torch
 import numpy as np
 import soundfile as sf
@@ -52,7 +50,8 @@ class BsRoformer_Loader:
     def demix_track(self, model, mix, device):
         C = 352800
-        N = 2
+        # num_overlap
+        N = 1
         fade_size = C // 10
         step = int(C // N)
         border = C - step
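
With N = 1, the two context lines just above make the chunking degenerate: step equals the chunk length C and border becomes 0, so consecutive inference windows are laid end to end with no overlap. A quick arithmetic check, mirroring only the values visible in this hunk:

    C = 352800          # chunk length in samples, as above
    N = 1               # num_overlap after this commit (previously 2)
    step = int(C // N)  # 352800: each window advances by a full chunk
    border = C - step   # 0: no overlapped region left to cross-fade
    print(step, border)
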
@@ -60,7 +59,7 @@ class BsRoformer_Loader:
         length_init = mix.shape[-1]
-        progress_bar = tqdm(total=(length_init//step)+3)
+        progress_bar = tqdm(total=length_init // step + 1)
         progress_bar.set_description("Processing")
         # Do pad from the beginning and end to account floating window results better
@@ -79,7 +78,7 @@ class BsRoformer_Loader:
         window_middle[-fade_size:] *= fadeout
         window_middle[:fade_size] *= fadein
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast('cuda'):
             with torch.inference_mode():
                 req_shape = (1, ) + tuple(mix.shape)
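
The autocast change follows PyTorch's API direction: torch.cuda.amp.autocast() is deprecated in newer PyTorch releases in favor of the device-generic torch.amp.autocast('cuda'), which behaves the same for CUDA inference. A minimal sketch of the pattern, with the model call left out:

    import torch

    # Old spelling (now emits a deprecation warning):
    #     with torch.cuda.amp.autocast():
    # Device-generic spelling used after this commit:
    with torch.amp.autocast('cuda'), torch.inference_mode():
        pass  # run the forward pass here under mixed precision
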
@@ -160,7 +159,6 @@ class BsRoformer_Loader:
         res = self.demix_track(self.model, mixture, self.device)
         estimates = res['vocals'].T
-        print("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format))
         if format in ["wav", "flac"]:
             sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)