mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-05-01 07:02:44 +08:00
关于标点符号导致参考泄漏的问题 (#1168)
* punctuation * Update inference_webui.py * Update * update * update
This commit is contained in:
parent
501a74ae96
commit
bedb421adb
@ -50,6 +50,7 @@ is_share = eval(is_share)
|
|||||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||||
|
# Characters treated as punctuation: runs of them are collapsed by
# replace_consecutive_punctuation, and the cut* helpers discard any segment
# that consists solely of these characters.
punctuation = {'!', '?', '…', ',', '.', '-', ' '}
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -322,6 +323,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||||
print(i18n("实际输入的参考文本:"), prompt_text)
|
print(i18n("实际输入的参考文本:"), prompt_text)
|
||||||
text = text.strip("\n")
|
text = text.strip("\n")
|
||||||
|
text = replace_consecutive_punctuation(text)
|
||||||
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
||||||
|
|
||||||
print(i18n("实际输入的目标文本:"), text)
|
print(i18n("实际输入的目标文本:"), text)
|
||||||
@ -366,6 +368,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
text = text.replace("\n\n", "\n")
|
text = text.replace("\n\n", "\n")
|
||||||
print(i18n("实际输入的目标文本(切句后):"), text)
|
print(i18n("实际输入的目标文本(切句后):"), text)
|
||||||
texts = text.split("\n")
|
texts = text.split("\n")
|
||||||
|
texts = process_text(texts)
|
||||||
texts = merge_short_text_in_array(texts, 5)
|
texts = merge_short_text_in_array(texts, 5)
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
if not ref_free:
|
if not ref_free:
|
||||||
@ -463,6 +466,7 @@ def cut1(inp):
|
|||||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||||
else:
|
else:
|
||||||
opts = [inp]
|
opts = [inp]
|
||||||
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
@ -487,17 +491,21 @@ def cut2(inp):
|
|||||||
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||||
opts[-2] = opts[-2] + opts[-1]
|
opts[-2] = opts[-2] + opts[-1]
|
||||||
opts = opts[:-1]
|
opts = opts[:-1]
|
||||||
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
def cut3(inp):
|
def cut3(inp):
|
||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
|
opts = ["%s" % item for item in inp.strip("。").split("。")]
|
||||||
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
def cut4(inp):
|
def cut4(inp):
|
||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
|
opts = ["%s" % item for item in inp.strip(".").split(".")]
|
||||||
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||||
@ -511,8 +519,8 @@ def cut5(inp):
|
|||||||
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
||||||
if len(items)%2 == 1:
|
if len(items)%2 == 1:
|
||||||
mergeitems.append(items[-1])
|
mergeitems.append(items[-1])
|
||||||
opt = "\n".join(mergeitems)
|
opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
|
||||||
return opt
|
return "\n".join(opt)
|
||||||
|
|
||||||
|
|
||||||
def custom_sort_key(s):
|
def custom_sort_key(s):
|
||||||
@ -522,6 +530,24 @@ def custom_sort_key(s):
|
|||||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
def process_text(texts):
    """Validate a list of text segments and drop the blank ones.

    Args:
        texts: Sequence of segments; each entry is a string or ``None``.
            It is traversed twice, so it must be re-iterable.

    Returns:
        A new list with every segment that is not ``None``, a single space,
        or the empty string, in the original order.

    Raises:
        ValueError: If every segment is ``None``, a single space, a lone
            newline, or the empty string (an empty ``texts`` also qualifies).
    """
    # A lone newline counts as blank for validation only; it is not dropped
    # by the filtering step below.
    blank_for_validation = (None, " ", "\n", "")
    has_content = any(segment not in blank_for_validation for segment in texts)
    if not has_content:
        raise ValueError(i18n("请输入有效文本"))

    droppable = (None, " ", "")
    return [segment for segment in texts if segment not in droppable]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_consecutive_punctuation(text):
    """Collapse each run of punctuation characters to its first character.

    A run is two or more adjacent characters drawn from the module-level
    ``punctuation`` set; the characters need not be identical.

    Args:
        text: The string to clean up.

    Returns:
        ``text`` with every such run replaced by the run's first character.
    """
    # Escape every mark so it is taken literally inside the character class.
    char_class = "".join(map(re.escape, punctuation))
    run_of_marks = "([" + char_class + "])([" + char_class + "])+"
    # Group 1 is the leading mark of the run; everything after it is dropped.
    return re.sub(run_of_marks, r"\1", text)
|
||||||
|
|
||||||
|
|
||||||
def change_choices():
|
def change_choices():
|
||||||
SoVITS_names, GPT_names = get_weights_names()
|
SoVITS_names, GPT_names = get_weights_names()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user