diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py
index bdeacc7..76a2d69 100644
--- a/GPT_SoVITS/prepare_datasets/1-get-text.py
+++ b/GPT_SoVITS/prepare_datasets/1-get-text.py
@@ -1,6 +1,9 @@
 # -*- coding: utf-8 -*-
 
 import os
+import re
+import LangSegment
+from text import chinese
 
 inp_text = os.environ.get("inp_text")
 inp_wav_dir = os.environ.get("inp_wav_dir")
@@ -83,24 +86,103 @@ if os.path.exists(txt_path) == False:
 
         return phone_level_feature.T
 
+    def get_bert_inf(phones:list, word2ph:list, norm_text:str, language:str):
+        """Return the BERT feature for a single-language text segment.
+
+        Segments whose language is "zh" (with or without the "all_" prefix)
+        are passed to get_bert_feature; every other language gets a zero
+        tensor of shape (1024, len(phones)), so that per-segment features
+        can still be concatenated along dim 1 by get_phones_and_bert.
+        """
+        language = language.replace("all_", "")
+        if language == "zh":
+            return get_bert_feature(norm_text, word2ph).to(device)
+        return torch.zeros((1024, len(phones)), dtype=torch.float32).to(device)
+
+    def get_phones_and_bert(text:str, language:str, version:str, final:bool=False):
+        """Convert text into phones, a BERT feature tensor and normalized text.
+
+        The language argument selects the strategy:
+
+        - "en", "all_zh", "all_ja", "all_ko", "all_yue": the text is handled
+          as one language ("en" keeps only what LangSegment returns with the
+          "en" filter). zh/yue text containing ASCII letters is upper-cased,
+          passed through chinese.mix_text_normalize and re-processed through
+          the segmenting path below.
+        - "zh", "ja", "ko", "yue": LangSegment splits the text; segments it
+          tags "en" stay English, all other segments use the given language.
+        - "auto": each segment uses the language LangSegment detected.
+        - "auto_yue": like "auto", but "zh" segments are treated as "yue".
+
+        final is accepted for signature compatibility and is not used here.
+
+        Returns a (phones, bert, norm_text) tuple. Raises ValueError when
+        language is none of the values listed above.
+        """
+        if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
+            language = language.replace("all_", "")
+            if language == "en":
+                LangSegment.setfilters(["en"])
+                formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
+            else:
+                # CJK ideographs cannot be attributed to one language
+                # automatically, so trust the language given by the caller.
+                formattext = text
+            # Collapse each run of spaces into a single space.
+            formattext = re.sub(r" {2,}", " ", formattext)
+            if language in {"zh", "yue"} and re.search(r"[A-Za-z]", formattext):
+                # Mixed Chinese/English input: upper-case the ASCII letters,
+                # normalize, then go through the segmenting branch below.
+                formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
+                formattext = chinese.mix_text_normalize(formattext)
+                return get_phones_and_bert(formattext, language, version)
+            phones, word2ph, norm_text = clean_text(formattext, language, version)
+            bert = get_bert_inf(phones, word2ph, norm_text, language)
+        elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
+            LangSegment.setfilters(["zh", "ja", "en", "ko"])
+            phones = []
+            bert_list = []
+            norm_text_list = []
+            for tmp in LangSegment.getTexts(text):
+                if language == "auto":
+                    lang = tmp["lang"]
+                elif language == "auto_yue":
+                    lang = "yue" if tmp["lang"] == "zh" else tmp["lang"]
+                elif tmp["lang"] == "en":
+                    lang = "en"
+                else:
+                    # CJK ideographs cannot be attributed to one language
+                    # automatically, so trust the language given by the caller.
+                    lang = language
+                seg_phones, word2ph, seg_norm_text = clean_text(tmp["text"], lang, version)
+                bert_list.append(get_bert_inf(seg_phones, word2ph, seg_norm_text, lang))
+                phones.extend(seg_phones)
+                norm_text_list.append(seg_norm_text)
+            bert = torch.cat(bert_list, dim=1)
+            norm_text = "".join(norm_text_list)
+        else:
+            raise ValueError("unsupported language: %r" % (language,))
+
+        return phones, bert, norm_text
+
     def process(data, res):
         for name, text, lan in data:
             try:
                 name=clean_path(name)
                 name = os.path.basename(name)
                 print(name)
-                phones, word2ph, norm_text = clean_text(
+                phones, bert_feature, norm_text = get_phones_and_bert(
                     text.replace("%", "-").replace("¥", ","), lan, version
                 )
                 path_bert = "%s/%s.pt" % (bert_dir, name)
                 if os.path.exists(path_bert) == False and lan == "zh":
-                    bert_feature = get_bert_feature(norm_text, word2ph)
                     assert bert_feature.shape[-1] == len(phones)
                     # torch.save(bert_feature, path_bert)
                     my_save(bert_feature, path_bert)
                 phones = " ".join(phones)
                 # res.append([name,phones])
-                res.append([name, phones, word2ph, norm_text])
+                # get_phones_and_bert does not return word2ph; keep the column as None.
+                res.append([name, phones, None, norm_text])
             except:
                 print(name, text, traceback.format_exc())
 