from functools import lru_cache
from typing import List, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

@lru_cache(maxsize=1000)
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
    """
    Cached BERT extraction: reuses the computed features when the same norm_text is seen again.

    Args:
        norm_text (str): normalized (cleaned) text, reusable as a cache key
        word2ph_tuple (tuple): the word2ph list converted to a tuple (lru_cache does not support lists)
        device_str (str): device the model and inputs are moved to

    Returns:
        Tensor: shape [hidden_dim, total_phonemes]
    """
    # NOTE: the tokenizer and model are reloaded from scratch on every cache miss;
    # inside a class, reuse self.tokenizer and self.model instead (see CachedBertExtractor below).
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval().to(device_str)

    inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
    with torch.no_grad():
        outputs = model(**inputs)
        hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # drop the CLS/SEP tokens

    # Expand word-level features to phone level: each word's vector is repeated
    # once per phoneme it maps to.
    word2ph = torch.tensor(list(word2ph_tuple), device=hidden.device)
    indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
    phone_level_feature = hidden[indices]  # [sum(word2ph), hidden_dim]

    return phone_level_feature.T.cpu()
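
# Worked example of the word2ph expansion used in get_cached_bert and
# CachedBertExtractor._cached_bert (illustrative values only, not from the
# GPT-SoVITS text front-end): with word2ph = (2, 1, 3),
#   torch.repeat_interleave(torch.arange(3), torch.tensor([2, 1, 3]))
# yields tensor([0, 0, 1, 2, 2, 2]), so hidden[indices] repeats the first word's
# vector twice, the second once, and the third three times, producing a
# [6, hidden_dim] phone-level matrix that the final transpose turns into [hidden_dim, 6].
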
class CachedBertExtractor:
    def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.device = device
        self.bert_model = AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, output_hidden_states=True
        ).eval().to(device)

    def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
        """
        Public method: returns the (possibly cached) phone-level BERT feature tensor on self.device.
        """
        word2ph_tuple = tuple(word2ph)  # lru_cache requires hashable arguments
        return self._cached_bert(norm_text, word2ph_tuple).to(self.device)

    @lru_cache(maxsize=1024)
    def _cached_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
        """
        Cached private method: returns a CPU tensor so cached entries do not hold GPU memory.
        The cache key includes self, norm_text, and word2ph_tuple.
        """
        inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
            hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # shape: [seq_len - 2, hidden_dim]

        word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
        indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
        phone_level_feature = hidden[indices]  # [sum(word2ph), hidden_dim]

        return phone_level_feature.T.cpu()  # cache-safe: stored on CPU

    def clear_cache(self):
        """
        Clear the internal BERT feature cache.
        """
        self._cached_bert.cache_clear()
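
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: "bert-base-chinese" can be downloaded, CUDA is used
# only when available, and the word2ph values below are purely illustrative --
# in GPT-SoVITS they come from the text front-end).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    extractor = CachedBertExtractor(device=device)
    norm_text = "你好世界"
    word2ph = [2, 2, 2, 2]  # hypothetical: two phonemes per character

    feat = extractor.get_bert_feature(norm_text, word2ph)         # first call runs BERT
    feat_cached = extractor.get_bert_feature(norm_text, word2ph)  # second call hits the cache
    print(feat.shape)  # [hidden_dim, sum(word2ph)], e.g. torch.Size([768, 8])

    extractor.clear_cache()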