Add fp16 (half-precision) inference support for BSRoformer

Add fp16 (half-precision) inference support for BSRoformer
This commit is contained in:
RVC-Boss 2024-08-01 21:26:59 +08:00 committed by GitHub
parent 10e885d9ac
commit e62e965323
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 7 deletions

View File

@ -1,4 +1,5 @@
# This code is modified from https://github.com/ZFTurbo/ # This code is modified from https://github.com/ZFTurbo/
import pdb
import librosa import librosa
from tqdm import tqdm from tqdm import tqdm
@ -10,6 +11,7 @@ import torch.nn as nn
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer
class BsRoformer_Loader: class BsRoformer_Loader:
def get_model_from_config(self): def get_model_from_config(self):
@ -40,7 +42,7 @@ class BsRoformer_Loader:
} }
from bs_roformer.bs_roformer import BSRoformer
model = BSRoformer( model = BSRoformer(
**dict(config) **dict(config)
) )
@ -95,6 +97,8 @@ class BsRoformer_Loader:
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect') part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
else: else:
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
if(self.is_half==True):
part=part.half()
batch_data.append(part) batch_data.append(part)
batch_locations.append((i, length)) batch_locations.append((i, length))
i += step i += step
@ -102,6 +106,7 @@ class BsRoformer_Loader:
if len(batch_data) >= batch_size or (i >= mix.shape[1]): if len(batch_data) >= batch_size or (i >= mix.shape[1]):
arr = torch.stack(batch_data, dim=0) arr = torch.stack(batch_data, dim=0)
# print(23333333,arr.dtype)
x = model(arr) x = model(arr)
window = window_middle window = window_middle
@ -192,14 +197,18 @@ class BsRoformer_Loader:
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) # print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def __init__(self, model_path, device): def __init__(self, model_path, device,is_half):
self.device = device self.device = device
self.extract_instrumental=True self.extract_instrumental=True
model = self.get_model_from_config() model = self.get_model_from_config()
state_dict = torch.load(model_path) state_dict = torch.load(model_path,map_location="cpu")
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
self.model = model.to(device) self.is_half=is_half
if(is_half==False):
self.model = model.to(device)
else:
self.model = model.half().to(device)
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False): def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):

View File

@ -17,8 +17,8 @@ from bsroformer import BsRoformer_Loader
weight_uvr5_root = "tools/uvr5/uvr5_weights" weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = [] uvr5_names = []
for name in os.listdir(weight_uvr5_root): for name in os.listdir(weight_uvr5_root):
if name.endswith(".pth") or "onnx" in name: if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
uvr5_names.append(name.replace(".pth", "")) uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))
device=sys.argv[1] device=sys.argv[1]
is_half=eval(sys.argv[2]) is_half=eval(sys.argv[2])
@ -37,8 +37,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower(): elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
func = BsRoformer_Loader func = BsRoformer_Loader
pre_fun = func( pre_fun = func(
model_path = os.path.join(weight_uvr5_root, model_name + ".pth"), model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
device = device, device = device,
is_half=is_half
) )
else: else:
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho