提高切分生成的鲁棒性

This commit is contained in:
ALEXuH 2024-03-04 19:28:22 +08:00
parent b75b5dcf6b
commit 6dfe56cfcf

View File

@ -17,6 +17,7 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
import jieba
if os.path.exists("./gweight.txt"):
with open("./gweight.txt", 'r', encoding="utf-8") as file:
@ -228,7 +229,7 @@ def get_bert_inf(phones, word2ph, norm_text, language):
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
MAX_LENGTH = 500
def get_first(text):
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
@ -432,6 +433,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
np.int16
)
def split_by_tokenizer(text):
tokens = jieba.lcut(text)
slices = []
temp_slice = []
for token in tokens:
if len("".join(temp_slice + [token])) > MAX_LENGTH:
slices.append("".join(temp_slice))
temp_slice = []
temp_slice.append(token)
if temp_slice:
slices.append("".join(temp_slice))
return slices
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
@ -440,55 +456,63 @@ def split(todo_text):
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
while True:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
segment = todo_text[i_split_tail:i_split_head]
if len(segment) > MAX_LENGTH:
# 如果段落长度超过最大长度,进行语义切分
todo_texts.extend(split_by_tokenizer(segment))
else:
todo_texts.append(segment)
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
def cut1(inp, max_concat=4):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
opts = []
temp_segment = []
for segment in inps:
if len("".join(temp_segment + [segment])) > MAX_LENGTH or len(temp_segment) + 1 > max_concat:
opts.append("".join(temp_segment))
temp_segment = [segment]
else:
temp_segment.append(segment)
if temp_segment:
opts.append("".join(temp_segment))
return "\n".join(opts)
def cut2(inp):
def cut2(inp, desired_length=50):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
for segment in inps:
if len(tmp_str + segment) > desired_length:
if tmp_str:
opts.append(tmp_str)
tmp_str = segment
else:
tmp_str += segment
if tmp_str:
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
if len(opts) > 1 and len(opts[-1]) < desired_length and len(opts[-1]) + len(opts[-2]) < MAX_LENGTH: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")