feat: Added path check

This commit is contained in:
__kaning123__ 2026-02-25 13:56:47 +08:00 committed by GitHub
parent 012eb93ef8
commit 69f1c9c2dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -987,7 +987,11 @@ def get_tts_wav(
sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt")
if not os.path.exists(sv_path):
os.makedirs(sv_path,exist_ok=True)
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",SaveSvEmbName)),sv_emb,SaveSvEmbName,file_names=names,access_list=names)
if not os.path.exists(SaveSvEmbName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",SaveSvEmbName))
else:
_pth_ = SaveSvEmbName
VoiceSave.save_tensor(_pth_,sv_emb,SaveSvEmbName,file_names=names,access_list=names)
except:
traceback.print_exc()
@ -999,15 +1003,24 @@ def get_tts_wav(
refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt")
if not os.path.exists(refers_path):
os.makedirs(refers_path,exist_ok=True)
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName)),refers,SaveRefersName,file_names=names,access_list=names)
if not os.path.exists(SaveRefersName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName))
else:
_pth_ = SaveRefersName
VoiceSave.save_tensor(_pth_,refers,SaveRefersName,file_names=names,access_list=names)
except:
traceback.print_exc()
#print("refers数量:", len(refers))
#print("sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
try:
if InjectSvEmb and is_v2pro:
_sv_emb = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",InjectSvEmbName)),InjectSvEmbName,find_func)
if not os.path.exists(InjectSvEmbName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",InjectSvEmbName))
else:
_pth_ = InjectSvEmbName
_sv_emb = VoiceSave.load_tensor(_pth_,InjectSvEmbName,find_func)
for i in range(len(_sv_emb)):
sv_emb.append(_sv_emb[i].to(device))
except:
@ -1015,7 +1028,11 @@ def get_tts_wav(
try:
if InjectRefers:
_refers = VoiceSave.load_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",InjectRefersName)),InjectRefersName,find_func)
if not os.path.exists(InjectRefersName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",InjectRefersName))
else:
_pth_ = InjectRefersName
_refers = VoiceSave.load_tensor(_pth_,InjectRefersName,find_func)
for i in range(len(_refers)):
refers.append(_refers[i].to(device))
except:
@ -1024,7 +1041,6 @@ def get_tts_wav(
#print("注入后refers数量:", len(refers))
#print("注入后sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
if is_v2pro:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb