语言分割及格式化优化 (#2488)

* better LangSegmenter

* add version num2str

* better version num2str

* sync fast infer

* sync api

* remove duplicate spaces

* remove unnecessary code

---------

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
This commit is contained in:
KamioRinn 2025-06-27 11:58:41 +08:00 committed by GitHub
parent 90ebefa78f
commit 6df61f58e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 332 additions and 296 deletions

View File

@ -121,71 +121,67 @@ class TextPreprocessor:
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False): def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock: with self.bert_lock:
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: text = re.sub(r' {2,}', ' ', text)
# language = language.replace("all_","") textlist = []
formattext = text langlist = []
while " " in formattext: if language == "all_zh":
formattext = formattext.replace(" ", " ") for tmp in LangSegmenter.getTexts(text,"zh"):
if language == "all_zh": langlist.append(tmp["lang"])
if re.search(r"[A-Za-z]", formattext): textlist.append(tmp["text"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) elif language == "all_yue":
formattext = chinese.mix_text_normalize(formattext) for tmp in LangSegmenter.getTexts(text,"zh"):
return self.get_phones_and_bert(formattext, "zh", version) if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else: else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) # 因无法区别中日韩文汉字,以用户输入为准
bert = self.get_bert_feature(norm_text, word2ph).to(self.device) langlist.append(language)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): textlist.append(tmp["text"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) # print(textlist)
formattext = chinese.mix_text_normalize(formattext) # print(langlist)
return self.get_phones_and_bert(formattext, "yue", version) phones_list = []
else: bert_list = []
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) norm_text_list = []
bert = torch.zeros( for i in range(len(textlist)):
(1024, len(phones)), lang = langlist[i]
dtype=torch.float32, phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
).to(self.device) bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: phones_list.append(phones)
textlist = [] norm_text_list.append(norm_text)
langlist = [] bert_list.append(bert)
if language == "auto": bert = torch.cat(bert_list, dim=1)
for tmp in LangSegmenter.getTexts(text): phones = sum(phones_list, [])
langlist.append(tmp["lang"]) norm_text = "".join(norm_text_list)
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6: if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True) return self.get_phones_and_bert("." + text, language, version, final=True)
@ -240,4 +236,4 @@ class TextPreprocessor:
punctuations = "".join(re.escape(p) for p in punctuation) punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+" pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text) result = re.sub(pattern, r"\1", text)
return result return result

View File

@ -586,68 +586,67 @@ from text import chinese
def get_phones_and_bert(text, language, version, final=False): def get_phones_and_bert(text, language, version, final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: text = re.sub(r' {2,}', ' ', text)
formattext = text textlist = []
while " " in formattext: langlist = []
formattext = formattext.replace(" ", " ") if language == "all_zh":
if language == "all_zh": for tmp in LangSegmenter.getTexts(text,"zh"):
if re.search(r"[A-Za-z]", formattext): langlist.append(tmp["lang"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) textlist.append(tmp["text"])
formattext = chinese.mix_text_normalize(formattext) elif language == "all_yue":
return get_phones_and_bert(formattext, "zh", version) for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else: else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) # 因无法区别中日韩文汉字,以用户输入为准
bert = get_bert_feature(norm_text, word2ph).to(device) langlist.append(language)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): textlist.append(tmp["text"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) print(textlist)
formattext = chinese.mix_text_normalize(formattext) print(langlist)
return get_phones_and_bert(formattext, "yue", version) phones_list = []
else: bert_list = []
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) norm_text_list = []
bert = torch.zeros( for i in range(len(textlist)):
(1024, len(phones)), lang = langlist[i]
dtype=torch.float16 if is_half == True else torch.float32, phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
).to(device) bert = get_bert_inf(phones, word2ph, norm_text, lang)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: phones_list.append(phones)
textlist = [] norm_text_list.append(norm_text)
langlist = [] bert_list.append(bert)
if language == "auto": bert = torch.cat(bert_list, dim=1)
for tmp in LangSegmenter.getTexts(text): phones = sum(phones_list, [])
langlist.append(tmp["lang"]) norm_text = "".join(norm_text_list)
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
print(textlist)
print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6: if not final and len(phones) < 6:
return get_phones_and_bert("." + text, language, version, final=True) return get_phones_and_bert("." + text, language, version, final=True)

View File

@ -3,44 +3,38 @@ import re
# jieba静音 # jieba静音
import jieba import jieba
jieba.setLogLevel(logging.CRITICAL) jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置 # 更改fast_langdetect大模型位置
from pathlib import Path from pathlib import Path
import fast_langdetect import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
)
)
from split_lang import LangSplitter from split_lang import LangSplitter
def full_en(text): def full_en(text):
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$" pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
return bool(re.match(pattern, text)) return bool(re.match(pattern, text))
def full_cjk(text): def full_cjk(text):
# 来自wiki # 来自wiki
cjk_ranges = [ cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs (0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A (0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B (0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C (0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D (0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E (0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F (0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G (0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H (0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H (0x2EBF0, 0x2EE5D), # CJK Extension H
] ]
pattern = r"[0-9、-〜。!?.!?… /]+$" pattern = r'[0-9、-〜。!?.!?… /]+$'
cjk_text = "" cjk_text = ""
for char in text: for char in text:
@ -51,7 +45,7 @@ def full_cjk(text):
return cjk_text return cjk_text
def split_jako(tag_lang, item): def split_jako(tag_lang,item):
if tag_lang == "ja": if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)" pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else: else:
@ -59,118 +53,165 @@ def split_jako(tag_lang, item):
lang_list: list[dict] = [] lang_list: list[dict] = []
tag = 0 tag = 0
for match in re.finditer(pattern, item["text"]): for match in re.finditer(pattern, item['text']):
if match.start() > tag: if match.start() > tag:
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]}) lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
tag = match.end() tag = match.end()
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]}) lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
if tag < len(item["text"]): if tag < len(item['text']):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]}) lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
return lang_list return lang_list
def merge_lang(lang_list, item): def merge_lang(lang_list, item):
if lang_list and item["lang"] == lang_list[-1]["lang"]: if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]["text"] += item["text"] lang_list[-1]['text'] += item['text']
else: else:
lang_list.append(item) lang_list.append(item)
return lang_list return lang_list
class LangSegmenter: class LangSegmenter():
# 默认过滤器, 基于gsv目前四种语言 # 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = { DEFAULT_LANG_MAP = {
"zh": "zh", "zh": "zh",
"yue": "zh", # 粤语 "yue": "zh", # 粤语
"wuu": "zh", # 吴语 "wuu": "zh", # 吴语
"zh-cn": "zh", "zh-cn": "zh",
"zh-tw": "x", # 繁体设置为x "zh-tw": "x", # 繁体设置为x
"ko": "ko", "ko": "ko",
"ja": "ja", "ja": "ja",
"en": "en", "en": "en",
} }
def getTexts(text): def getTexts(text,default_lang = ""):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
lang_splitter.merge_across_digit = False
substr = lang_splitter.split_by_lang(text=text) substr = lang_splitter.split_by_lang(text=text)
lang_list: list[dict] = [] lang_list: list[dict] = []
for _, item in enumerate(substr): have_num = False
dict_item = {"lang": item.lang, "text": item.text}
# 处理短英文被识别为其他语言的问题 for _, item in enumerate(substr):
if full_en(dict_item["text"]): dict_item = {'lang':item.lang,'text':item.text}
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item) if dict_item['lang'] == 'digit':
if default_lang != "":
dict_item['lang'] = default_lang
else:
have_num = True
lang_list = merge_lang(lang_list,dict_item)
continue continue
# 处理非日语夹日文的问题(不包含CJK) # 处理短英文被识别为其他语言的问题
ja_list: list[dict] = [] if full_en(dict_item['text']):
if dict_item["lang"] != "ja": dict_item['lang'] = 'en'
ja_list = split_jako("ja", dict_item) lang_list = merge_lang(lang_list,dict_item)
continue
if not ja_list: if default_lang != "":
ja_list.append(dict_item) dict_item['lang'] = default_lang
lang_list = merge_lang(lang_list,dict_item)
continue
else:
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
# 处理非韩语夹韩语的问题(不包含CJK) if not ja_list:
ko_list: list[dict] = [] ja_list.append(dict_item)
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)
if ko_list: # 处理非韩语夹韩语的问题(不包含CJK)
temp_list.extend(ko_list) ko_list: list[dict] = []
else: temp_list: list[dict] = []
temp_list.append(ko_item) for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
# 未存在非日韩文夹日韩文 if ko_list:
if len(temp_list) == 1: temp_list.extend(ko_list)
# 未知语言检查是否为CJK
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else: else:
lang_list = merge_lang(lang_list, dict_item) temp_list.append(ko_item)
continue
else:
lang_list = merge_lang(lang_list, dict_item)
continue
# 存在非日韩文夹日韩文 # 未存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list): if len(temp_list) == 1:
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if temp_item["lang"] == "x": if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item["text"]) cjk_text = full_cjk(dict_item['text'])
if cjk_text: if cjk_text:
dict_item = {"lang": "zh", "text": cjk_text} dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list, dict_item) lang_list = merge_lang(lang_list,dict_item)
else:
lang_list = merge_lang(lang_list,dict_item)
continue
else: else:
lang_list = merge_lang(lang_list, dict_item) lang_list = merge_lang(lang_list,dict_item)
else: continue
lang_list = merge_lang(lang_list, temp_item)
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(temp_item['text'])
if cjk_text:
lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
else:
lang_list = merge_lang(lang_list,temp_item)
else:
lang_list = merge_lang(lang_list,temp_item)
# 有数字
if have_num:
temp_list = lang_list
lang_list = []
for i, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'digit':
if default_lang:
temp_item['lang'] = default_lang
elif lang_list and i == len(temp_list) - 1:
temp_item['lang'] = lang_list[-1]['lang']
elif not lang_list and i < len(temp_list) - 1:
temp_item['lang'] = temp_list[1]['lang']
elif lang_list and i < len(temp_list) - 1:
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
temp_item['lang'] = lang_list[-1]['lang']
elif lang_list[-1]['text'][-1] in [",",".","!","?","","","",""]:
temp_item['lang'] = temp_list[i + 1]['lang']
elif temp_list[i + 1]['text'][0] in [",",".","!","?","","","",""]:
temp_item['lang'] = lang_list[-1]['lang']
elif temp_item['text'][-1] in ["","."]:
temp_item['lang'] = lang_list[-1]['lang']
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
temp_item['lang'] = lang_list[-1]['lang']
else:
temp_item['lang'] = temp_list[i + 1]['lang']
else:
temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list,temp_item)
# 筛X
temp_list = lang_list temp_list = lang_list
lang_list = [] lang_list = []
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
if temp_item["lang"] == "x": if temp_item['lang'] == 'x':
if lang_list: if lang_list:
temp_item["lang"] = lang_list[-1]["lang"] temp_item['lang'] = lang_list[-1]['lang']
elif len(temp_list) > 1: elif len(temp_list) > 1:
temp_item["lang"] = temp_list[1]["lang"] temp_item['lang'] = temp_list[1]['lang']
else: else:
temp_item["lang"] = "zh" temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list, temp_item) lang_list = merge_lang(lang_list,temp_item)
return lang_list return lang_list
if __name__ == "__main__": if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?" text = "MyGO?,你也喜欢まいご吗?"
@ -178,3 +219,7 @@ if __name__ == "__main__":
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。" text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text)) print(LangSegmenter.getTexts(text))
text = "当时ThinkPad T60刚刚发布一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽扩展出包括PCIe在内的一大堆接口并且自带电源让T60可以安装桌面显卡来提升性能。"
print(LangSegmenter.getTexts(text,"zh"))
print(LangSegmenter.getTexts(text))

View File

@ -181,20 +181,6 @@ def text_normalize(text):
return dest_text return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
if __name__ == "__main__": if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -326,20 +326,6 @@ def text_normalize(text):
return dest_text return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
if __name__ == "__main__": if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -256,6 +256,24 @@ def replace_to_range(match) -> str:
return result return result
RE_VERSION_NUM = re.compile(r"((\d+)(\.\d+)(\.\d+)?(\.\d+)+)")
def replace_vrsion_num(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
result = ""
for c in match.group(1):
if c == ".":
result += ""
else:
result += num2str(c)
return result
def _get_value(value_string: str, use_zero: bool = True) -> List[str]: def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
stripped = value_string.lstrip("0") stripped = value_string.lstrip("0")
if len(stripped) == 0: if len(stripped) == 0:
@ -308,7 +326,11 @@ def num2str(value_string: str) -> str:
result = verbalize_cardinal(integer) result = verbalize_cardinal(integer)
decimal = decimal.rstrip("0") if decimal.endswith("0"):
decimal = decimal.rstrip("0") + "0"
else:
decimal = decimal.rstrip("0")
if decimal: if decimal:
# '.22' is verbalized as '零点二二' # '.22' is verbalized as '零点二二'
# '3.20' is verbalized as '三点二 # '3.20' is verbalized as '三点二

View File

@ -25,6 +25,7 @@ from .chronology import replace_time
from .constants import F2H_ASCII_LETTERS from .constants import F2H_ASCII_LETTERS
from .constants import F2H_DIGITS from .constants import F2H_DIGITS
from .constants import F2H_SPACE from .constants import F2H_SPACE
from .num import RE_VERSION_NUM
from .num import RE_DECIMAL_NUM from .num import RE_DECIMAL_NUM
from .num import RE_DEFAULT_NUM from .num import RE_DEFAULT_NUM
from .num import RE_FRAC from .num import RE_FRAC
@ -36,6 +37,7 @@ from .num import RE_RANGE
from .num import RE_TO_RANGE from .num import RE_TO_RANGE
from .num import RE_ASMD from .num import RE_ASMD
from .num import RE_POWER from .num import RE_POWER
from .num import replace_vrsion_num
from .num import replace_default_num from .num import replace_default_num
from .num import replace_frac from .num import replace_frac
from .num import replace_negative_num from .num import replace_negative_num
@ -158,6 +160,7 @@ class TextNormalizer:
sentence = RE_RANGE.sub(replace_range, sentence) sentence = RE_RANGE.sub(replace_range, sentence)
sentence = RE_INTEGER.sub(replace_negative_num, sentence) sentence = RE_INTEGER.sub(replace_negative_num, sentence)
sentence = RE_VERSION_NUM.sub(replace_vrsion_num, sentence)
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence) sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence)
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)

117
api.py
View File

@ -543,66 +543,65 @@ from text import chinese
def get_phones_and_bert(text, language, version, final=False): def get_phones_and_bert(text, language, version, final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: text = re.sub(r' {2,}', ' ', text)
formattext = text textlist = []
while " " in formattext: langlist = []
formattext = formattext.replace(" ", " ") if language == "all_zh":
if language == "all_zh": for tmp in LangSegmenter.getTexts(text,"zh"):
if re.search(r"[A-Za-z]", formattext): langlist.append(tmp["lang"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) textlist.append(tmp["text"])
formattext = chinese.mix_text_normalize(formattext) elif language == "all_yue":
return get_phones_and_bert(formattext, "zh", version) for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else: else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) # 因无法区别中日韩文汉字,以用户输入为准
bert = get_bert_feature(norm_text, word2ph).to(device) langlist.append(language)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): textlist.append(tmp["text"])
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) phones_list = []
formattext = chinese.mix_text_normalize(formattext) bert_list = []
return get_phones_and_bert(formattext, "yue", version) norm_text_list = []
else: for i in range(len(textlist)):
phones, word2ph, norm_text = clean_text_inf(formattext, language, version) lang = langlist[i]
bert = torch.zeros( phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
(1024, len(phones)), bert = get_bert_inf(phones, word2ph, norm_text, lang)
dtype=torch.float16 if is_half == True else torch.float32, phones_list.append(phones)
).to(device) norm_text_list.append(norm_text)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: bert_list.append(bert)
textlist = [] bert = torch.cat(bert_list, dim=1)
langlist = [] phones = sum(phones_list, [])
if language == "auto": norm_text = "".join(norm_text_list)
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6: if not final and len(phones) < 6:
return get_phones_and_bert("." + text, language, version, final=True) return get_phones_and_bert("." + text, language, version, final=True)