diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index ee099627..263ae842 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -293,6 +293,22 @@ def get_phones_and_bert(text,language):
     return phones,bert.to(dtype),norm_text
 
 
+def get_target_phones_and_bert(text):
+    phones_list = []
+    bert_list = []
+    norm_text_list = []
+    for t in text:
+        phones, word2ph, norm_text = clean_text_inf(t['text'], t['lang'])
+        bert = get_bert_inf(phones, word2ph, norm_text, t['lang'])
+        phones_list.append(phones)
+        norm_text_list.append(norm_text)
+        bert_list.append(bert)
+    bert = torch.cat(bert_list, dim=1)
+    phones = sum(phones_list, [])
+    norm_text = ''.join(norm_text_list)
+
+    return phones,bert.to(dtype),norm_text
+
 def merge_short_text_in_array(texts, threshold):
     if (len(texts)) < 2:
         return texts
@@ -311,6 +327,209 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
+# cut1
+# 凑四句一切。切
+def cut1(language_splits):
+    sentences_list = []
+    sentences = []
+    sentences_count = 0
+
+    for lang_block in language_splits:
+        lang_block['text'] = lang_block['text'].replace('。','.')
+        text_parts = lang_block['text'].split('.')
+
+        if len(text_parts) == 1:
+            sentences.append({'lang': lang_block['lang'], 'text': lang_block['text']})
+        else :
+            for i, part in enumerate(text_parts):
+                if i < len(text_parts) - 1:
+                    sentences.append({'lang': lang_block['lang'], 'text': part + '.'})
+                    if part and not part.isspace():
+                        sentences_count += 1
+                        if sentences_count >=4:
+                            sentences_list.append(sentences)
+                            sentences = []
+                            sentences_count = 0
+                elif part and not part.isspace():
+                    sentences.append({'lang': lang_block['lang'], 'text': part})
+
+    if sentences:
+        sentences_list.append(sentences)
+
+    return sentences_list
+
+
+# cut2
+# 凑50字一切。切
+def cut2(language_splits):
+    sentences_list = []
+    sentences = []
+    senteces_count = 0
+    last_senteces_count = 0
+
+    for lang_block in language_splits:
+        senteces_count += len(lang_block['text'])
+
+        if senteces_count <= 50:
+            sentences.append({'lang': lang_block['lang'], 'text': lang_block['text']})
+            last_senteces_count = senteces_count
+        else :
+            t_text = lang_block['text']
+            while senteces_count > 50:
+                sentences.append({'lang': lang_block['lang'], 'text': t_text[:50-last_senteces_count]})
+                sentences_list.append(sentences)
+                sentences = []
+                t_text = t_text[50-last_senteces_count:]
+                last_senteces_count = 0
+                senteces_count -= 50
+
+            sentences.append({'lang': lang_block['lang'], 'text': t_text})
+            last_senteces_count = senteces_count
+
+    if sentences:
+        sentences_list.append(sentences)
+
+    return sentences_list
+
+
+# cut3
+# 按中文句号。切
+def cut3(language_splits):
+    sentences_list = []
+
+    for lang_block in language_splits:
+        text_parts = lang_block['text'].split('。')
+
+        if len(text_parts) <= 1:
+            sentences_list.append([{'lang': lang_block['lang'], 'text': lang_block['text']}])
+        else:
+            for i, part in enumerate(text_parts[:-1]):
+                sentences_list.append([{'lang': lang_block['lang'], 'text': part.strip() + "。"}])
+
+            last_part = text_parts[-1].strip()
+            if last_part:
+                sentences_list.append([{'lang': lang_block['lang'], 'text': last_part}])
+
+    return sentences_list
+
+# cut4
+# 按英文句号.切
+def cut4(language_splits):
+    sentences_list = []
+
+    for lang_block in language_splits:
+        text_parts = lang_block['text'].split('.')
+
+        if len(text_parts) <= 1:
+            sentences_list.append([{'lang': lang_block['lang'], 'text': lang_block['text']}])
+        else:
+            for i, part in enumerate(text_parts[:-1]):
+                sentences_list.append([{'lang': lang_block['lang'], 'text': part.strip() + "."}])
+
+            last_part = text_parts[-1].strip()
+            if last_part:
+                sentences_list.append([{'lang': lang_block['lang'], 'text': last_part}])
+
+    return sentences_list
+
+
+# cut5
+# 按标点符号切
+def cut5(language_splits):
+    sentences_list = []
+    sentences = []
+
+    for lang_block in language_splits:
+        punds = r'[,.;?!、,。?!;:…]'
+        text_parts = re.split(f'({punds})', lang_block['text'])
+        if len(text_parts) == 1:
+            if text_parts[-1] and not text_parts[-1].isspace():
+                sentences.append({'lang': lang_block['lang'], 'text': text_parts[-1]})
+        else:
+            for group in zip(text_parts[::2], text_parts[1::2]):
+                sentences.append({'lang': lang_block['lang'], 'text': "".join(group)})
+                sentences_list.append(sentences)
+                sentences = []
+            if len(text_parts)%2 == 1:
+                if text_parts[-1] and not text_parts[-1].isspace():
+                    sentences.append({'lang': lang_block['lang'], 'text': text_parts[-1]})
+
+    if sentences:
+        sentences_list.append(sentences)
+
+    return sentences_list
+
+
+# 预先分割语种
+def split_language(text,language):
+    if language in {"en","all_zh","all_ja"}:
+        language = language.replace("all_","")
+        if language == "en":
+            LangSegment.setfilters(["en"])
+            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
+        else:
+            # 因无法区别中日文汉字,以用户输入为准
+            formattext = text
+        while "  " in formattext:
+            formattext = formattext.replace("  ", " ")
+        language_splits = [{'lang': language, 'text': formattext}]
+    elif language in {"zh", "ja","auto"}:
+        LangSegment.setfilters(["zh","ja","en"])
+        if language == "auto":
+            language_splits = LangSegment.getTexts(text)
+        else:
+            language_splits = []
+            for tmp in LangSegment.getTexts(text):
+                if tmp["lang"] == "en":
+                    language_splits.append(tmp)
+                else:
+                    # 因无法区别中日文汉字,以用户输入为准
+                    language_splits.append({'lang': language, 'text': tmp["text"]})
+
+    return language_splits
+
+
+# 合并太碎的
+def merge_fragments(sentences_list):
+    new_sentences_list = []
+    temp_list = []
+
+    for sentences in sentences_list:
+        if sentences[0]['text'].strip() not in {".","。"}:
+            if temp_list:
+                temp_list.extend(sentences)
+                new_sentences_list.append(temp_list)
+                temp_list = []
+            else:
+                new_sentences_list.append(sentences)
+        else:
+            temp_list.extend(sentences)
+
+    if temp_list:
+        if len(new_sentences_list) >1:
+            new_sentences_list[-1].extend(temp_list)
+        else:
+            new_sentences_list.append(temp_list)
+
+    sentences_list = new_sentences_list
+    new_sentences_list = []
+
+    for sentences in sentences_list:
+        merged_sentences = []
+        prev_entry = None
+
+        for entry in sentences:
+            if prev_entry and entry['lang'] == prev_entry['lang']:
+                prev_entry['text'] += entry['text']
+            else:
+                merged_sentences.append(entry)
+                prev_entry = entry
+
+        new_sentences_list.append(merged_sentences)
+
+    return new_sentences_list
+
+
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False):
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
@@ -321,7 +540,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
         print(i18n("实际输入的参考文本:"), prompt_text)
-    text = text.strip("\n")
+    text = text.replace("\n","")
     if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
 
     print(i18n("实际输入的目标文本:"), text)
@@ -352,32 +571,36 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language
     prompt_semantic = codes[0, 0]
     t1 = ttime()
 
+    # 预先分割tts目标
+    texts = split_language(text,text_language)
+
     if (how_to_cut == i18n("凑四句一切")):
-        text = cut1(text)
+        texts = cut1(texts)
     elif (how_to_cut == i18n("凑50字一切")):
-        text = cut2(text)
+        texts = cut2(texts)
     elif (how_to_cut == i18n("按中文句号。切")):
-        text = cut3(text)
+        texts = cut3(texts)
     elif (how_to_cut == i18n("按英文句号.切")):
-        text = cut4(text)
+        texts = cut4(texts)
     elif (how_to_cut == i18n("按标点符号切")):
-        text = cut5(text)
-    while "\n\n" in text:
-        text = text.replace("\n\n", "\n")
-    print(i18n("实际输入的目标文本(切句后):"), text)
-    texts = text.split("\n")
-    texts = merge_short_text_in_array(texts, 5)
+        texts = cut5(texts)
+    else:
+        texts = [texts]
+
+    texts = merge_fragments(texts)
+
+    print(i18n("实际输入的目标文本(预分割后):"), texts)
     audio_opt = []
     if not ref_free:
         phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
 
     for text in texts:
         # 解决输入目标文本的空行导致报错的问题
-        if (len(text.strip()) == 0):
-            continue
-        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
-        print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
+        # if (len(text.strip()) == 0):
+        #     continue
+        # if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+        print(i18n("实际输入的目标文本(预分割后每句):"), text)
+        phones2,bert2,norm_text2=get_target_phones_and_bert(text)
         print(i18n("前端处理后的文本(每句):"), norm_text2)
         if not ref_free:
             bert = torch.cat([bert1, bert2], 1)
@@ -452,69 +675,6 @@ def split(todo_text):
     return todo_texts
 
 
-def cut1(inp):
-    inp = inp.strip("\n")
-    inps = split(inp)
-    split_idx = list(range(0, len(inps), 4))
-    split_idx[-1] = None
-    if len(split_idx) > 1:
-        opts = []
-        for idx in range(len(split_idx) - 1):
-            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
-    else:
-        opts = [inp]
-    return "\n".join(opts)
-
-
-def cut2(inp):
-    inp = inp.strip("\n")
-    inps = split(inp)
-    if len(inps) < 2:
-        return inp
-    opts = []
-    summ = 0
-    tmp_str = ""
-    for i in range(len(inps)):
-        summ += len(inps[i])
-        tmp_str += inps[i]
-        if summ > 50:
-            summ = 0
-            opts.append(tmp_str)
-            tmp_str = ""
-    if tmp_str != "":
-        opts.append(tmp_str)
-    # print(opts)
-    if len(opts) > 1 and len(opts[-1]) < 50:  ##如果最后一个太短了,和前一个合一起
-        opts[-2] = opts[-2] + opts[-1]
-        opts = opts[:-1]
-    return "\n".join(opts)
-
-
-def cut3(inp):
-    inp = inp.strip("\n")
-    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
-
-
-def cut4(inp):
-    inp = inp.strip("\n")
-    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
-
-
-# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
-def cut5(inp):
-    # if not re.search(r'[^\w\s]', inp[-1]):
-    #     inp += '。'
-    inp = inp.strip("\n")
-    punds = r'[,.;?!、,。?!;:…]'
-    items = re.split(f'({punds})', inp)
-    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
-    # 在句子不存在符号或句尾无符号的时候保证文本完整
-    if len(items)%2 == 1:
-        mergeitems.append(items[-1])
-    opt = "\n".join(mergeitems)
-    return opt
-
-
 def custom_sort_key(s):
     # 使用正则表达式提取字符串中的数字部分和非数字部分
     parts = re.split('(\d+)', s)
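
Not part of the patch above: a minimal sketch of the data structure the new helpers pass around. The example text and its 'lang' tags are assumptions for illustration; in the webui the tags come from LangSegment.getTexts() inside split_language().

# Illustrative sketch only, assuming LangSegment splits the example as shown.
text = "今天天气不错。Let's go hiking."

# split_language() returns language-tagged spans, e.g. (assumed):
language_splits = [
    {'lang': 'zh', 'text': "今天天气不错。"},
    {'lang': 'en', 'text': "Let's go hiking."},
]

# cut1()..cut5() regroup those spans into sentence groups while keeping the
# tags, and merge_fragments() folds punctuation-only pieces back in and merges
# adjacent entries that share a language, e.g.:
sentences_list = [
    [{'lang': 'zh', 'text': "今天天气不错。"}],
    [{'lang': 'en', 'text': "Let's go hiking."}],
]

# Each inner list is then fed to get_target_phones_and_bert(), which runs
# clean_text_inf()/get_bert_inf() per entry and concatenates the phones and
# BERT features across languages before synthesis.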