From e62e965323a60a76a025bcaa45268c1ddcbcf05c Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:26:59 +0800 Subject: [PATCH] bsroformer support fp16 inference bsroformer support fp16 inference --- tools/uvr5/bsroformer.py | 17 +++++++++++++---- tools/uvr5/webui.py | 7 ++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tools/uvr5/bsroformer.py b/tools/uvr5/bsroformer.py index 5c01011..d162032 100644 --- a/tools/uvr5/bsroformer.py +++ b/tools/uvr5/bsroformer.py @@ -1,4 +1,5 @@ # This code is modified from https://github.com/ZFTurbo/ +import pdb import librosa from tqdm import tqdm @@ -10,6 +11,7 @@ import torch.nn as nn import warnings warnings.filterwarnings("ignore") +from bs_roformer.bs_roformer import BSRoformer class BsRoformer_Loader: def get_model_from_config(self): @@ -40,7 +42,7 @@ class BsRoformer_Loader: } - from bs_roformer.bs_roformer import BSRoformer + model = BSRoformer( **dict(config) ) @@ -95,6 +97,8 @@ class BsRoformer_Loader: part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') else: part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) + if(self.is_half==True): + part=part.half() batch_data.append(part) batch_locations.append((i, length)) i += step @@ -102,6 +106,7 @@ class BsRoformer_Loader: if len(batch_data) >= batch_size or (i >= mix.shape[1]): arr = torch.stack(batch_data, dim=0) + # print(23333333,arr.dtype) x = model(arr) window = window_middle @@ -192,14 +197,18 @@ class BsRoformer_Loader: # print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) - def __init__(self, model_path, device): + def __init__(self, model_path, device,is_half): self.device = device self.extract_instrumental=True model = self.get_model_from_config() - state_dict = torch.load(model_path) + state_dict = torch.load(model_path,map_location="cpu") model.load_state_dict(state_dict) - self.model = model.to(device) + self.is_half=is_half + if(is_half==False): + self.model = model.to(device) + else: + self.model = model.half().to(device) def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): diff --git a/tools/uvr5/webui.py b/tools/uvr5/webui.py index 712c892..60dfdaa 100644 --- a/tools/uvr5/webui.py +++ b/tools/uvr5/webui.py @@ -17,8 +17,8 @@ from bsroformer import BsRoformer_Loader weight_uvr5_root = "tools/uvr5/uvr5_weights" uvr5_names = [] for name in os.listdir(weight_uvr5_root): - if name.endswith(".pth") or "onnx" in name: - uvr5_names.append(name.replace(".pth", "")) + if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name: + uvr5_names.append(name.replace(".pth", "").replace(".ckpt", "")) device=sys.argv[1] is_half=eval(sys.argv[2]) @@ -37,8 +37,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower(): func = BsRoformer_Loader pre_fun = func( - model_path = os.path.join(weight_uvr5_root, model_name + ".pth"), + model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"), device = device, + is_half=is_half ) else: func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho