From 770a1143add20ae78cccb7704ee0bbd773615e84 Mon Sep 17 00:00:00 2001
From: Karyl01 <148410288+Karyl01@users.noreply.github.com>
Date: Sun, 11 May 2025 22:55:34 +0800
Subject: [PATCH] Move BERT feature caching into a standalone module
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py |  3 ++
 GPT_SoVITS/TTS_infer_pack/cached.py           | 45 +++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 83f8cc6b..1c28bd52 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -22,6 +22,7 @@ from functools import lru_cache
 import torch
 from cached import get_cached_bert
+from cached import CachedBertExtractor
 
 
 
 
@@ -62,6 +63,8 @@ class TextPreprocessor:
 
         self.device = device
         self.bert_lock = threading.RLock()
+        self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)
+
     def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
diff --git a/GPT_SoVITS/TTS_infer_pack/cached.py b/GPT_SoVITS/TTS_infer_pack/cached.py
index 58484c9c..56405328 100644
--- a/GPT_SoVITS/TTS_infer_pack/cached.py
+++ b/GPT_SoVITS/TTS_infer_pack/cached.py
@@ -1,5 +1,7 @@
 from functools import lru_cache
 import torch
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+from typing import List, Tuple
 
 @lru_cache(maxsize=1000)
 def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
@@ -28,3 +30,46 @@ def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cud
     indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
     phone_level_feature = hidden[indices]
     return phone_level_feature.T.cpu()
+
+
+class CachedBertExtractor:
+    """BERT phone-level feature extractor with a per-instance LRU cache."""
+
+    def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        self.device = device
+        self.bert_model = (
+            AutoModelForMaskedLM.from_pretrained(model_name_or_path, output_hidden_states=True)
+            .eval()
+            .to(device)
+        )
+        # Build a per-instance LRU cache. Decorating the method itself with
+        # @lru_cache would key the cache on `self` and keep the instance
+        # alive for the lifetime of the cache.
+        self._cached_bert = lru_cache(maxsize=1024)(self._extract_bert)
+
+    def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
+        """Public method: returns the (possibly cached) BERT feature on self.device."""
+        word2ph_tuple = tuple(word2ph)  # lists are unhashable; tuples can serve as cache keys
+        return self._cached_bert(norm_text, word2ph_tuple).to(self.device)
+
+    def _extract_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
+        """Uncached extraction; returns a CPU tensor so cached entries stay off the GPU."""
+        inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
+        with torch.no_grad():
+            outputs = self.bert_model(**inputs)
+        # Third-to-last hidden layer with [CLS]/[SEP] stripped: [seq_len - 2, hidden_dim]
+        hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]
+
+        # Repeat each character's hidden vector once per phoneme mapped to it.
+        word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
+        indices = torch.repeat_interleave(
+            torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor
+        )
+        phone_level_feature = hidden[indices]  # [sum(word2ph), hidden_dim]
+
+        return phone_level_feature.T.cpu()  # cache-safe: no live GPU references
+
+    def clear_cache(self):
+        """Clear the internal BERT feature cache."""
+        self._cached_bert.cache_clear()
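
Review note, not part of the patch: a minimal sketch of how the new extractor could be exercised standalone, assuming GPT_SoVITS/TTS_infer_pack is on sys.path and the bert-base-chinese weights are available locally or from the Hugging Face hub. The input text and word2ph values below are illustrative, not taken from the repository.

    # Hedged usage sketch of CachedBertExtractor; inputs are illustrative.
    from cached import CachedBertExtractor

    extractor = CachedBertExtractor("bert-base-chinese", device="cpu")
    word2ph = [2, 2, 2]          # phonemes per character for a 3-character input
    feat = extractor.get_bert_feature("你好吗", word2ph)        # computed once
    feat_again = extractor.get_bert_feature("你好吗", word2ph)  # served from the LRU cache
    print(feat.shape)            # torch.Size([768, 6]): [hidden_dim, sum(word2ph)]
    extractor.clear_cache()      # drop cached features, e.g. between long sessions

Returning CPU tensors from the cached layer keeps GPU memory flat as the cache fills; the public method moves each result back to the target device, on hits as well as misses.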