diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py index 409b9e19..83f8cc6b 100644 --- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py +++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py @@ -1,4 +1,3 @@ -from imghdr import tests import os import sys import threading @@ -18,9 +17,13 @@ from text import cleaned_text_to_sequence from transformers import AutoModelForMaskedLM, AutoTokenizer from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method -from export_torch_script_v3 import extract_bert_features - from tools.i18n.i18n import I18nAuto, scan_language_list +from functools import lru_cache +import torch + +from cached import get_cached_bert + + language = os.environ.get("language", "Auto") language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language @@ -59,21 +62,14 @@ class TextPreprocessor: self.device = device self.bert_lock = threading.RLock() - def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: + def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: print(f"############ {i18n('切分文本')} ############") text = self.replace_consecutive_punctuation(text) texts = self.pre_seg_text(text, lang, text_split_method) result = [] - # text_batch = [] print(f"############ {i18n('提取文本Bert特征')} ############") - text_batch = [] - for text in texts: - if text.strip(): # 忽略空句子 - text_batch.append(text) - if not text_batch: - return [] - phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version) - for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts): + for text in tqdm(texts): + phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) if phones is None or norm_text == "": continue res = { @@ -83,52 +79,6 @@ class TextPreprocessor: } result.append(res) return result - # for text in texts: - # if text.strip(): # 忽略空句子 - # text_batch.append(text) - # phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version) - # for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts): - # if phones is None or norm_text == "": - # continue - # res = { - # "phones": phones, - # "bert_features": bert_features, - # "norm_text": norm_text, - # } - # result.append(res) - # return result - - - # for text in tqdm(texts): - # phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) - # if phones is None or norm_text == "": - # continue - # res = { - # "phones": phones, - # "bert_features": bert_features, - # "norm_text": norm_text, - # } - # result.append(res) - - # return result - - @torch.jit.script - def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor: - """ - 将词级别的 BERT 特征转换为音素级别的特征(通过 word2ph 指定每个词对应的音素数) - Args: - res: [N_words, hidden_dim] - word2ph: [N_words], 每个元素表示当前词需要复制多少次(即包含多少个音素) - - Returns: - [sum(word2ph), hidden_dim] 的 phone 级别特征 - """ - phone_level_feature = [] - for i in range(word2ph.shape[0]): - repeat_feature = res[i].repeat(word2ph[i].item(), 1) - phone_level_feature.append(repeat_feature) - return torch.cat(phone_level_feature, dim=0) - def pre_seg_text(self, text: str, lang: str, text_split_method: str): text = text.strip("\n") @@ -242,22 +192,25 @@ class TextPreprocessor: return phones, bert, norm_text - def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: - with torch.no_grad(): - inputs = self.tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(self.device) - res = self.bert_model(**inputs, output_hidden_states=True) - # # 优化:保留在GPU处理直到需要时再转CPU - res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1] # 移除不必要的cpu()调用 - assert len(word2ph) == len(text) - # 向量化优化:使用repeat_interleave替代循环 - word2ph_tensor = torch.tensor(word2ph, device=res.device) - indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor) - phone_level_feature = res[indices] - # 仅在需要时转CPU - phone_level_feature = phone_level_feature.cpu() - return phone_level_feature.T + # def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor: + # with torch.no_grad(): + # inputs = self.tokenizer(text, return_tensors="pt") + # for i in inputs: + # inputs[i] = inputs[i].to(self.device) + # res = self.bert_model(**inputs, output_hidden_states=True) + # res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + # assert len(word2ph) == len(text) + # phone_level_feature = [] + # for i in range(len(word2ph)): + # repeat_feature = res[i].repeat(word2ph[i], 1) + # phone_level_feature.append(repeat_feature) + # phone_level_feature = torch.cat(phone_level_feature, dim=0) + # return phone_level_feature.T + + def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor: + # 注意:word2ph 是 list,需转为 tuple 作为缓存键 + bert = get_cached_bert(norm_text, tuple(word2ph), str(self.device)) + return bert.to(self.device) def clean_text_inf(self, text: str, language: str, version: str = "v2"): language = language.replace("all_", "") @@ -295,79 +248,3 @@ class TextPreprocessor: return result - # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: - # print(f"############ {i18n('切分文本')} ############") - # text = self.replace_consecutive_punctuation(text) - # texts = self.pre_seg_text(text, lang, text_split_method) - # result = [] - # print(f"############ {i18n('提取文本Bert特征')} ############") - # for text in tqdm(texts): - # phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) - # if phones is None or norm_text == "": - # continue - # res = { - # "phones": phones, - # "bert_features": bert_features, - # "norm_text": norm_text, - # } - # result.append(res) - # return result - - def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]: - print(f"############ {i18n('切分文本')} ############") - text = self.replace_consecutive_punctuation(text) - texts = self.pre_seg_text(text, lang, text_split_method) - result = [] - - print(f"############ {i18n('提取文本Bert特征')} ############") - extract_bert_features(texts) - for text in texts: # - phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version) - if phones is None or norm_text == "": - continue - res = { - "phones": phones, - "bert_features": bert_features, - "norm_text": norm_text, - } - result.append(res) - return result - - - def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str): - phones_list = [] - bert_list = [] - norm_text_list = [] - - # 预处理文本,获取每句的 phones, word2ph, norm_text - format_texts = [self.clean_text_inf(t, language, version) for t in texts] - norm_texts = [x[2] for x in format_texts] - word2phs = [x[1] for x in format_texts] - - # 批量送入 tokenizer - inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True) - inputs = {k: v.to(self.device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = self.bert_model(**inputs, output_hidden_states=True) - # 使用 last_hidden_state 是正确且高效的方式 - hidden_states = outputs.last_hidden_state # [batch_size, seq_len, hidden_dim] - - for i in range(len(texts)): - res = hidden_states[i][1:-1].cpu() # 去掉 [CLS] 和 [SEP] - - word2ph = word2phs[i] - phone_level_feature = [] - for j in range(len(word2ph)): - if j >= res.shape[0]: - print(f"警告:BERT输出不足,跳过第 {i} 句中第 {j} 个 token") - continue - phone_level_feature.append(res[j].repeat(word2ph[j], 1)) - phone_level_feature = torch.cat(phone_level_feature, dim=0) - - bert_list.append(phone_level_feature.T) - phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version)) - norm_text_list.append(norm_texts[i]) - - return phones_list, bert_list, norm_text_list - diff --git a/GPT_SoVITS/TTS_infer_pack/cached.py b/GPT_SoVITS/TTS_infer_pack/cached.py new file mode 100644 index 00000000..58484c9c --- /dev/null +++ b/GPT_SoVITS/TTS_infer_pack/cached.py @@ -0,0 +1,30 @@ +from functools import lru_cache +import torch + +@lru_cache(maxsize=1000) +def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"): + """ + 缓存 BERT 提取函数,用于相同 norm_text 时复用特征 + + Args: + norm_text (str): 清洗后的文本(可复用) + word2ph_tuple (tuple): word2ph 列表转换成 tuple(因为 lru_cache 不支持 list) + device_str (str): 设备信息,用于转移到正确设备上 + + Returns: + Tensor: 形状 [hidden_dim, total_phonemes] + """ + from transformers import AutoTokenizer, AutoModelForMaskedLM + + # 如果你在类里,可以改成 self.tokenizer 和 self.model + tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese") + model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval().to(device_str) + + inputs = tokenizer(norm_text, return_tensors="pt").to(device_str) + with torch.no_grad(): + outputs = model(**inputs) + hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1] # 去掉 CLS/SEP + word2ph = torch.tensor(list(word2ph_tuple), device=hidden.device) + indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph) + phone_level_feature = hidden[indices] + return phone_level_feature.T.cpu() diff --git a/output.wav b/output.wav index c5babeab..a3842845 100644 Binary files a/output.wav and b/output.wav differ