diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 4fe8045..03440a3 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -50,6 +50,7 @@ is_share = eval(is_share) if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +punctuation = set(['!', '?', '…', ',', '.', '-'," "]) import gradio as gr from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np @@ -322,6 +323,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "." print(i18n("实际输入的参考文本:"), prompt_text) text = text.strip("\n") + text = replace_consecutive_punctuation(text) if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text print(i18n("实际输入的目标文本:"), text) @@ -366,6 +368,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, text = text.replace("\n\n", "\n") print(i18n("实际输入的目标文本(切句后):"), text) texts = text.split("\n") + texts = process_text(texts) texts = merge_short_text_in_array(texts, 5) audio_opt = [] if not ref_free: @@ -463,6 +466,7 @@ def cut1(inp): opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) else: opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] return "\n".join(opts) @@ -487,17 +491,21 @@ def cut2(inp): if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 opts[-2] = opts[-2] + opts[-1] opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] return "\n".join(opts) def cut3(inp): inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip("。").split("。")]) - + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) def cut4(inp): inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip(".").split(".")]) + opts = ["%s" % item for item in inp.strip(".").split(".")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py @@ -511,8 +519,8 @@ def cut5(inp): # 在句子不存在符号或句尾无符号的时候保证文本完整 if len(items)%2 == 1: mergeitems.append(items[-1]) - opt = "\n".join(mergeitems) - return opt + opt = [item for item in mergeitems if not set(item).issubset(punctuation)] + return "\n".join(opt) def custom_sort_key(s): @@ -522,6 +530,24 @@ def custom_sort_key(s): parts = [int(part) if part.isdigit() else part for part in parts] return parts +def process_text(texts): + _text=[] + if all(text in [None, " ", "\n",""] for text in texts): + raise ValueError(i18n("请输入有效文本")) + for text in texts: + if text in [None, " ", ""]: + pass + else: + _text.append(text) + return _text + + +def replace_consecutive_punctuation(text): + punctuations = ''.join(re.escape(p) for p in punctuation) + pattern = f'([{punctuations}])([{punctuations}])+' + result = re.sub(pattern, r'\1', text) + return result + def change_choices(): SoVITS_names, GPT_names = get_weights_names()