mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
bsroformer support fp16 inference
bsroformer support fp16 inference
This commit is contained in:
parent
10e885d9ac
commit
e62e965323
@ -1,4 +1,5 @@
|
|||||||
# This code is modified from https://github.com/ZFTurbo/
|
# This code is modified from https://github.com/ZFTurbo/
|
||||||
|
import pdb
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -10,6 +11,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
from bs_roformer.bs_roformer import BSRoformer
|
||||||
|
|
||||||
class BsRoformer_Loader:
|
class BsRoformer_Loader:
|
||||||
def get_model_from_config(self):
|
def get_model_from_config(self):
|
||||||
@ -40,7 +42,7 @@ class BsRoformer_Loader:
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
from bs_roformer.bs_roformer import BSRoformer
|
|
||||||
model = BSRoformer(
|
model = BSRoformer(
|
||||||
**dict(config)
|
**dict(config)
|
||||||
)
|
)
|
||||||
@ -95,6 +97,8 @@ class BsRoformer_Loader:
|
|||||||
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
|
part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
|
||||||
else:
|
else:
|
||||||
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
|
part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
|
||||||
|
if(self.is_half==True):
|
||||||
|
part=part.half()
|
||||||
batch_data.append(part)
|
batch_data.append(part)
|
||||||
batch_locations.append((i, length))
|
batch_locations.append((i, length))
|
||||||
i += step
|
i += step
|
||||||
@ -102,6 +106,7 @@ class BsRoformer_Loader:
|
|||||||
|
|
||||||
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
|
if len(batch_data) >= batch_size or (i >= mix.shape[1]):
|
||||||
arr = torch.stack(batch_data, dim=0)
|
arr = torch.stack(batch_data, dim=0)
|
||||||
|
# print(23333333,arr.dtype)
|
||||||
x = model(arr)
|
x = model(arr)
|
||||||
|
|
||||||
window = window_middle
|
window = window_middle
|
||||||
@ -192,14 +197,18 @@ class BsRoformer_Loader:
|
|||||||
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
# print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path, device):
|
def __init__(self, model_path, device,is_half):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.extract_instrumental=True
|
self.extract_instrumental=True
|
||||||
|
|
||||||
model = self.get_model_from_config()
|
model = self.get_model_from_config()
|
||||||
state_dict = torch.load(model_path)
|
state_dict = torch.load(model_path,map_location="cpu")
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
self.model = model.to(device)
|
self.is_half=is_half
|
||||||
|
if(is_half==False):
|
||||||
|
self.model = model.to(device)
|
||||||
|
else:
|
||||||
|
self.model = model.half().to(device)
|
||||||
|
|
||||||
|
|
||||||
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
|
def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
|
||||||
|
@ -17,8 +17,8 @@ from bsroformer import BsRoformer_Loader
|
|||||||
weight_uvr5_root = "tools/uvr5/uvr5_weights"
|
weight_uvr5_root = "tools/uvr5/uvr5_weights"
|
||||||
uvr5_names = []
|
uvr5_names = []
|
||||||
for name in os.listdir(weight_uvr5_root):
|
for name in os.listdir(weight_uvr5_root):
|
||||||
if name.endswith(".pth") or "onnx" in name:
|
if name.endswith(".pth") or name.endswith(".ckpt") or "onnx" in name:
|
||||||
uvr5_names.append(name.replace(".pth", ""))
|
uvr5_names.append(name.replace(".pth", "").replace(".ckpt", ""))
|
||||||
|
|
||||||
device=sys.argv[1]
|
device=sys.argv[1]
|
||||||
is_half=eval(sys.argv[2])
|
is_half=eval(sys.argv[2])
|
||||||
@ -37,8 +37,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
|
|||||||
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
|
elif model_name == "Bs_Roformer" or "bs_roformer" in model_name.lower():
|
||||||
func = BsRoformer_Loader
|
func = BsRoformer_Loader
|
||||||
pre_fun = func(
|
pre_fun = func(
|
||||||
model_path = os.path.join(weight_uvr5_root, model_name + ".pth"),
|
model_path = os.path.join(weight_uvr5_root, model_name + ".ckpt"),
|
||||||
device = device,
|
device = device,
|
||||||
|
is_half=is_half
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
func = AudioPre if "DeEcho" not in model_name else AudioPreDeEcho
|
||||||
|
Loading…
x
Reference in New Issue
Block a user