mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
fix: prevent concurrent access to BERT model with thread lock (#2165)
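Added a reentrant thread lock (`threading.RLock`) to protect the `get_phones_and_bert` method from potential race conditions during concurrent access. A reentrant lock is used because the method calls itself recursively while the lock is held, so the owning thread must be able to re-acquire it. This addresses issue #1844, where multiple threads accessing the BERT model simultaneously could cause data inconsistency or crashes. Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>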
This commit is contained in:
parent
b0e465eb72
commit
fef65d40fe
@ -1,5 +1,6 @@
|
|||||||
|
|
||||||
import os, sys
|
import os, sys
|
||||||
|
import threading
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
@ -54,6 +55,7 @@ class TextPreprocessor:
|
|||||||
self.bert_model = bert_model
|
self.bert_model = bert_model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.bert_lock = threading.RLock()
|
||||||
|
|
||||||
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
|
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
|
||||||
print(f'############ {i18n("切分文本")} ############')
|
print(f'############ {i18n("切分文本")} ############')
|
||||||
@ -117,70 +119,71 @@ class TextPreprocessor:
|
|||||||
return self.get_phones_and_bert(text, language, version)
|
return self.get_phones_and_bert(text, language, version)
|
||||||
|
|
||||||
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
||||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
with self.bert_lock:
|
||||||
# language = language.replace("all_","")
|
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||||
formattext = text
|
# language = language.replace("all_","")
|
||||||
while "  " in formattext:
|
formattext = text
|
||||||
formattext = formattext.replace("  ", " ")
|
while "  " in formattext:
|
||||||
if language == "all_zh":
|
formattext = formattext.replace("  ", " ")
|
||||||
if re.search(r'[A-Za-z]', formattext):
|
if language == "all_zh":
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
if re.search(r'[A-Za-z]', formattext):
|
||||||
formattext = chinese.mix_text_normalize(formattext)
|
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||||
return self.get_phones_and_bert(formattext,"zh",version)
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
else:
|
return self.get_phones_and_bert(formattext,"zh",version)
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
else:
|
||||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||||
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
|
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
|
||||||
formattext = chinese.mix_text_normalize(formattext)
|
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||||
return self.get_phones_and_bert(formattext,"yue",version)
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
else:
|
return self.get_phones_and_bert(formattext,"yue",version)
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
else:
|
||||||
bert = torch.zeros(
|
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||||
(1024, len(phones)),
|
bert = torch.zeros(
|
||||||
dtype=torch.float32,
|
(1024, len(phones)),
|
||||||
).to(self.device)
|
dtype=torch.float32,
|
||||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
).to(self.device)
|
||||||
textlist=[]
|
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||||
langlist=[]
|
textlist=[]
|
||||||
if language == "auto":
|
langlist=[]
|
||||||
for tmp in LangSegmenter.getTexts(text):
|
if language == "auto":
|
||||||
langlist.append(tmp["lang"])
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
textlist.append(tmp["text"])
|
langlist.append(tmp["lang"])
|
||||||
elif language == "auto_yue":
|
textlist.append(tmp["text"])
|
||||||
for tmp in LangSegmenter.getTexts(text):
|
elif language == "auto_yue":
|
||||||
if tmp["lang"] == "zh":
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
tmp["lang"] = "yue"
|
if tmp["lang"] == "zh":
|
||||||
langlist.append(tmp["lang"])
|
tmp["lang"] = "yue"
|
||||||
textlist.append(tmp["text"])
|
langlist.append(tmp["lang"])
|
||||||
else:
|
textlist.append(tmp["text"])
|
||||||
for tmp in LangSegmenter.getTexts(text):
|
else:
|
||||||
if tmp["lang"] == "en":
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
langlist.append(tmp["lang"])
|
if tmp["lang"] == "en":
|
||||||
else:
|
langlist.append(tmp["lang"])
|
||||||
# 因无法区别中日韩文汉字,以用户输入为准
|
else:
|
||||||
langlist.append(language)
|
# 因无法区别中日韩文汉字,以用户输入为准
|
||||||
textlist.append(tmp["text"])
|
langlist.append(language)
|
||||||
# print(textlist)
|
textlist.append(tmp["text"])
|
||||||
# print(langlist)
|
# print(textlist)
|
||||||
phones_list = []
|
# print(langlist)
|
||||||
bert_list = []
|
phones_list = []
|
||||||
norm_text_list = []
|
bert_list = []
|
||||||
for i in range(len(textlist)):
|
norm_text_list = []
|
||||||
lang = langlist[i]
|
for i in range(len(textlist)):
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
lang = langlist[i]
|
||||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||||
phones_list.append(phones)
|
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||||
norm_text_list.append(norm_text)
|
phones_list.append(phones)
|
||||||
bert_list.append(bert)
|
norm_text_list.append(norm_text)
|
||||||
bert = torch.cat(bert_list, dim=1)
|
bert_list.append(bert)
|
||||||
phones = sum(phones_list, [])
|
bert = torch.cat(bert_list, dim=1)
|
||||||
norm_text = ''.join(norm_text_list)
|
phones = sum(phones_list, [])
|
||||||
|
norm_text = ''.join(norm_text_list)
|
||||||
|
|
||||||
if not final and len(phones) < 6:
|
if not final and len(phones) < 6:
|
||||||
return self.get_phones_and_bert("." + text,language,version,final=True)
|
return self.get_phones_and_bert("." + text,language,version,final=True)
|
||||||
|
|
||||||
return phones, bert, norm_text
|
return phones, bert, norm_text
|
||||||
|
|
||||||
|
|
||||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user