GPT-SoVITS/GPT_SoVITS/text/japanese.py


# modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py
import re
import os
import hashlib
try:
    import pyopenjtalk

    current_file_path = os.path.dirname(__file__)

    # Prevent the dictionary/model from failing to load on Windows when paths contain
    # non-ASCII characters: relocate the Open JTalk dictionary and user dict to safe paths.
    if os.name == "nt":
        import shutil

        python_dir = os.getcwd()
        OPEN_JTALK_DICT_DIR = pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8")
        if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", OPEN_JTALK_DICT_DIR)):
            if OPEN_JTALK_DICT_DIR[: len(python_dir)].upper() == python_dir.upper():
                OPEN_JTALK_DICT_DIR = os.path.join(os.path.relpath(OPEN_JTALK_DICT_DIR, python_dir))
            else:
                if not os.path.exists("TEMP"):
                    os.mkdir("TEMP")
                if not os.path.exists(os.path.join("TEMP", "ja")):
                    os.mkdir(os.path.join("TEMP", "ja"))
                if os.path.exists(os.path.join("TEMP", "ja", "open_jtalk_dic")):
                    shutil.rmtree(os.path.join("TEMP", "ja", "open_jtalk_dic"))
                shutil.copytree(
                    pyopenjtalk.OPEN_JTALK_DICT_DIR.decode("utf-8"),
                    os.path.join("TEMP", "ja", "open_jtalk_dic"),
                )
                OPEN_JTALK_DICT_DIR = os.path.join("TEMP", "ja", "open_jtalk_dic")
            pyopenjtalk.OPEN_JTALK_DICT_DIR = OPEN_JTALK_DICT_DIR.encode("utf-8")

        if not (re.match(r"^[A-Za-z0-9_/\\:.\-]*$", current_file_path)):
            if current_file_path[: len(python_dir)].upper() == python_dir.upper():
                current_file_path = os.path.join(os.path.relpath(current_file_path, python_dir))
            else:
                if not os.path.exists("TEMP"):
                    os.mkdir("TEMP")
                if not os.path.exists(os.path.join("TEMP", "ja")):
                    os.mkdir(os.path.join("TEMP", "ja"))
                if not os.path.exists(os.path.join("TEMP", "ja", "ja_userdic")):
                    os.mkdir(os.path.join("TEMP", "ja", "ja_userdic"))
                shutil.copyfile(
                    os.path.join(current_file_path, "ja_userdic", "userdict.csv"),
                    os.path.join("TEMP", "ja", "ja_userdic", "userdict.csv"),
                )
                current_file_path = os.path.join("TEMP", "ja")

    def get_hash(fp: str) -> str:
        hash_md5 = hashlib.md5()
        with open(fp, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    USERDIC_CSV_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.csv")
    USERDIC_BIN_PATH = os.path.join(current_file_path, "ja_userdic", "user.dict")
    USERDIC_HASH_PATH = os.path.join(current_file_path, "ja_userdic", "userdict.md5")

    # If there is no compiled user dictionary, build one; if there is, compare the MD5 of the
    # CSV against the stored hash and rebuild the binary dictionary when the CSV has changed.
    if os.path.exists(USERDIC_CSV_PATH):
        if (
            not os.path.exists(USERDIC_BIN_PATH)
            or get_hash(USERDIC_CSV_PATH) != open(USERDIC_HASH_PATH, "r", encoding="utf-8").read()
        ):
            pyopenjtalk.mecab_dict_index(USERDIC_CSV_PATH, USERDIC_BIN_PATH)
            with open(USERDIC_HASH_PATH, "w", encoding="utf-8") as f:
                f.write(get_hash(USERDIC_CSV_PATH))

    if os.path.exists(USERDIC_BIN_PATH):
        pyopenjtalk.update_global_jtalk_with_user_dict(USERDIC_BIN_PATH)

except Exception:
    # Failed to set up the user dictionary; fall back to the default pyopenjtalk import and ignore the error.
    import pyopenjtalk

from text.symbols import punctuation
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
    r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)

# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(
    r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
)

# List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("％", "パーセント")]]
# List of (consonant, sokuon) pairs:
_real_sokuon = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        (r"Q([↑↓]*[kg])", r"k#\1"),
        (r"Q([↑↓]*[tdjʧ])", r"t#\1"),
        (r"Q([↑↓]*[sʃ])", r"s\1"),
        (r"Q([↑↓]*[pb])", r"p#\1"),
    ]
]

# List of (consonant, hatsuon) pairs:
_real_hatsuon = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        (r"N([↑↓]*[pbm])", r"m\1"),
        (r"N([↑↓]*[ʧʥj])", r"n^\1"),
        (r"N([↑↓]*[tdn])", r"n\1"),
        (r"N([↑↓]*[kg])", r"ŋ\1"),
    ]
]
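
# Note: the Q/N rewrite tables above rewrite the placeholder consonants Q (sokuon) and
# N (hatsuon) based on the following phoneme. They carry over from the original VITS-style
# cleaner this file was modified from and are not referenced by the pyopenjtalk-based
# pipeline below; they are kept for reference.
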
def post_replace_ph(ph):
    rep_map = {
        "：": ",",
        "；": ",",
        "，": ",",
        "。": ".",
        "！": "!",
        "？": "?",
        "\n": ".",
        "·": ",",
        "、": ",",
        "...": "…",
    }
    if ph in rep_map.keys():
        ph = rep_map[ph]
    return ph

def replace_consecutive_punctuation(text):
    punctuations = "".join(re.escape(p) for p in punctuation)
    pattern = f"([{punctuations}])([{punctuations}])+"
    result = re.sub(pattern, r"\1", text)
    return result

def symbols_to_japanese(text):
    for regex, replacement in _symbols_to_japanese:
        text = re.sub(regex, replacement, text)
    return text

def preprocess_jap(text, with_prosody=False):
    """Reference: https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html"""
    text = symbols_to_japanese(text)
    # Lower-case English words; this has no effect on Japanese text.
    text = text.lower()
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = []
    for i, sentence in enumerate(sentences):
        if re.match(_japanese_characters, sentence):
            if with_prosody:
                text += pyopenjtalk_g2p_prosody(sentence)[1:-1]
            else:
                p = pyopenjtalk.g2p(sentence)
                text += p.split(" ")
        if i < len(marks):
            if marks[i] == " ":  # skip bare spaces to prevent unexpected UNK tokens
                continue
            text += [marks[i].replace(" ", "")]
    return text

def text_normalize(text):
    # TODO: proper Japanese text normalization
    # Collapse repeated punctuation to avoid the reference leakage it can cause.
    text = replace_consecutive_punctuation(text)
    return text

def text_normalize_with_map(text):
    """
    Japanese normalization with a character-level mapping.

    Returns:
        normalized_text: the normalized text
        char_mappings: dict containing:
            - "orig_to_norm": list[int], for each original character, its position in the normalized text
            - "norm_to_orig": list[int], for each normalized character, its position in the original text
    """
    from .char_mapping_utils import build_char_mapping

    # Normalize first, then build the character mapping between the two strings.
    normalized_text = text_normalize(text)
    mappings = build_char_mapping(text, normalized_text)
    return normalized_text, mappings
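
# Illustrative use (assuming "!" is part of text.symbols.punctuation, so the duplicate collapses):
#   norm_text, maps = text_normalize_with_map("こんにちは!!")
#   # norm_text == "こんにちは!", and maps["norm_to_orig"][i] gives the original index of norm_text[i]
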
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True):
    """Extract phoneme + prosody symbol sequence from input full-context labels.

    The algorithm is based on `Prosodic features control by symbols as input of
    sequence-to-sequence acoustic modeling for neural TTS`_ with some of r9y9's tweaks.

    Args:
        text (str): Input text.
        drop_unvoiced_vowels (bool): Whether to drop unvoiced vowels.

    Returns:
        List[str]: List of phoneme + prosody symbols.

    Examples:
        >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
        >>> pyopenjtalk_g2p_prosody("こんにちは。")
        ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']

    .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
        modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104
    """
    labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text))
    N = len(labels)

    phones = []
    for n in range(N):
        lab_curr = labels[n]

        # current phoneme
        p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
        # treat unvoiced vowels as normal (voiced) vowels
        if drop_unvoiced_vowels and p3 in "AEIOU":
            p3 = p3.lower()

        # deal with sil at the beginning and the end of the text
        if p3 == "sil":
            assert n == 0 or n == N - 1
            if n == 0:
                phones.append("^")
            elif n == N - 1:
                # check whether the utterance is a question
                e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
                if e3 == 0:
                    phones.append("$")
                elif e3 == 1:
                    phones.append("?")
            continue
        elif p3 == "pau":
            phones.append("_")
            continue
        else:
            phones.append(p3)

        # accent type and position info (forward or backward)
        a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
        a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
        a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)

        # number of mora in the accent phrase
        f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)

        a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
        # accent phrase border
        if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
            phones.append("#")
        # pitch falling
        elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
            phones.append("]")
        # pitch rising
        elif a2 == 1 and a2_next == 2:
            phones.append("[")

    return phones
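
# Legend for the prosody symbols emitted by pyopenjtalk_g2p_prosody above:
#   ^  start of utterance          $  end of a declarative utterance
#   ?  end of an interrogative     _  pause
#   #  accent phrase border        ]  pitch falling
#   [  pitch rising
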
# Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
def _numeric_feature_by_regex(regex, s):
    match = re.search(regex, s)
    if match is None:
        return -50
    return int(match.group(1))

def g2p(norm_text, with_prosody=True):
    phones = preprocess_jap(norm_text, with_prosody)
    phones = [post_replace_ph(i) for i in phones]
    # TODO: implement tones and word2ph
    return phones

# Helper for alignment: build phones and word2ph by per-character g2p (ignoring prosody markers)
def g2p_with_word2ph(text, keep_punc=False):
    """
    Returns (phones, word2ph).

    - Per-character g2p; prosody markers such as '[', ']', '^', '$', '#', '_' are ignored.
    - Punctuation counts as 1 phone if keep_punc is True, otherwise it is skipped.
    """
    norm_text = text_normalize(text)
    phones_all = []
    word2ph = []
    prosody_markers = {"[", "]", "^", "$", "#", "_"}
    punc_set = set(punctuation)
    for ch in norm_text:
        if ch.isspace() or ch in punc_set:
            if keep_punc:
                word2ph.append(1)
            continue
        phs = preprocess_jap(ch, with_prosody=True)
        phs = [post_replace_ph(p) for p in phs if p not in prosody_markers and p not in punc_set]
        phones_all.extend(phs)
        word2ph.append(max(1, len(phs)))
    return phones_all, word2ph

if __name__ == "__main__":
    phones = g2p("Hello.こんにちは今日もNiCe天気ですねtokyotowerに行きましょう")
    print(phones)
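
    # Illustrative check of the per-character alignment helper (the exact phones depend on the
    # installed pyopenjtalk dictionary, so treat this as a sketch rather than a test).
    phones_aligned, word2ph = g2p_with_word2ph("こんにちは、今日もいい天気ですね。", keep_punc=True)
    print(phones_aligned)
    print(word2ph)  # one entry per character of the normalized input when keep_punc=True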