Implement character mapping functionality for text normalization across multiple languages, including Cantonese, Chinese, English, Japanese, and Korean. Introduce text_normalize_with_map methods to return normalized text along with character mappings. Add a new utility module for building character mappings.

This commit is contained in:
白菜工厂1145号员工 2025-10-12 02:27:13 +08:00
parent 0bde5d4a81
commit 90a2f0471f
9 changed files with 781 additions and 222 deletions

View File

@ -112,6 +112,27 @@ def text_normalize(text):
return dest_text
def text_normalize_with_map(text):
"""
带字符映射的粤语标准化函数
Returns:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
from .char_mapping_utils import build_char_mapping
# 先进行标准化
normalized_text = text_normalize(text)
# 构建字符映射
mappings = build_char_mapping(text, normalized_text)
return normalized_text, mappings
punctuation_set = set(punctuation)

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
通用字符映射工具
用于在文本标准化后建立原始文本到标准化文本的字符级映射
"""
def build_char_mapping(original_text, normalized_text):
"""
通过字符对齐算法构建原始文本到标准化文本的映射
Args:
original_text: 原始文本
normalized_text: 标准化后的文本
Returns:
dict: {
"orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
"norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
}
"""
# 使用动态规划找到最优对齐
m, n = len(original_text), len(normalized_text)
# dp[i][j] 表示 original_text[:i] 和 normalized_text[:j] 的对齐代价
# 0 = 匹配, 1 = 替换, 插入, 删除
dp = [[float('inf')] * (n + 1) for _ in range(m + 1)]
# 记录路径
path = [[None] * (n + 1) for _ in range(m + 1)]
# 初始化
for i in range(m + 1):
dp[i][0] = i
if i > 0:
path[i][0] = ('del', i-1, -1)
for j in range(n + 1):
dp[0][j] = j
if j > 0:
path[0][j] = ('ins', -1, j-1)
# 动态规划
for i in range(1, m + 1):
for j in range(1, n + 1):
orig_char = original_text[i-1]
norm_char = normalized_text[j-1]
# 匹配代价相同字符代价为0不同字符代价为1
match_cost = 0 if orig_char == norm_char else 1
# 三种操作的代价
costs = [
(dp[i-1][j-1] + match_cost, 'match' if match_cost == 0 else 'replace', i-1, j-1),
(dp[i-1][j] + 1, 'del', i-1, j),
(dp[i][j-1] + 1, 'ins', i, j-1),
]
min_cost, op, pi, pj = min(costs, key=lambda x: x[0])
dp[i][j] = min_cost
path[i][j] = (op, pi, pj)
# 回溯路径,构建映射
orig_to_norm = [-1] * len(original_text)
norm_to_orig = [-1] * len(normalized_text)
i, j = m, n
alignments = []
while i > 0 or j > 0:
if path[i][j] is None:
break
op, pi, pj = path[i][j]
if op in ['match', 'replace']:
# 原始字符i-1 对应 标准化字符j-1
alignments.append((i-1, j-1))
i, j = pi, pj
elif op == 'del':
# 原始字符i-1 被删除,映射到当前标准化位置(如果存在)
if j > 0:
alignments.append((i-1, j-1))
i = pi
elif op == 'ins':
# 标准化字符j-1 是插入的,没有对应的原始字符
j = pj
# 根据对齐结果建立映射
for orig_idx, norm_idx in alignments:
if orig_idx >= 0 and orig_idx < len(original_text):
if orig_to_norm[orig_idx] == -1:
orig_to_norm[orig_idx] = norm_idx
if norm_idx >= 0 and norm_idx < len(normalized_text):
if norm_to_orig[norm_idx] == -1:
norm_to_orig[norm_idx] = orig_idx
return {
"orig_to_norm": orig_to_norm,
"norm_to_orig": norm_to_orig
}
def test_char_mapping():
"""测试字符映射功能"""
test_cases = [
("50元", "五十元"),
("3.5度", "三点五度"),
("GPT-4", "GPT minus four"), # 也可以测试英文
]
for orig, norm in test_cases:
print(f"原始: '{orig}'")
print(f"标准化: '{norm}'")
mappings = build_char_mapping(orig, norm)
orig_to_norm = mappings["orig_to_norm"]
norm_to_orig = mappings["norm_to_orig"]
print(f"原始→标准化映射:")
for i, c in enumerate(orig):
norm_idx = orig_to_norm[i]
if norm_idx >= 0 and norm_idx < len(norm):
print(f" [{i}]'{c}' → [{norm_idx}]'{norm[norm_idx]}'")
else:
print(f" [{i}]'{c}' → 无映射")
print(f"标准化→原始映射:")
for i, c in enumerate(norm):
orig_idx = norm_to_orig[i]
if orig_idx >= 0 and orig_idx < len(orig):
print(f" [{i}]'{c}' ← [{orig_idx}]'{orig[orig_idx]}'")
else:
print(f" [{i}]'{c}' ← 无对应")
print()
if __name__ == "__main__":
test_char_mapping()

View File

@ -181,6 +181,27 @@ def text_normalize(text):
return dest_text
def text_normalize_with_map(text):
"""
带字符映射的中文标准化函数旧版
Returns:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
from .char_mapping_utils import build_char_mapping
# 先进行标准化
normalized_text = text_normalize(text)
# 构建字符映射
mappings = build_char_mapping(text, normalized_text)
return normalized_text, mappings
if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -326,6 +326,27 @@ def text_normalize(text):
return dest_text
def text_normalize_with_map(text):
"""
带字符映射的中文标准化函数
Returns:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
from .char_mapping_utils import build_char_mapping
# 先进行标准化
normalized_text = text_normalize(text)
# 构建字符映射
mappings = build_char_mapping(text, normalized_text)
return normalized_text, mappings
if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -61,6 +61,87 @@ def clean_text(text, language, version=None):
return phones, word2ph, norm_text
def clean_text_with_mapping(text, language, version=None):
"""
带字符映射的文本清洗函数
Args:
text: 原始文本
language: 语言代码 (zh, ja, en, ko, yue)
version: 模型版本 (v1, v2)
Returns:
tuple: (phones, word2ph, norm_text, char_mappings)
- phones: 音素序列
- word2ph: 词到音素的映射
- norm_text: 标准化后的文本
- char_mappings: 字符映射字典 {"orig_to_norm": [...], "norm_to_orig": [...]}
"""
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":
symbols = symbols_v1.symbols
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
else:
symbols = symbols_v2.symbols
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
if language not in language_module_map:
language = "en"
text = " "
# 特殊符号处理暂不支持映射
for special_s, special_l, target_symbol in special:
if special_s in text and language == special_l:
phones, word2ph, norm_text = clean_special(text, language, special_s, target_symbol, version)
# 返回空映射
char_mappings = {
"orig_to_norm": list(range(len(text))),
"norm_to_orig": list(range(len(norm_text)))
}
return phones, word2ph, norm_text, char_mappings
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
# 使用带映射的标准化函数
if hasattr(language_module, "text_normalize_with_map"):
norm_text, char_mappings = language_module.text_normalize_with_map(text)
elif hasattr(language_module, "text_normalize"):
# 如果没有带映射的版本,使用普通版本并构建映射
norm_text = language_module.text_normalize(text)
from .char_mapping_utils import build_char_mapping
char_mappings = build_char_mapping(text, norm_text)
else:
# 不进行标准化
norm_text = text
char_mappings = {
"orig_to_norm": list(range(len(text))),
"norm_to_orig": list(range(len(text)))
}
# 处理音素
if language == "zh" or language == "yue":
phones, word2ph = language_module.g2p(norm_text)
assert len(phones) == sum(word2ph)
assert len(norm_text) == len(word2ph)
else:
# Try per-language word2ph helpers
if hasattr(language_module, "g2p_with_word2ph"):
try:
phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False)
except Exception:
phones = language_module.g2p(norm_text)
word2ph = None
else:
phones = language_module.g2p(norm_text)
word2ph = None
if language == "en" and len(phones) < 4:
phones = [","] + phones
phones = ["UNK" if ph not in symbols else ph for ph in phones]
return phones, word2ph, norm_text, char_mappings
def clean_special(text, language, special_s, target_symbol, version=None):
if version is None:
version = os.environ.get("version", "v2")

View File

@ -238,6 +238,214 @@ def _expand_number(m):
return _inflect.number_to_words(num, andword="")
class CharMapper:
"""
字符映射追踪器用于记录文本标准化过程中字符位置的变化
核心思想维护从原始文本到当前文本的映射
- orig_to_curr[i] 表示原始文本第i个字符对应当前文本的位置
"""
def __init__(self, text):
self.original_text = text
self.text = text
# 初始化:每个原始字符映射到自己
self.orig_to_curr = list(range(len(text)))
def apply_sub(self, pattern, replacement_func):
"""
应用正则替换并更新映射
关键需要通过旧映射来更新新映射
支持捕获组的特殊处理如大写字母拆分
"""
new_text = ""
# curr_to_new[i] 表示当前文本第i个字符在新文本中的位置
curr_to_new = [-1] * len(self.text)
pos = 0
for match in pattern.finditer(self.text):
# 处理匹配前的未变化文本
for i in range(pos, match.start()):
curr_to_new[i] = len(new_text)
new_text += self.text[i]
# 处理匹配的部分
replacement = replacement_func(match)
replacement_start_pos = len(new_text)
# 特殊处理:如果替换文本包含原始文本的字符(例如 "A" -> " A"
# 尝试找到对应关系
match_text = match.group(0)
if len(match.groups()) > 0 and match.group(1) in replacement:
# 有捕获组,尝试精确映射
# 例如: "(?<!^)(?<![\s])([A-Z])" 匹配 "P",替换为 " P"
captured = match.group(1)
replacement_idx = replacement.find(captured)
if replacement_idx >= 0:
# 捕获的字符在替换文本中
for i in range(match.start(), match.end()):
char = self.text[i]
if char in captured:
# 这个字符在捕获组中,映射到替换文本中的对应位置
char_idx_in_replacement = replacement.find(char, replacement_idx)
if char_idx_in_replacement >= 0:
curr_to_new[i] = replacement_start_pos + char_idx_in_replacement
else:
curr_to_new[i] = replacement_start_pos
else:
curr_to_new[i] = replacement_start_pos
else:
# 捕获的字符不在替换文本中,都映射到起始位置
for i in range(match.start(), match.end()):
curr_to_new[i] = replacement_start_pos
else:
# 没有捕获组或不包含原始字符,匹配部分的所有字符都映射到替换文本的起始位置
for i in range(match.start(), match.end()):
curr_to_new[i] = replacement_start_pos
new_text += replacement
pos = match.end()
# 处理剩余文本
for i in range(pos, len(self.text)):
curr_to_new[i] = len(new_text)
new_text += self.text[i]
# 更新原始到当前的映射orig -> old_curr -> new_curr
new_orig_to_curr = []
for orig_idx in range(len(self.original_text)):
old_curr_idx = self.orig_to_curr[orig_idx]
if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
new_orig_to_curr.append(curr_to_new[old_curr_idx])
else:
new_orig_to_curr.append(-1)
self.text = new_text
self.orig_to_curr = new_orig_to_curr
def apply_char_filter(self, keep_pattern):
"""
应用字符过滤只保留符合模式的字符并更新映射
keep_pattern: 正则表达式字符串 "[ A-Za-z'.,?!-]"
"""
new_text = ""
curr_to_new = []
for i, char in enumerate(self.text):
if re.match(keep_pattern, char):
curr_to_new.append(len(new_text))
new_text += char
else:
# 字符被删除
if new_text:
curr_to_new.append(len(new_text) - 1)
else:
curr_to_new.append(-1)
# 更新原始映射
new_orig_to_curr = []
for orig_idx in range(len(self.original_text)):
old_curr_idx = self.orig_to_curr[orig_idx]
if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
new_orig_to_curr.append(curr_to_new[old_curr_idx])
else:
new_orig_to_curr.append(-1)
self.text = new_text
self.orig_to_curr = new_orig_to_curr
def get_norm_to_orig(self):
"""
构建标准化文本到原始文本的反向映射
"""
if not self.text:
return []
norm_to_orig = [-1] * len(self.text)
for orig_idx, norm_idx in enumerate(self.orig_to_curr):
if 0 <= norm_idx < len(self.text):
# 如果多个原始字符映射到同一个标准化位置,取第一个
if norm_to_orig[norm_idx] == -1:
norm_to_orig[norm_idx] = orig_idx
return norm_to_orig
def normalize_with_map(text):
"""
带字符映射的标准化函数
返回:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
mapper = CharMapper(text)
# 按照 normalize() 的顺序应用所有转换
mapper.apply_sub(_ordinal_number_re, _convert_ordinal)
mapper.apply_sub(re.compile(r"(?<!\d)-|-(?!\d)"), lambda m: " minus ")
mapper.apply_sub(_comma_number_re, _remove_commas)
mapper.apply_sub(_time_re, _expand_time)
mapper.apply_sub(_measurement_re, _expand_measurement)
mapper.apply_sub(_pounds_re_start, _expand_pounds)
mapper.apply_sub(_pounds_re_end, _expand_pounds)
mapper.apply_sub(_dollars_re_start, _expand_dollars)
mapper.apply_sub(_dollars_re_end, _expand_dollars)
mapper.apply_sub(_decimal_number_re, _expand_decimal_number)
mapper.apply_sub(_fraction_re, _expend_fraction)
mapper.apply_sub(_ordinal_re, _expand_ordinal)
mapper.apply_sub(_number_re, _expand_number)
# Strip accents - 需要手动处理映射
normalized_nfd = unicodedata.normalize("NFD", mapper.text)
new_text = ""
curr_to_new = []
for i, char in enumerate(normalized_nfd):
if unicodedata.category(char) != "Mn":
curr_to_new.append(len(new_text))
new_text += char
else:
# 重音符号被删除,映射到前一个字符
if new_text:
curr_to_new.append(len(new_text) - 1)
else:
curr_to_new.append(-1)
# 更新原始映射 - 需要处理 NFD 可能改变字符数的情况
# 简化处理:假设 NFD 不会显著改变字符数(对于英文通常是这样)
if len(curr_to_new) >= len(mapper.text):
# NFD 展开了一些字符
new_orig_to_curr = []
for orig_idx in range(len(mapper.original_text)):
old_curr_idx = mapper.orig_to_curr[orig_idx]
if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
new_orig_to_curr.append(curr_to_new[old_curr_idx])
else:
new_orig_to_curr.append(-1)
mapper.orig_to_curr = new_orig_to_curr
mapper.text = new_text
# 继续其他替换
mapper.apply_sub(re.compile("%"), lambda m: " percent")
# 删除非法字符 - 使用 apply_char_filter
mapper.apply_char_filter(r"[ A-Za-z'.,?!\-]")
mapper.apply_sub(re.compile(r"(?i)i\.e\."), lambda m: "that is")
mapper.apply_sub(re.compile(r"(?i)e\.g\."), lambda m: "for example")
mapper.apply_sub(re.compile(r"(?<!^)(?<![\s])([A-Z])"), lambda m: " " + m.group(1))
norm_to_orig = mapper.get_norm_to_orig()
return mapper.text, {
"orig_to_norm": mapper.orig_to_curr,
"norm_to_orig": norm_to_orig
}
def normalize(text):
"""
!!! 所有的处理都需要正确的输入 !!!

View File

@ -245,6 +245,32 @@ def text_normalize(text):
return text
def text_normalize_with_map(text):
"""
带字符映射的英文标准化函数
Returns:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
from .en_normalization.expend import normalize_with_map
# 先进行标点替换
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
text = pattern.sub(lambda x: rep_map[x.group()], text)
text = unicode(text)
# 使用带映射的标准化函数
normalized_text, mappings = normalize_with_map(text)
# 避免重复标点引起的参考泄露
normalized_text = replace_consecutive_punctuation(normalized_text)
return normalized_text, mappings
class en_G2p(G2p):
def __init__(self):
super().__init__()

View File

@ -179,6 +179,27 @@ def text_normalize(text):
return text
def text_normalize_with_map(text):
"""
带字符映射的日文标准化函数
Returns:
normalized_text: 标准化后的文本
char_mappings: 字典包含:
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
"""
from .char_mapping_utils import build_char_mapping
# 先进行标准化
normalized_text = text_normalize(text)
# 构建字符映射
mappings = build_char_mapping(text, normalized_text)
return normalized_text, mappings
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
"""Extract phoneme + prosoody symbol sequence from input full-context labels.

View File

@ -357,6 +357,22 @@ def g2p_with_word2ph(text, keep_punc=False):
word2ph.append(max(1, len(phs_valid)))
return phones_all, word2ph
def text_normalize_with_map(text):
"""
韩文不需要标准化直接返回原始文本和恒等映射
Returns:
normalized_text: 原始文本未改变
char_mappings: 字典包含恒等映射
"""
# 韩文不进行标准化,直接返回恒等映射
mappings = {
"orig_to_norm": list(range(len(text))),
"norm_to_orig": list(range(len(text)))
}
return text, mappings
if __name__ == "__main__":
text = "안녕하세요"
print(g2p(text))