# Punctuation marks (full-width and ASCII) on which text is cut into clauses.
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
# Upper bound, in characters, for any single chunk produced by the cutters below.
MAX_LENGTH = 500


def split_by_tokenizer(text, tokenizer=None, max_length=MAX_LENGTH):
    """Cut *text* into consecutive slices of at most ``max_length`` characters.

    Slices break on token boundaries so words are kept intact whenever
    possible. A single token that is itself longer than ``max_length`` is
    hard-cut, so the length bound always holds and no empty slice is emitted.

    Args:
        text: The text to cut.
        tokenizer: Callable mapping a string to a list of string tokens.
            Defaults to ``jieba.lcut``.
        max_length: Maximum number of characters per slice (must be >= 1).

    Returns:
        A list of non-empty strings whose concatenation equals the
        concatenation of the tokens returned by ``tokenizer``.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be a positive integer")
    if tokenizer is None:
        # Imported lazily so the pure-Python cutters stay usable (and
        # testable) without loading the tokenizer.
        import jieba

        tokenizer = jieba.lcut

    slices = []
    current = []
    current_len = 0  # running length avoids re-joining the buffer per token

    for token in tokenizer(text):
        if not token:
            continue
        # Flush only a non-empty buffer; flushing an empty one would emit "".
        if current and current_len + len(token) > max_length:
            slices.append("".join(current))
            current, current_len = [], 0
        # The buffer is necessarily empty here when the token is oversize.
        while len(token) > max_length:
            slices.append(token[:max_length])
            token = token[max_length:]
        current.append(token)
        current_len += len(token)

    if current:
        slices.append("".join(current))

    return slices


def split(todo_text):
    """Split text into clauses, each ending with a separator from ``splits``.

    ``"……"`` is first normalised to ``"。"`` and ``"——"`` to ``","``. If the
    text does not end with a separator, ``"。"`` is appended so the trailing
    clause is not dropped. A clause longer than ``MAX_LENGTH`` is further
    cut by ``split_by_tokenizer``.

    Args:
        todo_text: The text to split.

    Returns:
        A list of clause strings; empty for empty input.
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        return []
    # The scan below only emits a clause when it reaches a separator, so make
    # sure the text ends with one.
    if todo_text[-1] not in splits:
        todo_text += "。"

    todo_texts = []
    start = 0
    for end, char in enumerate(todo_text, 1):
        if char not in splits:
            continue
        segment = todo_text[start:end]
        start = end
        if len(segment) > MAX_LENGTH:
            # Clause exceeds the maximum length: cut it on token boundaries.
            todo_texts.extend(split_by_tokenizer(segment))
        else:
            todo_texts.append(segment)
    return todo_texts


def cut1(inp, max_concat=4):
    """Group clauses into lines of at most ``max_concat`` clauses each.

    A line is also closed early when adding the next clause would push it
    past ``MAX_LENGTH`` characters. Every line holds at least one clause, so
    no empty line is ever produced.

    Args:
        inp: The text to cut; leading/trailing newlines are stripped.
        max_concat: Maximum number of clauses joined into one line.

    Returns:
        The grouped lines joined with ``"\\n"``.
    """
    inp = inp.strip("\n")
    opts = []
    group = []
    group_len = 0

    for segment in split(inp):
        too_long = group_len + len(segment) > MAX_LENGTH
        too_many = len(group) + 1 > max_concat
        # Never flush an empty group: that would emit a blank line.
        if group and (too_long or too_many):
            opts.append("".join(group))
            group, group_len = [], 0
        group.append(segment)
        group_len += len(segment)

    if group:
        opts.append("".join(group))

    return "\n".join(opts)


def cut2(inp, desired_length=50):
    """Group clauses into lines of roughly ``desired_length`` characters.

    Clauses are accumulated until adding the next one would exceed
    ``desired_length``; a single clause longer than that gets its own line.
    If the last line is shorter than ``desired_length`` it is merged into the
    previous one, provided the merged line stays below ``MAX_LENGTH``.

    Args:
        inp: The text to cut; leading/trailing newlines are stripped.
        desired_length: Target line length in characters.

    Returns:
        The grouped lines joined with ``"\\n"``, or the stripped input
        unchanged when it contains fewer than two clauses.
    """
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp

    opts = []
    pending = []
    pending_len = 0
    for segment in inps:
        if pending_len + len(segment) > desired_length:
            if pending:
                opts.append("".join(pending))
            pending, pending_len = [segment], len(segment)
        else:
            pending.append(segment)
            pending_len += len(segment)
    if pending:
        opts.append("".join(pending))

    # If the last line is too short, merge it into the previous one.
    if (
        len(opts) > 1
        and len(opts[-1]) < desired_length
        and len(opts[-1]) + len(opts[-2]) < MAX_LENGTH
    ):
        opts[-2:] = [opts[-2] + opts[-1]]

    return "\n".join(opts)