mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-19 08:19:47 +08:00
更新BERT模块缓存为一个独立模块进行调用
This commit is contained in:
parent
321d01f6d1
commit
770a1143ad
@ -22,6 +22,7 @@ from functools import lru_cache
|
||||
import torch
|
||||
|
||||
from cached import get_cached_bert
|
||||
from cached import CachedBertExtractor
|
||||
|
||||
|
||||
|
||||
@ -62,6 +63,8 @@ class TextPreprocessor:
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
|
||||
self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)
|
||||
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
text = self.replace_consecutive_punctuation(text)
|
||||
|
@ -1,5 +1,9 @@
|
||||
from functools import lru_cache
|
||||
import torch
|
||||
import torch
|
||||
from functools import lru_cache
|
||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
||||
from typing import List, Tuple
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
|
||||
@ -28,3 +32,44 @@ def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cud
|
||||
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
|
||||
phone_level_feature = hidden[indices]
|
||||
return phone_level_feature.T.cpu()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class CachedBertExtractor:
|
||||
def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
self.device = device
|
||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(
|
||||
model_name_or_path, output_hidden_states=True
|
||||
).eval().to(device)
|
||||
|
||||
def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Public method: gets cached BERT feature tensor
|
||||
"""
|
||||
word2ph_tuple = tuple(word2ph)
|
||||
return self._cached_bert(norm_text, word2ph_tuple).to(self.device)
|
||||
|
||||
@lru_cache(maxsize=1024)
|
||||
def _cached_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Cached private method: returns CPU tensor (for lru_cache compatibility)
|
||||
"""
|
||||
inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self.bert_model(**inputs)
|
||||
hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1] # shape: [seq_len-2, hidden_dim]
|
||||
|
||||
word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
|
||||
indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
|
||||
phone_level_feature = hidden[indices] # [sum(word2ph), hidden_size]
|
||||
|
||||
return phone_level_feature.T.cpu() # cache-safe
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
Clear the internal BERT feature cache
|
||||
"""
|
||||
self._cached_bert.cache_clear()
|
||||
|
Loading…
x
Reference in New Issue
Block a user