better LangSegmenter

This commit is contained in:
KamioRinn 2025-06-26 11:43:38 +08:00
parent ed89a02337
commit d90ee93c23
2 changed files with 117 additions and 65 deletions

View File

@ -586,32 +586,34 @@ from text import chinese
def get_phones_and_bert(text, language, version, final=False): def get_phones_and_bert(text, language, version, final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: if language in {"all_zh", "all_yue", "all_ja", "all_ko", "zh", "ja", "ko", "yue", "en", "auto", "auto_yue"}:
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist = [] textlist = []
langlist = [] langlist = []
if language == "auto": if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
langlist.append("en")
textlist.append(formattext)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])

View File

@ -87,70 +87,116 @@ class LangSegmenter():
"en": "en", "en": "en",
} }
def getTexts(text,default_lang = ""):
def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
lang_splitter.merge_across_digit = False
substr = lang_splitter.split_by_lang(text=text) substr = lang_splitter.split_by_lang(text=text)
lang_list: list[dict] = [] lang_list: list[dict] = []
have_num = False
for _, item in enumerate(substr): for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text} dict_item = {'lang':item.lang,'text':item.text}
if dict_item['lang'] == 'digit':
if default_lang != "":
dict_item['lang'] = default_lang
else:
have_num = True
lang_list = merge_lang(lang_list,dict_item)
continue
# 处理短英文被识别为其他语言的问题 # 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']): if full_en(dict_item['text']):
dict_item['lang'] = 'en' dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,dict_item)
continue continue
# 处理非日语夹日文的问题(不包含CJK) if default_lang != "":
ja_list: list[dict] = [] dict_item['lang'] = default_lang
if dict_item['lang'] != 'ja': lang_list = merge_lang(lang_list,dict_item)
ja_list = split_jako('ja',dict_item) continue
else:
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if not ja_list: if not ja_list:
ja_list.append(dict_item) ja_list.append(dict_item)
# 处理非韩语夹韩语的问题(不包含CJK) # 处理非韩语夹韩语的问题(不包含CJK)
ko_list: list[dict] = [] ko_list: list[dict] = []
temp_list: list[dict] = [] temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list): for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko': if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item) ko_list = split_jako('ko',ko_item)
if ko_list: if ko_list:
temp_list.extend(ko_list) temp_list.extend(ko_list)
else: else:
temp_list.append(ko_item) temp_list.append(ko_item)
# 未存在非日韩文夹日韩文 # 未存在非日韩文夹日韩文
if len(temp_list) == 1: if len(temp_list) == 1:
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if dict_item['lang'] == 'x': if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text']) cjk_text = full_cjk(dict_item['text'])
if cjk_text: if cjk_text:
dict_item = {'lang':'zh','text':cjk_text} dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,dict_item)
else:
lang_list = merge_lang(lang_list,dict_item)
continue
else: else:
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,dict_item)
continue continue
else:
lang_list = merge_lang(lang_list,dict_item)
continue
# 存在非日韩文夹日韩文 # 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if temp_item['lang'] == 'x': if temp_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text']) cjk_text = full_cjk(temp_item['text'])
if cjk_text: if cjk_text:
dict_item = {'lang':'zh','text':cjk_text} lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
lang_list = merge_lang(lang_list,dict_item) else:
lang_list = merge_lang(lang_list,temp_item)
else: else:
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,temp_item)
else:
lang_list = merge_lang(lang_list,temp_item)
# 有数字
if have_num:
temp_list = lang_list
lang_list = []
for i, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'digit':
if default_lang:
temp_item['lang'] = default_lang
elif lang_list and i == len(temp_list) - 1:
temp_item['lang'] = lang_list[-1]['lang']
elif not lang_list and i < len(temp_list) - 1:
temp_item['lang'] = temp_list[1]['lang']
elif lang_list and i < len(temp_list) - 1:
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
temp_item['lang'] = lang_list[-1]['lang']
elif lang_list[-1]['text'][-1] in [",",".","!","?","","","",""]:
temp_item['lang'] = temp_list[i + 1]['lang']
elif temp_list[i + 1]['text'][0] in [",",".","!","?","","","",""]:
temp_item['lang'] = lang_list[-1]['lang']
elif temp_item['text'][-1] in ["","."]:
temp_item['lang'] = lang_list[-1]['lang']
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
temp_item['lang'] = lang_list[-1]['lang']
else:
temp_item['lang'] = temp_list[i + 1]['lang']
else:
temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list,temp_item)
# 筛X
temp_list = lang_list temp_list = lang_list
lang_list = [] lang_list = []
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
@ -173,3 +219,7 @@ if __name__ == "__main__":
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。" text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text)) print(LangSegmenter.getTexts(text))
text = "当时ThinkPad T60刚刚发布一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽扩展出包括PCIe在内的一大堆接口并且自带电源让T60可以安装桌面显卡来提升性能。"
print(LangSegmenter.getTexts(text,"zh"))
print(LangSegmenter.getTexts(text))