Patch-LangSegmenter (#2073)

This commit is contained in:
KamioRinn 2025-02-18 15:23:06 +08:00 committed by GitHub
parent c17dd642c7
commit a2bb1dab91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,99 @@ def merge_lang(lang_list, item):
lang_list.append(item) lang_list.append(item)
return lang_list return lang_list
from typing import List
from split_lang import SubString
def _special_merge_for_zh_ja(
    self,
    substrings: List["SubString"],
) -> List["SubString"]:
    """Fold short stray fragments into neighbouring zh/ja runs and unify labels.

    Intended to be attached to ``LangSplitter`` as a method (it relies on
    ``self._merge_substrings``, which is defined elsewhere).

    The pass works in three stages:

    1. Walk the blocks left to right. A block whose length is less than one
       tenth of its zh/ja neighbour(s) is absorbed into that neighbour:
       the first block into the block on its right, the last block into the
       block on its left, and a middle block (any language except ``"en"``)
       into its left neighbour together with the right neighbour when both
       neighbours share the same zh/ja language.
    2. If any ``"x"`` (undetermined) text was seen, relabel the remaining
       ``"x"`` blocks with the bucket that has the largest total length.
    3. If the total ja length is at least ten times the total zh length,
       relabel every zh block as ja.

    The result is finally handed to ``self._merge_substrings``.

    Args:
        substrings: Blocks exposing ``lang``, ``text``, ``length`` and
            ``index``. Blocks that absorb a neighbour in the last/middle
            cases are modified in place.

    Returns:
        The merged list. Inputs with zero or one block are returned as-is.
    """
    # Nothing can be merged with fewer than two blocks (this also keeps the
    # ``substrings[index + 1]`` lookup for the first block in range).
    if len(substrings) <= 1:
        return substrings

    # NOTE: total text length per language. This is tallied over *every*
    # input block before merging; tallying inside the merge loop would miss
    # the right-hand blocks that get consumed (and skipped) by a merge,
    # which skews the majority vote and the ja/zh ratio below.
    substring_text_len_by_lang = {
        "zh": 0,
        "ja": 0,
        "x": 0,
        "digit": 0,
        "punctuation": 0,
        "newline": 0,
    }
    for block in substrings:
        # ``.get`` keeps languages outside the seeded buckets (e.g. "en")
        # from raising KeyError.
        substring_text_len_by_lang[block.lang] = (
            substring_text_len_by_lang.get(block.lang, 0) + block.length
        )

    new_substrings: List[SubString] = []
    last_index = len(substrings) - 1
    index = 0
    while index < len(substrings):
        current_block = substrings[index]
        if index == 0:
            right_block = substrings[index + 1]
            if (
                right_block.lang in ("zh", "ja")
                and current_block.lang in ("zh", "ja", "x")
                and current_block.length * 10 < right_block.length
            ):
                # Tiny leading fragment: glue it onto the block that follows
                # and adopt that block's language.
                new_substrings.append(
                    SubString(
                        is_digit=False,
                        is_punctuation=False,
                        lang=right_block.lang,
                        text=current_block.text + right_block.text,
                        length=current_block.length + right_block.length,
                        index=current_block.index,
                    )
                )
                index += 1  # the right block has been consumed
            else:
                new_substrings.append(current_block)
        elif index == last_index:
            left_block = new_substrings[-1]
            if (
                left_block.lang in ("zh", "ja")
                and current_block.lang in ("zh", "ja", "x")
                and current_block.length * 10 < left_block.length
            ):
                # Tiny trailing fragment: append it to the preceding block.
                left_block.text += current_block.text
                left_block.length += current_block.length
            else:
                new_substrings.append(current_block)
        else:
            left_block = new_substrings[-1]
            right_block = substrings[index + 1]
            if (
                left_block.lang == right_block.lang
                and left_block.lang in ("zh", "ja")
                and current_block.lang != "en"
                and current_block.length * 10
                < left_block.length + right_block.length
            ):
                # Short fragment sandwiched between two blocks of the same
                # zh/ja language: fuse all three into the left block.
                left_block.text += current_block.text + right_block.text
                left_block.length += current_block.length + right_block.length
                index += 1  # the right block has been consumed
            else:
                new_substrings.append(current_block)
        index += 1

    # NOTE: if any "x" text exists, relabel "x" blocks with the dominant bucket.
    # NOTE(review): the candidates include "x" itself and the non-language
    # buckets (digit/punctuation/newline) - confirm that this is intended.
    if substring_text_len_by_lang["x"] > 0:
        max_lang = max(
            substring_text_len_by_lang, key=substring_text_len_by_lang.get
        )
        for substr in new_substrings:
            if substr.lang == "x":
                substr.lang = max_lang

    # NOTE: if ja text is at least 10x the zh text, treat the zh blocks as ja.
    if substring_text_len_by_lang["ja"] >= substring_text_len_by_lang["zh"] * 10:
        for substr in new_substrings:
            if substr.lang == "zh":
                substr.lang = "ja"

    return self._merge_substrings(substrings=new_substrings)
class LangSegmenter(): class LangSegmenter():
# 默认过滤器, 基于gsv目前四种语言 # 默认过滤器, 基于gsv目前四种语言
@ -62,6 +155,7 @@ class LangSegmenter():
def getTexts(text): def getTexts(text):
LangSplitter._special_merge_for_zh_ja = _special_merge_for_zh_ja
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
substr = lang_splitter.split_by_lang(text=text) substr = lang_splitter.split_by_lang(text=text)
@ -120,5 +214,6 @@ if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?" text = "MyGO?,你也喜欢まいご吗?"
print(LangSegmenter.getTexts(text)) print(LangSegmenter.getTexts(text))
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text))