feat: Added entrys to save sv_emb and refers

This commit is contained in:
__kaning123__ 2026-02-25 07:53:03 +08:00 committed by GitHub
parent a6a53f7231
commit 1c54a945cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 66 additions and 6 deletions

View File

@ -1,11 +1,11 @@
import numpy as np
import torch
import zipfile
from . import file_lib as fl
from . import time_lib as tl
from . import info_lib as il
import os
from typing import Union
import numpy as np
import torch
POOL:set = set()
def get_unique_name(name,MySet:set=set()):
@ -59,9 +59,9 @@ class ZIP_File:
file_path = fl.merge_dir_txt2(self.temp_write,file_name)
else:
file_path = fl.merge_dir_txt2(self.temp_write, location, file_name)
if not os.path.exists(location):
os.makedirs(location,exist_ok=True)
with open(file_path, 'x') as f:
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path),exist_ok=True)
with open(file_path, 'w') as f:
pass
def get_file_path(self, file_name:str,location:str=''):
@ -116,7 +116,7 @@ def save_tensor(path: str, tensors: Union[torch.Tensor, list],name:str,MySet:set
zf.create_file("voice.json")
info = {'name': name}
info.update(info_save)
il.save_info(str(zf.get_file_path("voice.json")), info)
il.save_info(info, str(zf.get_file_path("voice.json")))
for i in range(len(files)):
file_name = files[i]
np_array = np_arrays[i]

View File

@ -8,6 +8,36 @@
"""
import psutil
import os
import sys
from pathlib import Path
def get_my_dir():
return os.path.dirname(os.path.abspath(__file__))
def get_parent_dir(dir_path,depth=1):
parent_path = Path(dir_path)
for _ in range(depth):
parent_path = parent_path.parent
return parent_path
def merge_dir_txt2(*TXT):
return Path(os.path.join(*TXT))
ROOT_DIR = str(get_parent_dir(get_my_dir()))
sys.path.append(get_my_dir())
import VoiceSave
POOL:set = set()
def _get_unique_name(name,MySet:set=set()):
_id = 1
if name not in POOL and name not in MySet:
POOL.add(name)
return name
while name in POOL or name in MySet:
_id += 1
name = f'{name}_{_id}'
POOL.add(name)
return name
def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
@ -765,6 +795,11 @@ def get_tts_wav(
sample_steps=8,
if_sr=False,
pause_second=0.3,
SaveSvEmb=False,
SaveRefers=False,
SaveSvEmbName=None,
SaveRefersName=None
):
global cache
if ref_wav_path:
@ -912,6 +947,31 @@ def get_tts_wav(
refers = [refers]
if is_v2pro:
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
try:
if SaveSvEmb and is_v2pro:
names = []
for i in sv_emb:
names.append(_get_unique_name(str(i.shape)))
sv_path = merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt")
if not os.path.exists(sv_path):
os.makedirs(sv_path,exist_ok=True)
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","sv_emb_opt",SaveSvEmbName)),sv_emb,SaveSvEmbName,file_names=names,access_list=names)
except:
traceback.print_exc()
try:
if SaveRefers:
names = []
for i in refers:
names.append(_get_unique_name(str(i.shape)))
refers_path = merge_dir_txt2(ROOT_DIR,"output","refers_opt")
if not os.path.exists(refers_path):
os.makedirs(refers_path,exist_ok=True)
VoiceSave.save_tensor(str(merge_dir_txt2(ROOT_DIR,"output","refers_opt",SaveRefersName)),refers,SaveRefersName,file_names=names,access_list=names)
except:
traceback.print_exc()
if is_v2pro:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb