31 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from functools import lru_cache
import torch
# Hugging Face checkpoint used for both the tokenizer and the model.
_BERT_MODEL_NAME = "bert-base-chinese"


@lru_cache(maxsize=None)
def _load_bert(device_str: str):
    """Load the tokenizer and model once per device string and reuse them.

    Loading is kept out of ``get_cached_bert`` so that a cache miss there
    (a new sentence) does not reload the weights from disk and re-upload
    them to the device. The key space is the set of device strings used,
    which is small, so an unbounded cache is acceptable here.

    Args:
        device_str: Device the model is moved to (e.g. ``"cuda"``, ``"cpu"``).

    Returns:
        ``(tokenizer, model)``, with the model in eval mode and configured to
        return all hidden states.
    """
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained(_BERT_MODEL_NAME)
    model = (
        AutoModelForMaskedLM.from_pretrained(_BERT_MODEL_NAME, output_hidden_states=True)
        .eval()
        .to(device_str)
    )
    return tokenizer, model


@lru_cache(maxsize=1000)
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
    """Extract phoneme-level BERT features, cached on the exact inputs.

    Identical ``(norm_text, word2ph_tuple, device_str)`` calls reuse the
    previously computed features.

    Args:
        norm_text: Normalized (cleaned) input text.
        word2ph_tuple: Number of phonemes for each token, as a tuple because
            ``lru_cache`` needs hashable arguments. It must contain one
            non-negative integer per BERT token, excluding [CLS]/[SEP].
        device_str: Device on which the model runs.

    Returns:
        CPU tensor of shape ``[hidden_dim, sum(word2ph_tuple)]``. Each token's
        feature vector from the third-from-last encoder layer is repeated
        once per phoneme. The same tensor object is returned on every cache
        hit, so callers must not modify it in place.

    Raises:
        TypeError: If an entry of ``word2ph_tuple`` is not an integer.
        ValueError: If an entry is negative, or the number of entries does
            not match the number of tokens produced for ``norm_text``.
    """
    import operator

    # Validate cheaply before touching the (expensive) model.
    counts = [operator.index(n) for n in word2ph_tuple]
    if any(c < 0 for c in counts):
        raise ValueError(f"word2ph entries must be non-negative, got {word2ph_tuple!r}")

    tokenizer, model = _load_bert(device_str)
    inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
    with torch.no_grad():
        outputs = model(**inputs)

    # Third-from-last layer, first (only) batch item, [CLS]/[SEP] stripped.
    hidden = outputs.hidden_states[-3][0][1:-1]
    if len(counts) != hidden.shape[0]:
        # Without this check, a longer word2ph indexes out of range (a device
        # assert on CUDA), and a shorter one silently drops trailing tokens.
        raise ValueError(
            f"word2ph has {len(counts)} entries but the text produced "
            f"{hidden.shape[0]} tokens (excluding [CLS]/[SEP])"
        )

    # An explicit dtype keeps an empty tuple from becoming a float tensor,
    # which repeat_interleave rejects as `repeats`.
    repeats = torch.tensor(counts, dtype=torch.long, device=hidden.device)
    indices = torch.repeat_interleave(
        torch.arange(len(counts), device=hidden.device), repeats
    )
    phone_level_feature = hidden[indices]
    # Move to CPU so cached results do not pin device memory.
    return phone_level_feature.T.cpu()