Simplify BSR (#1356)

KamioRinn authored 2024-07-30 10:32:37 +08:00, committed by GitHub
parent 8abc0342d7
commit 7670bc77c3
3 changed files with 13 additions and 13 deletions

@@ -25,4 +25,5 @@ jieba_fast
 jieba
 LangSegment>=0.2.0
 Faster_Whisper
-wordsegment
+wordsegment
+rotary_embedding_torch
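Note: rotary_embedding_torch is added because the vendored bs_roformer module imports RotaryEmbedding (next file). A minimal usage sketch, assuming the lucidrains rotary-embedding-torch API; shapes are illustrative, not taken from this commit:

    import torch
    from rotary_embedding_torch import RotaryEmbedding

    rotary = RotaryEmbedding(dim=32)         # rotate pairs across the head dim
    q = torch.randn(1, 8, 1024, 64)          # (batch, heads, seq, dim_head)
    k = torch.randn(1, 8, 1024, 64)
    q = rotary.rotate_queries_or_keys(q)     # positions inferred from the seq axis
    k = rotary.rotate_queries_or_keys(k)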

@@ -7,8 +7,9 @@ import torch.nn.functional as F
 
 from bs_roformer.attend import Attend
 
-from beartype.typing import Tuple, Optional, List, Callable
-from beartype import beartype
+from typing import Tuple, Optional, List, Callable
+# from beartype.typing import Tuple, Optional, List, Callable
+# from beartype import beartype
 
 from rotary_embedding_torch import RotaryEmbedding
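Note: beartype.typing re-exports the standard typing names, so swapping in the stdlib typing module is a drop-in change; commenting out @beartype (here and in the four hunks below) removes the runtime type checks entirely. If the checks were to stay optional instead of being removed, a no-op fallback decorator is a common pattern (a sketch, not what this commit does):

    try:
        from beartype import beartype
    except ImportError:
        def beartype(func):
            # beartype not installed: leave the callable unchecked
            return func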
@@ -125,7 +126,7 @@ class LinearAttention(Module):
     this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
     """
 
-    @beartype
+    # @beartype
     def __init__(
         self,
         *,
@@ -219,7 +220,7 @@ class Transformer(Module):
 # bandsplit module
 
 class BandSplit(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,
@@ -274,7 +275,7 @@ def MLP(
 
 class MaskEstimator(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,
@@ -325,7 +326,7 @@ DEFAULT_FREQS_PER_BANDS = (
 
 class BSRoformer(Module):
-    @beartype
+    # @beartype
     def __init__(
         self,
         dim,

@@ -1,10 +1,8 @@
 # This code is modified from https://github.com/ZFTurbo/
-import time
 import librosa
 from tqdm import tqdm
 import os
-import glob
 import torch
 import numpy as np
 import soundfile as sf
@@ -52,7 +50,8 @@ class BsRoformer_Loader:
     def demix_track(self, model, mix, device):
         C = 352800
-        N = 2
+        # num_overlap
+        N = 1
         fade_size = C // 10
         step = int(C // N)
         border = C - step
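Note: N is the overlap count for the sliding window: step = C // N and border = C - step, so dropping N from 2 to 1 makes step equal the chunk size C and border zero, meaning chunks no longer overlap and demixing needs roughly half the forward passes. A quick check of the arithmetic (variable names mirror the diff):

    C = 352800                # chunk length in samples
    for N in (2, 1):          # old and new overlap counts
        step = int(C // N)
        border = C - step
        print(N, step, border)
    # N=2 -> step=176400, border=176400 (50% overlap)
    # N=1 -> step=352800, border=0      (no overlap)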
@@ -60,7 +59,7 @@ class BsRoformer_Loader:
         length_init = mix.shape[-1]
 
-        progress_bar = tqdm(total=(length_init//step)+3)
+        progress_bar = tqdm(total=length_init // step + 1)
         progress_bar.set_description("Processing")
 
         # Do pad from the beginning and end to account floating window results better
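Note: the new total follows from the no-overlap setting above: with border == 0 the mix is not reflect-padded, so the chunk loop visits about length_init // step + 1 windows; the old + 3 slack covered the extra padded, overlapped windows when N was 2. This reading assumes the ZFTurbo-style loop that advances by step until the (possibly padded) end:

    length_init, C = 441000, 352800   # e.g. a 10 s mixture at 44.1 kHz
    step = C                          # N == 1
    total = length_init // step + 1   # -> 2, matching the new tqdm total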
@@ -79,7 +78,7 @@ class BsRoformer_Loader:
         window_middle[-fade_size:] *= fadeout
         window_middle[:fade_size] *= fadein
 
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast('cuda'):
             with torch.inference_mode():
                 req_shape = (1, ) + tuple(mix.shape)
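Note: torch.cuda.amp.autocast() is the legacy CUDA-only spelling and is deprecated in recent PyTorch releases; torch.amp.autocast(device_type) is the device-agnostic replacement the commit moves to. The two forms are equivalent on CUDA (a sketch, assuming a recent PyTorch):

    import torch

    # with torch.cuda.amp.autocast():     # old spelling, emits a deprecation warning
    with torch.amp.autocast('cuda'):      # new spelling, device passed explicitly
        with torch.inference_mode():
            pass  # model forward passes run here in mixed precision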
@@ -160,7 +159,6 @@ class BsRoformer_Loader:
         res = self.demix_track(self.model, mixture, self.device)
         estimates = res['vocals'].T
-        print("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format))
 
         if format in ["wav", "flac"]:
             sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
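Note: the dropped line was a debug print of the output path; the write itself is unchanged. For reference, sf.write(file, data, samplerate) picks the container from the file extension, which is why only wav and flac are routed through soundfile here. A self-contained sketch of the surviving write path (the function name is illustrative):

    import os
    import soundfile as sf

    def write_vocals(estimates, sr, path, vocal_root, format="wav"):
        # estimates: (samples, channels) float array, as produced above
        out = "{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], "vocals", format)
        if format in ["wav", "flac"]:
            sf.write(out, estimates, sr)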