mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-05 05:48:14 +08:00
feat:Added .voice loader
This commit is contained in:
parent
1c54a945cb
commit
f6e8ec8a78
@ -73,8 +73,11 @@ class ZIP_File:
|
||||
raise FileNotFoundError(f"File {file_path} does not exist.")
|
||||
return file_path
|
||||
|
||||
def get_file_obj(self, file_name:str,location:str,mode:str='r'):
|
||||
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
|
||||
def get_file_obj(self, file_name:str,location:str='',mode:str='r'):
|
||||
if location == '':
|
||||
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
|
||||
else:
|
||||
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File {file_path} does not exist.")
|
||||
return open(file_path, mode)
|
||||
@ -125,14 +128,14 @@ def save_tensor(path: str, tensors: Union[torch.Tensor, list],name:str,MySet:set
|
||||
zf.close()
|
||||
del zf
|
||||
|
||||
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list:
|
||||
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list[torch.Tensor]:
|
||||
zf = ZIP_File(path, name, MySet=MySet)
|
||||
zf.release()
|
||||
voice_path = find_func(zf,il)
|
||||
tensors = []
|
||||
for i in range(len(voice_path)):
|
||||
v = voice_path[i]
|
||||
np_array = np.load(v)
|
||||
np_array = np.load(v,allow_pickle=True)
|
||||
tensor = torch.from_numpy(np_array)
|
||||
tensors.append(tensor)
|
||||
zf.close()
|
||||
|
||||
@ -39,6 +39,21 @@ def _get_unique_name(name,MySet:set=set()):
|
||||
POOL.add(name)
|
||||
return name
|
||||
|
||||
def find_func(zf,il):
|
||||
f = zf.get_file_path("voice.json")
|
||||
info = il.load_info(f)
|
||||
if info is None:
|
||||
return None
|
||||
list_names = info["access_list"]
|
||||
ret = []
|
||||
for name in list_names:
|
||||
try:
|
||||
a = zf.get_file_path(name)
|
||||
ret.append(a)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
return ret
|
||||
|
||||
def set_high_priority():
|
||||
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
|
||||
if os.name != "nt":
|
||||
@ -798,8 +813,14 @@ def get_tts_wav(
|
||||
|
||||
SaveSvEmb=False,
|
||||
SaveRefers=False,
|
||||
SaveSvEmbName=None,
|
||||
SaveRefersName=None
|
||||
SaveSvEmbName="sv_emb.voice",
|
||||
SaveRefersName="refers.voice",
|
||||
|
||||
InjectSvEmb=False,
|
||||
InjectRefers=False,
|
||||
InjectSvEmbName="sv_emb.voice",
|
||||
InjectRefersName="refers.voice",
|
||||
|
||||
):
|
||||
global cache
|
||||
if ref_wav_path:
|
||||
@ -952,7 +973,7 @@ def get_tts_wav(
|
||||
if SaveSvEmb and is_v2pro:
|
||||
names = []
|
||||
for i in sv_emb:
|
||||
names.append(_get_unique_name(str(i.shape)))
|
||||
names.append(_get_unique_name(str(i.shape))+".npy")
|
||||
sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt")
|
||||
if not os.path.exists(sv_path):
|
||||
os.makedirs(sv_path,exist_ok=True)
|
||||
@ -964,13 +985,35 @@ def get_tts_wav(
|
||||
if SaveRefers:
|
||||
names = []
|
||||
for i in refers:
|
||||
names.append(_get_unique_name(str(i.shape)))
|
||||
names.append(_get_unique_name(str(i.shape))+".npy")
|
||||
refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt")
|
||||
if not os.path.exists(refers_path):
|
||||
os.makedirs(refers_path,exist_ok=True)
|
||||
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName)),refers,SaveRefersName,file_names=names,access_list=names)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
#print("refers数量:", len(refers))
|
||||
#print("sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
|
||||
try:
|
||||
if InjectSvEmb and is_v2pro:
|
||||
_sv_emb = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",InjectSvEmbName)),InjectSvEmbName,find_func)
|
||||
for i in range(len(_sv_emb)):
|
||||
sv_emb.append(_sv_emb[i].to(device))
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
if InjectRefers:
|
||||
_refers = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",InjectRefersName)),InjectRefersName,find_func)
|
||||
for i in range(len(_refers)):
|
||||
refers.append(_refers[i].to(device))
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
#print("注入后refers数量:", len(refers))
|
||||
#print("注入后sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
|
||||
|
||||
|
||||
if is_v2pro:
|
||||
audio = vq_model.decode(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user