mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-19 16:29:47 +08:00
Refactor the BERT module cache into a standalone module and call it from there
This commit is contained in:
parent
321d01f6d1
commit
770a1143ad
@@ -22,6 +22,7 @@ from functools import lru_cache
 import torch
 from cached import get_cached_bert
+from cached import CachedBertExtractor
@@ -62,6 +63,8 @@ class TextPreprocessor:
         self.device = device
         self.bert_lock = threading.RLock()
+
+        self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)
 
     def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
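For context, a minimal sketch of how TextPreprocessor would consume the new attribute. The diff only shows the extractor being created in __init__; the method below, its name, and the use of bert_lock are assumptions for illustration, not part of this commit:

    import threading
    import torch
    from cached import CachedBertExtractor

    class TextPreprocessorSketch:
        # Hypothetical stand-in for the real TextPreprocessor; only the
        # pieces relevant to the new extractor are shown.
        def __init__(self, device: str = "cpu"):
            self.bert_lock = threading.RLock()
            self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)

        def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor:
            # Serialize access to the shared BERT model, then delegate
            # caching to the extractor.
            with self.bert_lock:
                return self.bert_extractor.get_bert_feature(norm_text, word2ph)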
@@ -1,5 +1,9 @@
 from functools import lru_cache
 import torch
+import torch
+from functools import lru_cache
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+from typing import List, Tuple
 
 @lru_cache(maxsize=1000)
 def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
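Both the module-level helper and the class below take word2ph as a tuple rather than a list because functools.lru_cache builds its cache key by hashing the call arguments, and lists are unhashable. A minimal sketch of the constraint; the function and values here are illustrative only:

    from functools import lru_cache

    @lru_cache(maxsize=8)
    def phone_count(word2ph_tuple: tuple) -> int:
        # Tuples hash, so repeated calls with equal contents hit the cache.
        return sum(word2ph_tuple)

    phone_count((2, 2, 1))      # computed once, then served from the cache
    # phone_count([2, 2, 1])    # TypeError: unhashable type: 'list'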
@@ -28,3 +32,44 @@ def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
     indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
     phone_level_feature = hidden[indices]
     return phone_level_feature.T.cpu()
+
+
+class CachedBertExtractor:
+    def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.device = device
+        self.bert_model = AutoModelForMaskedLM.from_pretrained(
+            model_name_or_path, output_hidden_states=True
+        ).eval().to(device)
+
+    def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
+        """
+        Public method: gets cached BERT feature tensor
+        """
+        word2ph_tuple = tuple(word2ph)
+        return self._cached_bert(norm_text, word2ph_tuple).to(self.device)
+
+    @lru_cache(maxsize=1024)
+    def _cached_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
+        """
+        Cached private method: returns CPU tensor (for lru_cache compatibility)
+        """
+        inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            outputs = self.bert_model(**inputs)
+            hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # shape: [seq_len-2, hidden_dim]
+
+        word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
+        indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
+        phone_level_feature = hidden[indices]  # [sum(word2ph), hidden_size]
+
+        return phone_level_feature.T.cpu()  # cache-safe
+
+    def clear_cache(self):
+        """
+        Clear the internal BERT feature cache
+        """
+        self._cached_bert.cache_clear()
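Two properties of the class are worth noting when using it. functools.lru_cache on an instance method includes self in the cache key, so each CachedBertExtractor instance keeps its own entries, and the cache holds a reference to the instance until those entries are evicted; clear_cache wraps cache_clear for exactly that reason. Cached values are stored on the CPU and moved back to self.device on each hit, which keeps GPU memory out of the cache. A minimal usage sketch, assuming the bert-base-chinese weights are available locally or from the Hugging Face hub; the sample text and word2ph values are invented:

    from cached import CachedBertExtractor

    extractor = CachedBertExtractor("bert-base-chinese", device="cpu")

    # word2ph maps each character of the normalized text to its phoneme
    # count; the values below are made up for illustration.
    feat_a = extractor.get_bert_feature("你好", [2, 2])  # forward pass, then cached
    feat_b = extractor.get_bert_feature("你好", [2, 2])  # same key: served from cache
    assert feat_a.shape == feat_b.shape                  # [hidden_dim, sum(word2ph)]

    extractor.clear_cache()  # drop all cached phone-level features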