Better normlization (#1351)

This commit is contained in:
KamioRinn 2024-07-27 16:03:43 +08:00 committed by GitHub
parent f042030cca
commit e851ae34c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 93 additions and 18 deletions

View File

@ -236,7 +236,7 @@ def get_first(text):
text = re.split(pattern, text)[0].strip()
return text
from text import chinese
def get_phones_and_bert(text,language):
if language in {"en","all_zh","all_ja"}:
language = language.replace("all_","")
@ -248,10 +248,17 @@ def get_phones_and_bert(text,language):
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language)
if language == "zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.text_normalize(formattext)
return get_phones_and_bert(formattext,"zh")
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
bert = get_bert_feature(norm_text, word2ph).to(device)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
@ -327,7 +334,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(i18n("实际输入的参考文本:"), prompt_text)
text = text.strip("\n")
text = replace_consecutive_punctuation(text)
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(i18n("实际输入的目标文本:"), text)
@ -551,13 +557,6 @@ def process_text(texts):
return _text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def change_choices():
SoVITS_names, GPT_names = get_weights_names()
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}

View File

@ -48,12 +48,19 @@ def replace_punctuation(text):
replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
)
return replaced_text
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def g2p(text):
pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
@ -158,6 +165,9 @@ def text_normalize(text):
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text

View File

@ -4,7 +4,7 @@ import re
import wordsegment
from g2p_en import G2p
from string import punctuation
from text.symbols import punctuation
from text import symbols
@ -110,6 +110,13 @@ def replace_phs(phs):
return phs_new
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def read_dict():
g2p_dict = {}
start_line = 49
@ -234,6 +241,9 @@ def text_normalize(text):
text = re.sub(r"(?i)i\.e\.", "that is", text)
text = re.sub(r"(?i)e\.g\.", "for example", text)
# 避免重复标点引起的参考泄露
text = replace_consecutive_punctuation(text)
return text

View File

@ -6,6 +6,7 @@ import pyopenjtalk
from text import symbols
from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@ -65,6 +66,13 @@ def post_replace_ph(ph):
return ph
def replace_consecutive_punctuation(text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
return result
def symbols_to_japanese(text):
for regex, replacement in _symbols_to_japanese:
text = re.sub(regex, replacement, text)
@ -94,6 +102,9 @@ def preprocess_jap(text, with_prosody=False):
def text_normalize(text):
# todo: jap text normalize
# 避免重复标点引起的参考泄露
text = replace_consecutive_punctuation(text)
return text
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py

View File

@ -107,8 +107,11 @@ def replace_default_num(match):
# 加减乘除
# RE_ASMD = re.compile(
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
asmd_map = {
'+': '',
'-': '',
@ -117,7 +120,6 @@ asmd_map = {
'=': '等于'
}
def replace_asmd(match) -> str:
"""
Args:
@ -129,6 +131,39 @@ def replace_asmd(match) -> str:
return result
# 次方专项
RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
power_map = {
'': '0',
'¹': '1',
'²': '2',
'³': '3',
'': '4',
'': '5',
'': '6',
'': '7',
'': '8',
'': '9',
'ˣ': 'x',
'ʸ': 'y',
'': 'n'
}
def replace_power(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
power_num = ""
for m in match.group(0):
power_num += power_map[m]
result = "" + power_num + "次方"
return result
# 数字表达式
# 纯小数
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')

View File

@ -35,6 +35,7 @@ from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE
from .num import RE_TO_RANGE
from .num import RE_ASMD
from .num import RE_POWER
from .num import replace_default_num
from .num import replace_frac
from .num import replace_negative_num
@ -44,6 +45,7 @@ from .num import replace_positive_quantifier
from .num import replace_range
from .num import replace_to_range
from .num import replace_asmd
from .num import replace_power
from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE
@ -114,6 +116,12 @@ class TextNormalizer():
sentence = sentence.replace('χ', '')
sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
# 兜底数学运算,顺便兼容懒人用语
sentence = sentence.replace('+', '')
sentence = sentence.replace('-', '')
sentence = sentence.replace('×', '')
sentence = sentence.replace('÷', '')
sentence = sentence.replace('=', '')
# re filter special characters, have one more character "-" than line 68
sentence = re.sub(r'[-——《》【】<=>{}()#&@“”^_|\\]', '', sentence)
return sentence
@ -136,6 +144,12 @@ class TextNormalizer():
sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = replace_measure(sentence)
# 处理数学运算
while RE_ASMD.search(sentence):
sentence = RE_ASMD.sub(replace_asmd, sentence)
sentence = RE_POWER.sub(replace_power, sentence)
sentence = RE_FRAC.sub(replace_frac, sentence)
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
@ -145,10 +159,6 @@ class TextNormalizer():
sentence = RE_RANGE.sub(replace_range, sentence)
# 处理加减乘除
while RE_ASMD.search(sentence):
sentence = RE_ASMD.sub(replace_asmd, sentence)
sentence = RE_INTEGER.sub(replace_negative_num, sentence)
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,