From 90a2f0471f101353ee64e6948f9261da99d00008 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=99=BD=E8=8F=9C=E5=B7=A5=E5=8E=821145=E5=8F=B7=E5=91=98?= =?UTF-8?q?=E5=B7=A5?= <114749500+baicai-1145@users.noreply.github.com> Date: Sun, 12 Oct 2025 02:27:13 +0800 Subject: [PATCH] Implement character mapping functionality for text normalization across multiple languages, including Cantonese, Chinese, English, Japanese, and Korean. Introduce `text_normalize_with_map` methods to return normalized text along with character mappings. Add a new utility module for building character mappings. --- GPT_SoVITS/text/cantonese.py | 465 +++++++++++---------- GPT_SoVITS/text/char_mapping_utils.py | 144 +++++++ GPT_SoVITS/text/chinese.py | 21 + GPT_SoVITS/text/chinese2.py | 21 + GPT_SoVITS/text/cleaner.py | 81 ++++ GPT_SoVITS/text/en_normalization/expend.py | 208 +++++++++ GPT_SoVITS/text/english.py | 26 ++ GPT_SoVITS/text/japanese.py | 21 + GPT_SoVITS/text/korean.py | 16 + 9 files changed, 781 insertions(+), 222 deletions(-) create mode 100644 GPT_SoVITS/text/char_mapping_utils.py diff --git a/GPT_SoVITS/text/cantonese.py b/GPT_SoVITS/text/cantonese.py index 1f07c414..667dfdd9 100644 --- a/GPT_SoVITS/text/cantonese.py +++ b/GPT_SoVITS/text/cantonese.py @@ -1,222 +1,243 @@ -# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py - -import re -import cn2an -import ToJyutping - -from text.symbols import punctuation -from text.zh_normalization.text_normlization import TextNormalizer - -normalizer = lambda x: cn2an.transform(x, "an2cn") - -INITIALS = [ - "aa", - "aai", - "aak", - "aap", - "aat", - "aau", - "ai", - "au", - "ap", - "at", - "ak", - "a", - "p", - "b", - "e", - "ts", - "t", - "dz", - "d", - "kw", - "k", - "gw", - "g", - "f", - "h", - "l", - "m", - "ng", - "n", - "s", - "y", - "w", - "c", - "z", - "j", - "ong", - "on", - "ou", - "oi", - "ok", - "o", - "uk", - "ung", -] -INITIALS += ["sp", "spl", "spn", "sil"] - - -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": "…", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - "—": "-", - "~": "-", - "~": "-", - "「": "'", - "」": "'", -} - - -def replace_punctuation(text): - # text = text.replace("嗯", "恩").replace("呣", "母") - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - - replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text) - - return replaced_text - - -def text_normalize(text): - tx = TextNormalizer() - sentences = tx.normalize(text) - dest_text = "" - for sentence in sentences: - dest_text += replace_punctuation(sentence) - return dest_text - - -punctuation_set = set(punctuation) - - -def jyuping_to_initials_finals_tones(jyuping_syllables): - initials_finals = [] - tones = [] - word2ph = [] - - for syllable in jyuping_syllables: - if syllable in punctuation: - initials_finals.append(syllable) - tones.append(0) - word2ph.append(1) # Add 1 for punctuation - elif syllable == "_": - initials_finals.append(syllable) - tones.append(0) - word2ph.append(1) # Add 1 for underscore - else: - try: - tone = int(syllable[-1]) - syllable_without_tone = syllable[:-1] - except ValueError: - tone = 0 - syllable_without_tone = syllable - - for initial in INITIALS: - if 
syllable_without_tone.startswith(initial): - if syllable_without_tone.startswith("nga"): - initials_finals.extend( - [ - syllable_without_tone[:2], - syllable_without_tone[2:] or syllable_without_tone[-1], - ] - ) - # tones.extend([tone, tone]) - tones.extend([-1, tone]) - word2ph.append(2) - else: - final = syllable_without_tone[len(initial) :] or initial[-1] - initials_finals.extend([initial, final]) - # tones.extend([tone, tone]) - tones.extend([-1, tone]) - word2ph.append(2) - break - assert len(initials_finals) == len(tones) - - ###魔改为辅音+带音调的元音 - phones = [] - for a, b in zip(initials_finals, tones): - if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y,如果是标点,不加。 - todo = "%s%s" % (a, b) - else: - todo = a - if todo not in punctuation_set: - todo = "Y%s" % todo - phones.append(todo) - - # return initials_finals, tones, word2ph - return phones, word2ph - - -def get_jyutping(text): - jyutping_array = [] - punct_pattern = re.compile(r"^[{}]+$".format(re.escape("".join(punctuation)))) - - syllables = ToJyutping.get_jyutping_list(text) - - for word, syllable in syllables: - if punct_pattern.match(word): - puncts = re.split(r"([{}])".format(re.escape("".join(punctuation))), word) - for punct in puncts: - if len(punct) > 0: - jyutping_array.append(punct) - else: - # match multple jyutping eg: liu4 ge3, or single jyutping eg: liu4 - if not re.search(r"^([a-z]+[1-6]+[ ]?)+$", syllable): - raise ValueError(f"Failed to convert {word} to jyutping: {syllable}") - jyutping_array.append(syllable) - - return jyutping_array - - -def get_bert_feature(text, word2ph): - from text import chinese_bert - - return chinese_bert.get_bert_feature(text, word2ph) - - -def g2p(text): - # word2ph = [] - jyuping = get_jyutping(text) - # print(jyuping) - # phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping) - phones, word2ph = jyuping_to_initials_finals_tones(jyuping) - # phones = ["_"] + phones + ["_"] - # tones = [0] + tones + [0] - # word2ph = [1] + word2ph + [1] - return phones, word2ph - - -if __name__ == "__main__": - # text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" - text = "佢個鋤頭太短啦。" - text = text_normalize(text) - # phones, tones, word2ph = g2p(text) - phones, word2ph = g2p(text) - # print(phones, tones, word2ph) - print(phones, word2ph) +# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py + +import re +import cn2an +import ToJyutping + +from text.symbols import punctuation +from text.zh_normalization.text_normlization import TextNormalizer + +normalizer = lambda x: cn2an.transform(x, "an2cn") + +INITIALS = [ + "aa", + "aai", + "aak", + "aap", + "aat", + "aau", + "ai", + "au", + "ap", + "at", + "ak", + "a", + "p", + "b", + "e", + "ts", + "t", + "dz", + "d", + "kw", + "k", + "gw", + "g", + "f", + "h", + "l", + "m", + "ng", + "n", + "s", + "y", + "w", + "c", + "z", + "j", + "ong", + "on", + "ou", + "oi", + "ok", + "o", + "uk", + "ung", +] +INITIALS += ["sp", "spl", "spn", "sil"] + + +rep_map = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", +} + + +def replace_punctuation(text): + # text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) + + replaced_text = 
pattern.sub(lambda x: rep_map[x.group()], text) + + replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text) + + return replaced_text + + +def text_normalize(text): + tx = TextNormalizer() + sentences = tx.normalize(text) + dest_text = "" + for sentence in sentences: + dest_text += replace_punctuation(sentence) + return dest_text + + +def text_normalize_with_map(text): + """ + 带字符映射的粤语标准化函数 + + Returns: + normalized_text: 标准化后的文本 + char_mappings: 字典,包含: + - "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置 + - "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置 + """ + from .char_mapping_utils import build_char_mapping + + # 先进行标准化 + normalized_text = text_normalize(text) + + # 构建字符映射 + mappings = build_char_mapping(text, normalized_text) + + return normalized_text, mappings + + +punctuation_set = set(punctuation) + + +def jyuping_to_initials_finals_tones(jyuping_syllables): + initials_finals = [] + tones = [] + word2ph = [] + + for syllable in jyuping_syllables: + if syllable in punctuation: + initials_finals.append(syllable) + tones.append(0) + word2ph.append(1) # Add 1 for punctuation + elif syllable == "_": + initials_finals.append(syllable) + tones.append(0) + word2ph.append(1) # Add 1 for underscore + else: + try: + tone = int(syllable[-1]) + syllable_without_tone = syllable[:-1] + except ValueError: + tone = 0 + syllable_without_tone = syllable + + for initial in INITIALS: + if syllable_without_tone.startswith(initial): + if syllable_without_tone.startswith("nga"): + initials_finals.extend( + [ + syllable_without_tone[:2], + syllable_without_tone[2:] or syllable_without_tone[-1], + ] + ) + # tones.extend([tone, tone]) + tones.extend([-1, tone]) + word2ph.append(2) + else: + final = syllable_without_tone[len(initial) :] or initial[-1] + initials_finals.extend([initial, final]) + # tones.extend([tone, tone]) + tones.extend([-1, tone]) + word2ph.append(2) + break + assert len(initials_finals) == len(tones) + + ###魔改为辅音+带音调的元音 + phones = [] + for a, b in zip(initials_finals, tones): + if b not in [-1, 0]: ###防止粤语和普通话重合开头加Y,如果是标点,不加。 + todo = "%s%s" % (a, b) + else: + todo = a + if todo not in punctuation_set: + todo = "Y%s" % todo + phones.append(todo) + + # return initials_finals, tones, word2ph + return phones, word2ph + + +def get_jyutping(text): + jyutping_array = [] + punct_pattern = re.compile(r"^[{}]+$".format(re.escape("".join(punctuation)))) + + syllables = ToJyutping.get_jyutping_list(text) + + for word, syllable in syllables: + if punct_pattern.match(word): + puncts = re.split(r"([{}])".format(re.escape("".join(punctuation))), word) + for punct in puncts: + if len(punct) > 0: + jyutping_array.append(punct) + else: + # match multple jyutping eg: liu4 ge3, or single jyutping eg: liu4 + if not re.search(r"^([a-z]+[1-6]+[ ]?)+$", syllable): + raise ValueError(f"Failed to convert {word} to jyutping: {syllable}") + jyutping_array.append(syllable) + + return jyutping_array + + +def get_bert_feature(text, word2ph): + from text import chinese_bert + + return chinese_bert.get_bert_feature(text, word2ph) + + +def g2p(text): + # word2ph = [] + jyuping = get_jyutping(text) + # print(jyuping) + # phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping) + phones, word2ph = jyuping_to_initials_finals_tones(jyuping) + # phones = ["_"] + phones + ["_"] + # tones = [0] + tones + [0] + # word2ph = [1] + word2ph + [1] + return phones, word2ph + + +if __name__ == "__main__": + # text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" + text = "佢個鋤頭太短啦。" + text = 
text_normalize(text) + # phones, tones, word2ph = g2p(text) + phones, word2ph = g2p(text) + # print(phones, tones, word2ph) + print(phones, word2ph) diff --git a/GPT_SoVITS/text/char_mapping_utils.py b/GPT_SoVITS/text/char_mapping_utils.py new file mode 100644 index 00000000..fab656a0 --- /dev/null +++ b/GPT_SoVITS/text/char_mapping_utils.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +通用字符映射工具 + +用于在文本标准化后,建立原始文本到标准化文本的字符级映射 +""" + + +def build_char_mapping(original_text, normalized_text): + """ + 通过字符对齐算法,构建原始文本到标准化文本的映射 + + Args: + original_text: 原始文本 + normalized_text: 标准化后的文本 + + Returns: + dict: { + "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置 + "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置 + } + """ + # 使用动态规划找到最优对齐 + m, n = len(original_text), len(normalized_text) + + # dp[i][j] 表示 original_text[:i] 和 normalized_text[:j] 的对齐代价 + # 0 = 匹配, 1 = 替换, 插入, 删除 + dp = [[float('inf')] * (n + 1) for _ in range(m + 1)] + + # 记录路径 + path = [[None] * (n + 1) for _ in range(m + 1)] + + # 初始化 + for i in range(m + 1): + dp[i][0] = i + if i > 0: + path[i][0] = ('del', i-1, -1) + + for j in range(n + 1): + dp[0][j] = j + if j > 0: + path[0][j] = ('ins', -1, j-1) + + # 动态规划 + for i in range(1, m + 1): + for j in range(1, n + 1): + orig_char = original_text[i-1] + norm_char = normalized_text[j-1] + + # 匹配代价(相同字符代价为0,不同字符代价为1) + match_cost = 0 if orig_char == norm_char else 1 + + # 三种操作的代价 + costs = [ + (dp[i-1][j-1] + match_cost, 'match' if match_cost == 0 else 'replace', i-1, j-1), + (dp[i-1][j] + 1, 'del', i-1, j), + (dp[i][j-1] + 1, 'ins', i, j-1), + ] + + min_cost, op, pi, pj = min(costs, key=lambda x: x[0]) + dp[i][j] = min_cost + path[i][j] = (op, pi, pj) + + # 回溯路径,构建映射 + orig_to_norm = [-1] * len(original_text) + norm_to_orig = [-1] * len(normalized_text) + + i, j = m, n + alignments = [] + + while i > 0 or j > 0: + if path[i][j] is None: + break + + op, pi, pj = path[i][j] + + if op in ['match', 'replace']: + # 原始字符i-1 对应 标准化字符j-1 + alignments.append((i-1, j-1)) + i, j = pi, pj + elif op == 'del': + # 原始字符i-1 被删除,映射到当前标准化位置(如果存在) + if j > 0: + alignments.append((i-1, j-1)) + i = pi + elif op == 'ins': + # 标准化字符j-1 是插入的,没有对应的原始字符 + j = pj + + # 根据对齐结果建立映射 + for orig_idx, norm_idx in alignments: + if orig_idx >= 0 and orig_idx < len(original_text): + if orig_to_norm[orig_idx] == -1: + orig_to_norm[orig_idx] = norm_idx + + if norm_idx >= 0 and norm_idx < len(normalized_text): + if norm_to_orig[norm_idx] == -1: + norm_to_orig[norm_idx] = orig_idx + + return { + "orig_to_norm": orig_to_norm, + "norm_to_orig": norm_to_orig + } + + +def test_char_mapping(): + """测试字符映射功能""" + test_cases = [ + ("50元", "五十元"), + ("3.5度", "三点五度"), + ("GPT-4", "GPT minus four"), # 也可以测试英文 + ] + + for orig, norm in test_cases: + print(f"原始: '{orig}'") + print(f"标准化: '{norm}'") + + mappings = build_char_mapping(orig, norm) + orig_to_norm = mappings["orig_to_norm"] + norm_to_orig = mappings["norm_to_orig"] + + print(f"原始→标准化映射:") + for i, c in enumerate(orig): + norm_idx = orig_to_norm[i] + if norm_idx >= 0 and norm_idx < len(norm): + print(f" [{i}]'{c}' → [{norm_idx}]'{norm[norm_idx]}'") + else: + print(f" [{i}]'{c}' → 无映射") + + print(f"标准化→原始映射:") + for i, c in enumerate(norm): + orig_idx = norm_to_orig[i] + if orig_idx >= 0 and orig_idx < len(orig): + print(f" [{i}]'{c}' ← [{orig_idx}]'{orig[orig_idx]}'") + else: + print(f" [{i}]'{c}' ← 无对应") + + print() + + +if __name__ == "__main__": + test_char_mapping() + diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py index 
944c9cb7..dff68757 100644 --- a/GPT_SoVITS/text/chinese.py +++ b/GPT_SoVITS/text/chinese.py @@ -181,6 +181,27 @@ def text_normalize(text): return dest_text +def text_normalize_with_map(text): + """ + 带字符映射的中文标准化函数(旧版) + + Returns: + normalized_text: 标准化后的文本 + char_mappings: 字典,包含: + - "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置 + - "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置 + """ + from .char_mapping_utils import build_char_mapping + + # 先进行标准化 + normalized_text = text_normalize(text) + + # 构建字符映射 + mappings = build_char_mapping(text, normalized_text) + + return normalized_text, mappings + + if __name__ == "__main__": text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "呣呣呣~就是…大人的鼹鼠党吧?" diff --git a/GPT_SoVITS/text/chinese2.py b/GPT_SoVITS/text/chinese2.py index dcce0d96..a497223f 100644 --- a/GPT_SoVITS/text/chinese2.py +++ b/GPT_SoVITS/text/chinese2.py @@ -326,6 +326,27 @@ def text_normalize(text): return dest_text +def text_normalize_with_map(text): + """ + 带字符映射的中文标准化函数 + + Returns: + normalized_text: 标准化后的文本 + char_mappings: 字典,包含: + - "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置 + - "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置 + """ + from .char_mapping_utils import build_char_mapping + + # 先进行标准化 + normalized_text = text_normalize(text) + + # 构建字符映射 + mappings = build_char_mapping(text, normalized_text) + + return normalized_text, mappings + + if __name__ == "__main__": text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "呣呣呣~就是…大人的鼹鼠党吧?" diff --git a/GPT_SoVITS/text/cleaner.py b/GPT_SoVITS/text/cleaner.py index 4176db43..7f753fd8 100644 --- a/GPT_SoVITS/text/cleaner.py +++ b/GPT_SoVITS/text/cleaner.py @@ -61,6 +61,87 @@ def clean_text(text, language, version=None): return phones, word2ph, norm_text +def clean_text_with_mapping(text, language, version=None): + """ + 带字符映射的文本清洗函数 + + Args: + text: 原始文本 + language: 语言代码 (zh, ja, en, ko, yue) + version: 模型版本 (v1, v2) + + Returns: + tuple: (phones, word2ph, norm_text, char_mappings) + - phones: 音素序列 + - word2ph: 词到音素的映射 + - norm_text: 标准化后的文本 + - char_mappings: 字符映射字典 {"orig_to_norm": [...], "norm_to_orig": [...]} + """ + if version is None: + version = os.environ.get("version", "v2") + if version == "v1": + symbols = symbols_v1.symbols + language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"} + else: + symbols = symbols_v2.symbols + language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"} + + if language not in language_module_map: + language = "en" + text = " " + + # 特殊符号处理暂不支持映射 + for special_s, special_l, target_symbol in special: + if special_s in text and language == special_l: + phones, word2ph, norm_text = clean_special(text, language, special_s, target_symbol, version) + # 返回空映射 + char_mappings = { + "orig_to_norm": list(range(len(text))), + "norm_to_orig": list(range(len(norm_text))) + } + return phones, word2ph, norm_text, char_mappings + + language_module = __import__("text." 
+ language_module_map[language], fromlist=[language_module_map[language]]) + + # 使用带映射的标准化函数 + if hasattr(language_module, "text_normalize_with_map"): + norm_text, char_mappings = language_module.text_normalize_with_map(text) + elif hasattr(language_module, "text_normalize"): + # 如果没有带映射的版本,使用普通版本并构建映射 + norm_text = language_module.text_normalize(text) + from .char_mapping_utils import build_char_mapping + char_mappings = build_char_mapping(text, norm_text) + else: + # 不进行标准化 + norm_text = text + char_mappings = { + "orig_to_norm": list(range(len(text))), + "norm_to_orig": list(range(len(text))) + } + + # 处理音素 + if language == "zh" or language == "yue": + phones, word2ph = language_module.g2p(norm_text) + assert len(phones) == sum(word2ph) + assert len(norm_text) == len(word2ph) + else: + # Try per-language word2ph helpers + if hasattr(language_module, "g2p_with_word2ph"): + try: + phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False) + except Exception: + phones = language_module.g2p(norm_text) + word2ph = None + else: + phones = language_module.g2p(norm_text) + word2ph = None + if language == "en" and len(phones) < 4: + phones = [","] + phones + + phones = ["UNK" if ph not in symbols else ph for ph in phones] + return phones, word2ph, norm_text, char_mappings + + def clean_special(text, language, special_s, target_symbol, version=None): if version is None: version = os.environ.get("version", "v2") diff --git a/GPT_SoVITS/text/en_normalization/expend.py b/GPT_SoVITS/text/en_normalization/expend.py index bbd607cd..484f1f91 100644 --- a/GPT_SoVITS/text/en_normalization/expend.py +++ b/GPT_SoVITS/text/en_normalization/expend.py @@ -238,6 +238,214 @@ def _expand_number(m): return _inflect.number_to_words(num, andword="") +class CharMapper: + """ + 字符映射追踪器,用于记录文本标准化过程中字符位置的变化 + + 核心思想:维护从原始文本到当前文本的映射 + - orig_to_curr[i] 表示原始文本第i个字符对应当前文本的位置 + """ + def __init__(self, text): + self.original_text = text + self.text = text + # 初始化:每个原始字符映射到自己 + self.orig_to_curr = list(range(len(text))) + + def apply_sub(self, pattern, replacement_func): + """ + 应用正则替换并更新映射 + + 关键:需要通过旧映射来更新新映射 + 支持捕获组的特殊处理(如大写字母拆分) + """ + new_text = "" + # curr_to_new[i] 表示当前文本第i个字符在新文本中的位置 + curr_to_new = [-1] * len(self.text) + + pos = 0 + for match in pattern.finditer(self.text): + # 处理匹配前的未变化文本 + for i in range(pos, match.start()): + curr_to_new[i] = len(new_text) + new_text += self.text[i] + + # 处理匹配的部分 + replacement = replacement_func(match) + replacement_start_pos = len(new_text) + + # 特殊处理:如果替换文本包含原始文本的字符(例如 "A" -> " A") + # 尝试找到对应关系 + match_text = match.group(0) + if len(match.groups()) > 0 and match.group(1) in replacement: + # 有捕获组,尝试精确映射 + # 例如: "(?= 0: + # 捕获的字符在替换文本中 + for i in range(match.start(), match.end()): + char = self.text[i] + if char in captured: + # 这个字符在捕获组中,映射到替换文本中的对应位置 + char_idx_in_replacement = replacement.find(char, replacement_idx) + if char_idx_in_replacement >= 0: + curr_to_new[i] = replacement_start_pos + char_idx_in_replacement + else: + curr_to_new[i] = replacement_start_pos + else: + curr_to_new[i] = replacement_start_pos + else: + # 捕获的字符不在替换文本中,都映射到起始位置 + for i in range(match.start(), match.end()): + curr_to_new[i] = replacement_start_pos + else: + # 没有捕获组或不包含原始字符,匹配部分的所有字符都映射到替换文本的起始位置 + for i in range(match.start(), match.end()): + curr_to_new[i] = replacement_start_pos + + new_text += replacement + pos = match.end() + + # 处理剩余文本 + for i in range(pos, len(self.text)): + curr_to_new[i] = len(new_text) + new_text += self.text[i] + + # 更新原始到当前的映射:orig -> old_curr 
-> new_curr + new_orig_to_curr = [] + for orig_idx in range(len(self.original_text)): + old_curr_idx = self.orig_to_curr[orig_idx] + if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new): + new_orig_to_curr.append(curr_to_new[old_curr_idx]) + else: + new_orig_to_curr.append(-1) + + self.text = new_text + self.orig_to_curr = new_orig_to_curr + + def apply_char_filter(self, keep_pattern): + """ + 应用字符过滤(只保留符合模式的字符)并更新映射 + + keep_pattern: 正则表达式字符串,如 "[ A-Za-z'.,?!-]" + """ + new_text = "" + curr_to_new = [] + + for i, char in enumerate(self.text): + if re.match(keep_pattern, char): + curr_to_new.append(len(new_text)) + new_text += char + else: + # 字符被删除 + if new_text: + curr_to_new.append(len(new_text) - 1) + else: + curr_to_new.append(-1) + + # 更新原始映射 + new_orig_to_curr = [] + for orig_idx in range(len(self.original_text)): + old_curr_idx = self.orig_to_curr[orig_idx] + if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new): + new_orig_to_curr.append(curr_to_new[old_curr_idx]) + else: + new_orig_to_curr.append(-1) + + self.text = new_text + self.orig_to_curr = new_orig_to_curr + + def get_norm_to_orig(self): + """ + 构建标准化文本到原始文本的反向映射 + """ + if not self.text: + return [] + + norm_to_orig = [-1] * len(self.text) + for orig_idx, norm_idx in enumerate(self.orig_to_curr): + if 0 <= norm_idx < len(self.text): + # 如果多个原始字符映射到同一个标准化位置,取第一个 + if norm_to_orig[norm_idx] == -1: + norm_to_orig[norm_idx] = orig_idx + + return norm_to_orig + + +def normalize_with_map(text): + """ + 带字符映射的标准化函数 + + 返回: + normalized_text: 标准化后的文本 + char_mappings: 字典,包含: + - "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置 + - "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置 + """ + mapper = CharMapper(text) + + # 按照 normalize() 的顺序应用所有转换 + mapper.apply_sub(_ordinal_number_re, _convert_ordinal) + mapper.apply_sub(re.compile(r"(?= len(mapper.text): + # NFD 展开了一些字符 + new_orig_to_curr = [] + for orig_idx in range(len(mapper.original_text)): + old_curr_idx = mapper.orig_to_curr[orig_idx] + if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new): + new_orig_to_curr.append(curr_to_new[old_curr_idx]) + else: + new_orig_to_curr.append(-1) + mapper.orig_to_curr = new_orig_to_curr + mapper.text = new_text + + # 继续其他替换 + mapper.apply_sub(re.compile("%"), lambda m: " percent") + + # 删除非法字符 - 使用 apply_char_filter + mapper.apply_char_filter(r"[ A-Za-z'.,?!\-]") + + mapper.apply_sub(re.compile(r"(?i)i\.e\."), lambda m: "that is") + mapper.apply_sub(re.compile(r"(?i)e\.g\."), lambda m: "for example") + mapper.apply_sub(re.compile(r"(?
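
The round-trip behaviour of the new utility module can be sanity-checked in isolation. The sketch below is an illustration, not part of the patch: it only assumes it is run from the GPT_SoVITS directory (so the `text` package imports resolve) and reuses the "50元" → "五十元" pair from test_char_mapping.

    from text.char_mapping_utils import build_char_mapping

    orig, norm = "50元", "五十元"  # sample pair from test_char_mapping
    mappings = build_char_mapping(orig, norm)

    # Each normalized character points back at an original index (or -1),
    # so downstream consumers can recover source offsets from normalized ones.
    for norm_idx, orig_idx in enumerate(mappings["norm_to_orig"]):
        src = orig[orig_idx] if 0 <= orig_idx < len(orig) else "<none>"
        print(f"norm[{norm_idx}]={norm[norm_idx]!r} <- orig[{orig_idx}]={src!r}")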
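CharMapper takes the opposite approach to build_char_mapping: instead of aligning two finished strings after the fact, it threads the orig_to_curr index list through every substitution as it happens. A minimal sketch of that incremental tracking, reusing the "%" → " percent" rule that normalize_with_map itself applies; the import path is assumed from the file layout in this patch.

    import re
    from text.en_normalization.expend import CharMapper

    mapper = CharMapper("100%")
    mapper.apply_sub(re.compile("%"), lambda m: " percent")

    print(mapper.text)                # "100 percent"
    print(mapper.orig_to_curr)        # [0, 1, 2, 3]: "%" maps to the space
    print(mapper.get_norm_to_orig())  # positions 4 onward have no source (-1)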
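At the cleaner level, the mappings compose naturally with word2ph: for "zh" and "yue" the function asserts len(norm_text) == len(word2ph), so each phone group can be traced through norm_to_orig back to a character of the raw input (useful, for example, for character-level timestamps). A hedged usage sketch; the sample sentence is arbitrary and a full GPT-SoVITS setup (g2p dependencies, version environment variable) is assumed.

    from text.cleaner import clean_text_with_mapping

    raw = "今天是2024年"  # hypothetical input containing digits to normalize
    phones, word2ph, norm_text, maps = clean_text_with_mapping(raw, "zh")

    pos = 0
    for char_idx, n in enumerate(word2ph):
        orig_idx = maps["norm_to_orig"][char_idx]
        orig_char = raw[orig_idx] if orig_idx >= 0 else None
        # normalized char, the raw char it came from, and its phone group
        print(norm_text[char_idx], orig_char, phones[pos:pos + n])
        pos += n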