GPT-SoVITS/tools/uvr5/bsroformer.py
XXXXRT666 53cac93589
Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
2025-04-07 16:42:47 +08:00

305 lines
14 KiB
Python

# This code is modified from https://github.com/ZFTurbo/
import os
import warnings
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import yaml
from tqdm import tqdm
warnings.filterwarnings("ignore")
class Roformer_Loader:
def get_config(self, config_path):
with open(config_path, "r", encoding="utf-8") as f:
# use fullloader to load tag !!python/tuple, code can be improved
config = yaml.load(f, Loader=yaml.FullLoader)
return config
def get_default_config(self):
default_config = None
if self.model_type == "bs_roformer":
# Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files
# Other BS_Roformer models may not be compatible
# fmt: off
default_config = {
"audio": {"chunk_size": 352800, "sample_rate": 44100},
"model": {
"dim": 512,
"depth": 12,
"stereo": True,
"num_stems": 1,
"time_transformer_depth": 1,
"freq_transformer_depth": 1,
"linear_transformer_depth": 0,
"freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
"dim_head": 64,
"heads": 8,
"attn_dropout": 0.1,
"ff_dropout": 0.1,
"flash_attn": True,
"dim_freqs_in": 1025,
"stft_n_fft": 2048,
"stft_hop_length": 441,
"stft_win_length": 2048,
"stft_normalized": False,
"mask_estimator_depth": 2,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
"multi_stft_hop_size": 147,
"multi_stft_normalized": False,
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2},
}
# fmt: on
elif self.model_type == "mel_band_roformer":
# Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files
# Other Mel_Band_Roformer models may not be compatible
default_config = {
"audio": {"chunk_size": 352800, "sample_rate": 44100},
"model": {
"dim": 384,
"depth": 12,
"stereo": True,
"num_stems": 1,
"time_transformer_depth": 1,
"freq_transformer_depth": 1,
"linear_transformer_depth": 0,
"num_bands": 60,
"dim_head": 64,
"heads": 8,
"attn_dropout": 0.1,
"ff_dropout": 0.1,
"flash_attn": True,
"dim_freqs_in": 1025,
"sample_rate": 44100,
"stft_n_fft": 2048,
"stft_hop_length": 441,
"stft_win_length": 2048,
"stft_normalized": False,
"mask_estimator_depth": 2,
"multi_stft_resolution_loss_weight": 1.0,
"multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
"multi_stft_hop_size": 147,
"multi_stft_normalized": False,
},
"training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"},
"inference": {"batch_size": 2, "num_overlap": 2},
}
return default_config
def get_model_from_config(self):
if self.model_type == "bs_roformer":
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer(**dict(self.config["model"]))
elif self.model_type == "mel_band_roformer":
from bs_roformer.mel_band_roformer import MelBandRoformer
model = MelBandRoformer(**dict(self.config["model"]))
else:
print("Error: Unknown model: {}".format(self.model_type))
model = None
return model
def demix_track(self, model, mix, device):
C = self.config["audio"]["chunk_size"] # chunk_size
N = self.config["inference"]["num_overlap"]
fade_size = C // 10
step = int(C // N)
border = C - step
batch_size = self.config["inference"]["batch_size"]
length_init = mix.shape[-1]
progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False)
# Do pad from the beginning and end to account floating window results better
if length_init > 2 * border and (border > 0):
mix = nn.functional.pad(mix, (border, border), mode="reflect")
# Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment
window_size = C
fadein = torch.linspace(0, 1, fade_size)
fadeout = torch.linspace(1, 0, fade_size)
window_start = torch.ones(window_size)
window_middle = torch.ones(window_size)
window_finish = torch.ones(window_size)
window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
window_middle[-fade_size:] *= fadeout
window_middle[:fade_size] *= fadein
with torch.amp.autocast("cuda"):
with torch.inference_mode():
if self.config["training"]["target_instrument"] is None:
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
else:
req_shape = (1,) + tuple(mix.shape)
result = torch.zeros(req_shape, dtype=torch.float32)
counter = torch.zeros(req_shape, dtype=torch.float32)
i = 0
batch_data = []
batch_locations = []
while i < mix.shape[1]:
part = mix[:, i : i + C].to(device)
length = part.shape[-1]
if length < C:
if length > C // 2 + 1:
part = nn.functional.pad(input=part, pad=(0, C - length), mode="reflect")
else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode="constant", value=0)
if self.is_half:
part = part.half()
batch_data.append(part)
batch_locations.append((i, length))
i += step
progress_bar.update(1)
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0)
# print(23333333,arr.dtype)
x = model(arr)
window = window_middle
if i - step == 0: # First audio chunk, no fadein
window = window_start
elif i >= mix.shape[1]: # Last audio chunk, no fadeout
window = window_finish
for j in range(len(batch_locations)):
start, l = batch_locations[j]
result[..., start : start + l] += x[j][..., :l].cpu() * window[..., :l]
counter[..., start : start + l] += window[..., :l]
batch_data = []
batch_locations = []
estimated_sources = result / counter
estimated_sources = estimated_sources.cpu().numpy()
np.nan_to_num(estimated_sources, copy=False, nan=0.0)
if length_init > 2 * border and (border > 0):
# Remove pad
estimated_sources = estimated_sources[..., border:-border]
progress_bar.close()
if self.config["training"]["target_instrument"] is None:
return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)}
else:
return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)}
def run_folder(self, input, vocal_root, others_root, format):
self.model.eval()
path = input
os.makedirs(vocal_root, exist_ok=True)
os.makedirs(others_root, exist_ok=True)
file_base_name = os.path.splitext(os.path.basename(path))[0]
sample_rate = 44100
if "sample_rate" in self.config["audio"]:
sample_rate = self.config["audio"]["sample_rate"]
try:
mix, sr = librosa.load(path, sr=sample_rate, mono=False)
except Exception as e:
print("Can read track: {}".format(path))
print("Error message: {}".format(str(e)))
return
# in case if model only supports mono tracks
isstereo = self.config["model"].get("stereo", True)
if not isstereo and len(mix.shape) != 1:
mix = np.mean(mix, axis=0) # if more than 2 channels, take mean
print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.")
mix_orig = mix.copy()
mixture = torch.tensor(mix, dtype=torch.float32)
res = self.demix_track(self.model, mixture, self.device)
if self.config["training"]["target_instrument"] is not None:
# if target instrument is specified, save target instrument as vocal and other instruments as others
# other instruments are caculated by subtracting target instrument from mixture
target_instrument = self.config["training"]["target_instrument"]
other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument]
other = mix_orig - res[target_instrument] # caculate other instruments
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument)
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0])
self.save_audio(path_vocal, res[target_instrument].T, sr, format)
self.save_audio(path_other, other.T, sr, format)
else:
# if target instrument is not specified, save the first instrument as vocal and the rest as others
vocal_inst = self.config["training"]["instruments"][0]
path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst)
self.save_audio(path_vocal, res[vocal_inst].T, sr, format)
for other in self.config["training"]["instruments"][1:]: # save other instruments
path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other)
self.save_audio(path_other, res[other].T, sr, format)
def save_audio(self, path, data, sr, format):
# input path should be endwith '.wav'
if format in ["wav", "flac"]:
if format == "flac":
path = path[:-3] + "flac"
sf.write(path, data, sr)
else:
sf.write(path, data, sr)
os.system('ffmpeg -i "{}" -vn "{}" -q:a 2 -y'.format(path, path[:-3] + format))
try:
os.remove(path)
except:
pass
def __init__(self, model_path, config_path, device, is_half):
self.device = device
self.is_half = is_half
self.model_type = None
self.config = None
# get model_type, first try:
if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower():
self.model_type = "bs_roformer"
elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower():
self.model_type = "mel_band_roformer"
if not os.path.exists(config_path):
if self.model_type is None:
# if model_type is still None, raise an error
raise ValueError(
"Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '<model_name>.yaml' then try it again."
)
self.config = self.get_default_config()
else:
# if there is a configuration file
self.config = self.get_config(config_path)
if self.model_type is None:
# if model_type is still None, second try, get model_type from the configuration file
if "freqs_per_bands" in self.config["model"]:
# if freqs_per_bands in config, it's a bs_roformer model
self.model_type = "bs_roformer"
else:
# else it's a mel_band_roformer model
self.model_type = "mel_band_roformer"
print("Detected model type: {}".format(self.model_type))
model = self.get_model_from_config()
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
if is_half == False:
self.model = model.to(device)
else:
self.model = model.half().to(device)
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
self.run_folder(input, vocal_root, others_root, format)