mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
Simplify BSR (#1356)
This commit is contained in:
parent
8abc0342d7
commit
7670bc77c3
@ -25,4 +25,5 @@ jieba_fast
|
|||||||
jieba
|
jieba
|
||||||
LangSegment>=0.2.0
|
LangSegment>=0.2.0
|
||||||
Faster_Whisper
|
Faster_Whisper
|
||||||
wordsegment
|
wordsegment
|
||||||
|
rotary_embedding_torch
|
@ -7,8 +7,9 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from bs_roformer.attend import Attend
|
from bs_roformer.attend import Attend
|
||||||
|
|
||||||
from beartype.typing import Tuple, Optional, List, Callable
|
from typing import Tuple, Optional, List, Callable
|
||||||
from beartype import beartype
|
# from beartype.typing import Tuple, Optional, List, Callable
|
||||||
|
# from beartype import beartype
|
||||||
|
|
||||||
from rotary_embedding_torch import RotaryEmbedding
|
from rotary_embedding_torch import RotaryEmbedding
|
||||||
|
|
||||||
@ -125,7 +126,7 @@ class LinearAttention(Module):
|
|||||||
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
|
this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@beartype
|
# @beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -219,7 +220,7 @@ class Transformer(Module):
|
|||||||
# bandsplit module
|
# bandsplit module
|
||||||
|
|
||||||
class BandSplit(Module):
|
class BandSplit(Module):
|
||||||
@beartype
|
# @beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
@ -274,7 +275,7 @@ def MLP(
|
|||||||
|
|
||||||
|
|
||||||
class MaskEstimator(Module):
|
class MaskEstimator(Module):
|
||||||
@beartype
|
# @beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
@ -325,7 +326,7 @@ DEFAULT_FREQS_PER_BANDS = (
|
|||||||
|
|
||||||
class BSRoformer(Module):
|
class BSRoformer(Module):
|
||||||
|
|
||||||
@beartype
|
# @beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# This code is modified from https://github.com/ZFTurbo/
|
# This code is modified from https://github.com/ZFTurbo/
|
||||||
|
|
||||||
import time
|
|
||||||
import librosa
|
import librosa
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import os
|
import os
|
||||||
import glob
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -52,7 +50,8 @@ class BsRoformer_Loader:
|
|||||||
|
|
||||||
def demix_track(self, model, mix, device):
|
def demix_track(self, model, mix, device):
|
||||||
C = 352800
|
C = 352800
|
||||||
N = 2
|
# num_overlap
|
||||||
|
N = 1
|
||||||
fade_size = C // 10
|
fade_size = C // 10
|
||||||
step = int(C // N)
|
step = int(C // N)
|
||||||
border = C - step
|
border = C - step
|
||||||
@ -60,7 +59,7 @@ class BsRoformer_Loader:
|
|||||||
|
|
||||||
length_init = mix.shape[-1]
|
length_init = mix.shape[-1]
|
||||||
|
|
||||||
progress_bar = tqdm(total=(length_init//step)+3)
|
progress_bar = tqdm(total=length_init // step + 1)
|
||||||
progress_bar.set_description("Processing")
|
progress_bar.set_description("Processing")
|
||||||
|
|
||||||
# Do pad from the beginning and end to account floating window results better
|
# Do pad from the beginning and end to account floating window results better
|
||||||
@ -79,7 +78,7 @@ class BsRoformer_Loader:
|
|||||||
window_middle[-fade_size:] *= fadeout
|
window_middle[-fade_size:] *= fadeout
|
||||||
window_middle[:fade_size] *= fadein
|
window_middle[:fade_size] *= fadein
|
||||||
|
|
||||||
with torch.cuda.amp.autocast():
|
with torch.amp.autocast('cuda'):
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
req_shape = (1, ) + tuple(mix.shape)
|
req_shape = (1, ) + tuple(mix.shape)
|
||||||
|
|
||||||
@ -160,7 +159,6 @@ class BsRoformer_Loader:
|
|||||||
res = self.demix_track(self.model, mixture, self.device)
|
res = self.demix_track(self.model, mixture, self.device)
|
||||||
|
|
||||||
estimates = res['vocals'].T
|
estimates = res['vocals'].T
|
||||||
print("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format))
|
|
||||||
|
|
||||||
if format in ["wav", "flac"]:
|
if format in ["wav", "flac"]:
|
||||||
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
|
sf.write("{}/{}_{}.{}".format(vocal_root, os.path.basename(path)[:-4], 'vocals', format), estimates, sr)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user