diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 633ef79..7ae9259 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -365,15 +365,19 @@ def merge_short_text_in_array(texts, threshold): result[len(result) - 1] += text return result -def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6): +def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False): + if prompt_text is None or len(prompt_text) == 0: + ref_free = True t0 = ttime() prompt_language = dict_language[prompt_language] text_language = dict_language[text_language] - prompt_text = prompt_text.strip("\n") - if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." + if not ref_free: + prompt_text = prompt_text.strip("\n") + if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." + print(i18n("实际输入的参考文本:"), prompt_text) text = text.strip("\n") if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text - print(i18n("实际输入的参考文本:"), prompt_text) + print(i18n("实际输入的目标文本:"), text) zero_wav = np.zeros( int(hps.data.sampling_rate * 0.3), @@ -402,8 +406,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, prompt_semantic = codes[0, 0] t1 = ttime() - phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language) - if (how_to_cut == i18n("凑四句一切")): text = cut1(text) elif (how_to_cut == i18n("凑50字一切")): @@ -420,7 +422,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, texts = text.split("\n") texts = merge_short_text_in_array(texts, 5) audio_opt = [] - bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype) + if not ref_free: + phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language) + bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype) for text in texts: # 解决输入目标文本的空行导致报错的问题 @@ -430,9 +434,13 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, print(i18n("实际输入的目标文本(每句):"), text) phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language) bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype) - bert = torch.cat([bert2], 1) + if not ref_free: + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) + else: + bert = bert2 + all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) - all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) bert = bert.to(device).unsqueeze(0) all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) prompt = prompt_semantic.unsqueeze(0).to(device) @@ -442,7 +450,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, pred_semantic, idx = t2s_model.model.infer_panel( all_phoneme_ids, all_phoneme_len, - None, + None if ref_free else prompt, bert, # prompt_phone_len=ph_offset, top_k=top_k, @@ -451,7 +459,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, early_stop_num=hz * max_sec, ) t3 = ttime() - print(pred_semantic,idx) + # print(pred_semantic.shape,idx) pred_semantic = pred_semantic[:, -idx:].unsqueeze( 0 ) # .unsqueeze(0)#mq要多unsqueeze一次 @@ -608,7 +616,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: gr.Markdown(value=i18n("*请上传并填写参考信息")) with gr.Row(): inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath") - prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="") + with gr.Column(): + ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式 无参考文本时该选项无效"), value=False, interactive=True, show_label=True) + gr.Markdown("使用无参考文本模式时建议使用微调GPT") + prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="") prompt_language = gr.Dropdown( label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文") ) @@ -625,6 +636,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: interactive=True, ) with gr.Row(): + gr.Markdown("gpt采样参数(无参考文本时不要太低):") top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True) temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True) @@ -633,7 +645,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: inference_button.click( get_tts_wav, - [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature], + [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free], [output], )