mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
371 lines
9.5 KiB
Python
371 lines
9.5 KiB
Python
import pickle
|
||
import os
|
||
import re
|
||
import wordsegment
|
||
from g2p_en import G2p
|
||
|
||
from text.symbols import punctuation
|
||
|
||
from text.symbols2 import symbols
|
||
|
||
from builtins import str as unicode
|
||
from text.en_normalization.expend import normalize
|
||
from nltk.tokenize import TweetTokenizer
|
||
word_tokenize = TweetTokenizer().tokenize
|
||
from nltk import pos_tag
|
||
|
||
current_file_path = os.path.dirname(__file__)
|
||
CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
|
||
CMU_DICT_FAST_PATH = os.path.join(current_file_path, "cmudict-fast.rep")
|
||
CMU_DICT_HOT_PATH = os.path.join(current_file_path, "engdict-hot.rep")
|
||
CACHE_PATH = os.path.join(current_file_path, "engdict_cache.pickle")
|
||
NAMECACHE_PATH = os.path.join(current_file_path, "namedict_cache.pickle")
|
||
|
||
|
||
# 适配中文及 g2p_en 标点
|
||
rep_map = {
|
||
"[;::,;]": ",",
|
||
'["’]': "'",
|
||
"。": ".",
|
||
"!": "!",
|
||
"?": "?",
|
||
}
|
||
|
||
|
||
arpa = {
|
||
"AH0",
|
||
"S",
|
||
"AH1",
|
||
"EY2",
|
||
"AE2",
|
||
"EH0",
|
||
"OW2",
|
||
"UH0",
|
||
"NG",
|
||
"B",
|
||
"G",
|
||
"AY0",
|
||
"M",
|
||
"AA0",
|
||
"F",
|
||
"AO0",
|
||
"ER2",
|
||
"UH1",
|
||
"IY1",
|
||
"AH2",
|
||
"DH",
|
||
"IY0",
|
||
"EY1",
|
||
"IH0",
|
||
"K",
|
||
"N",
|
||
"W",
|
||
"IY2",
|
||
"T",
|
||
"AA1",
|
||
"ER1",
|
||
"EH2",
|
||
"OY0",
|
||
"UH2",
|
||
"UW1",
|
||
"Z",
|
||
"AW2",
|
||
"AW1",
|
||
"V",
|
||
"UW2",
|
||
"AA2",
|
||
"ER",
|
||
"AW0",
|
||
"UW0",
|
||
"R",
|
||
"OW1",
|
||
"EH1",
|
||
"ZH",
|
||
"AE0",
|
||
"IH2",
|
||
"IH",
|
||
"Y",
|
||
"JH",
|
||
"P",
|
||
"AY1",
|
||
"EY0",
|
||
"OY2",
|
||
"TH",
|
||
"HH",
|
||
"D",
|
||
"ER0",
|
||
"CH",
|
||
"AO1",
|
||
"AE1",
|
||
"AO2",
|
||
"OY1",
|
||
"AY2",
|
||
"IH1",
|
||
"OW0",
|
||
"L",
|
||
"SH",
|
||
}
|
||
|
||
|
||
def replace_phs(phs):
|
||
rep_map = {"'": "-"}
|
||
phs_new = []
|
||
for ph in phs:
|
||
if ph in symbols:
|
||
phs_new.append(ph)
|
||
elif ph in rep_map.keys():
|
||
phs_new.append(rep_map[ph])
|
||
else:
|
||
print("ph not in symbols: ", ph)
|
||
return phs_new
|
||
|
||
|
||
def replace_consecutive_punctuation(text):
|
||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||
pattern = f'([{punctuations}\s])([{punctuations}])+'
|
||
result = re.sub(pattern, r'\1', text)
|
||
return result
|
||
|
||
|
||
def read_dict():
|
||
g2p_dict = {}
|
||
start_line = 49
|
||
with open(CMU_DICT_PATH) as f:
|
||
line = f.readline()
|
||
line_index = 1
|
||
while line:
|
||
if line_index >= start_line:
|
||
line = line.strip()
|
||
word_split = line.split(" ")
|
||
word = word_split[0].lower()
|
||
|
||
syllable_split = word_split[1].split(" - ")
|
||
g2p_dict[word] = []
|
||
for syllable in syllable_split:
|
||
phone_split = syllable.split(" ")
|
||
g2p_dict[word].append(phone_split)
|
||
|
||
line_index = line_index + 1
|
||
line = f.readline()
|
||
|
||
return g2p_dict
|
||
|
||
|
||
def read_dict_new():
|
||
g2p_dict = {}
|
||
with open(CMU_DICT_PATH) as f:
|
||
line = f.readline()
|
||
line_index = 1
|
||
while line:
|
||
if line_index >= 57:
|
||
line = line.strip()
|
||
word_split = line.split(" ")
|
||
word = word_split[0].lower()
|
||
g2p_dict[word] = [word_split[1].split(" ")]
|
||
|
||
line_index = line_index + 1
|
||
line = f.readline()
|
||
|
||
with open(CMU_DICT_FAST_PATH) as f:
|
||
line = f.readline()
|
||
line_index = 1
|
||
while line:
|
||
if line_index >= 0:
|
||
line = line.strip()
|
||
word_split = line.split(" ")
|
||
word = word_split[0].lower()
|
||
if word not in g2p_dict:
|
||
g2p_dict[word] = [word_split[1:]]
|
||
|
||
line_index = line_index + 1
|
||
line = f.readline()
|
||
|
||
return g2p_dict
|
||
|
||
def hot_reload_hot(g2p_dict):
|
||
with open(CMU_DICT_HOT_PATH) as f:
|
||
line = f.readline()
|
||
line_index = 1
|
||
while line:
|
||
if line_index >= 0:
|
||
line = line.strip()
|
||
word_split = line.split(" ")
|
||
word = word_split[0].lower()
|
||
# 自定义发音词直接覆盖字典
|
||
g2p_dict[word] = [word_split[1:]]
|
||
|
||
line_index = line_index + 1
|
||
line = f.readline()
|
||
|
||
return g2p_dict
|
||
|
||
|
||
def cache_dict(g2p_dict, file_path):
|
||
with open(file_path, "wb") as pickle_file:
|
||
pickle.dump(g2p_dict, pickle_file)
|
||
|
||
|
||
def get_dict():
|
||
if os.path.exists(CACHE_PATH):
|
||
with open(CACHE_PATH, "rb") as pickle_file:
|
||
g2p_dict = pickle.load(pickle_file)
|
||
else:
|
||
g2p_dict = read_dict_new()
|
||
cache_dict(g2p_dict, CACHE_PATH)
|
||
|
||
g2p_dict = hot_reload_hot(g2p_dict)
|
||
|
||
return g2p_dict
|
||
|
||
|
||
def get_namedict():
|
||
if os.path.exists(NAMECACHE_PATH):
|
||
with open(NAMECACHE_PATH, "rb") as pickle_file:
|
||
name_dict = pickle.load(pickle_file)
|
||
else:
|
||
name_dict = {}
|
||
|
||
return name_dict
|
||
|
||
|
||
def text_normalize(text):
|
||
# todo: eng text normalize
|
||
|
||
# 效果相同,和 chinese.py 保持一致
|
||
pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
|
||
text = pattern.sub(lambda x: rep_map[x.group()], text)
|
||
|
||
text = unicode(text)
|
||
text = normalize(text)
|
||
|
||
# 避免重复标点引起的参考泄露
|
||
text = replace_consecutive_punctuation(text)
|
||
return text
|
||
|
||
|
||
class en_G2p(G2p):
|
||
def __init__(self):
|
||
super().__init__()
|
||
# 分词初始化
|
||
wordsegment.load()
|
||
|
||
# 扩展过时字典, 添加姓名字典
|
||
self.cmu = get_dict()
|
||
self.namedict = get_namedict()
|
||
|
||
# 剔除读音错误的几个缩写
|
||
for word in ["AE", "AI", "AR", "IOS", "HUD", "OS"]:
|
||
del self.cmu[word.lower()]
|
||
|
||
# 修正多音字
|
||
self.homograph2features["read"] = (['R', 'IY1', 'D'], ['R', 'EH1', 'D'], 'VBP')
|
||
self.homograph2features["complex"] = (['K', 'AH0', 'M', 'P', 'L', 'EH1', 'K', 'S'], ['K', 'AA1', 'M', 'P', 'L', 'EH0', 'K', 'S'], 'JJ')
|
||
|
||
|
||
def __call__(self, text):
|
||
# tokenization
|
||
words = word_tokenize(text)
|
||
tokens = pos_tag(words) # tuples of (word, tag)
|
||
|
||
# steps
|
||
prons = []
|
||
for o_word, pos in tokens:
|
||
# 还原 g2p_en 小写操作逻辑
|
||
word = o_word.lower()
|
||
|
||
if re.search("[a-z]", word) is None:
|
||
pron = [word]
|
||
# 先把单字母推出去
|
||
elif len(word) == 1:
|
||
# 单读 A 发音修正, 这里需要原格式 o_word 判断大写
|
||
if o_word == "A":
|
||
pron = ['EY1']
|
||
else:
|
||
pron = self.cmu[word][0]
|
||
# g2p_en 原版多音字处理
|
||
elif word in self.homograph2features: # Check homograph
|
||
pron1, pron2, pos1 = self.homograph2features[word]
|
||
if pos.startswith(pos1):
|
||
pron = pron1
|
||
# pos1比pos长仅出现在read
|
||
elif len(pos) < len(pos1) and pos == pos1[:len(pos)]:
|
||
pron = pron1
|
||
else:
|
||
pron = pron2
|
||
else:
|
||
# 递归查找预测
|
||
pron = self.qryword(o_word)
|
||
|
||
prons.extend(pron)
|
||
prons.extend([" "])
|
||
|
||
return prons[:-1]
|
||
|
||
|
||
def qryword(self, o_word):
|
||
word = o_word.lower()
|
||
|
||
# 查字典, 单字母除外
|
||
if len(word) > 1 and word in self.cmu: # lookup CMU dict
|
||
return self.cmu[word][0]
|
||
|
||
# 单词仅首字母大写时查找姓名字典
|
||
if o_word.istitle() and word in self.namedict:
|
||
return self.namedict[word][0]
|
||
|
||
# oov 长度小于等于 3 直接读字母
|
||
if len(word) <= 3:
|
||
phones = []
|
||
for w in word:
|
||
# 单读 A 发音修正, 此处不存在大写的情况
|
||
if w == "a":
|
||
phones.extend(['EY1'])
|
||
elif not w.isalpha():
|
||
phones.extend([w])
|
||
else:
|
||
phones.extend(self.cmu[w][0])
|
||
return phones
|
||
|
||
# 尝试分离所有格
|
||
if re.match(r"^([a-z]+)('s)$", word):
|
||
phones = self.qryword(word[:-2])[:]
|
||
# P T K F TH HH 无声辅音结尾 's 发 ['S']
|
||
if phones[-1] in ['P', 'T', 'K', 'F', 'TH', 'HH']:
|
||
phones.extend(['S'])
|
||
# S Z SH ZH CH JH 擦声结尾 's 发 ['IH1', 'Z'] 或 ['AH0', 'Z']
|
||
elif phones[-1] in ['S', 'Z', 'SH', 'ZH', 'CH', 'JH']:
|
||
phones.extend(['AH0', 'Z'])
|
||
# B D G DH V M N NG L R W Y 有声辅音结尾 's 发 ['Z']
|
||
# AH0 AH1 AH2 EY0 EY1 EY2 AE0 AE1 AE2 EH0 EH1 EH2 OW0 OW1 OW2 UH0 UH1 UH2 IY0 IY1 IY2 AA0 AA1 AA2 AO0 AO1 AO2
|
||
# ER ER0 ER1 ER2 UW0 UW1 UW2 AY0 AY1 AY2 AW0 AW1 AW2 OY0 OY1 OY2 IH IH0 IH1 IH2 元音结尾 's 发 ['Z']
|
||
else:
|
||
phones.extend(['Z'])
|
||
return phones
|
||
|
||
# 尝试进行分词,应对复合词
|
||
comps = wordsegment.segment(word.lower())
|
||
|
||
# 无法分词的送回去预测
|
||
if len(comps)==1:
|
||
return self.predict(word)
|
||
|
||
# 可以分词的递归处理
|
||
return [phone for comp in comps for phone in self.qryword(comp)]
|
||
|
||
|
||
_g2p = en_G2p()
|
||
|
||
|
||
def g2p(text):
|
||
# g2p_en 整段推理,剔除不存在的arpa返回
|
||
phone_list = _g2p(text)
|
||
phones = [ph if ph != "<unk>" else "UNK" for ph in phone_list if ph not in [" ", "<pad>", "UW", "</s>", "<s>"]]
|
||
|
||
return replace_phs(phones)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
print(g2p("hello"))
|
||
print(g2p(text_normalize("e.g. I used openai's AI tool to draw a picture.")))
|
||
print(g2p(text_normalize("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")))
|