punctuation

This commit is contained in:
XXXXRT666 2024-06-08 16:38:38 +01:00
parent e106a5ee88
commit 0182d581f7

View File

@ -17,7 +17,7 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb import pdb
import torch import torch
import warnings
if os.path.exists("./gweight.txt"): if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file: with open("./gweight.txt", 'r', encoding="utf-8") as file:
gweight_data = file.read() gweight_data = file.read()
@ -50,6 +50,7 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
# Characters regarded as "punctuation only": the cut* helpers drop any segment
# whose characters all belong to this set.
# NOTE(review): the '' entry can never equal a single character, so it has no
# effect on those subset checks - it looks like a non-ASCII mark that was lost;
# confirm the intended character.
# NOTE(review): the trailing print() runs at import time and looks like
# leftover debugging output - confirm whether it is still wanted.
punctuation = set(['!', '?', '', ',', '.', '-'," "]);print(punctuation)
import gradio as gr import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
@ -366,6 +367,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
text = text.replace("\n\n", "\n") text = text.replace("\n\n", "\n")
print(i18n("实际输入的目标文本(切句后):"), text) print(i18n("实际输入的目标文本(切句后):"), text)
texts = text.split("\n") texts = text.split("\n")
text = process_text(text)
texts = merge_short_text_in_array(texts, 5) texts = merge_short_text_in_array(texts, 5)
audio_opt = [] audio_opt = []
if not ref_free: if not ref_free:
@ -463,6 +465,7 @@ def cut1(inp):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]])) opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else: else:
opts = [inp] opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts) return "\n".join(opts)
@ -487,17 +490,21 @@ def cut2(inp):
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1] opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1] opts = opts[:-1]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts) return "\n".join(opts)
def cut3(inp):
    """Split text into one segment per line at each Chinese full stop.

    Args:
        inp: Text to split.

    Returns:
        The kept segments joined by newlines. A segment is dropped when
        every character in it belongs to the module-level ``punctuation``
        set; empty segments are dropped for the same reason.
    """
    inp = inp.strip("\n")
    # The separator has to be a real character: str.split("") raises
    # ValueError ("empty separator"), so an empty literal here can never work.
    segments = inp.strip("。").split("。")
    kept = [seg for seg in segments if not set(seg).issubset(punctuation)]
    return "\n".join(kept)
def cut4(inp):
    """Split text into one segment per line at each ASCII full stop.

    Args:
        inp: Text to split.

    Returns:
        The kept segments joined by newlines. Segments made up solely of
        characters from the module-level ``punctuation`` set (including
        empty segments) are left out.
    """
    # Trim surrounding newlines first, then surrounding dots, so that no
    # empty segment is produced at either end of the text.
    trimmed = inp.strip("\n").strip(".")
    kept = []
    for piece in trimmed.split("."):
        if set(piece).issubset(punctuation):
            continue
        kept.append(piece)
    return "\n".join(kept)
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@ -511,8 +518,8 @@ def cut5(inp):
# 在句子不存在符号或句尾无符号的时候保证文本完整 # 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1: if len(items)%2 == 1:
mergeitems.append(items[-1]) mergeitems.append(items[-1])
opt = "\n".join(mergeitems) opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
return opt return "\n".join(opt)
def custom_sort_key(s): def custom_sort_key(s):
@ -522,6 +529,18 @@ def custom_sort_key(s):
parts = [int(part) if part.isdigit() else part for part in parts] parts = [int(part) if part.isdigit() else part for part in parts]
return parts return parts
def process_text(texts):
    """Remove blank segments from a sequence of text segments.

    A segment counts as blank when it is ``None``, the empty string, a
    single space, or a single newline.

    Args:
        texts: Sequence of text segments; it is iterated twice.

    Returns:
        A new list holding the non-blank segments in their original order.

    Raises:
        ValueError: If every segment is blank (this includes an empty
            ``texts``).
    """
    # One shared definition of "blank". Previously the validity check below
    # treated "" as blank while the filter loop did not, so empty strings
    # leaked into the returned list.
    blank = (None, " ", "\n", "")
    if all(text in blank for text in texts):
        raise ValueError(i18n("请输入有效文本"))
    kept = []
    for text in texts:
        if text in blank:
            # Warn rather than fail: one blank segment among valid ones is
            # simply skipped.
            warnings.warn(i18n("文本中包含连续标点"))
        else:
            kept.append(text)
    return kept
def change_choices(): def change_choices():
SoVITS_names, GPT_names = get_weights_names() SoVITS_names, GPT_names = get_weights_names()