Compatible with default model name

This commit is contained in:
KamioRinn 2024-07-15 05:05:20 +08:00
parent 104429d864
commit 1295bf22fb

View File

@ -34,10 +34,10 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
is_hp3 = "HP3" in model_name is_hp3 = "HP3" in model_name
if model_name == "onnx_dereverb_By_FoxJoy": if model_name == "onnx_dereverb_By_FoxJoy":
pre_fun = MDXNetDereverb(15) pre_fun = MDXNetDereverb(15)
elif model_name == "Bs_Roformer": elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
func = BsRoformer_Loader func = BsRoformer_Loader
pre_fun = func( pre_fun = func(
model_path = os.path.join(weight_uvr5_root, "Bs_Roformer.pth"), model_path = os.path.join(weight_uvr5_root, model_name + ".pth"),
device = device, device = device,
) )
else: else: