diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 653656a..0ebe553 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,5 +1,6 @@ import os, sys +import threading from tqdm import tqdm now_dir = os.getcwd() @@ -54,6 +55,7 @@ class TextPreprocessor: self.bert_model = bert_model self.tokenizer = tokenizer self.device = device + self.bert_lock = threading.RLock() def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]: print(f'############ {i18n("切分文本")} ############') @@ -117,70 +119,71 @@ class TextPreprocessor: return self.get_phones_and_bert(text, language, version) def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False): - if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: - # language = language.replace("all_","") - formattext = text - while " " in formattext: - formattext = formattext.replace(" ", " ") - if language == "all_zh": - if re.search(r'[A-Za-z]', formattext): - formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) - formattext = chinese.mix_text_normalize(formattext) - return self.get_phones_and_bert(formattext,"zh",version) - else: - phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) - bert = self.get_bert_feature(norm_text, word2ph).to(self.device) - elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): - formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) - formattext = chinese.mix_text_normalize(formattext) - return self.get_phones_and_bert(formattext,"yue",version) - else: - phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) - bert = torch.zeros( - (1024, len(phones)), - dtype=torch.float32, - ).to(self.device) - elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: - textlist=[] - langlist=[] - if language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - # print(textlist) - # print(langlist) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) - bert = self.get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = ''.join(norm_text_list) + with self.bert_lock: + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: + # language = language.replace("all_","") + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + if language == "all_zh": + if re.search(r'[A-Za-z]', formattext): + formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return self.get_phones_and_bert(formattext,"zh",version) + else: + phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) + bert = self.get_bert_feature(norm_text, word2ph).to(self.device) + elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): + formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return self.get_phones_and_bert(formattext,"yue",version) + else: + phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float32, + ).to(self.device) + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: + textlist=[] + langlist=[] + if language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + # print(textlist) + # print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) + bert = self.get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = ''.join(norm_text_list) - if not final and len(phones) < 6: - return self.get_phones_and_bert("." + text,language,version,final=True) + if not final and len(phones) < 6: + return self.get_phones_and_bert("." + text,language,version,final=True) - return phones, bert, norm_text + return phones, bert, norm_text def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: