mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-22 19:19:47 +08:00

Cache BERT features so identical tokens are not converted again (BERT缓存记录,减少对于同样token的转化)

This commit is contained in:
parent f1cfc63851
commit 321d01f6d1
GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py

@@ -1,4 +1,3 @@
-from imghdr import tests
 import os
 import sys
 import threading
@@ -18,9 +17,13 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
-from export_torch_script_v3 import extract_bert_features
-
 from tools.i18n.i18n import I18nAuto, scan_language_list
+from functools import lru_cache
+import torch
+
+from cached import get_cached_bert
+
+
 
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
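A side note on the new bare import: the file's other in-package import above is package-qualified (TTS_infer_pack.text_segmentation_method), so `from cached import get_cached_bert` only resolves when the TTS_infer_pack directory itself is on sys.path. A minimal sketch of the package-qualified alternative (an assumption, not part of this commit):

# assumption: this module is imported as part of the TTS_infer_pack package,
# mirroring the text_segmentation_method import above
from TTS_infer_pack.cached import get_cached_bert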
@@ -59,21 +62,14 @@ class TextPreprocessor:
         self.device = device
         self.bert_lock = threading.RLock()
 
-    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
-        # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
-        text_batch = []
-        for text in texts:
-            if text.strip():  # skip empty sentences
-                text_batch.append(text)
-        if not text_batch:
-            return []
-        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
-        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
+        for text in tqdm(texts):
+            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
             if phones is None or norm_text == "":
                 continue
             res = {
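The rewritten preprocess returns to per-sentence extraction under a tqdm progress bar instead of the batched path. A hedged usage sketch; the constructor arguments (bert_model, tokenizer, device) and the text_split_method key "cut2" are assumptions inferred from the class fields and codebase conventions, not part of this diff:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# hypothetical setup; GPT-SoVITS normally loads its pretrained Chinese BERT here
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
bert_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval()

tp = TextPreprocessor(bert_model, tokenizer, device)
segments = tp.preprocess("今天天气不错。我们去公园散步吧。", lang="all_zh",
                         text_split_method="cut2", version="v2")
for seg in segments:
    print(seg["norm_text"], len(seg["phones"]), tuple(seg["bert_features"].shape))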
@@ -83,52 +79,6 @@ class TextPreprocessor:
             }
             result.append(res)
         return result
-        # for text in texts:
-        #     if text.strip():  # skip empty sentences
-        #         text_batch.append(text)
-        # phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
-        # for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
-        #     if phones is None or norm_text == "":
-        #         continue
-        #     res = {
-        #         "phones": phones,
-        #         "bert_features": bert_features,
-        #         "norm_text": norm_text,
-        #     }
-        #     result.append(res)
-        # return result
-
-
-        # for text in tqdm(texts):
-        #     phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-        #     if phones is None or norm_text == "":
-        #         continue
-        #     res = {
-        #         "phones": phones,
-        #         "bert_features": bert_features,
-        #         "norm_text": norm_text,
-        #     }
-        #     result.append(res)
-
-        # return result
-
-    @torch.jit.script
-    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
-        """
-        Convert word-level BERT features to phone-level features (word2ph gives the phone count of each word)
-
-        Args:
-            res: [N_words, hidden_dim]
-            word2ph: [N_words], each element is how many times the word's feature is repeated (i.e. how many phones it covers)
-
-        Returns:
-            phone-level features of shape [sum(word2ph), hidden_dim]
-        """
-        phone_level_feature = []
-        for i in range(word2ph.shape[0]):
-            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
-            phone_level_feature.append(repeat_feature)
-        return torch.cat(phone_level_feature, dim=0)
-
-
     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
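The deleted build_phone_level_feature expands each word's feature vector word2ph[i] times with a Python loop; the surviving code paths in this commit do the same expansion with torch.repeat_interleave. A self-contained sketch (shapes arbitrary) showing the two forms agree:

import torch

res = torch.randn(3, 1024)            # word-level features: [N_words, hidden_dim]
word2ph = torch.tensor([2, 1, 3])     # phones per word

# loop form, as in the removed function
loop = torch.cat([res[i].repeat(int(word2ph[i]), 1) for i in range(len(word2ph))], dim=0)

# vectorized form, as used by the cached path
indices = torch.repeat_interleave(torch.arange(len(word2ph)), word2ph)
vectorized = res[indices]

assert torch.equal(loop, vectorized)  # both are [sum(word2ph), hidden_dim] == [6, 1024]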
@@ -242,22 +192,25 @@ class TextPreprocessor:
 
         return phones, bert, norm_text
 
-    def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
-        with torch.no_grad():
-            inputs = self.tokenizer(text, return_tensors="pt")
-            for i in inputs:
-                inputs[i] = inputs[i].to(self.device)
-            res = self.bert_model(**inputs, output_hidden_states=True)
-            # Optimization: keep processing on the GPU until a CPU copy is needed
-            res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]  # removed the unnecessary cpu() call
-        assert len(word2ph) == len(text)
-        # Vectorization: repeat_interleave instead of a Python loop
-        word2ph_tensor = torch.tensor(word2ph, device=res.device)
-        indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
-        phone_level_feature = res[indices]
-        # Move to CPU only when actually needed
-        phone_level_feature = phone_level_feature.cpu()
-        return phone_level_feature.T
+    # def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
+    #     with torch.no_grad():
+    #         inputs = self.tokenizer(text, return_tensors="pt")
+    #         for i in inputs:
+    #             inputs[i] = inputs[i].to(self.device)
+    #         res = self.bert_model(**inputs, output_hidden_states=True)
+    #         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+    #     assert len(word2ph) == len(text)
+    #     phone_level_feature = []
+    #     for i in range(len(word2ph)):
+    #         repeat_feature = res[i].repeat(word2ph[i], 1)
+    #         phone_level_feature.append(repeat_feature)
+    #     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+    #     return phone_level_feature.T
+
+    def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor:
+        # Note: word2ph is a list; convert it to a tuple so it can serve as a cache key
+        bert = get_cached_bert(norm_text, tuple(word2ph), str(self.device))
+        return bert.to(self.device)
 
     def clean_text_inf(self, text: str, language: str, version: str = "v2"):
         language = language.replace("all_", "")
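The new get_bert_feature converts word2ph to a tuple because functools.lru_cache keys on a hash of the arguments and a list is unhashable; the cached tensor comes back on CPU (cached.py ends with .T.cpu()), so bert.to(self.device) moves it to the working device on every call. A self-contained sketch of the hashability constraint:

from functools import lru_cache

@lru_cache(maxsize=4)
def expand(word2ph_tuple: tuple) -> int:
    return sum(word2ph_tuple)

expand((2, 1, 3))    # miss: computed
expand((2, 1, 3))    # hit: served from the cache
# expand([2, 1, 3])  # would raise TypeError: unhashable type: 'list'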
@@ -295,79 +248,3 @@ class TextPreprocessor:
         return result
-
-
-    # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-    #     print(f"############ {i18n('切分文本')} ############")
-    #     text = self.replace_consecutive_punctuation(text)
-    #     texts = self.pre_seg_text(text, lang, text_split_method)
-    #     result = []
-    #     print(f"############ {i18n('提取文本Bert特征')} ############")
-    #     for text in tqdm(texts):
-    #         phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-    #         if phones is None or norm_text == "":
-    #             continue
-    #         res = {
-    #             "phones": phones,
-    #             "bert_features": bert_features,
-    #             "norm_text": norm_text,
-    #         }
-    #         result.append(res)
-    #     return result
-
-    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
-        print(f"############ {i18n('切分文本')} ############")
-        text = self.replace_consecutive_punctuation(text)
-        texts = self.pre_seg_text(text, lang, text_split_method)
-        result = []
-
-        print(f"############ {i18n('提取文本Bert特征')} ############")
-        extract_bert_features(texts)
-        for text in texts:
-            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
-            if phones is None or norm_text == "":
-                continue
-            res = {
-                "phones": phones,
-                "bert_features": bert_features,
-                "norm_text": norm_text,
-            }
-            result.append(res)
-        return result
-
-
-    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
-        phones_list = []
-        bert_list = []
-        norm_text_list = []
-
-        # Preprocess the text to get phones, word2ph, norm_text for each sentence
-        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
-        norm_texts = [x[2] for x in format_texts]
-        word2phs = [x[1] for x in format_texts]
-
-        # Feed the whole batch to the tokenizer at once
-        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = self.bert_model(**inputs, output_hidden_states=True)
-            # Using last_hidden_state is the correct and efficient choice here
-            hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
-
-        for i in range(len(texts)):
-            res = hidden_states[i][1:-1].cpu()  # drop [CLS] and [SEP]
-
-            word2ph = word2phs[i]
-            phone_level_feature = []
-            for j in range(len(word2ph)):
-                if j >= res.shape[0]:
-                    print(f"Warning: BERT output too short, skipping token {j} of sentence {i}")
-                    continue
-                phone_level_feature.append(res[j].repeat(word2ph[j], 1))
-            phone_level_feature = torch.cat(phone_level_feature, dim=0)
-
-            bert_list.append(phone_level_feature.T)
-            phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version))
-            norm_text_list.append(norm_texts[i])
-
-        return phones_list, bert_list, norm_text_list
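One pitfall in the removed batch path: it tokenizes with padding=True but then slices every sentence with a fixed hidden_states[i][1:-1], which keeps pad positions for all but the longest sentence. A hedged sketch of mask-based trimming (model name and layer choice follow the conventions already visible in this diff):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval()

batch = tokenizer(["今天天气不错", "好"], return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    hidden = model(**batch).hidden_states[-3]    # same layer as hidden_states[-3:-2] above

for i in range(hidden.shape[0]):
    n = int(batch["attention_mask"][i].sum())    # true length including [CLS] and [SEP]
    feats = hidden[i][1 : n - 1]                 # drop [CLS], [SEP], and padding
    print(tuple(feats.shape))                    # (n_tokens, 768) for bert-base-chinese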
GPT_SoVITS/TTS_infer_pack/cached.py (new file)

@@ -0,0 +1,30 @@
+from functools import lru_cache
+import torch
+
+
+@lru_cache(maxsize=1000)
+def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
+    """
+    Cached BERT extraction: reuses the features whenever the same norm_text recurs.
+
+    Args:
+        norm_text (str): cleaned text (reusable across calls)
+        word2ph_tuple (tuple): the word2ph list converted to a tuple (lru_cache cannot hash a list)
+        device_str (str): device to run the extraction on
+
+    Returns:
+        Tensor: shape [hidden_dim, total_phonemes]
+    """
+    from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+    # If this lives inside a class, switch to self.tokenizer and self.model
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
+    model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval().to(device_str)
+
+    inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # drop CLS/SEP
+    word2ph = torch.tensor(list(word2ph_tuple), device=hidden.device)
+    indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
+    phone_level_feature = hidden[indices]
+    return phone_level_feature.T.cpu()
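A usage sketch for the new module (import path as used in TextPreprocessor above; the second identical call is a cache hit):

from cached import get_cached_bert

feat_a = get_cached_bert("今天天气不错", (2, 2, 2, 2, 2, 2), "cpu")
feat_b = get_cached_bert("今天天气不错", (2, 2, 2, 2, 2, 2), "cpu")  # served from the cache
print(tuple(feat_a.shape))             # (768, 12): [hidden_dim, sum(word2ph)]
print(get_cached_bert.cache_info())    # CacheInfo(hits=1, misses=1, maxsize=1000, currsize=1)

Note the trade-off the in-code comment already points at: tokenizer and model are re-created with from_pretrained on every cache miss, so wiring this to a long-lived self.tokenizer / self.model avoids reloading the weights each time.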
BIN output.wav (binary file not shown)