Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-16 05:36:34 +08:00)
Implement character mapping functionality for text normalization across multiple languages, including Cantonese, Chinese, English, Japanese, and Korean. Introduce text_normalize_with_map
methods to return normalized text along with character mappings. Add a new utility module for building character mappings.
This commit is contained in:
parent 0bde5d4a81
commit 90a2f0471f
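For orientation, a minimal usage sketch of the mapping utility this commit introduces is shown below; it is illustrative only and not part of the diff. The import path assumes the GPT_SoVITS package root is on sys.path (matching the relative imports used by the new code), and the example strings are made up.

# Sketch only: build_char_mapping is the helper added in GPT_SoVITS/text/char_mapping_utils.py.
from text.char_mapping_utils import build_char_mapping

orig = "50元"
norm = "五十元"  # roughly what numeric normalization would produce
mappings = build_char_mapping(orig, norm)
# mappings["orig_to_norm"][i]: index in norm that original character i aligns to (-1 if unaligned)
# mappings["norm_to_orig"][j]: index in orig that normalized character j came from (-1 if inserted)
print(mappings["orig_to_norm"], mappings["norm_to_orig"])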
@@ -1,222 +1,243 @@
# reference: https://huggingface.co/spaces/Naozumi0512/Bert-VITS2-Cantonese-Yue/blob/main/text/chinese.py

import re
import cn2an
import ToJyutping

from text.symbols import punctuation
from text.zh_normalization.text_normlization import TextNormalizer

normalizer = lambda x: cn2an.transform(x, "an2cn")

INITIALS = [
    "aa",
    "aai",
    "aak",
    "aap",
    "aat",
    "aau",
    "ai",
    "au",
    "ap",
    "at",
    "ak",
    "a",
    "p",
    "b",
    "e",
    "ts",
    "t",
    "dz",
    "d",
    "kw",
    "k",
    "gw",
    "g",
    "f",
    "h",
    "l",
    "m",
    "ng",
    "n",
    "s",
    "y",
    "w",
    "c",
    "z",
    "j",
    "ong",
    "on",
    "ou",
    "oi",
    "ok",
    "o",
    "uk",
    "ung",
]
INITIALS += ["sp", "spl", "spn", "sil"]


rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": "…",
    "$": ".",
    "“": "'",
    "”": "'",
    '"': "'",
    "‘": "'",
    "’": "'",
    "(": "'",
    ")": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    "—": "-",
    "~": "-",
    "~": "-",
    "「": "'",
    "」": "'",
}


def replace_punctuation(text):
    # text = text.replace("嗯", "恩").replace("呣", "母")
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))

    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)

    replaced_text = re.sub(r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text)

    return replaced_text


def text_normalize(text):
    tx = TextNormalizer()
    sentences = tx.normalize(text)
    dest_text = ""
    for sentence in sentences:
        dest_text += replace_punctuation(sentence)
    return dest_text


def text_normalize_with_map(text):
    """
    Cantonese normalization with character mapping.

    Returns:
        normalized_text: the normalized text
        char_mappings: dict containing:
            - "orig_to_norm": list[int], position in the normalized text for each original character
            - "norm_to_orig": list[int], position in the original text for each normalized character
    """
    from .char_mapping_utils import build_char_mapping

    # Normalize first
    normalized_text = text_normalize(text)

    # Build the character mapping
    mappings = build_char_mapping(text, normalized_text)

    return normalized_text, mappings


punctuation_set = set(punctuation)


def jyuping_to_initials_finals_tones(jyuping_syllables):
    initials_finals = []
    tones = []
    word2ph = []

    for syllable in jyuping_syllables:
        if syllable in punctuation:
            initials_finals.append(syllable)
            tones.append(0)
            word2ph.append(1)  # Add 1 for punctuation
        elif syllable == "_":
            initials_finals.append(syllable)
            tones.append(0)
            word2ph.append(1)  # Add 1 for underscore
        else:
            try:
                tone = int(syllable[-1])
                syllable_without_tone = syllable[:-1]
            except ValueError:
                tone = 0
                syllable_without_tone = syllable

            for initial in INITIALS:
                if syllable_without_tone.startswith(initial):
                    if syllable_without_tone.startswith("nga"):
                        initials_finals.extend(
                            [
                                syllable_without_tone[:2],
                                syllable_without_tone[2:] or syllable_without_tone[-1],
                            ]
                        )
                        # tones.extend([tone, tone])
                        tones.extend([-1, tone])
                        word2ph.append(2)
                    else:
                        final = syllable_without_tone[len(initial) :] or initial[-1]
                        initials_finals.extend([initial, final])
                        # tones.extend([tone, tone])
                        tones.extend([-1, tone])
                        word2ph.append(2)
                    break
    assert len(initials_finals) == len(tones)

    ### Reworked into consonant + tone-carrying vowel
    phones = []
    for a, b in zip(initials_finals, tones):
        if b not in [-1, 0]:  ### Prefix with Y to keep Cantonese phones distinct from Mandarin; punctuation is left alone.
            todo = "%s%s" % (a, b)
        else:
            todo = a
        if todo not in punctuation_set:
            todo = "Y%s" % todo
        phones.append(todo)

    # return initials_finals, tones, word2ph
    return phones, word2ph


def get_jyutping(text):
    jyutping_array = []
    punct_pattern = re.compile(r"^[{}]+$".format(re.escape("".join(punctuation))))

    syllables = ToJyutping.get_jyutping_list(text)

    for word, syllable in syllables:
        if punct_pattern.match(word):
            puncts = re.split(r"([{}])".format(re.escape("".join(punctuation))), word)
            for punct in puncts:
                if len(punct) > 0:
                    jyutping_array.append(punct)
        else:
            # match multiple jyutping, e.g. "liu4 ge3", or a single jyutping, e.g. "liu4"
            if not re.search(r"^([a-z]+[1-6]+[ ]?)+$", syllable):
                raise ValueError(f"Failed to convert {word} to jyutping: {syllable}")
            jyutping_array.append(syllable)

    return jyutping_array


def get_bert_feature(text, word2ph):
    from text import chinese_bert

    return chinese_bert.get_bert_feature(text, word2ph)


def g2p(text):
    # word2ph = []
    jyuping = get_jyutping(text)
    # print(jyuping)
    # phones, tones, word2ph = jyuping_to_initials_finals_tones(jyuping)
    phones, word2ph = jyuping_to_initials_finals_tones(jyuping)
    # phones = ["_"] + phones + ["_"]
    # tones = [0] + tones + [0]
    # word2ph = [1] + word2ph + [1]
    return phones, word2ph


if __name__ == "__main__":
    # text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
    text = "佢個鋤頭太短啦。"
    text = text_normalize(text)
    # phones, tones, word2ph = g2p(text)
    phones, word2ph = g2p(text)
    # print(phones, tones, word2ph)
    print(phones, word2ph)
GPT_SoVITS/text/char_mapping_utils.py (new file, 144 lines)
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generic character-mapping utilities.

Used to build a character-level mapping from the original text to the
normalized text after text normalization.
"""


def build_char_mapping(original_text, normalized_text):
    """
    Build the mapping between the original text and the normalized text
    with a character alignment algorithm.

    Args:
        original_text: the original text
        normalized_text: the normalized text

    Returns:
        dict: {
            "orig_to_norm": list[int], position in the normalized text for each original character
            "norm_to_orig": list[int], position in the original text for each normalized character
        }
    """
    # Use dynamic programming to find the optimal alignment
    m, n = len(original_text), len(normalized_text)

    # dp[i][j] is the alignment cost between original_text[:i] and normalized_text[:j]
    # 0 = match, 1 = substitution, insertion, or deletion
    dp = [[float('inf')] * (n + 1) for _ in range(m + 1)]

    # Backtracking table
    path = [[None] * (n + 1) for _ in range(m + 1)]

    # Initialization
    for i in range(m + 1):
        dp[i][0] = i
        if i > 0:
            path[i][0] = ('del', i-1, -1)

    for j in range(n + 1):
        dp[0][j] = j
        if j > 0:
            path[0][j] = ('ins', -1, j-1)

    # Dynamic programming
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            orig_char = original_text[i-1]
            norm_char = normalized_text[j-1]

            # Match cost (0 for identical characters, 1 otherwise)
            match_cost = 0 if orig_char == norm_char else 1

            # Costs of the three operations
            costs = [
                (dp[i-1][j-1] + match_cost, 'match' if match_cost == 0 else 'replace', i-1, j-1),
                (dp[i-1][j] + 1, 'del', i-1, j),
                (dp[i][j-1] + 1, 'ins', i, j-1),
            ]

            min_cost, op, pi, pj = min(costs, key=lambda x: x[0])
            dp[i][j] = min_cost
            path[i][j] = (op, pi, pj)

    # Trace the path back and build the mapping
    orig_to_norm = [-1] * len(original_text)
    norm_to_orig = [-1] * len(normalized_text)

    i, j = m, n
    alignments = []

    while i > 0 or j > 0:
        if path[i][j] is None:
            break

        op, pi, pj = path[i][j]

        if op in ['match', 'replace']:
            # Original character i-1 corresponds to normalized character j-1
            alignments.append((i-1, j-1))
            i, j = pi, pj
        elif op == 'del':
            # Original character i-1 was deleted; map it to the current normalized position (if any)
            if j > 0:
                alignments.append((i-1, j-1))
            i = pi
        elif op == 'ins':
            # Normalized character j-1 was inserted and has no corresponding original character
            j = pj

    # Build the mappings from the alignment result
    for orig_idx, norm_idx in alignments:
        if orig_idx >= 0 and orig_idx < len(original_text):
            if orig_to_norm[orig_idx] == -1:
                orig_to_norm[orig_idx] = norm_idx

        if norm_idx >= 0 and norm_idx < len(normalized_text):
            if norm_to_orig[norm_idx] == -1:
                norm_to_orig[norm_idx] = orig_idx

    return {
        "orig_to_norm": orig_to_norm,
        "norm_to_orig": norm_to_orig
    }


def test_char_mapping():
    """Test the character-mapping functionality."""
    test_cases = [
        ("50元", "五十元"),
        ("3.5度", "三点五度"),
        ("GPT-4", "GPT minus four"),  # English can be tested as well
    ]

    for orig, norm in test_cases:
        print(f"原始: '{orig}'")
        print(f"标准化: '{norm}'")

        mappings = build_char_mapping(orig, norm)
        orig_to_norm = mappings["orig_to_norm"]
        norm_to_orig = mappings["norm_to_orig"]

        print(f"原始→标准化映射:")
        for i, c in enumerate(orig):
            norm_idx = orig_to_norm[i]
            if norm_idx >= 0 and norm_idx < len(norm):
                print(f"  [{i}]'{c}' → [{norm_idx}]'{norm[norm_idx]}'")
            else:
                print(f"  [{i}]'{c}' → 无映射")

        print(f"标准化→原始映射:")
        for i, c in enumerate(norm):
            orig_idx = norm_to_orig[i]
            if orig_idx >= 0 and orig_idx < len(orig):
                print(f"  [{i}]'{c}' ← [{orig_idx}]'{orig[orig_idx]}'")
            else:
                print(f"  [{i}]'{c}' ← 无对应")

        print()


if __name__ == "__main__":
    test_char_mapping()
@ -181,6 +181,27 @@ def text_normalize(text):
|
||||
return dest_text
|
||||
|
||||
|
||||
def text_normalize_with_map(text):
|
||||
"""
|
||||
带字符映射的中文标准化函数(旧版)
|
||||
|
||||
Returns:
|
||||
normalized_text: 标准化后的文本
|
||||
char_mappings: 字典,包含:
|
||||
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
|
||||
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
|
||||
"""
|
||||
from .char_mapping_utils import build_char_mapping
|
||||
|
||||
# 先进行标准化
|
||||
normalized_text = text_normalize(text)
|
||||
|
||||
# 构建字符映射
|
||||
mappings = build_char_mapping(text, normalized_text)
|
||||
|
||||
return normalized_text, mappings
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
|
||||
text = "呣呣呣~就是…大人的鼹鼠党吧?"
|
||||
|
@@ -326,6 +326,27 @@ def text_normalize(text):
    return dest_text


def text_normalize_with_map(text):
    """
    Chinese normalization with character mapping.

    Returns:
        normalized_text: the normalized text
        char_mappings: dict containing:
            - "orig_to_norm": list[int], position in the normalized text for each original character
            - "norm_to_orig": list[int], position in the original text for each normalized character
    """
    from .char_mapping_utils import build_char_mapping

    # Normalize first
    normalized_text = text_normalize(text)

    # Build the character mapping
    mappings = build_char_mapping(text, normalized_text)

    return normalized_text, mappings


if __name__ == "__main__":
    text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
    text = "呣呣呣~就是…大人的鼹鼠党吧?"
@ -61,6 +61,87 @@ def clean_text(text, language, version=None):
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
def clean_text_with_mapping(text, language, version=None):
|
||||
"""
|
||||
带字符映射的文本清洗函数
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
language: 语言代码 (zh, ja, en, ko, yue)
|
||||
version: 模型版本 (v1, v2)
|
||||
|
||||
Returns:
|
||||
tuple: (phones, word2ph, norm_text, char_mappings)
|
||||
- phones: 音素序列
|
||||
- word2ph: 词到音素的映射
|
||||
- norm_text: 标准化后的文本
|
||||
- char_mappings: 字符映射字典 {"orig_to_norm": [...], "norm_to_orig": [...]}
|
||||
"""
|
||||
if version is None:
|
||||
version = os.environ.get("version", "v2")
|
||||
if version == "v1":
|
||||
symbols = symbols_v1.symbols
|
||||
language_module_map = {"zh": "chinese", "ja": "japanese", "en": "english"}
|
||||
else:
|
||||
symbols = symbols_v2.symbols
|
||||
language_module_map = {"zh": "chinese2", "ja": "japanese", "en": "english", "ko": "korean", "yue": "cantonese"}
|
||||
|
||||
if language not in language_module_map:
|
||||
language = "en"
|
||||
text = " "
|
||||
|
||||
# 特殊符号处理暂不支持映射
|
||||
for special_s, special_l, target_symbol in special:
|
||||
if special_s in text and language == special_l:
|
||||
phones, word2ph, norm_text = clean_special(text, language, special_s, target_symbol, version)
|
||||
# 返回空映射
|
||||
char_mappings = {
|
||||
"orig_to_norm": list(range(len(text))),
|
||||
"norm_to_orig": list(range(len(norm_text)))
|
||||
}
|
||||
return phones, word2ph, norm_text, char_mappings
|
||||
|
||||
language_module = __import__("text." + language_module_map[language], fromlist=[language_module_map[language]])
|
||||
|
||||
# 使用带映射的标准化函数
|
||||
if hasattr(language_module, "text_normalize_with_map"):
|
||||
norm_text, char_mappings = language_module.text_normalize_with_map(text)
|
||||
elif hasattr(language_module, "text_normalize"):
|
||||
# 如果没有带映射的版本,使用普通版本并构建映射
|
||||
norm_text = language_module.text_normalize(text)
|
||||
from .char_mapping_utils import build_char_mapping
|
||||
char_mappings = build_char_mapping(text, norm_text)
|
||||
else:
|
||||
# 不进行标准化
|
||||
norm_text = text
|
||||
char_mappings = {
|
||||
"orig_to_norm": list(range(len(text))),
|
||||
"norm_to_orig": list(range(len(text)))
|
||||
}
|
||||
|
||||
# 处理音素
|
||||
if language == "zh" or language == "yue":
|
||||
phones, word2ph = language_module.g2p(norm_text)
|
||||
assert len(phones) == sum(word2ph)
|
||||
assert len(norm_text) == len(word2ph)
|
||||
else:
|
||||
# Try per-language word2ph helpers
|
||||
if hasattr(language_module, "g2p_with_word2ph"):
|
||||
try:
|
||||
phones, word2ph = language_module.g2p_with_word2ph(norm_text, keep_punc=False)
|
||||
except Exception:
|
||||
phones = language_module.g2p(norm_text)
|
||||
word2ph = None
|
||||
else:
|
||||
phones = language_module.g2p(norm_text)
|
||||
word2ph = None
|
||||
if language == "en" and len(phones) < 4:
|
||||
phones = [","] + phones
|
||||
|
||||
phones = ["UNK" if ph not in symbols else ph for ph in phones]
|
||||
return phones, word2ph, norm_text, char_mappings
|
||||
|
||||
|
||||
def clean_special(text, language, special_s, target_symbol, version=None):
|
||||
if version is None:
|
||||
version = os.environ.get("version", "v2")
|
||||
|
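As a quick illustration of how this new entry point is meant to be called, a hedged sketch follows. The module path text.cleaner is an assumption inferred from the clean_text/clean_special context in this hunk, and the input string is made up.

# Sketch only: assumes the hunk above lives in text/cleaner.py and that the
# package root is importable, as with the existing clean_text entry point.
from text.cleaner import clean_text_with_mapping

phones, word2ph, norm_text, char_mappings = clean_text_with_mapping("今天是2024年", "zh", version="v2")
# norm_text is the normalized string fed to g2p; char_mappings["norm_to_orig"][j]
# gives the original character that normalized character j came from (-1 if it
# was inserted), so downstream code can align phonemes back to the raw input.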
@@ -238,6 +238,214 @@ def _expand_number(m):
    return _inflect.number_to_words(num, andword="")


class CharMapper:
    """
    Character-mapping tracker that records how character positions move
    during text normalization.

    Core idea: maintain a mapping from the original text to the current text.
    - orig_to_curr[i] is the position in the current text of the i-th original character
    """
    def __init__(self, text):
        self.original_text = text
        self.text = text
        # Initialization: every original character maps to itself
        self.orig_to_curr = list(range(len(text)))

    def apply_sub(self, pattern, replacement_func):
        """
        Apply a regex substitution and update the mapping.

        Key point: the new mapping has to be derived through the old one.
        Capture groups get special handling (e.g. splitting off capital letters).
        """
        new_text = ""
        # curr_to_new[i] is the position in the new text of the i-th character of the current text
        curr_to_new = [-1] * len(self.text)

        pos = 0
        for match in pattern.finditer(self.text):
            # Copy the unchanged text before the match
            for i in range(pos, match.start()):
                curr_to_new[i] = len(new_text)
                new_text += self.text[i]

            # Handle the matched part
            replacement = replacement_func(match)
            replacement_start_pos = len(new_text)

            # Special case: the replacement reuses characters of the original text
            # (for example "A" -> " A"); try to recover the correspondence
            match_text = match.group(0)
            if len(match.groups()) > 0 and match.group(1) in replacement:
                # There is a capture group; try to map it precisely.
                # Example: "(?<!^)(?<![\s])([A-Z])" matches "P" and replaces it with " P"
                captured = match.group(1)
                replacement_idx = replacement.find(captured)
                if replacement_idx >= 0:
                    # The captured characters appear in the replacement text
                    for i in range(match.start(), match.end()):
                        char = self.text[i]
                        if char in captured:
                            # This character belongs to the capture group; map it to its position in the replacement
                            char_idx_in_replacement = replacement.find(char, replacement_idx)
                            if char_idx_in_replacement >= 0:
                                curr_to_new[i] = replacement_start_pos + char_idx_in_replacement
                            else:
                                curr_to_new[i] = replacement_start_pos
                        else:
                            curr_to_new[i] = replacement_start_pos
                else:
                    # The captured characters are not in the replacement text; map everything to the start position
                    for i in range(match.start(), match.end()):
                        curr_to_new[i] = replacement_start_pos
            else:
                # No capture group (or it does not reuse original characters):
                # map every matched character to the start of the replacement text
                for i in range(match.start(), match.end()):
                    curr_to_new[i] = replacement_start_pos

            new_text += replacement
            pos = match.end()

        # Copy the remaining text
        for i in range(pos, len(self.text)):
            curr_to_new[i] = len(new_text)
            new_text += self.text[i]

        # Update the original-to-current mapping: orig -> old_curr -> new_curr
        new_orig_to_curr = []
        for orig_idx in range(len(self.original_text)):
            old_curr_idx = self.orig_to_curr[orig_idx]
            if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
                new_orig_to_curr.append(curr_to_new[old_curr_idx])
            else:
                new_orig_to_curr.append(-1)

        self.text = new_text
        self.orig_to_curr = new_orig_to_curr

    def apply_char_filter(self, keep_pattern):
        """
        Apply a character filter (keep only characters matching the pattern)
        and update the mapping.

        keep_pattern: a regex string such as "[ A-Za-z'.,?!-]"
        """
        new_text = ""
        curr_to_new = []

        for i, char in enumerate(self.text):
            if re.match(keep_pattern, char):
                curr_to_new.append(len(new_text))
                new_text += char
            else:
                # The character is dropped
                if new_text:
                    curr_to_new.append(len(new_text) - 1)
                else:
                    curr_to_new.append(-1)

        # Update the original mapping
        new_orig_to_curr = []
        for orig_idx in range(len(self.original_text)):
            old_curr_idx = self.orig_to_curr[orig_idx]
            if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
                new_orig_to_curr.append(curr_to_new[old_curr_idx])
            else:
                new_orig_to_curr.append(-1)

        self.text = new_text
        self.orig_to_curr = new_orig_to_curr

    def get_norm_to_orig(self):
        """
        Build the reverse mapping from the normalized text back to the original text.
        """
        if not self.text:
            return []

        norm_to_orig = [-1] * len(self.text)
        for orig_idx, norm_idx in enumerate(self.orig_to_curr):
            if 0 <= norm_idx < len(self.text):
                # If several original characters map to the same normalized position, keep the first one
                if norm_to_orig[norm_idx] == -1:
                    norm_to_orig[norm_idx] = orig_idx

        return norm_to_orig


def normalize_with_map(text):
    """
    Normalization with character mapping.

    Returns:
        normalized_text: the normalized text
        char_mappings: dict containing:
            - "orig_to_norm": list[int], position in the normalized text for each original character
            - "norm_to_orig": list[int], position in the original text for each normalized character
    """
    mapper = CharMapper(text)

    # Apply all transformations in the same order as normalize()
    mapper.apply_sub(_ordinal_number_re, _convert_ordinal)
    mapper.apply_sub(re.compile(r"(?<!\d)-|-(?!\d)"), lambda m: " minus ")
    mapper.apply_sub(_comma_number_re, _remove_commas)
    mapper.apply_sub(_time_re, _expand_time)
    mapper.apply_sub(_measurement_re, _expand_measurement)
    mapper.apply_sub(_pounds_re_start, _expand_pounds)
    mapper.apply_sub(_pounds_re_end, _expand_pounds)
    mapper.apply_sub(_dollars_re_start, _expand_dollars)
    mapper.apply_sub(_dollars_re_end, _expand_dollars)
    mapper.apply_sub(_decimal_number_re, _expand_decimal_number)
    mapper.apply_sub(_fraction_re, _expend_fraction)
    mapper.apply_sub(_ordinal_re, _expand_ordinal)
    mapper.apply_sub(_number_re, _expand_number)

    # Strip accents - the mapping has to be maintained by hand here
    normalized_nfd = unicodedata.normalize("NFD", mapper.text)
    new_text = ""
    curr_to_new = []
    for i, char in enumerate(normalized_nfd):
        if unicodedata.category(char) != "Mn":
            curr_to_new.append(len(new_text))
            new_text += char
        else:
            # The accent mark is dropped; map it to the preceding character
            if new_text:
                curr_to_new.append(len(new_text) - 1)
            else:
                curr_to_new.append(-1)

    # Update the original mapping - NFD may change the number of characters
    # Simplification: assume NFD does not change the character count significantly (usually true for English)
    if len(curr_to_new) >= len(mapper.text):
        # NFD expanded some characters
        new_orig_to_curr = []
        for orig_idx in range(len(mapper.original_text)):
            old_curr_idx = mapper.orig_to_curr[orig_idx]
            if old_curr_idx >= 0 and old_curr_idx < len(curr_to_new):
                new_orig_to_curr.append(curr_to_new[old_curr_idx])
            else:
                new_orig_to_curr.append(-1)
        mapper.orig_to_curr = new_orig_to_curr
    mapper.text = new_text

    # Remaining substitutions
    mapper.apply_sub(re.compile("%"), lambda m: " percent")

    # Drop disallowed characters using apply_char_filter
    mapper.apply_char_filter(r"[ A-Za-z'.,?!\-]")

    mapper.apply_sub(re.compile(r"(?i)i\.e\."), lambda m: "that is")
    mapper.apply_sub(re.compile(r"(?i)e\.g\."), lambda m: "for example")
    mapper.apply_sub(re.compile(r"(?<!^)(?<![\s])([A-Z])"), lambda m: " " + m.group(1))

    norm_to_orig = mapper.get_norm_to_orig()

    return mapper.text, {
        "orig_to_norm": mapper.orig_to_curr,
        "norm_to_orig": norm_to_orig
    }


def normalize(text):
    """
    !!! All processing requires correct input !!!
@ -245,6 +245,32 @@ def text_normalize(text):
|
||||
return text
|
||||
|
||||
|
||||
def text_normalize_with_map(text):
|
||||
"""
|
||||
带字符映射的英文标准化函数
|
||||
|
||||
Returns:
|
||||
normalized_text: 标准化后的文本
|
||||
char_mappings: 字典,包含:
|
||||
- "orig_to_norm": list[int], 原始文本每个字符对应标准化文本的位置
|
||||
- "norm_to_orig": list[int], 标准化文本每个字符对应原始文本的位置
|
||||
"""
|
||||
from .en_normalization.expend import normalize_with_map
|
||||
|
||||
# 先进行标点替换
|
||||
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
|
||||
text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||||
text = unicode(text)
|
||||
|
||||
# 使用带映射的标准化函数
|
||||
normalized_text, mappings = normalize_with_map(text)
|
||||
|
||||
# 避免重复标点引起的参考泄露
|
||||
normalized_text = replace_consecutive_punctuation(normalized_text)
|
||||
|
||||
return normalized_text, mappings
|
||||
|
||||
|
||||
class en_G2p(G2p):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@@ -179,6 +179,27 @@ def text_normalize(text):
    return text


def text_normalize_with_map(text):
    """
    Japanese normalization with character mapping.

    Returns:
        normalized_text: the normalized text
        char_mappings: dict containing:
            - "orig_to_norm": list[int], position in the normalized text for each original character
            - "norm_to_orig": list[int], position in the original text for each normalized character
    """
    from .char_mapping_utils import build_char_mapping

    # Normalize first
    normalized_text = text_normalize(text)

    # Build the character mapping
    mappings = build_char_mapping(text, normalized_text)

    return normalized_text, mappings


# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
    """Extract phoneme + prosoody symbol sequence from input full-context labels.
@@ -357,6 +357,22 @@ def g2p_with_word2ph(text, keep_punc=False):
        word2ph.append(max(1, len(phs_valid)))
    return phones_all, word2ph

def text_normalize_with_map(text):
    """
    Korean needs no normalization; return the original text and an identity mapping.

    Returns:
        normalized_text: the original text (unchanged)
        char_mappings: dict containing the identity mapping
    """
    # Korean is not normalized; return the identity mapping directly
    mappings = {
        "orig_to_norm": list(range(len(text))),
        "norm_to_orig": list(range(len(text)))
    }
    return text, mappings


if __name__ == "__main__":
    text = "안녕하세요"
    print(g2p(text))