fix: prevent concurrent access to BERT model with thread lock (#2165)

Added thread lock to protect get_phones_and_bert method from potential race conditions during concurrent access. This addresses issue #1844 where multiple threads accessing the BERT model simultaneously could cause data inconsistency or crashes.

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
This commit is contained in:
lishq 2025-03-26 15:03:36 +08:00 committed by GitHub
parent b0e465eb72
commit fef65d40fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
import os, sys import os, sys
import threading
from tqdm import tqdm from tqdm import tqdm
now_dir = os.getcwd() now_dir = os.getcwd()
@ -54,6 +55,7 @@ class TextPreprocessor:
self.bert_model = bert_model self.bert_model = bert_model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.device = device self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]: def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(f'############ {i18n("切分文本")} ############') print(f'############ {i18n("切分文本")} ############')
@ -117,70 +119,71 @@ class TextPreprocessor:
return self.get_phones_and_bert(text, language, version) return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False): def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: with self.bert_lock:
# language = language.replace("all_","") if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
formattext = text # language = language.replace("all_","")
while " " in formattext: formattext = text
formattext = formattext.replace(" ", " ") while " " in formattext:
if language == "all_zh": formattext = formattext.replace(" ", " ")
if re.search(r'[A-Za-z]', formattext): if language == "all_zh":
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) if re.search(r'[A-Za-z]', formattext):
formattext = chinese.mix_text_normalize(formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
return self.get_phones_and_bert(formattext,"zh",version) formattext = chinese.mix_text_normalize(formattext)
else: return self.get_phones_and_bert(formattext,"zh",version)
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) else:
bert = self.get_bert_feature(norm_text, word2ph).to(self.device) phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext): bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext) elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
formattext = chinese.mix_text_normalize(formattext) formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
return self.get_phones_and_bert(formattext,"yue",version) formattext = chinese.mix_text_normalize(formattext)
else: return self.get_phones_and_bert(formattext,"yue",version)
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version) else:
bert = torch.zeros( phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
(1024, len(phones)), bert = torch.zeros(
dtype=torch.float32, (1024, len(phones)),
).to(self.device) dtype=torch.float32,
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: ).to(self.device)
textlist=[] elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
langlist=[] textlist=[]
if language == "auto": langlist=[]
for tmp in LangSegmenter.getTexts(text): if language == "auto":
langlist.append(tmp["lang"]) for tmp in LangSegmenter.getTexts(text):
textlist.append(tmp["text"]) langlist.append(tmp["lang"])
elif language == "auto_yue": textlist.append(tmp["text"])
for tmp in LangSegmenter.getTexts(text): elif language == "auto_yue":
if tmp["lang"] == "zh": for tmp in LangSegmenter.getTexts(text):
tmp["lang"] = "yue" if tmp["lang"] == "zh":
langlist.append(tmp["lang"]) tmp["lang"] = "yue"
textlist.append(tmp["text"]) langlist.append(tmp["lang"])
else: textlist.append(tmp["text"])
for tmp in LangSegmenter.getTexts(text): else:
if tmp["lang"] == "en": for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) if tmp["lang"] == "en":
else: langlist.append(tmp["lang"])
# 因无法区别中日韩文汉字,以用户输入为准 else:
langlist.append(language) # 因无法区别中日韩文汉字,以用户输入为准
textlist.append(tmp["text"]) langlist.append(language)
# print(textlist) textlist.append(tmp["text"])
# print(langlist) # print(textlist)
phones_list = [] # print(langlist)
bert_list = [] phones_list = []
norm_text_list = [] bert_list = []
for i in range(len(textlist)): norm_text_list = []
lang = langlist[i] for i in range(len(textlist)):
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version) lang = langlist[i]
bert = self.get_bert_inf(phones, word2ph, norm_text, lang) phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
phones_list.append(phones) bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
norm_text_list.append(norm_text) phones_list.append(phones)
bert_list.append(bert) norm_text_list.append(norm_text)
bert = torch.cat(bert_list, dim=1) bert_list.append(bert)
phones = sum(phones_list, []) bert = torch.cat(bert_list, dim=1)
norm_text = ''.join(norm_text_list) phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
if not final and len(phones) < 6: if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True) return self.get_phones_and_bert("." + text,language,version,final=True)
return phones, bert, norm_text return phones, bert, norm_text
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor: