Merge 0cc6b8aaef4c637c674f8a6660ce8f59124a240e into 9da7e17efe05041e31d3c3f42c8730ae890397f2

This commit is contained in:
wzy3650 2025-04-02 04:19:19 +09:00 committed by GitHub
commit 99b6bf4e02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,9 @@
# -*- coding: utf-8 -*-
import os
import re
import LangSegment
from text import chinese
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
@ -83,24 +86,104 @@ if os.path.exists(txt_path) == False:
return phone_level_feature.T
def get_bert_inf(phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
if language == "zh":
feature = get_bert_feature(norm_text, word2ph).to(device)
else:
feature = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(device)
return feature
def get_phones_and_bert(text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
if language == "en":
LangSegment.setfilters(["en"])
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
else:
# 因无法区别中日韩文汉字,以用户输入为准
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"zh",version)
else:
phones, word2ph, norm_text = clean_text(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = clean_text(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
LangSegment.setfilters(["zh","ja","en","ko"])
if language == "auto":
for tmp in LangSegment.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegment.getTexts(text):
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# 因无法区别中日韩文汉字,以用户输入为准
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = clean_text(textlist[i], lang, version)
bert = get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
return phones, bert, norm_text
def process(data, res):
for name, text, lan in data:
try:
name=clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","), lan, version
phones, bert_feature, norm_text = get_phones_and_bert(
text.replace("%", "-").replace("", ","), lan, 'v2'
)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert)
my_save(bert_feature, path_bert)
phones = " ".join(phones)
# res.append([name,phones])
res.append([name, phones, word2ph, norm_text])
res.append([name, phones, None, norm_text])
except:
print(name, text, traceback.format_exc())