更新BERT模块缓存为一个独立模块进行调用

This commit is contained in:
Karyl01 2025-05-11 22:55:34 +08:00
parent 321d01f6d1
commit 770a1143ad
2 changed files with 48 additions and 0 deletions

View File

@ -22,6 +22,7 @@ from functools import lru_cache
import torch import torch
from cached import get_cached_bert from cached import get_cached_bert
from cached import CachedBertExtractor
@ -62,6 +63,8 @@ class TextPreprocessor:
self.device = device self.device = device
self.bert_lock = threading.RLock() self.bert_lock = threading.RLock()
self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############") print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text) text = self.replace_consecutive_punctuation(text)

View File

@ -1,5 +1,9 @@
from functools import lru_cache from functools import lru_cache
import torch import torch
import torch
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import List, Tuple
@lru_cache(maxsize=1000) @lru_cache(maxsize=1000)
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"): def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
@ -28,3 +32,44 @@ def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cud
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph) indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
phone_level_feature = hidden[indices] phone_level_feature = hidden[indices]
return phone_level_feature.T.cpu() return phone_level_feature.T.cpu()
class CachedBertExtractor:
def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.device = device
self.bert_model = AutoModelForMaskedLM.from_pretrained(
model_name_or_path, output_hidden_states=True
).eval().to(device)
def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
"""
Public method: gets cached BERT feature tensor
"""
word2ph_tuple = tuple(word2ph)
return self._cached_bert(norm_text, word2ph_tuple).to(self.device)
@lru_cache(maxsize=1024)
def _cached_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
"""
Cached private method: returns CPU tensor (for lru_cache compatibility)
"""
inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.bert_model(**inputs)
hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1] # shape: [seq_len-2, hidden_dim]
word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
phone_level_feature = hidden[indices] # [sum(word2ph), hidden_size]
return phone_level_feature.T.cpu() # cache-safe
def clear_cache(self):
"""
Clear the internal BERT feature cache
"""
self._cached_bert.cache_clear()