From e851ae34c9ad97012babac9455561c73ddc6e6cd Mon Sep 17 00:00:00 2001
From: KamioRinn <63162909+KamioRinn@users.noreply.github.com>
Date: Sat, 27 Jul 2024 16:03:43 +0800
Subject: [PATCH] Better normlization (#1351)

---
 GPT_SoVITS/inference_webui.py                 | 19 +++++----
 GPT_SoVITS/text/chinese.py                    | 12 +++++-
 GPT_SoVITS/text/english.py                    | 12 +++++-
 GPT_SoVITS/text/japanese.py                   | 11 ++++++
 GPT_SoVITS/text/zh_normalization/num.py       | 39 ++++++++++++++++++-
 .../zh_normalization/text_normlization.py     | 18 +++++++--
 6 files changed, 93 insertions(+), 18 deletions(-)

diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index ecbbd2f..9189674 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -236,7 +236,7 @@ def get_first(text):
     text = re.split(pattern, text)[0].strip()
     return text
 
-
+from text import chinese
 def get_phones_and_bert(text,language):
     if language in {"en","all_zh","all_ja"}:
         language = language.replace("all_","")
@@ -248,10 +248,17 @@ def get_phones_and_bert(text,language):
         formattext = text
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
-        phones, word2ph, norm_text = clean_text_inf(formattext, language)
         if language == "zh":
+            if re.search(r'[A-Za-z]', formattext):
+                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                formattext = chinese.text_normalize(formattext)
+                return get_phones_and_bert(formattext,"zh")
+            else:
+                phones, word2ph, norm_text = clean_text_inf(formattext, language)
+                bert = get_bert_feature(norm_text, word2ph).to(device)
         else:
+            phones, word2ph, norm_text = clean_text_inf(formattext, language)
             bert = torch.zeros(
                 (1024, len(phones)),
                 dtype=torch.float16 if is_half == True else torch.float32,
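The new zh branch above detects Latin letters, uppercases them, runs the whole sentence through chinese.text_normalize, and then re-enters get_phones_and_bert. A minimal standalone sketch of that transform (normalize_zh_mixed and the identity stub are hypothetical stand-ins; the real normalizer is chinese.text_normalize from GPT_SoVITS/text/chinese.py):

    import re

    def normalize_zh_mixed(text, zh_normalize):
        # Uppercase lowercase Latin letters, then run the Chinese normalizer,
        # whose character filter now keeps A-Za-z (see the chinese.py hunk below).
        if re.search(r'[A-Za-z]', text):
            text = re.sub(r'[a-z]', lambda m: m.group(0).upper(), text)
            text = zh_normalize(text)
        return text

    print(normalize_zh_mixed("请打开app设置", lambda s: s))  # -> 请打开APP设置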
@@ -327,7 +334,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
     print(i18n("实际输入的参考文本:"), prompt_text)
     text = text.strip("\n")
-    text = replace_consecutive_punctuation(text)
     if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
     print(i18n("实际输入的目标文本:"), text)
@@ -551,13 +557,6 @@ def process_text(texts):
     return _text
 
 
-def replace_consecutive_punctuation(text):
-    punctuations = ''.join(re.escape(p) for p in punctuation)
-    pattern = f'([{punctuations}])([{punctuations}])+'
-    result = re.sub(pattern, r'\1', text)
-    return result
-
-
 def change_choices():
     SoVITS_names, GPT_names = get_weights_names()
     return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py
index f9a4b36..bebf3f0 100644
--- a/GPT_SoVITS/text/chinese.py
+++ b/GPT_SoVITS/text/chinese.py
@@ -48,12 +48,19 @@ def replace_punctuation(text):
     replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
 
     replaced_text = re.sub(
-        r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
+        r"[^\u4e00-\u9fa5A-Za-z" + "".join(punctuation) + r"]+", "", replaced_text
     )
 
     return replaced_text
 
 
+def replace_consecutive_punctuation(text):
+    punctuations = ''.join(re.escape(p) for p in punctuation)
+    pattern = f'([{punctuations}])([{punctuations}])+'
+    result = re.sub(pattern, r'\1', text)
+    return result
+
+
 def g2p(text):
     pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
     sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
@@ -158,6 +165,9 @@ def text_normalize(text):
     dest_text = ""
     for sentence in sentences:
         dest_text += replace_punctuation(sentence)
+
+    # Avoid reference-audio leakage caused by repeated punctuation
+    dest_text = replace_consecutive_punctuation(dest_text)
     return dest_text
diff --git a/GPT_SoVITS/text/english.py b/GPT_SoVITS/text/english.py
index 30fafb5..6c80aea 100644
--- a/GPT_SoVITS/text/english.py
+++ b/GPT_SoVITS/text/english.py
@@ -4,7 +4,7 @@ import re
 
 import wordsegment
 from g2p_en import G2p
-from string import punctuation
+from text.symbols import punctuation
 
 from text import symbols
 
@@ -110,6 +110,13 @@ def replace_phs(phs):
     return phs_new
 
 
+def replace_consecutive_punctuation(text):
+    punctuations = ''.join(re.escape(p) for p in punctuation)
+    pattern = f'([{punctuations}])([{punctuations}])+'
+    result = re.sub(pattern, r'\1', text)
+    return result
+
+
 def read_dict():
     g2p_dict = {}
     start_line = 49
@@ -234,6 +241,9 @@ def text_normalize(text):
     text = re.sub(r"(?i)i\.e\.", "that is", text)
     text = re.sub(r"(?i)e\.g\.", "for example", text)
 
+    # Avoid reference-audio leakage caused by repeated punctuation
+    text = replace_consecutive_punctuation(text)
+
     return text
diff --git a/GPT_SoVITS/text/japanese.py b/GPT_SoVITS/text/japanese.py
index a571467..5aa6a8f 100644
--- a/GPT_SoVITS/text/japanese.py
+++ b/GPT_SoVITS/text/japanese.py
@@ -6,6 +6,7 @@ import pyopenjtalk
 
 from text import symbols
+from text.symbols import punctuation
 
 # Regular expression matching Japanese without punctuation marks:
 _japanese_characters = re.compile(
     r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
@@ -65,6 +66,13 @@ def post_replace_ph(ph):
     return ph
 
 
+def replace_consecutive_punctuation(text):
+    punctuations = ''.join(re.escape(p) for p in punctuation)
+    pattern = f'([{punctuations}])([{punctuations}])+'
+    result = re.sub(pattern, r'\1', text)
+    return result
+
+
 def symbols_to_japanese(text):
     for regex, replacement in _symbols_to_japanese:
         text = re.sub(regex, replacement, text)
@@ -94,6 +102,9 @@ def preprocess_jap(text, with_prosody=False):
 
 def text_normalize(text):
     # todo: jap text normalize
+
+    # Avoid reference-audio leakage caused by repeated punctuation
+    text = replace_consecutive_punctuation(text)
     return text
 
 # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
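All three frontends now share the same replace_consecutive_punctuation helper, which collapses a run of punctuation marks down to its first mark. Its behavior can be checked in isolation; the punctuation list below is a small stand-in for the repo's text.symbols.punctuation:

    import re

    # Stand-in for text.symbols.punctuation.
    punctuation = ['!', '?', '…', ',', '.', '-']

    def replace_consecutive_punctuation(text):
        punctuations = ''.join(re.escape(p) for p in punctuation)
        pattern = f'([{punctuations}])([{punctuations}])+'
        # Keep only the first mark of every punctuation run.
        return re.sub(pattern, r'\1', text)

    print(replace_consecutive_punctuation("Wait... what?!"))  # -> Wait. what?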
diff --git a/GPT_SoVITS/text/zh_normalization/num.py b/GPT_SoVITS/text/zh_normalization/num.py
index d38d5a6..43718e7 100644
--- a/GPT_SoVITS/text/zh_normalization/num.py
+++ b/GPT_SoVITS/text/zh_normalization/num.py
@@ -107,8 +107,11 @@ def replace_default_num(match):
 
 
 # Addition, subtraction, multiplication, division
+# RE_ASMD = re.compile(
+#     r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
 RE_ASMD = re.compile(
-    r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
+    r'((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))')
+
 asmd_map = {
     '+': '加',
     '-': '减',
     '×': '乘',
     '÷': '除',
     '=': '等于'
 }
@@ -117,7 +120,6 @@ asmd_map = {
     '=': '等于'
 }
 
-
 def replace_asmd(match) -> str:
     """
     Args:
         match (re.Match)
     Returns:
         str
     """
@@ -129,6 +131,39 @@ def replace_asmd(match) -> str:
     return result
 
 
+# Dedicated handling for superscript powers
+RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
+
+power_map = {
+    '⁰': '0',
+    '¹': '1',
+    '²': '2',
+    '³': '3',
+    '⁴': '4',
+    '⁵': '5',
+    '⁶': '6',
+    '⁷': '7',
+    '⁸': '8',
+    '⁹': '9',
+    'ˣ': 'x',
+    'ʸ': 'y',
+    'ⁿ': 'n'
+}
+
+def replace_power(match) -> str:
+    """
+    Args:
+        match (re.Match)
+    Returns:
+        str
+    """
+    power_num = ""
+    for m in match.group(0):
+        power_num += power_map[m]
+    result = "的" + power_num + "次方"
+    return result
+
+
 # Numeric expressions
 # Pure decimals
 RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
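Taken together, RE_POWER and replace_power spell a run of superscripts out as "的…次方" ("to the power of …"). A quick self-contained check, with both definitions copied from the hunk above:

    import re

    RE_POWER = re.compile(r'[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+')
    power_map = {'⁰': '0', '¹': '1', '²': '2', '³': '3', '⁴': '4', '⁵': '5',
                 '⁶': '6', '⁷': '7', '⁸': '8', '⁹': '9', 'ˣ': 'x', 'ʸ': 'y', 'ⁿ': 'n'}

    def replace_power(match):
        # Map each superscript character back to its plain form, then verbalize.
        power_num = ''.join(power_map[ch] for ch in match.group(0))
        return "的" + power_num + "次方"

    print(RE_POWER.sub(replace_power, "3²"))  # -> 3的2次方
    print(RE_POWER.sub(replace_power, "xⁿ"))  # -> x的n次方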
diff --git a/GPT_SoVITS/text/zh_normalization/text_normlization.py b/GPT_SoVITS/text/zh_normalization/text_normlization.py
index e852fe9..400b30f 100644
--- a/GPT_SoVITS/text/zh_normalization/text_normlization.py
+++ b/GPT_SoVITS/text/zh_normalization/text_normlization.py
@@ -35,6 +35,7 @@ from .num import RE_POSITIVE_QUANTIFIERS
 from .num import RE_RANGE
 from .num import RE_TO_RANGE
 from .num import RE_ASMD
+from .num import RE_POWER
 from .num import replace_default_num
 from .num import replace_frac
 from .num import replace_negative_num
@@ -44,6 +45,7 @@ from .num import replace_positive_quantifier
 from .num import replace_range
 from .num import replace_to_range
 from .num import replace_asmd
+from .num import replace_power
 from .phonecode import RE_MOBILE_PHONE
 from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
 from .phonecode import RE_TELEPHONE
@@ -114,6 +116,12 @@ class TextNormalizer():
         sentence = sentence.replace('χ', '器')
         sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛')
         sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽')
+        # Fallback for math operators; also tolerates casual shorthand
+        sentence = sentence.replace('+', '加')
+        sentence = sentence.replace('-', '减')
+        sentence = sentence.replace('×', '乘')
+        sentence = sentence.replace('÷', '除')
+        sentence = sentence.replace('=', '等')
         # re filter special characters, have one more character "-" than line 68
         sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|\\]', '', sentence)
         return sentence
@@ -136,6 +144,12 @@ class TextNormalizer():
         sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
         sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
         sentence = replace_measure(sentence)
+
+        # Handle math expressions
+        while RE_ASMD.search(sentence):
+            sentence = RE_ASMD.sub(replace_asmd, sentence)
+        sentence = RE_POWER.sub(replace_power, sentence)
+
         sentence = RE_FRAC.sub(replace_frac, sentence)
         sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
         sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
@@ -145,10 +159,6 @@ class TextNormalizer():
 
         sentence = RE_RANGE.sub(replace_range, sentence)
 
-        # Handle addition, subtraction, multiplication, division
-        while RE_ASMD.search(sentence):
-            sentence = RE_ASMD.sub(replace_asmd, sentence)
-
         sentence = RE_INTEGER.sub(replace_negative_num, sentence)
         sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
         sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
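One reason the substitution runs in a while loop: a single RE_ASMD.sub pass only rewrites non-overlapping matches, so chained expressions keep an operator behind. A simplified sketch of this effect (RE_ASMD_LITE is a hypothetical digits-only cut-down of RE_ASMD):

    import re

    RE_ASMD_LITE = re.compile(r'(\d+)([\+\-×÷=])(\d+)')
    asmd_map = {'+': '加', '-': '减', '×': '乘', '÷': '除', '=': '等于'}

    def replace_asmd_lite(m):
        return m.group(1) + asmd_map[m.group(2)] + m.group(3)

    s = "1+2+3=6"
    # Pass 1 rewrites "1+2" and "3=6" but leaves the middle "+" untouched;
    # looping until no match remains picks it up on the next pass.
    while RE_ASMD_LITE.search(s):
        s = RE_ASMD_LITE.sub(replace_asmd_lite, s)
    print(s)  # -> 1加2加3等于6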