# This code is modified from https://github.com/ZFTurbo/ import librosa from tqdm import tqdm import os import torch import numpy as np import soundfile as sf import torch.nn as nn import yaml import warnings warnings.filterwarnings("ignore") class Roformer_Loader: def get_config(self, config_path): with open(config_path, 'r', encoding='utf-8') as f: # use fullloader to load tag !!python/tuple, code can be improved config = yaml.load(f, Loader=yaml.FullLoader) return config def get_default_config(self): default_config = None if self.model_type == 'bs_roformer': # Use model_bs_roformer_ep_368_sdr_12.9628.yaml and model_bs_roformer_ep_317_sdr_12.9755.yaml as default configuration files # Other BS_Roformer models may not be compatible default_config = { "audio": {"chunk_size": 352800, "sample_rate": 44100}, "model": { "dim": 512, "depth": 12, "stereo": True, "num_stems": 1, "time_transformer_depth": 1, "freq_transformer_depth": 1, "linear_transformer_depth": 0, "freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129), "dim_head": 64, "heads": 8, "attn_dropout": 0.1, "ff_dropout": 0.1, "flash_attn": True, "dim_freqs_in": 1025, "stft_n_fft": 2048, "stft_hop_length": 441, "stft_win_length": 2048, "stft_normalized": False, "mask_estimator_depth": 2, "multi_stft_resolution_loss_weight": 1.0, "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256), "multi_stft_hop_size": 147, "multi_stft_normalized": False, }, "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"}, "inference": {"batch_size": 2, "num_overlap": 2} } elif self.model_type == 'mel_band_roformer': # Use model_mel_band_roformer_ep_3005_sdr_11.4360.yaml as default configuration files # Other Mel_Band_Roformer models may not be compatible default_config = { "audio": {"chunk_size": 352800, "sample_rate": 44100}, "model": { "dim": 384, "depth": 12, "stereo": True, "num_stems": 1, "time_transformer_depth": 1, "freq_transformer_depth": 1, "linear_transformer_depth": 0, "num_bands": 60, "dim_head": 64, "heads": 8, "attn_dropout": 0.1, "ff_dropout": 0.1, "flash_attn": True, "dim_freqs_in": 1025, "sample_rate": 44100, "stft_n_fft": 2048, "stft_hop_length": 441, "stft_win_length": 2048, "stft_normalized": False, "mask_estimator_depth": 2, "multi_stft_resolution_loss_weight": 1.0, "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256), "multi_stft_hop_size": 147, "multi_stft_normalized": False }, "training": {"instruments": ["vocals", "other"], "target_instrument": "vocals"}, "inference": {"batch_size": 2, "num_overlap": 2} } return default_config def get_model_from_config(self): if self.model_type == 'bs_roformer': from bs_roformer.bs_roformer import BSRoformer model = BSRoformer(**dict(self.config["model"])) elif self.model_type == 'mel_band_roformer': from bs_roformer.mel_band_roformer import MelBandRoformer model = MelBandRoformer(**dict(self.config["model"])) else: print('Error: Unknown model: {}'.format(self.model_type)) model = None return model def demix_track(self, model, mix, device): C = self.config["audio"]["chunk_size"] # chunk_size N = self.config["inference"]["num_overlap"] fade_size = C // 10 step = int(C // N) border = C - step batch_size = self.config["inference"]["batch_size"] length_init = mix.shape[-1] progress_bar = tqdm(total=length_init // step + 1, desc="Processing", leave=False) # Do pad from the beginning and end to account floating window results better if length_init > 2 * border and (border > 0): mix = nn.functional.pad(mix, (border, border), mode='reflect') # Prepare windows arrays (do 1 time for speed up). This trick repairs click problems on the edges of segment window_size = C fadein = torch.linspace(0, 1, fade_size) fadeout = torch.linspace(1, 0, fade_size) window_start = torch.ones(window_size) window_middle = torch.ones(window_size) window_finish = torch.ones(window_size) window_start[-fade_size:] *= fadeout # First audio chunk, no fadein window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout window_middle[-fade_size:] *= fadeout window_middle[:fade_size] *= fadein with torch.amp.autocast('cuda'): with torch.inference_mode(): if self.config["training"]["target_instrument"] is None: req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape) else: req_shape = (1, ) + tuple(mix.shape) result = torch.zeros(req_shape, dtype=torch.float32) counter = torch.zeros(req_shape, dtype=torch.float32) i = 0 batch_data = [] batch_locations = [] while i < mix.shape[1]: part = mix[:, i:i + C].to(device) length = part.shape[-1] if length < C: if length > C // 2 + 1: part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') else: part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) if self.is_half: part=part.half() batch_data.append(part) batch_locations.append((i, length)) i += step progress_bar.update(1) if len(batch_data) >= batch_size or (i >= mix.shape[1]): arr = torch.stack(batch_data, dim=0) # print(23333333,arr.dtype) x = model(arr) window = window_middle if i - step == 0: # First audio chunk, no fadein window = window_start elif i >= mix.shape[1]: # Last audio chunk, no fadeout window = window_finish for j in range(len(batch_locations)): start, l = batch_locations[j] result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l] counter[..., start:start+l] += window[..., :l] batch_data = [] batch_locations = [] estimated_sources = result / counter estimated_sources = estimated_sources.cpu().numpy() np.nan_to_num(estimated_sources, copy=False, nan=0.0) if length_init > 2 * border and (border > 0): # Remove pad estimated_sources = estimated_sources[..., border:-border] progress_bar.close() if self.config["training"]["target_instrument"] is None: return {k: v for k, v in zip(self.config["training"]["instruments"], estimated_sources)} else: return {k: v for k, v in zip([self.config["training"]["target_instrument"]], estimated_sources)} def run_folder(self, input, vocal_root, others_root, format): self.model.eval() path = input os.makedirs(vocal_root, exist_ok=True) os.makedirs(others_root, exist_ok=True) file_base_name = os.path.splitext(os.path.basename(path))[0] sample_rate = 44100 if 'sample_rate' in self.config["audio"]: sample_rate = self.config["audio"]['sample_rate'] try: mix, sr = librosa.load(path, sr=sample_rate, mono=False) except Exception as e: print('Can read track: {}'.format(path)) print('Error message: {}'.format(str(e))) return # in case if model only supports mono tracks isstereo = self.config["model"].get("stereo", True) if not isstereo and len(mix.shape) != 1: mix = np.mean(mix, axis=0) # if more than 2 channels, take mean print("Warning: Track has more than 1 channels, but model is mono, taking mean of all channels.") mix_orig = mix.copy() mixture = torch.tensor(mix, dtype=torch.float32) res = self.demix_track(self.model, mixture, self.device) if self.config["training"]["target_instrument"] is not None: # if target instrument is specified, save target instrument as vocal and other instruments as others # other instruments are caculated by subtracting target instrument from mixture target_instrument = self.config["training"]["target_instrument"] other_instruments = [i for i in self.config["training"]["instruments"] if i != target_instrument] other = mix_orig - res[target_instrument] # caculate other instruments path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, target_instrument) path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other_instruments[0]) self.save_audio(path_vocal, res[target_instrument].T, sr, format) self.save_audio(path_other, other.T, sr, format) else: # if target instrument is not specified, save the first instrument as vocal and the rest as others vocal_inst = self.config["training"]["instruments"][0] path_vocal = "{}/{}_{}.wav".format(vocal_root, file_base_name, vocal_inst) self.save_audio(path_vocal, res[vocal_inst].T, sr, format) for other in self.config["training"]["instruments"][1:]: # save other instruments path_other = "{}/{}_{}.wav".format(others_root, file_base_name, other) self.save_audio(path_other, res[other].T, sr, format) def save_audio(self, path, data, sr, format): # input path should be endwith '.wav' if format in ["wav", "flac"]: if format == "flac": path = path[:-3] + "flac" sf.write(path, data, sr) else: sf.write(path, data, sr) os.system("ffmpeg -i \"{}\" -vn \"{}\" -q:a 2 -y".format(path, path[:-3] + format)) try: os.remove(path) except: pass def __init__(self, model_path, config_path, device, is_half): self.device = device self.is_half = is_half self.model_type = None self.config = None # get model_type, first try: if "bs_roformer" in model_path.lower() or "bsroformer" in model_path.lower(): self.model_type = "bs_roformer" elif "mel_band_roformer" in model_path.lower() or "melbandroformer" in model_path.lower(): self.model_type = "mel_band_roformer" if not os.path.exists(config_path): if self.model_type is None: # if model_type is still None, raise an error raise ValueError("Error: Unknown model type. If you are using a model without a configuration file, Ensure that your model name includes 'bs_roformer', 'bsroformer', 'mel_band_roformer', or 'melbandroformer'. Otherwise, you can manually place the model configuration file into 'tools/uvr5/uvr5w_weights' and ensure that the configuration file is named as '.yaml' then try it again.") self.config = self.get_default_config() else: # if there is a configuration file self.config = self.get_config(config_path) if self.model_type is None: # if model_type is still None, second try, get model_type from the configuration file if "freqs_per_bands" in self.config["model"]: # if freqs_per_bands in config, it's a bs_roformer model self.model_type = "bs_roformer" else: # else it's a mel_band_roformer model self.model_type = "mel_band_roformer" print("Detected model type: {}".format(self.model_type)) model = self.get_model_from_config() state_dict = torch.load(model_path, map_location="cpu") model.load_state_dict(state_dict) if(is_half==False): self.model = model.to(device) else: self.model = model.half().to(device) def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): self.run_folder(input, vocal_root, others_root, format)