mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
提高切分生成的鲁棒性
This commit is contained in:
parent
b75b5dcf6b
commit
6dfe56cfcf
@ -17,6 +17,7 @@ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
|||||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||||
import pdb
|
import pdb
|
||||||
import torch
|
import torch
|
||||||
|
import jieba
|
||||||
|
|
||||||
if os.path.exists("./gweight.txt"):
|
if os.path.exists("./gweight.txt"):
|
||||||
with open("./gweight.txt", 'r', encoding="utf-8") as file:
|
with open("./gweight.txt", 'r', encoding="utf-8") as file:
|
||||||
@ -228,7 +229,7 @@ def get_bert_inf(phones, word2ph, norm_text, language):
|
|||||||
|
|
||||||
|
|
||||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||||
|
MAX_LENGTH = 500
|
||||||
|
|
||||||
def get_first(text):
|
def get_first(text):
|
||||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||||
@ -432,6 +433,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|||||||
np.int16
|
np.int16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def split_by_tokenizer(text):
    """Split ``text`` into slices no longer than ``MAX_LENGTH`` characters.

    The text is tokenized with ``jieba.lcut`` and the tokens are packed
    greedily, so a cut falls on a token boundary whenever possible. A single
    token that is itself longer than ``MAX_LENGTH`` is cut at fixed character
    offsets so that the length cap always holds.

    Args:
        text: The string to split.

    Returns:
        A list of non-empty strings, each at most ``MAX_LENGTH`` characters
        long; joining them reproduces the joined tokens in order.
    """
    slices = []
    current = []
    # Track the buffered length instead of re-joining the buffer per token.
    current_len = 0
    for token in jieba.lcut(text):
        # Usually one piece; more only when the token alone exceeds the cap.
        for start in range(0, len(token), MAX_LENGTH):
            piece = token[start:start + MAX_LENGTH]
            # Flush only a non-empty buffer so no empty slice is ever emitted.
            if current and current_len + len(piece) > MAX_LENGTH:
                slices.append("".join(current))
                current = []
                current_len = 0
            current.append(piece)
            current_len += len(piece)

    if current:
        slices.append("".join(current))

    return slices
|
||||||
|
|
||||||
def split(todo_text):
|
def split(todo_text):
|
||||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||||
@ -440,55 +456,63 @@ def split(todo_text):
|
|||||||
i_split_head = i_split_tail = 0
|
i_split_head = i_split_tail = 0
|
||||||
len_text = len(todo_text)
|
len_text = len(todo_text)
|
||||||
todo_texts = []
|
todo_texts = []
|
||||||
while 1:
|
while True:
|
||||||
if i_split_head >= len_text:
|
if i_split_head >= len_text:
|
||||||
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||||
if todo_text[i_split_head] in splits:
|
if todo_text[i_split_head] in splits:
|
||||||
i_split_head += 1
|
i_split_head += 1
|
||||||
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
segment = todo_text[i_split_tail:i_split_head]
|
||||||
|
if len(segment) > MAX_LENGTH:
|
||||||
|
# 如果段落长度超过最大长度,进行语义切分
|
||||||
|
todo_texts.extend(split_by_tokenizer(segment))
|
||||||
|
else:
|
||||||
|
todo_texts.append(segment)
|
||||||
i_split_tail = i_split_head
|
i_split_tail = i_split_head
|
||||||
else:
|
else:
|
||||||
i_split_head += 1
|
i_split_head += 1
|
||||||
return todo_texts
|
return todo_texts
|
||||||
|
|
||||||
|
def cut1(inp, max_concat=4):
    """Regroup the segments of ``inp`` into newline-separated lines.

    ``inp`` is stripped of leading/trailing newlines and passed to
    ``split``; consecutive segments are then concatenated until adding the
    next one would exceed ``MAX_LENGTH`` characters or ``max_concat``
    segments, at which point a new line is started.

    Args:
        inp: The text to cut.
        max_concat: Maximum number of segments merged into one output line.
            Values below 1 behave like 1 (one segment per line).

    Returns:
        The regrouped text, one group per line, joined with ``"\n"``.
    """
    inp = inp.strip("\n")
    opts = []
    group = []
    # Running character count of ``group``; avoids re-joining per segment.
    group_len = 0
    for segment in split(inp):
        too_long = group_len + len(segment) > MAX_LENGTH
        too_many = len(group) + 1 > max_concat
        # Never flush an empty group: that would emit a blank line when the
        # first segment is oversized or max_concat is smaller than 1.
        if group and (too_long or too_many):
            opts.append("".join(group))
            group = []
            group_len = 0
        group.append(segment)
        group_len += len(segment)

    if group:
        opts.append("".join(group))

    return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
def cut2(inp, desired_length=50):
    """Pack the segments of ``inp`` into lines of about ``desired_length``.

    ``inp`` is stripped of leading/trailing newlines and passed to
    ``split``. Segments are accumulated while the accumulated text stays
    within ``desired_length`` characters; a segment that would overflow
    starts a new line instead. If ``split`` yields fewer than two segments
    the stripped input is returned unchanged.

    Args:
        inp: The text to cut.
        desired_length: Target maximum number of characters per line.

    Returns:
        The packed text, one chunk per line, joined with ``"\n"``.
    """
    inp = inp.strip("\n")
    pieces = split(inp)
    if len(pieces) < 2:
        return inp

    chunks = []
    pending = ""
    for piece in pieces:
        candidate = pending + piece
        if len(candidate) <= desired_length:
            pending = candidate
            continue
        # The piece does not fit: close the current chunk and start over.
        if pending:
            chunks.append(pending)
        pending = piece
    if pending:
        chunks.append(pending)

    # If the last chunk is too short, merge it into the previous one, as
    # long as the merged chunk stays below MAX_LENGTH.
    if len(chunks) > 1:
        tail = chunks[-1]
        if len(tail) < desired_length and len(tail) + len(chunks[-2]) < MAX_LENGTH:
            chunks[-2:] = [chunks[-2] + tail]

    return "\n".join(chunks)
|
||||||
|
|
||||||
def cut3(inp):
|
def cut3(inp):
|
||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user