mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
punctuation
This commit is contained in:
parent
e106a5ee88
commit
0182d581f7
@ -17,7 +17,7 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
import pdb
|
||||
import torch
|
||||
|
||||
import warnings
|
||||
if os.path.exists("./gweight.txt"):
|
||||
with open("./gweight.txt", 'r', encoding="utf-8") as file:
|
||||
gweight_data = file.read()
|
||||
@ -50,6 +50,7 @@ is_share = eval(is_share)
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
punctuation = set(['!', '?', '…', ',', '.', '-'," "]);print(punctuation)
|
||||
import gradio as gr
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
@ -366,6 +367,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
text = text.replace("\n\n", "\n")
|
||||
print(i18n("实际输入的目标文本(切句后):"), text)
|
||||
texts = text.split("\n")
|
||||
text = process_text(text)
|
||||
texts = merge_short_text_in_array(texts, 5)
|
||||
audio_opt = []
|
||||
if not ref_free:
|
||||
@ -463,6 +465,7 @@ def cut1(inp):
|
||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||
else:
|
||||
opts = [inp]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
@ -487,17 +490,21 @@ def cut2(inp):
|
||||
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||
opts[-2] = opts[-2] + opts[-1]
|
||||
opts = opts[:-1]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut3(inp):
|
||||
inp = inp.strip("\n")
|
||||
return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
|
||||
|
||||
opts = ["%s" % item for item in inp.strip("。").split("。")]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
def cut4(inp):
|
||||
inp = inp.strip("\n")
|
||||
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
|
||||
opts = ["%s" % item for item in inp.strip(".").split(".")]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||
@ -511,8 +518,8 @@ def cut5(inp):
|
||||
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
||||
if len(items)%2 == 1:
|
||||
mergeitems.append(items[-1])
|
||||
opt = "\n".join(mergeitems)
|
||||
return opt
|
||||
opt = [item for item in mergeitems if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opt)
|
||||
|
||||
|
||||
def custom_sort_key(s):
|
||||
@ -522,6 +529,18 @@ def custom_sort_key(s):
|
||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||
return parts
|
||||
|
||||
def process_text(texts):
|
||||
_text=[]
|
||||
if all(text in [None, " ", "\n",""] for text in texts):
|
||||
raise ValueError(i18n("请输入有效文本"))
|
||||
|
||||
for text in texts:
|
||||
if text in [None, " ", "\n"]:
|
||||
warnings.warn(i18n("文本中包含连续标点"))
|
||||
else:
|
||||
_text.append(text)
|
||||
return _text
|
||||
|
||||
|
||||
def change_choices():
|
||||
SoVITS_names, GPT_names = get_weights_names()
|
||||
|
Loading…
x
Reference in New Issue
Block a user