from typing import List

from split_lang import SubString


def _special_merge_for_zh_ja(
    self,
    substrings: List[SubString],
) -> List[SubString]:
    """Absorb very short fragments into much longer zh/ja neighbours.

    The function takes ``self`` because it is meant to be bound as a method:
    ``LangSegmenter.getTexts`` assigns it to
    ``LangSplitter._special_merge_for_zh_ja`` before building the splitter.

    Merge rules (a fragment is "short" when ``length * 10`` is smaller than
    the length it is compared against):

    * First block: if it is zh/ja/x, the next block is zh/ja and the first
      block is short relative to the next one, both are replaced by a new
      ``SubString`` carrying the next block's language.
    * Last block: if it is zh/ja/x, the previously emitted block is zh/ja
      and the last block is short relative to it, its text is appended to
      the previously emitted block.
    * Interior block: if the previously emitted block and the next block
      share the same zh/ja language, the block is not ``"en"`` and it is
      short relative to the two neighbours combined, the block and the next
      block are both appended to the previously emitted block.

    After merging, ``"x"`` blocks are relabelled with the language that has
    the largest tallied length, ``"zh"`` blocks are relabelled ``"ja"`` when
    the tallied ja length is at least ten times the zh length, and the
    result is passed through ``self._merge_substrings``.

    Args:
        substrings: Language-tagged substrings in text order. Blocks that
            absorb a neighbour or get relabelled are modified in place
            (``text``, ``length``, ``lang``).

    Returns:
        The merged list. An input with zero or one element is returned
        unchanged.
    """
    # Nothing to merge; also keeps the first-block rule from indexing past
    # the end of the list.
    if len(substrings) <= 1:
        return substrings

    zh_ja = ("zh", "ja")
    absorbable = ("zh", "ja", "x")

    # Total text length seen per language tag.
    text_len_by_lang = {
        "zh": 0,
        "ja": 0,
        "x": 0,
        "digit": 0,
        "punctuation": 0,
        "newline": 0,
    }

    merged: List[SubString] = []
    last = len(substrings) - 1
    index = 0
    while index <= last:
        current = substrings[index]
        # Tags outside the table (for example "en", which the interior rule
        # below explicitly tests for) are left untallied instead of raising
        # KeyError.
        # NOTE(review): a right-hand neighbour that gets absorbed below is
        # skipped by the extra ``index += 1`` and so never reaches this
        # tally; confirm whether that undercount is intended.
        if current.lang in text_len_by_lang:
            text_len_by_lang[current.lang] += current.length

        if index == 0:
            following = substrings[1]
            if (
                following.lang in zh_ja
                and current.lang in absorbable
                and current.length * 10 < following.length
            ):
                merged.append(
                    SubString(
                        is_digit=False,
                        is_punctuation=False,
                        lang=following.lang,
                        text=current.text + following.text,
                        length=current.length + following.length,
                        index=current.index,
                    )
                )
                index += 1  # the following block has been consumed
            else:
                merged.append(current)
        elif index == last:
            previous = merged[-1]
            if (
                previous.lang in zh_ja
                and current.lang in absorbable
                and current.length * 10 < previous.length
            ):
                previous.text += current.text
                previous.length += current.length
            else:
                merged.append(current)
        else:
            previous = merged[-1]
            following = substrings[index + 1]
            # Unlike the two edge rules, the interior rule accepts any tag
            # except "en" rather than only zh/ja/x.
            if (
                previous.lang == following.lang
                and previous.lang in zh_ja
                and current.lang != "en"
                and current.length * 10 < previous.length + following.length
            ):
                previous.text += current.text + following.text
                previous.length += current.length + following.length
                index += 1  # the following block has been consumed
            else:
                merged.append(current)
        index += 1

    # If any "x" text was seen, relabel "x" blocks with the tag that has the
    # largest tallied length.
    # NOTE(review): max() ranges over every key of the tally, so the winner
    # can be "x" itself or "digit"/"punctuation"/"newline"; confirm intent.
    if text_len_by_lang["x"] > 0:
        dominant_lang = max(text_len_by_lang, key=text_len_by_lang.get)
        for block in merged:
            if block.lang == "x":
                block.lang = dominant_lang

    # When the ja length is at least ten times the zh length, treat the
    # remaining zh blocks as ja.
    if text_len_by_lang["ja"] >= text_len_by_lang["zh"] * 10:
        for block in merged:
            if block.lang == "zh":
                block.lang = "ja"

    return self._merge_substrings(substrings=merged)