关于标点符号导致参考泄漏的问题 (#1168)

* punctuation

* Update inference_webui.py

* Update

* update

* update
This commit is contained in:
XXXXRT666 2024-06-10 09:18:30 +01:00 committed by GitHub
parent 501a74ae96
commit bedb421adb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -50,6 +50,7 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
punctuation = set(['!', '?', '', ',', '.', '-'," "])
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
@ -322,6 +323,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
text = replace_consecutive_punctuation(text)
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的目标文本:"), text)
@ -366,6 +368,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n")
texts = process_text(texts)
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
if not ref_free:
@ -463,6 +466,7 @@ def cut1(inp):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
@ -487,17 +491,21 @@ def cut2(inp):
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
opts = ["%s" % item for item in inp.strip("").split("")]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
opts = ["%s" % item for item in inp.strip(".").split(".")]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@ -511,8 +519,8 @@ def cut5(inp):
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
return "\n".join(opt)
def custom_sort_key(s):
@ -522,6 +530,24 @@ def custom_sort_key(s):
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
def process_text(texts):
_text=[]
if all(text in [None, " ", "\n",""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
pass
else:
_text.append(text)
return _text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def change_choices():
SoVITS_names, GPT_names = get_weights_names()