Merge pull request #559 from KamioRinn/Adjust-text-normlization

中文前端处理优化
This commit is contained in:
RVC-Boss 2024-02-21 16:24:43 +08:00 committed by GitHub
commit b19c1a2b0a
4 changed files with 27 additions and 5 deletions

View File

@ -245,7 +245,7 @@ def get_phones_and_bert(text,language):
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else: else:
# 因无法区别中日文汉字,以用户输入为准 # 因无法区别中日文汉字,以用户输入为准
formattext = re.sub('[a-zA-Z]', '', text) formattext = text
while " " in formattext: while " " in formattext:
formattext = formattext.replace(" ", " ") formattext = formattext.replace(" ", " ")
phones, word2ph, norm_text = clean_text_inf(formattext, language) phones, word2ph, norm_text = clean_text_inf(formattext, language)
@ -375,6 +375,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (text[-1] not in splits): text += "" if text_language != "en" else "." if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(i18n("实际输入的目标文本(每句):"), text) print(i18n("实际输入的目标文本(每句):"), text)
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language) phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(i18n("前端处理后的文本(每句):"), norm_text2)
if not ref_free: if not ref_free:
bert = torch.cat([bert1, bert2], 1) bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)

View File

@ -34,6 +34,8 @@ rep_map = {
"$": ".", "$": ".",
"/": ",", "/": ",",
"": "-", "": "-",
"~": "",
"":"",
} }
tone_modifier = ToneSandhi() tone_modifier = ToneSandhi()

View File

@ -172,6 +172,21 @@ def replace_range(match) -> str:
return result return result
# ~至表达式
RE_TO_RANGE = re.compile(
r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)')
def replace_to_range(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
result = match.group(0).replace('~', '')
return result
def _get_value(value_string: str, use_zero: bool=True) -> List[str]: def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
stripped = value_string.lstrip('0') stripped = value_string.lstrip('0')
if len(stripped) == 0: if len(stripped) == 0:

View File

@ -33,6 +33,7 @@ from .num import RE_NUMBER
from .num import RE_PERCENTAGE from .num import RE_PERCENTAGE
from .num import RE_POSITIVE_QUANTIFIERS from .num import RE_POSITIVE_QUANTIFIERS
from .num import RE_RANGE from .num import RE_RANGE
from .num import RE_TO_RANGE
from .num import replace_default_num from .num import replace_default_num
from .num import replace_frac from .num import replace_frac
from .num import replace_negative_num from .num import replace_negative_num
@ -40,6 +41,7 @@ from .num import replace_number
from .num import replace_percentage from .num import replace_percentage
from .num import replace_positive_quantifier from .num import replace_positive_quantifier
from .num import replace_range from .num import replace_range
from .num import replace_to_range
from .phonecode import RE_MOBILE_PHONE from .phonecode import RE_MOBILE_PHONE
from .phonecode import RE_NATIONAL_UNIFORM_NUMBER from .phonecode import RE_NATIONAL_UNIFORM_NUMBER
from .phonecode import RE_TELEPHONE from .phonecode import RE_TELEPHONE
@ -65,7 +67,7 @@ class TextNormalizer():
if lang == "zh": if lang == "zh":
text = text.replace(" ", "") text = text.replace(" ", "")
# 过滤掉特殊字符 # 过滤掉特殊字符
text = re.sub(r'[——《》【】<=>{}()#&@“”^_|\\]', '', text) text = re.sub(r'[——《》【】<=>{}()#&@“”^_|\\]', '', text)
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip() text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
@ -73,8 +75,8 @@ class TextNormalizer():
def _post_replace(self, sentence: str) -> str: def _post_replace(self, sentence: str) -> str:
sentence = sentence.replace('/', '') sentence = sentence.replace('/', '')
sentence = sentence.replace('~', '') # sentence = sentence.replace('~', '至')
sentence = sentence.replace('', '') # sentence = sentence.replace('', '至')
sentence = sentence.replace('', '') sentence = sentence.replace('', '')
sentence = sentence.replace('', '') sentence = sentence.replace('', '')
sentence = sentence.replace('', '') sentence = sentence.replace('', '')
@ -128,6 +130,8 @@ class TextNormalizer():
sentence = RE_TIME_RANGE.sub(replace_time, sentence) sentence = RE_TIME_RANGE.sub(replace_time, sentence)
sentence = RE_TIME.sub(replace_time, sentence) sentence = RE_TIME.sub(replace_time, sentence)
# 处理~波浪号作为至的替换
sentence = RE_TO_RANGE.sub(replace_to_range, sentence)
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence) sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
sentence = replace_measure(sentence) sentence = replace_measure(sentence)
sentence = RE_FRAC.sub(replace_frac, sentence) sentence = RE_FRAC.sub(replace_frac, sentence)