feat: Added .voice loader

This commit is contained in:
__kaning123__ 2026-02-25 10:20:48 +08:00 committed by GitHub
parent 1c54a945cb
commit f6e8ec8a78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 8 deletions

View File

@ -73,7 +73,10 @@ class ZIP_File:
raise FileNotFoundError(f"File {file_path} does not exist.")
return file_path
def get_file_obj(self, file_name:str,location:str,mode:str='r'):
def get_file_obj(self, file_name:str,location:str='',mode:str='r'):
if location == '':
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
else:
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"File {file_path} does not exist.")
@ -125,14 +128,14 @@ def save_tensor(path: str, tensors: Union[torch.Tensor, list],name:str,MySet:set
zf.close()
del zf
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list:
def load_tensor(path: str,name:str,find_func,MySet:set=set()) -> list[torch.Tensor]:
zf = ZIP_File(path, name, MySet=MySet)
zf.release()
voice_path = find_func(zf,il)
tensors = []
for i in range(len(voice_path)):
v = voice_path[i]
np_array = np.load(v)
np_array = np.load(v,allow_pickle=True)
tensor = torch.from_numpy(np_array)
tensors.append(tensor)
zf.close()

View File

@ -39,6 +39,21 @@ def _get_unique_name(name,MySet:set=set()):
POOL.add(name)
return name
def find_func(zf, il):
    """Collect the paths of the entries listed in an archive's ``voice.json``.

    The manifest is looked up through ``zf``, parsed through ``il``, and each
    name in its ``"access_list"`` entry is resolved to a path via ``zf``.

    Args:
        zf: Object exposing ``get_file_path(name)``; it is expected to raise
            ``FileNotFoundError`` for a name it cannot resolve.
        il: Object exposing ``load_info(path)``; a ``None`` result is treated
            as "manifest could not be loaded".

    Returns:
        A list with the resolved path of every listed entry that exists, in
        manifest order. The list is empty when the manifest could not be
        loaded.

    Raises:
        FileNotFoundError: Propagated from ``zf.get_file_path`` if
            ``voice.json`` itself cannot be resolved.
        KeyError: If the loaded manifest has no ``"access_list"`` key.
    """
    manifest_path = zf.get_file_path("voice.json")
    info = il.load_info(manifest_path)
    if info is None:
        # Always hand back a list: load_tensor calls len() on the result, so
        # returning None here would raise TypeError before it can close the
        # archive.
        return []

    resolved = []
    for name in info["access_list"]:
        try:
            resolved.append(zf.get_file_path(name))
        except FileNotFoundError:
            # Best effort: entries listed in the manifest but missing from
            # the archive are skipped rather than aborting the whole load.
            continue
    return resolved
def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
if os.name != "nt":
@ -798,8 +813,14 @@ def get_tts_wav(
SaveSvEmb=False,
SaveRefers=False,
SaveSvEmbName=None,
SaveRefersName=None
SaveSvEmbName="sv_emb.voice",
SaveRefersName="refers.voice",
InjectSvEmb=False,
InjectRefers=False,
InjectSvEmbName="sv_emb.voice",
InjectRefersName="refers.voice",
):
global cache
if ref_wav_path:
@ -952,7 +973,7 @@ def get_tts_wav(
if SaveSvEmb and is_v2pro:
names = []
for i in sv_emb:
names.append(_get_unique_name(str(i.shape)))
names.append(_get_unique_name(str(i.shape))+".npy")
sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt")
if not os.path.exists(sv_path):
os.makedirs(sv_path,exist_ok=True)
@ -964,7 +985,7 @@ def get_tts_wav(
if SaveRefers:
names = []
for i in refers:
names.append(_get_unique_name(str(i.shape)))
names.append(_get_unique_name(str(i.shape))+".npy")
refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt")
if not os.path.exists(refers_path):
os.makedirs(refers_path,exist_ok=True)
@ -972,6 +993,28 @@ def get_tts_wav(
except:
traceback.print_exc()
#print("refers数量:", len(refers))
#print("sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
try:
if InjectSvEmb and is_v2pro:
_sv_emb = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",InjectSvEmbName)),InjectSvEmbName,find_func)
for i in range(len(_sv_emb)):
sv_emb.append(_sv_emb[i].to(device))
except:
traceback.print_exc()
try:
if InjectRefers:
_refers = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",InjectRefersName)),InjectRefersName,find_func)
for i in range(len(_refers)):
refers.append(_refers[i].to(device))
except:
traceback.print_exc()
#print("注入后refers数量:", len(refers))
#print("注入后sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
if is_v2pro:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb