Convert BERT to TorchScript
Convert BERT to TorchScript; next step is to prepare ONNX acceleration.
This commit is contained in:
parent
c7b61c6fd4
commit
f1cfc63851
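The commit message points to ONNX acceleration as the next step. Below is a minimal sketch of what that export step could look like, reusing the MyBertModel wrapper and the example-input recipe from torch2torchscript_pack.py in this commit; the ONNX output path, opset version, and dynamic-axes names are assumptions, not part of the commit.

# Hedged sketch of a possible follow-up ONNX export (not part of this commit).
# Reuses MyBertModel and the example inputs from torch2torchscript_pack.py;
# the output path and opset_version below are placeholders.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from export_torch_script_v3 import MyBertModel

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
wrapped_model = MyBertModel(model).eval()

text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)

torch.onnx.export(
    wrapped_model,
    (encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"], word2ph),
    "pretrained_models/bert_model.onnx",  # hypothetical output path
    input_names=["input_ids", "attention_mask", "token_type_ids", "word2ph"],
    output_names=["phone_level_feature"],
    dynamic_axes={
        "input_ids": {1: "seq_len"},
        "attention_mask": {1: "seq_len"},
        "token_type_ids": {1: "seq_len"},
        "word2ph": {0: "n_words"},
    },
    opset_version=17,
)

Because torch.onnx.export traces the wrapper, the Python loop over word2ph in MyBertModel.forward would be unrolled for the example sentence's length; for truly dynamic inputs it would likely have to be rewritten with repeat_interleave, as the TextPreprocessor change in this commit already does.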
@@ -1,3 +1,4 @@
from imghdr import tests
import os
import sys
import threading
@@ -17,6 +18,8 @@ from text import cleaned_text_to_sequence
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method

from export_torch_script_v3 import extract_bert_features

from tools.i18n.i18n import I18nAuto, scan_language_list

language = os.environ.get("language", "Auto")
@@ -56,15 +59,21 @@ class TextPreprocessor:
        self.device = device
        self.bert_lock = threading.RLock()

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []
        # text_batch = []
        print(f"############ {i18n('提取文本Bert特征')} ############")
        for text in tqdm(texts):
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
        text_batch = []
        for text in texts:
            if text.strip():  # skip empty sentences
                text_batch.append(text)
        if not text_batch:
            return []
        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
            if phones is None or norm_text == "":
                continue
            res = {
@@ -103,6 +112,24 @@ class TextPreprocessor:

        # return result


    @torch.jit.script
    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
        """
        Expand word-level BERT features to phone-level features (word2ph gives the number of phones for each word).
        Args:
            res: [N_words, hidden_dim]
            word2ph: [N_words]; each element is how many times that word's feature is repeated (i.e. how many phones it maps to)

        Returns:
            Phone-level features of shape [sum(word2ph), hidden_dim]
        """
        phone_level_feature = []
        for i in range(word2ph.shape[0]):
            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
            phone_level_feature.append(repeat_feature)
        return torch.cat(phone_level_feature, dim=0)


    def pre_seg_text(self, text: str, lang: str, text_split_method: str):
        text = text.strip("\n")
        if len(text) == 0:
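As a quick illustration of the docstring above, here is a toy call with made-up sizes (hidden_dim 1024 matches chinese-roberta-wwm-ext-large, but the word counts are invented):

# Toy illustration of the word2ph expansion performed by build_phone_level_feature.
import torch

word_features = torch.randn(3, 1024)                  # 3 words, hidden_dim = 1024
word2ph = torch.tensor([2, 1, 3], dtype=torch.int32)  # phones per word
phone_features = build_phone_level_feature(word_features, word2ph)
print(phone_features.shape)  # torch.Size([6, 1024]) == (sum(word2ph), hidden_dim)
# Rows 0-1 repeat word 0's vector, row 2 is word 1's, rows 3-5 repeat word 2's.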
@@ -221,13 +248,15 @@ class TextPreprocessor:
            for i in inputs:
                inputs[i] = inputs[i].to(self.device)
            res = self.bert_model(**inputs, output_hidden_states=True)
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
            # Optimization: keep the features on the GPU and move them to the CPU only when needed
            res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]  # drop the unnecessary cpu() call
            assert len(word2ph) == len(text)
            phone_level_feature = []
            for i in range(len(word2ph)):
                repeat_feature = res[i].repeat(word2ph[i], 1)
                phone_level_feature.append(repeat_feature)
            phone_level_feature = torch.cat(phone_level_feature, dim=0)
            # Vectorized optimization: use repeat_interleave instead of the Python loop
            word2ph_tensor = torch.tensor(word2ph, device=res.device)
            indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
            phone_level_feature = res[indices]
            # Move to the CPU only when it is actually needed
            phone_level_feature = phone_level_feature.cpu()
        return phone_level_feature.T

    def clean_text_inf(self, text: str, language: str, version: str = "v2"):
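A small self-contained check (toy values, not from the repo) that the repeat_interleave indexing introduced above reproduces the per-word repeat loop it replaces:

# Toy equivalence check between the old loop and the vectorized repeat_interleave path.
import torch

res = torch.randn(4, 1024)   # word-level BERT features
word2ph = [2, 1, 1, 3]       # phones per word

# Old per-word loop
loop_out = torch.cat([res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0)

# Vectorized version from this commit
word2ph_tensor = torch.tensor(word2ph, device=res.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
vec_out = res[indices]

assert torch.equal(loop_out, vec_out)  # both are [sum(word2ph), 1024]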
@@ -266,6 +295,45 @@ class TextPreprocessor:
        return result


    # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
    #     print(f"############ {i18n('切分文本')} ############")
    #     text = self.replace_consecutive_punctuation(text)
    #     texts = self.pre_seg_text(text, lang, text_split_method)
    #     result = []
    #     print(f"############ {i18n('提取文本Bert特征')} ############")
    #     for text in tqdm(texts):
    #         phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
    #         if phones is None or norm_text == "":
    #             continue
    #         res = {
    #             "phones": phones,
    #             "bert_features": bert_features,
    #             "norm_text": norm_text,
    #         }
    #         result.append(res)
    #     return result

    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
        print(f"############ {i18n('切分文本')} ############")
        text = self.replace_consecutive_punctuation(text)
        texts = self.pre_seg_text(text, lang, text_split_method)
        result = []

        print(f"############ {i18n('提取文本Bert特征')} ############")
        extract_bert_features(texts)
        for text in texts:
            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
            if phones is None or norm_text == "":
                continue
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result


    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
        phones_list = []
        bert_list = []
@@ -20,6 +20,13 @@ import torch
import soundfile
from librosa.filters import mel as librosa_mel_fn

import time
import random
import torch
from tqdm import tqdm
from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")


from inference_webui import get_spepc, norm_spec, resample, ssl_model

@@ -921,6 +928,24 @@ def test_export1(
import time


@torch.jit.script
def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
    """
    Expand word-level BERT features to phone-level features (word2ph gives the number of phones for each word).
    Args:
        res: [N_words, hidden_dim]
        word2ph: [N_words]; each element is how many times that word's feature is repeated (i.e. how many phones it maps to)

    Returns:
        Phone-level features of shape [sum(word2ph), hidden_dim]
    """
    phone_level_feature = []
    for i in range(word2ph.shape[0]):
        repeat_feature = res[i].repeat(word2ph[i].item(), 1)
        phone_level_feature.append(repeat_feature)
    return torch.cat(phone_level_feature, dim=0)


def test_():
    sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")

@@ -1010,6 +1035,23 @@ def test_():
# )


def extract_bert_features(texts: list, desc: str = "提取文本Bert特征"):
    """
    Placeholder: tokenize each sentence, build a dummy feature tensor, and sleep for a
    short random delay. Nothing is returned; it only drives the progress bar.
    """
    # print(f"############ {desc} ############")

    for text in tqdm(texts, desc=desc, unit="it"):
        # Tokenize the sentence
        tokens = tokenizer.tokenize(text)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        fake_tensor = torch.randn(768, len(input_ids))
        _ = fake_tensor.mean(dim=1)

        delay = round(random.uniform(0.8, 1.6), 2)
        time.sleep(delay)


def test_export_gpt_sovits_v3():
    gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
    # test_export1(
@@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
    )


with torch.no_grad():
    # export()
    test_()
    # test_export_gpt_sovits_v3()
class MyBertModel(torch.nn.Module):
    def __init__(self, bert_model):
        super(MyBertModel, self).__init__()
        self.bert = bert_model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]  # drop the [CLS] and [SEP] tokens
        phone_level_feature = []
        for i in range(word2ph.shape[0]):
            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
            phone_level_feature.append(repeat_feature)
        phone_level_feature = torch.cat(phone_level_feature, dim=0)
        return phone_level_feature.T


# with torch.no_grad():
#     # export()
#     # test_()
#     # test_export_gpt_sovits_v3()
#     print()
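Since build_phone_level_feature is scripted in the same file, the loop inside MyBertModel.forward could arguably call it instead of repeating the code. A hypothetical, behavior-equivalent variant is sketched below; it is not what this commit does and assumes it lives in export_torch_script_v3.py next to the definitions above.

# Hypothetical variant that reuses the scripted helper; not part of this commit.
class MyBertModelAlt(torch.nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, output_hidden_states=True)
        res = torch.cat(outputs.hidden_states[-3:-2], -1)[0][1:-1]  # drop [CLS]/[SEP]
        return build_phone_level_feature(res, word2ph).T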
GPT_SoVITS/torch2torchscript_pack.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from export_torch_script_v3 import MyBertModel, build_phone_level_feature

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)

# Build the wrapper model
wrapped_model = MyBertModel(model)

# Prepare an example input
text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
# Rough heuristic: two phones per character, one for punctuation
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)

# Pack the example inputs
example_inputs = {
    "input_ids": encoded["input_ids"],
    "attention_mask": encoded["attention_mask"],
    "token_type_ids": encoded["token_type_ids"],
    "word2ph": word2ph
}

# Trace the model and save it
traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)
traced.save("pretrained_models/bert_script.pt")
print("✅ BERT TorchScript 模型导出完成!")
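For completeness, here is a hedged sketch of loading the file this script saves and running it. The tokenizer path, save path, and word2ph heuristic are taken from the script above; the input sentence is deliberately the same one used for tracing (see the caveat after the code).

# Hedged sketch: load the traced BERT wrapper and run it on the traced example sentence.
import torch
from transformers import AutoTokenizer

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_ts = torch.jit.load("pretrained_models/bert_script.pt")

text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)

with torch.no_grad():
    feature = bert_ts(encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"], word2ph)
print(feature.shape)  # [hidden_dim, sum(word2ph)], since forward returns the transpose

Note that torch.jit.trace unrolls the Python loop in MyBertModel.forward for this particular sentence length, so sentences of a different length would likely need the loop scripted (or replaced with repeat_interleave) before the traced module can serve arbitrary inputs.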
BIN output.wav
Binary file not shown.
test.py (4 lines changed)
@@ -39,7 +39,7 @@ response = requests.post(url, json=payload)
if response.status_code == 200:
    with open("output.wav", "wb") as f:
        f.write(response.content)
    print("✅ 生成成功,保存为 output.wav")
    print(" 生成成功,保存为 output.wav")
else:
    print(f"❌ 生成失败: {response.status_code}, 返回信息: {response.text}")
    print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")