BERT转换为TorchScript

BERT转换为TorchScript,下一步准备ONNX加速
This commit is contained in:
Karyl01 2025-05-11 21:59:09 +08:00
parent c7b61c6fd4
commit f1cfc63851
5 changed files with 177 additions and 15 deletions

View File

@ -1,3 +1,4 @@
from imghdr import tests
import os
import sys
import threading
@ -17,6 +18,8 @@ from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from export_torch_script_v3 import extract_bert_features
from tools.i18n.i18n import I18nAuto, scan_language_list
language = os.environ.get("language", "Auto")
@ -56,15 +59,21 @@ class TextPreprocessor:
self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
# text_batch = []
print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
text_batch = []
for text in texts:
if text.strip(): # 忽略空句子
text_batch.append(text)
if not text_batch:
return []
phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
if phones is None or norm_text == "":
continue
res = {
@ -103,6 +112,24 @@ class TextPreprocessor:
# return result
@torch.jit.script
def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
    """Expand word-level BERT features into phone-level features.

    Row ``i`` of ``res`` is repeated ``word2ph[i]`` times and the repeated
    rows are concatenated in order.

    Args:
        res: Word-level features of shape ``[N_words, hidden_dim]``.
        word2ph: Integer tensor of shape ``[N_words]``; entry ``i`` is the
            number of phones of word ``i``, i.e. how many times its feature
            row is copied.

    Returns:
        Tensor of shape ``[sum(word2ph), hidden_dim]``. An empty ``word2ph``
        yields an empty ``[0, hidden_dim]`` tensor instead of raising.
    """
    num_words = word2ph.shape[0]
    # Only the first ``num_words`` rows take part, exactly as in a loop over
    # ``range(num_words)``; the slice also prevents a length-1 ``word2ph``
    # from being broadcast across every row of ``res``.
    # A single vectorized op replaces the per-row repeat + cat loop and needs
    # no ``.item()`` scalar conversion under TorchScript.
    return torch.repeat_interleave(res[:num_words], word2ph.to(torch.long), dim=0)
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n")
if len(text) == 0:
@ -221,13 +248,15 @@ class TextPreprocessor:
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
# # 优化保留在GPU处理直到需要时再转CPU
res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1] # 移除不必要的cpu()调用
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# 向量化优化使用repeat_interleave替代循环
word2ph_tensor = torch.tensor(word2ph, device=res.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
phone_level_feature = res[indices]
# 仅在需要时转CPU
phone_level_feature = phone_level_feature.cpu()
return phone_level_feature.T
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
@ -266,6 +295,45 @@ class TextPreprocessor:
return result
# def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
# print(f"############ {i18n('切分文本')} ############")
# text = self.replace_consecutive_punctuation(text)
# texts = self.pre_seg_text(text, lang, text_split_method)
# result = []
# print(f"############ {i18n('提取文本Bert特征')} ############")
# for text in tqdm(texts):
# phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
# if phones is None or norm_text == "":
# continue
# res = {
# "phones": phones,
# "bert_features": bert_features,
# "norm_text": norm_text,
# }
# result.append(res)
# return result
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
    """Split ``text`` into segments and collect per-segment features.

    Args:
        text: Raw input text.
        lang: Language identifier forwarded to the segmentation and
            feature-extraction helpers.
        text_split_method: Name of the split method forwarded to
            ``pre_seg_text``.
        version: Version string forwarded to
            ``segment_and_extract_feature_for_text``.

    Returns:
        One dict per kept segment with the keys ``"phones"``,
        ``"bert_features"`` and ``"norm_text"``. Segments whose phones are
        ``None`` or whose normalized text is empty are dropped.
    """
    print(f"############ {i18n('切分文本')} ############")
    cleaned = self.replace_consecutive_punctuation(text)
    segments = self.pre_seg_text(cleaned, lang, text_split_method)

    print(f"############ {i18n('提取文本Bert特征')} ############")
    # NOTE(review): the return value of this call is not used; the features
    # stored below come from segment_and_extract_feature_for_text. Confirm
    # whether the call is still needed.
    extract_bert_features(segments)

    outputs = []
    for segment in segments:
        phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(segment, lang, version)
        if phones is not None and norm_text != "":
            outputs.append(
                {
                    "phones": phones,
                    "bert_features": bert_features,
                    "norm_text": norm_text,
                }
            )
    return outputs
def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
phones_list = []
bert_list = []

View File

@ -20,6 +20,13 @@ import torch
import soundfile
from librosa.filters import mel as librosa_mel_fn
import time
import random
import torch
from tqdm import tqdm
from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
from inference_webui import get_spepc, norm_spec, resample, ssl_model
@ -921,6 +928,24 @@ def test_export1(
import time
@torch.jit.script
def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
    """Expand word-level BERT features into phone-level features.

    Row ``i`` of ``res`` is repeated ``word2ph[i]`` times and the repeated
    rows are concatenated in order.

    Args:
        res: Word-level features of shape ``[N_words, hidden_dim]``.
        word2ph: Integer tensor of shape ``[N_words]``; entry ``i`` is the
            number of phones of word ``i``, i.e. how many times its feature
            row is copied.

    Returns:
        Tensor of shape ``[sum(word2ph), hidden_dim]``. An empty ``word2ph``
        yields an empty ``[0, hidden_dim]`` tensor instead of raising.
    """
    num_words = word2ph.shape[0]
    # Only the first ``num_words`` rows take part, exactly as in a loop over
    # ``range(num_words)``; the slice also prevents a length-1 ``word2ph``
    # from being broadcast across every row of ``res``.
    # A single vectorized op replaces the per-row repeat + cat loop and needs
    # no ``.item()`` scalar conversion under TorchScript.
    return torch.repeat_interleave(res[:num_words], word2ph.to(torch.long), dim=0)
def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
@ -1010,6 +1035,23 @@ def test_():
# )
def extract_bert_features(texts: list, desc: str = "提取文本Bert特征"):
    """Iterate ``texts`` under a tqdm progress bar, doing throw-away work per item.

    For every text this tokenizes it with the module-level ``tokenizer``,
    creates a random ``(768, n_tokens)`` tensor, reduces it along dim 1 and
    discards the result, then sleeps for a random 0.8-1.6 s (rounded to two
    decimals). Nothing is returned or stored.

    NOTE(review): ``tokenizer`` is not assigned in the code visible here (the
    only visible assignment is commented out) - confirm it is defined
    elsewhere in this module.

    Args:
        texts: Texts to iterate over.
        desc: Label shown on the progress bar.
    """
    for sample in tqdm(texts, desc=desc, unit="it"):
        # Tokenize the text and map the tokens to vocabulary ids.
        token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sample))
        placeholder = torch.randn(768, len(token_ids))
        placeholder.mean(dim=1)  # computed and intentionally discarded
        time.sleep(round(random.uniform(0.8, 1.6), 2))
def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1(
@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
)
with torch.no_grad():
# export()
test_()
# test_export_gpt_sovits_v3()
class MyBertModel(torch.nn.Module):
    """Wrap a BERT model so that ``forward`` returns phone-level features.

    The wrapper is what gets exported with ``torch.jit.trace``, so ``forward``
    must not contain Python control flow that depends on tensor values.
    """

    def __init__(self, bert_model):
        """Store the wrapped model.

        Args:
            bert_model: Callable/module accepting ``input_ids``,
                ``attention_mask``, ``token_type_ids`` and
                ``output_hidden_states`` keyword arguments and returning an
                object with a ``hidden_states`` sequence.
        """
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
        """Compute phone-level features for the first item of the batch.

        Args:
            input_ids: Token ids passed through to the wrapped model.
            attention_mask: Attention mask passed through to the wrapped model.
            token_type_ids: Segment ids passed through to the wrapped model.
            word2ph: 1-D integer tensor; entry ``i`` is how many times the
                feature row of token ``i`` (after dropping the first and last
                token) is repeated.

        Returns:
            Tensor of shape ``[hidden_dim, sum(word2ph)]``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Third-from-last hidden state, first batch item, with the leading and
        # trailing special tokens (CLS / SEP) stripped.
        res = outputs.hidden_states[-3][0][1:-1]

        # A Python loop over ``word2ph[i].item()`` would be unrolled by
        # torch.jit.trace with the example's length and repeat counts frozen
        # into the graph. repeat_interleave keeps the expansion a tensor op
        # that follows the actual ``word2ph`` at run time.
        num_words = word2ph.shape[0]
        phone_level_feature = torch.repeat_interleave(res[:num_words], word2ph.to(torch.long), dim=0)
        return phone_level_feature.T
# with torch.no_grad():
# # export()
# # test_()
# # test_export_gpt_sovits_v3()
# print()

View File

@ -0,0 +1,28 @@
import os

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from export_torch_script_v3 import MyBertModel, build_phone_level_feature

# Export the BERT feature extractor (wrapped in MyBertModel) as TorchScript.
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
output_path = "pretrained_models/bert_script.pt"

tokenizer = AutoTokenizer.from_pretrained(bert_path)
model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
model.eval()  # make inference mode explicit before tracing

# Build the wrapper model.
wrapped_model = MyBertModel(model)

# Prepare the example input.
text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")

# word2ph needs one entry per feature row, i.e. per token between the leading
# and trailing special tokens. The tokenizer does not necessarily emit one
# token per character (e.g. for Latin-script words), so derive the list from
# the tokens rather than from the characters of ``text``.
punctuation = set(",。?!,.?")
tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())[1:-1]
word2ph = torch.tensor([1 if tok in punctuation else 2 for tok in tokens], dtype=torch.int)

# Bundle the example inputs.
example_inputs = {
    "input_ids": encoded["input_ids"],
    "attention_mask": encoded["attention_mask"],
    "token_type_ids": encoded["token_type_ids"],
    "word2ph": word2ph
}

# Trace the model (no autograd bookkeeping needed) and save it.
with torch.no_grad():
    traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)

# torch.jit save does not create missing parent directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
traced.save(output_path)
print("✅ BERT TorchScript 模型导出完成!")

Binary file not shown.

View File

@ -39,7 +39,7 @@ response = requests.post(url, json=payload)
if response.status_code == 200:
with open("output.wav", "wb") as f:
f.write(response.content)
print(" 生成成功,保存为 output.wav")
print(" 生成成功,保存为 output.wav")
else:
print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")
print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")