Cache BERT results so features for identical tokens are not recomputed

This commit is contained in:
Karyl01 2025-05-11 22:52:33 +08:00
parent f1cfc63851
commit 321d01f6d1
3 changed files with 58 additions and 151 deletions

View File

@@ -1,4 +1,3 @@
-from imghdr import tests
 import os
 import sys
 import threading
@@ -18,9 +17,13 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
+from export_torch_script_v3 import extract_bert_features
 from tools.i18n.i18n import I18nAuto, scan_language_list
+from functools import lru_cache
+import torch
+from cached import get_cached_bert

 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -59,21 +62,14 @@ class TextPreprocessor:
         self.device = device
         self.bert_lock = threading.RLock()

-    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
-        text_batch = []
-        for text in texts:
-            if text.strip():  # skip empty sentences
-                text_batch.append(text)
-        if not text_batch:
-            return []
-        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
-        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
+        for text in tqdm(texts):
+            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
             if phones is None or norm_text == "":
                 continue
             res = {
@@ -83,52 +79,6 @@ class TextPreprocessor:
             }
             result.append(res)
         return result
-
-        # for text in texts:
-        #     if text.strip():  # skip empty sentences
-        #         text_batch.append(text)
-        # phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
-        # for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
-        #     if phones is None or norm_text == "":
-        #         continue
-        #     res = {
-        #         "phones": phones,
-        #         "bert_features": bert_features,
-        #         "norm_text": norm_text,
-        #     }
-        #     result.append(res)
-        # return result
-
-        # for text in tqdm(texts):
-        #     phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-        #     if phones is None or norm_text == "":
-        #         continue
-        #     res = {
-        #         "phones": phones,
-        #         "bert_features": bert_features,
-        #         "norm_text": norm_text,
-        #     }
-        #     result.append(res)
-        # return result
-
-    @torch.jit.script
-    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
-        """
-        Convert word-level BERT features to phone-level features; word2ph gives the number of phones for each word.
-        Args:
-            res: [N_words, hidden_dim]
-            word2ph: [N_words]; each element is how many times that word's feature is repeated (i.e. its phone count)
-        Returns:
-            [sum(word2ph), hidden_dim] phone-level features
-        """
-        phone_level_feature = []
-        for i in range(word2ph.shape[0]):
-            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
-            phone_level_feature.append(repeat_feature)
-        return torch.cat(phone_level_feature, dim=0)

     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
@@ -242,22 +192,25 @@ class TextPreprocessor:
         return phones, bert, norm_text

-    def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
-        with torch.no_grad():
-            inputs = self.tokenizer(text, return_tensors="pt")
-            for i in inputs:
-                inputs[i] = inputs[i].to(self.device)
-            res = self.bert_model(**inputs, output_hidden_states=True)
-            # optimization: keep everything on the GPU and only move to CPU when needed
-            res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]  # removed the unnecessary cpu() call
-        assert len(word2ph) == len(text)
-        # vectorized: repeat_interleave instead of a Python loop
-        word2ph_tensor = torch.tensor(word2ph, device=res.device)
-        indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
-        phone_level_feature = res[indices]
-        # move to CPU only when actually needed
-        phone_level_feature = phone_level_feature.cpu()
-        return phone_level_feature.T
+    # def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
+    #     with torch.no_grad():
+    #         inputs = self.tokenizer(text, return_tensors="pt")
+    #         for i in inputs:
+    #             inputs[i] = inputs[i].to(self.device)
+    #         res = self.bert_model(**inputs, output_hidden_states=True)
+    #         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+    #     assert len(word2ph) == len(text)
+    #     phone_level_feature = []
+    #     for i in range(len(word2ph)):
+    #         repeat_feature = res[i].repeat(word2ph[i], 1)
+    #         phone_level_feature.append(repeat_feature)
+    #     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    #     return phone_level_feature.T
+
+    def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor:
+        # note: word2ph is a list; convert it to a tuple so it can be used as the cache key
+        bert = get_cached_bert(norm_text, tuple(word2ph), str(self.device))
+        return bert.to(self.device)

     def clean_text_inf(self, text: str, language: str, version: str = "v2"):
         language = language.replace("all_", "")
@@ -295,79 +248,3 @@ class TextPreprocessor:
         return result
-
-    # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-    #     print(f"############ {i18n('切分文本')} ############")
-    #     text = self.replace_consecutive_punctuation(text)
-    #     texts = self.pre_seg_text(text, lang, text_split_method)
-    #     result = []
-    #     print(f"############ {i18n('提取文本Bert特征')} ############")
-    #     for text in tqdm(texts):
-    #         phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-    #         if phones is None or norm_text == "":
-    #             continue
-    #         res = {
-    #             "phones": phones,
-    #             "bert_features": bert_features,
-    #             "norm_text": norm_text,
-    #         }
-    #         result.append(res)
-    #     return result
-
-    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-        print(f"############ {i18n('切分文本')} ############")
-        text = self.replace_consecutive_punctuation(text)
-        texts = self.pre_seg_text(text, lang, text_split_method)
-        result = []
-        print(f"############ {i18n('提取文本Bert特征')} ############")
-        extract_bert_features(texts)
-        for text in texts:
-            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-            if phones is None or norm_text == "":
-                continue
-            res = {
-                "phones": phones,
-                "bert_features": bert_features,
-                "norm_text": norm_text,
-            }
-            result.append(res)
-        return result
-
-    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
-        phones_list = []
-        bert_list = []
-        norm_text_list = []
-        # clean each sentence to get its phones, word2ph and norm_text
-        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
-        norm_texts = [x[2] for x in format_texts]
-        word2phs = [x[1] for x in format_texts]
-        # run the whole batch through the tokenizer at once
-        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-        with torch.no_grad():
-            outputs = self.bert_model(**inputs, output_hidden_states=True)
-        # using last_hidden_state here is both correct and efficient
-        hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
-        for i in range(len(texts)):
-            res = hidden_states[i][1:-1].cpu()  # strip [CLS] and [SEP]
-            word2ph = word2phs[i]
-            phone_level_feature = []
-            for j in range(len(word2ph)):
-                if j >= res.shape[0]:
-                    print(f"Warning: BERT output too short, skipping token {j} of sentence {i}")
-                    continue
-                phone_level_feature.append(res[j].repeat(word2ph[j], 1))
-            phone_level_feature = torch.cat(phone_level_feature, dim=0)
-            bert_list.append(phone_level_feature.T)
-            phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version))
-            norm_text_list.append(norm_texts[i])
-        return phones_list, bert_list, norm_text_list
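
One practical detail of the new get_bert_feature above: functools.lru_cache hashes its arguments, so the word2ph list must be converted to a tuple before it can be part of the cache key. A standalone sketch of that constraint (illustration only, not part of the diff; demo is a hypothetical stand-in for get_cached_bert):

from functools import lru_cache

@lru_cache(maxsize=8)
def demo(word2ph_key):           # stands in for get_cached_bert's hashable arguments
    return sum(word2ph_key)

demo((2, 2, 1))                  # works: a tuple is hashable, so the result is cached
try:
    demo([2, 2, 1])              # raises TypeError because a list cannot be hashed
except TypeError as err:
    print(err)                   # "unhashable type: 'list'"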

View File

@@ -0,0 +1,30 @@
+from functools import lru_cache
+import torch
+
+
+@lru_cache(maxsize=1000)
+def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
+    """
+    Cached BERT extraction: reuses the features whenever the same norm_text is seen again.
+    Args:
+        norm_text (str): normalized text (the reusable cache key)
+        word2ph_tuple (tuple): the word2ph list converted to a tuple, because lru_cache cannot hash a list
+        device_str (str): device string used to move the model and inputs to the right device
+    Returns:
+        Tensor: shape [hidden_dim, total_phonemes]
+    """
+    from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+    # if this runs inside a class, switch to self.tokenizer and self.model instead
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
+    model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval().to(device_str)
+
+    inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # drop [CLS]/[SEP]
+
+    word2ph = torch.tensor(list(word2ph_tuple), device=hidden.device)
+    indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
+    phone_level_feature = hidden[indices]
+    return phone_level_feature.T.cpu()
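
A minimal usage sketch of the new cached.py helper (assumptions: it runs on CPU here purely for illustration, the bert-base-chinese weights can be downloaded, and the word2ph values are made up for a three-character input). The second call with identical arguments is answered from the lru_cache instead of re-running the model:

from cached import get_cached_bert

word2ph = (2, 2, 1)                                 # hypothetical phones-per-character counts
feat1 = get_cached_bert("你好吗", word2ph, "cpu")    # miss: tokenizes the text and runs BERT
feat2 = get_cached_bert("你好吗", word2ph, "cpu")    # hit: the cached tensor is returned directly
print(feat1.shape)                                  # torch.Size([768, 5]) -> [hidden_dim, sum(word2ph)]
print(get_cached_bert.cache_info())                 # CacheInfo(hits=1, misses=1, maxsize=1000, currsize=1)

Note that, as written, every cache miss re-loads the tokenizer and model with from_pretrained inside the function; the inline comment in cached.py already points at the fix, namely reusing self.tokenizer and self.model once the helper is wired into a class.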

Binary file not shown.