diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 4fe8045d..574b06f6 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -17,7 +17,7 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR) logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) import pdb import torch - +import warnings if os.path.exists("./gweight.txt"): with open("./gweight.txt", 'r', encoding="utf-8") as file: gweight_data = file.read() @@ -50,6 +50,7 @@ is_share = eval(is_share) if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +punctuation = set(['!', '?', '…', ',', '.', '-'," "]);print(punctuation) import gradio as gr from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np @@ -366,6 +367,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, text = text.replace("\n\n", "\n") print(i18n("实际输入的目标文本(切句后):"), text) texts = text.split("\n") + text = process_text(text) texts = merge_short_text_in_array(texts, 5) audio_opt = [] if not ref_free: @@ -463,6 +465,7 @@ def cut1(inp): opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) else: opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] return "\n".join(opts) @@ -487,17 +490,21 @@ def cut2(inp): if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 opts[-2] = opts[-2] + opts[-1] opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] return "\n".join(opts) def cut3(inp): inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip("。").split("。")]) - + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) def cut4(inp): inp = inp.strip("\n") - return "\n".join(["%s" % item for item in inp.strip(".").split(".")]) + opts = ["%s" % item for item in inp.strip(".").split(".")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py @@ -511,8 +518,8 @@ def cut5(inp): # 在句子不存在符号或句尾无符号的时候保证文本完整 if len(items)%2 == 1: mergeitems.append(items[-1]) - opt = "\n".join(mergeitems) - return opt + opt = [item for item in mergeitems if not set(item).issubset(punctuation)] + return "\n".join(opt) def custom_sort_key(s): @@ -522,6 +529,18 @@ def custom_sort_key(s): parts = [int(part) if part.isdigit() else part for part in parts] return parts +def process_text(texts): + _text=[] + if all(text in [None, " ", "\n",""] for text in texts): + raise ValueError(i18n("请输入有效文本")) + + for text in texts: + if text in [None, " ", "\n"]: + warnings.warn(i18n("文本中包含连续标点")) + else: + _text.append(text) + return _text + def change_choices(): SoVITS_names, GPT_names = get_weights_names()