feat:添加关闭tts webui 的入口 与 ge 等中间量的保存入口用于分发及使用

This commit is contained in:
Kaning123 2026-04-02 17:24:19 +08:00
parent 47170fd555
commit 46ae12bf17
2 changed files with 129 additions and 0 deletions

3
GPT_SoVITS/config.json Normal file
View File

@ -0,0 +1,3 @@
{
"running_on" : "local"
}

View File

@ -9,7 +9,11 @@
import psutil
import os
import sys
import json
from pathlib import Path
import uuid
def get_my_dir():
return os.path.dirname(os.path.abspath(__file__))
@ -23,6 +27,11 @@ def get_parent_dir(dir_path,depth=1):
def merge_dir_txt2(*TXT):
return Path(os.path.join(*TXT))
with open(merge_dir_txt2(get_my_dir(), "config.json"), "r", encoding="utf-8") as f:
config_json = f.read()
config_json = json.loads(config_json)
running_on = config_json["running_on"]
ROOT_DIR = str(get_parent_dir(get_my_dir()))
sys.path.append(get_my_dir())
import VoiceSave
@ -816,12 +825,19 @@ def get_tts_wav(
SaveSvEmbName="sv_emb.voice",
SaveRefersName="refers.voice",
SaveGE=False,
SaveGEName="ge.voice",
InjectSvEmb=False,
InjectRefers=False,
InjectSvEmbName="sv_emb.voice",
InjectRefersName="refers.voice",
EnableAudioLoad=True,
SaveOutputAsUndecoded=False,
SaveOutputAsUndecodedName="output.voice",
AddRandomSaltToSaveOutputAsUndecodedName=False,
):
global cache
if ref_wav_path:
@ -1041,6 +1057,60 @@ def get_tts_wav(
#print("注入后refers数量:", len(refers))
#print("注入后sv_emb数量:", len(sv_emb) if is_v2pro else "无sv_emb")
try:
ges = []
for i in range(len(refers)):
if is_v2pro:
ge_ = vq_model.ge_(refers[i],sv_emb[i])
else:
ge_ = vq_model.ge_(refers[i])
ges.append(ge_)
if SaveGE:
names = []
for i in ges:
names.append(_get_unique_name(str(i.shape))+".npy")
ge_path = merge_dir_txt2(ROOT_DIR,"output","ge_opt")
if not os.path.exists(ge_path):
os.makedirs(ge_path,exist_ok=True)
if not os.path.exists(SaveGEName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","ge_opt",SaveGEName))
else:
_pth_ = SaveGEName
VoiceSave.save_tensor(_pth_,ges,SaveGEName,file_names=names,access_list=names)
except:
traceback.print_exc()
if AddRandomSaltToSaveOutputAsUndecodedName:
ranA = uuid.uuid4()
ranB = uuid.uuid4()
SaveOutputAsUndecodedName = f"{SaveOutputAsUndecodedName}_{ranA}_{ranB}.voice"
try:
if SaveOutputAsUndecoded:
if is_v2pro:
z_p,mask,ge = vq_model.decode2(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refers, speed=speed, sv_emb=sv_emb)
else:
z_p,mask,ge = vq_model.decode2(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refers, speed=speed)
ret = [z_p.cpu().detach(),
mask.cpu().detach(),
ge.cpu().detach()]
names = [f"z_p_{str(ret[0].shape)}",
f"mask_{str(ret[1].shape)}",
f"ge_{str(ret[2].shape)}"]
undecoded_path = merge_dir_txt2(ROOT_DIR,"output","undecoded_opt")
if not os.path.exists(undecoded_path):
os.makedirs(undecoded_path,exist_ok=True)
if not os.path.exists(SaveOutputAsUndecodedName):
_pth_ = str(merge_dir_txt2(ROOT_DIR,"output","undecoded_opt",SaveOutputAsUndecodedName))
else:
_pth_ = SaveOutputAsUndecodedName
VoiceSave.save_tensor(_pth_,ret,SaveOutputAsUndecodedName,file_names=names,access_list=names)
except:
traceback.print_exc()
if is_v2pro:
audio = vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
@ -1129,6 +1199,11 @@ def get_tts_wav(
audio_opt = audio_opt.cpu().detach().numpy()
yield opt_sr, (audio_opt * 32767).astype(np.int16)
def close_serv():
if running_on == "local"
sys.exit(0)
else:
gr.Warning(i18n("服务器环境下该功能不可用"))
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
@ -1372,7 +1447,47 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
show_label=True,
visible=True,
)
SaveGE = gr.Checkbox(
label = i18n("保存GE"),
value = True,
interactive = True,
show_label = True,
visible = True,
)
SaveGEName = gr.Textbox(
label = i18n("保存的GE文件名默认保存在output/ge_opt目录下"),
value = "ge.voice",
interactive = True,
show_label = True,
visible = True,
)
SaveOutputAsUndecoded = gr.Checkbox(
label = i18n("保存未解码的输出"),
value = False,
interactive = True,
show_label = True,
visible = True,
)
SaveOutputAsUndecodedName = gr.Textbox(
label = i18n("保存的未解码输出文件名默认保存在output/undecoded_opt目录下"),
value = "output.voice",
interactive = True,
show_label = True,
visible = True,
)
AddRandomSaltToSaveOutputAsUndecodedName = gr.Checkbox(
label = i18n("给未解码输出文件名添加随机盐,防止覆盖"),
value = False,
interactive = True,
show_label = True,
visible = True,
)
with gr.Column(scale=14):
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"),
@ -1482,6 +1597,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
output = gr.Audio(label=i18n("输出的语音"), scale=14)
with gr.Row():
close_button = gr.Button(value=i18n("关闭服务器"), variant="danger", size="lg", scale=25)
close_button.click(close_serv)
inference_button.click(
get_tts_wav,
[
@ -1506,12 +1626,18 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
SaveRefers,
SaveSvEmbName,
SaveRefersName,
SaveGE,
SaveGEName,
InjectSvEmb,
InjectRefers,
InjectSvEmbName,
InjectRefersName,
EnableAudioLoad,
SaveOutputAsUndecoded,
SaveOutputAsUndecodedName,
AddRandomSaltToSaveOutputAsUndecodedName,
],
[output],