feat:Added .voice loader

This commit is contained in:
__kaning123__ 2026-02-25 10:20:48 +08:00 committed by GitHub
parent 1c54a945cb
commit f6e8ec8a78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 8 deletions

View File

@ -73,8 +73,11 @@ class ZIP_File:
raise FileNotFoundError(f"File {file_path} does not exist.") raise FileNotFoundError(f"File {file_path} does not exist.")
return file_path return file_path
def get_file_obj(self, file_name:str,location:str,mode:str='r'): def get_file_obj(self, file_name:str,location:str='',mode:str='r'):
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name) if location == '':
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
else:
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist.") raise FileNotFoundError(f"File {file_path} does not exist.")
return open(file_path, mode) return open(file_path, mode)
@ -125,14 +128,14 @@ def save_tensor(path: str, tensors: Union[torch.Tensor, list],name:str,MySet:set
zf.close() zf.close()
del zf del zf
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list: def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list[torch.Tensor]:
zf = ZIP_File(path, name, MySet=MySet) zf = ZIP_File(path, name, MySet=MySet)
zf.release() zf.release()
voice_path = find_func(zf,il) voice_path = find_func(zf,il)
tensors = [] tensors = []
for i in range(len(voice_path)): for i in range(len(voice_path)):
v = voice_path[i] v = voice_path[i]
np_array = np.load(v) np_array = np.load(v,allow_pickle=True)
tensor = torch.from_numpy(np_array) tensor = torch.from_numpy(np_array)
tensors.append(tensor) tensors.append(tensor)
zf.close() zf.close()

View File

@ -39,6 +39,21 @@ def _get_unique_name(name,MySet:set=set()):
POOL.add(name) POOL.add(name)
return name return name
def find_func(zf,il):
f = zf.get_file_path("voice.json")
info = il.load_info(f)
if info is None:
return None
list_names = info["access_list"]
ret = []
for name in list_names:
try:
a = zf.get_file_path(name)
ret.append(a)
except FileNotFoundError:
continue
return ret
def set_high_priority(): def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS""" """把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
if os.name != "nt": if os.name != "nt":
@ -798,8 +813,14 @@ def get_tts_wav(
SaveSvEmb=False, SaveSvEmb=False,
SaveRefers=False, SaveRefers=False,
SaveSvEmbName=None, SaveSvEmbName="sv_emb.voice",
SaveRefersName=None SaveRefersName="refers.voice",
InjectSvEmb=False,
InjectRefers=False,
InjectSvEmbName="sv_emb.voice",
InjectRefersName="refers.voice",
): ):
global cache global cache
if ref_wav_path: if ref_wav_path:
@ -952,7 +973,7 @@ def get_tts_wav(
if SaveSvEmb and is_v2pro: if SaveSvEmb and is_v2pro:
names = [] names = []
for i in sv_emb: for i in sv_emb:
names.append(_get_unique_name(str(i.shape))) names.append(_get_unique_name(str(i.shape))+".npy")
sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt") sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt")
if not os.path.exists(sv_path): if not os.path.exists(sv_path):
os.makedirs(sv_path,exist_ok=True) os.makedirs(sv_path,exist_ok=True)
@ -964,13 +985,35 @@ def get_tts_wav(
if SaveRefers: if SaveRefers:
names = [] names = []
for i in refers: for i in refers:
names.append(_get_unique_name(str(i.shape))) names.append(_get_unique_name(str(i.shape))+".npy")
refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt") refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt")
if not os.path.exists(refers_path): if not os.path.exists(refers_path):
os.makedirs(refers_path,exist_ok=True) os.makedirs(refers_path,exist_ok=True)
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName)),refers,SaveRefersName,file_names=names,access_list=names) VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName)),refers,SaveRefersName,file_names=names,access_list=names)
except: except:
traceback.print_exc() traceback.print_exc()
#print("refers数量:", len(refers))
#print("sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
try:
if InjectSvEmb and is_v2pro:
_sv_emb = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",InjectSvEmbName)),InjectSvEmbName,find_func)
for i in range(len(_sv_emb)):
sv_emb.append(_sv_emb[i].to(device))
except:
traceback.print_exc()
try:
if InjectRefers:
_refers = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",InjectRefersName)),InjectRefersName,find_func)
for i in range(len(_refers)):
refers.append(_refers[i].to(device))
except:
traceback.print_exc()
#print("注入后refers数量:", len(refers))
#print("注入后sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
if is_v2pro: if is_v2pro:
audio = vq_model.decode( audio = vq_model.decode(