update webui

This commit is contained in:
Watchtower-Liu 2024-02-16 17:47:54 +08:00
parent 013db82d7d
commit 9cac8ed160

View File

@ -365,15 +365,19 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text
return result
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
t0 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的参考文本:"), prompt_text)
print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
@ -402,8 +406,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_semantic = codes[0, 0]
t1 = ttime()
phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
if (how_to_cut == i18n("凑四句一切")):
text = cut1(text)
elif (how_to_cut == i18n("凑50字一切")):
@ -420,7 +422,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
texts = text.split("\n")
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
if not ref_free:
phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language)
bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
@ -430,9 +434,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
print(i18n("实际输入的目标文本(每句):"), text)
phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language)
bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype)
bert = torch.cat([bert2], 1)
if not ref_free:
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
else:
bert = bert2
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
@ -442,7 +450,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
None,
None if ref_free else prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=top_k,
@ -451,7 +459,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
early_stop_num=hz * max_sec,
)
t3 = ttime()
print(pred_semantic,idx)
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
@ -608,7 +616,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式 无参考文本时该选项无效"), value=False, interactive=True, show_label=True)
gr.Markdown("使用无参考文本模式时建议使用微调GPT")
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
@ -625,6 +636,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
interactive=True,
)
with gr.Row():
gr.Markdown("gpt采样参数(无参考文本时不要太低)")
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
@ -633,7 +645,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button.click(
get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature],
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free],
[output],
)