Modify batched BERT feature extraction

This commit is contained in:
Karyl01 2025-05-09 21:53:15 +08:00
parent 0d654cd238
commit c7b61c6fd4
5 changed files with 72 additions and 4 deletions

View File

@@ -61,6 +61,7 @@ class TextPreprocessor:
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
+        # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
         for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
@@ -73,6 +74,34 @@
             }
             result.append(res)
         return result
+        # for text in texts:
+        #     if text.strip():  # skip empty sentences
+        #         text_batch.append(text)
+        # phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
+        # for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
+        #     if phones is None or norm_text == "":
+        #         continue
+        #     res = {
+        #         "phones": phones,
+        #         "bert_features": bert_features,
+        #         "norm_text": norm_text,
+        #     }
+        #     result.append(res)
+        # return result
+
+        # for text in tqdm(texts):
+        #     phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+        #     if phones is None or norm_text == "":
+        #         continue
+        #     res = {
+        #         "phones": phones,
+        #         "bert_features": bert_features,
+        #         "norm_text": norm_text,
+        #     }
+        #     result.append(res)
+        # return result

     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
@@ -235,3 +264,42 @@ class TextPreprocessor:
         pattern = f"([{punctuations}])([{punctuations}])+"
         result = re.sub(pattern, r"\1", text)
         return result
+
+    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+
+        # Preprocess the texts to get phones, word2ph, norm_text for each sentence
+        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
+        norm_texts = [x[2] for x in format_texts]
+        word2phs = [x[1] for x in format_texts]
+
+        # Feed the whole batch to the tokenizer at once
+        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            outputs = self.bert_model(**inputs, output_hidden_states=True)
+            # Using last_hidden_state is the correct and efficient approach here
+            hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
+
+        for i in range(len(texts)):
+            res = hidden_states[i][1:-1].cpu()  # drop [CLS] and [SEP]
+            word2ph = word2phs[i]
+            phone_level_feature = []
+            for j in range(len(word2ph)):
+                if j >= res.shape[0]:
+                    print(f"Warning: BERT output too short, skipping token {j} of sentence {i}")
+                    continue
+                phone_level_feature.append(res[j].repeat(word2ph[j], 1))
+            phone_level_feature = torch.cat(phone_level_feature, dim=0)
+            bert_list.append(phone_level_feature.T)
+            phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version))
+            norm_text_list.append(norm_texts[i])
+
+        return phones_list, bert_list, norm_text_list
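
The hunk above encodes every normalized sentence through BERT in a single padded forward pass and then slices the output back into per-sentence features. Below is a minimal standalone sketch of that idea (not part of the commit): the function name batch_bert_features and the use of AutoModel with the local chinese-roberta-wwm-ext-large path from the config are assumptions, and attention_mask is used to strip the padding that a batched encode introduces.

import torch
from transformers import AutoModel, AutoTokenizer

def batch_bert_features(texts, model_path="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", device="cpu"):
    # Hypothetical helper: encode all sentences in one padded batch and return
    # one [num_tokens, hidden_dim] tensor per sentence (CLS/SEP/padding removed).
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path).to(device).eval()

    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # [batch, seq_len, hidden_dim]

    features = []
    for i, mask in enumerate(inputs["attention_mask"]):
        n = int(mask.sum())                         # number of real (non-padding) tokens
        features.append(hidden[i, 1:n - 1].cpu())   # drop [CLS], [SEP] and padding
    return features

# Example:
# feats = batch_bert_features(["今天天气不错。", "我们去公园散步吧。"])
# print([tuple(f.shape) for f in feats])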

View File

@@ -1,8 +1,8 @@
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
+  device: cuda
+  is_half: true
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
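
The hunk above switches the custom preset to device: cuda and is_half: true. As a minimal illustration only (the yaml path is assumed, and the guard logic is not taken from the project's loader), deriving safe runtime settings from that block might look like:

import torch
import yaml

# Path assumed; this is the yaml whose "custom" block is edited above.
with open("GPT_SoVITS/configs/tts_infer.yaml") as f:
    cfg = yaml.safe_load(f)["custom"]

device = cfg["device"] if torch.cuda.is_available() else "cpu"
# Half precision is only meaningful together with CUDA; fall back to fp32 on CPU.
is_half = bool(cfg["is_half"]) and device == "cuda"
dtype = torch.float16 if is_half else torch.float32
print(device, dtype)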

View File

@@ -26,7 +26,7 @@ from inference_webui import get_spepc, norm_spec, resample, ssl_model
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
 logger = logging.getLogger("uvicorn")
-is_half = False
+is_half = True
 device = "cuda" if torch.cuda.is_available() else "cpu"
 now_dir = os.getcwd()

View File

@@ -69,7 +69,7 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
-is_half = False
+is_half = True
 device = "cpu"
 dict_language_v1 = {
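
This hunk flips is_half to True while the module-level override device = "cpu" remains in place two lines below. A small sketch (the helper name to_inference_dtype is hypothetical, not from the repo) of keeping feature tensors in a dtype consistent with both flags:

import torch

def to_inference_dtype(feat: torch.Tensor, is_half: bool, device: str) -> torch.Tensor:
    # Hypothetical helper: cast a feature tensor (e.g. BERT features) so its
    # dtype matches the is_half / device combination actually in use.
    if is_half and device == "cuda":
        return feat.to(device).half()
    return feat.to(device).float()

# bert_features = to_inference_dtype(bert_features, is_half=True, device="cpu")  # stays fp32 on CPU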

Binary file not shown.