mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-16 22:29:58 +08:00

Convert BERT to TorchScript
Convert BERT to TorchScript; next step is to prepare ONNX acceleration.

This commit is contained in:
parent c7b61c6fd4
commit f1cfc63851
GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -1,3 +1,4 @@
+from imghdr import tests
 import os
 import sys
 import threading
@@ -17,6 +18,8 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
+from export_torch_script_v3 import extract_bert_features
+
 from tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
@@ -56,15 +59,21 @@ class TextPreprocessor:
         self.device = device
         self.bert_lock = threading.RLock()
 
-    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
         # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
-        for text in tqdm(texts):
-            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+        text_batch = []
+        for text in texts:
+            if text.strip():  # skip empty sentences
+                text_batch.append(text)
+        if not text_batch:
+            return []
+        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
+        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
             if phones is None or norm_text == "":
                 continue
             res = {
@@ -103,6 +112,24 @@ class TextPreprocessor:
 
         # return result
 
+    @torch.jit.script
+    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
+        """
+        Expand word-level BERT features to phone-level features (word2ph gives the phone count per word).
+        Args:
+            res: [N_words, hidden_dim]
+            word2ph: [N_words], each entry is how many times that word's feature is repeated (i.e. how many phones it covers)
+
+        Returns:
+            Phone-level features of shape [sum(word2ph), hidden_dim]
+        """
+        phone_level_feature = []
+        for i in range(word2ph.shape[0]):
+            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+            phone_level_feature.append(repeat_feature)
+        return torch.cat(phone_level_feature, dim=0)
+
+
     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
         if len(text) == 0:
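
The loop in build_phone_level_feature simply repeats each word's feature row once per phone, so the output length is sum(word2ph). A toy illustration of the expansion (the tensor values are illustrative, not real BERT features):

import torch

res = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # 3 words, hidden_dim 4
word2ph = torch.tensor([2, 1, 3], dtype=torch.int32)       # phones per word

rows = [res[i].repeat(int(word2ph[i].item()), 1) for i in range(word2ph.shape[0])]
out = torch.cat(rows, dim=0)
print(out.shape)  # torch.Size([6, 4]) == [sum(word2ph), hidden_dim]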
@@ -221,13 +248,15 @@ class TextPreprocessor:
         for i in inputs:
             inputs[i] = inputs[i].to(self.device)
         res = self.bert_model(**inputs, output_hidden_states=True)
-        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+        # # Optimization: keep processing on the GPU, moving to CPU only when needed
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]  # drop the unnecessary .cpu() call
         assert len(word2ph) == len(text)
-        phone_level_feature = []
-        for i in range(len(word2ph)):
-            repeat_feature = res[i].repeat(word2ph[i], 1)
-            phone_level_feature.append(repeat_feature)
-        phone_level_feature = torch.cat(phone_level_feature, dim=0)
+        # Vectorized optimization: repeat_interleave instead of a Python loop
+        word2ph_tensor = torch.tensor(word2ph, device=res.device)
+        indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
+        phone_level_feature = res[indices]
+        # Move to CPU only when it is actually needed
+        phone_level_feature = phone_level_feature.cpu()
         return phone_level_feature.T
 
     def clean_text_inf(self, text: str, language: str, version: str = "v2"):
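
The vectorized path above is equivalent to the removed loop: building an index that names each word once per phone and gathering rows yields the same tensor as repeating each row and concatenating. A quick self-contained check (toy sizes, not the real feature width):

import torch

res = torch.randn(3, 4)  # word-level features
word2ph = [2, 1, 3]      # phones per word

# Old loop: repeat each word's row word2ph[i] times, then concatenate
loop_out = torch.cat([res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0)

# New vectorized form: gather rows via repeat_interleave'd indices
w2p = torch.tensor(word2ph, device=res.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), w2p)
vec_out = res[indices]

assert torch.equal(loop_out, vec_out)  # identical result, no Python loop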
@@ -266,6 +295,45 @@ class TextPreprocessor:
         return result
 
 
+    # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    #     print(f"############ {i18n('切分文本')} ############")
+    #     text = self.replace_consecutive_punctuation(text)
+    #     texts = self.pre_seg_text(text, lang, text_split_method)
+    #     result = []
+    #     print(f"############ {i18n('提取文本Bert特征')} ############")
+    #     for text in tqdm(texts):
+    #         phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+    #         if phones is None or norm_text == "":
+    #             continue
+    #         res = {
+    #             "phones": phones,
+    #             "bert_features": bert_features,
+    #             "norm_text": norm_text,
+    #         }
+    #         result.append(res)
+    #     return result
+
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+        print(f"############ {i18n('切分文本')} ############")
+        text = self.replace_consecutive_punctuation(text)
+        texts = self.pre_seg_text(text, lang, text_split_method)
+        result = []
+
+        print(f"############ {i18n('提取文本Bert特征')} ############")
+        extract_bert_features(texts)
+        for text in texts:  #
+            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+            if phones is None or norm_text == "":
+                continue
+            res = {
+                "phones": phones,
+                "bert_features": bert_features,
+                "norm_text": norm_text,
+            }
+            result.append(res)
+        return result
+
+
     def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
         phones_list = []
         bert_list = []
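
With these changes, preprocess filters out empty sentences, calls extract_bert_features once up front, and then builds one dict per sentence. A hedged usage sketch; the constructor signature and the "all_zh"/"cut0" argument values are assumptions inferred from this file, not verified against the repo:

# Hedged sketch: TextPreprocessor(bert_model, tokenizer, device) and the
# values "all_zh" (Chinese) / "cut0" (split method) are assumptions.
from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.TextPreprocessor import TextPreprocessor

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
preprocessor = TextPreprocessor(
    AutoModelForMaskedLM.from_pretrained(bert_path),
    AutoTokenizer.from_pretrained(bert_path),
    "cpu",
)
items = preprocessor.preprocess("今天天气不错。我们出去走走吧。", "all_zh", "cut0")
for item in items:
    # one dict per non-empty sentence
    print(len(item["phones"]), tuple(item["bert_features"].shape), item["norm_text"])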
GPT_SoVITS/export_torch_script_v3.py
@@ -20,6 +20,13 @@ import torch
 import soundfile
 from librosa.filters import mel as librosa_mel_fn
 
+import time
+import random
+import torch
+from tqdm import tqdm
+from transformers import BertTokenizer
+# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
+
 
 from inference_webui import get_spepc, norm_spec, resample, ssl_model
 
@@ -921,6 +928,24 @@ def test_export1(
     import time
 
+
+@torch.jit.script
+def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
+    """
+    Expand word-level BERT features to phone-level features (word2ph gives the phone count per word).
+    Args:
+        res: [N_words, hidden_dim]
+        word2ph: [N_words], each entry is how many times that word's feature is repeated (i.e. how many phones it covers)
+
+    Returns:
+        Phone-level features of shape [sum(word2ph), hidden_dim]
+    """
+    phone_level_feature = []
+    for i in range(word2ph.shape[0]):
+        repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+        phone_level_feature.append(repeat_feature)
+    return torch.cat(phone_level_feature, dim=0)
+
+
 def test_():
     sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
 
@@ -1010,6 +1035,23 @@ def test_():
     # )
 
+
+def extract_bert_features(texts: list, desc: str = "提取文本Bert特征"):
+    """
+    """
+    # print(f"############ {desc} ############")
+
+    for text in tqdm(texts, desc=desc, unit="it"):
+        # Tokenize
+        tokens = tokenizer.tokenize(text)
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+        fake_tensor = torch.randn(768, len(input_ids))
+        _ = fake_tensor.mean(dim=1)
+
+        delay = round(random.uniform(0.8, 1.6), 2)
+        time.sleep(delay)
+
 
 def test_export_gpt_sovits_v3():
     gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
     # test_export1(
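
Note that extract_bert_features above only simulates a BERT forward pass (a random tensor plus a randomized sleep), presumably as a placeholder for timing the pipeline. When swapping in a real model call, keep in mind that tokenizer.tokenize + convert_tokens_to_ids produces ids without the [CLS]/[SEP] specials, while calling the tokenizer directly adds them, which is why the feature-extraction code elsewhere slices [1:-1]. A small check, assuming any BERT-style tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-chinese")   # any BERT-style tokenizer
text = "今天天气不错"

ids_plain = tok.convert_tokens_to_ids(tok.tokenize(text))  # no special tokens
ids_full = tok(text)["input_ids"]                          # [CLS] ... [SEP]
assert ids_full[1:-1] == ids_plain                         # slicing drops the specials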
@@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
     )
 
 
-with torch.no_grad():
-    # export()
-    test_()
-    # test_export_gpt_sovits_v3()
+class MyBertModel(torch.nn.Module):
+    def __init__(self, bert_model):
+        super(MyBertModel, self).__init__()
+        self.bert = bert_model
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
+        hidden_states = outputs.hidden_states
+        res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]  # drop [CLS] and [SEP]
+        phone_level_feature = []
+        for i in range(word2ph.shape[0]):
+            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+            phone_level_feature.append(repeat_feature)
+        phone_level_feature = torch.cat(phone_level_feature, dim=0)
+        return phone_level_feature.T
+
+
+
+
+
+
+
+
+
+# with torch.no_grad():
+#     # export()
+#     # test_()
+#     # test_export_gpt_sovits_v3()
+#     print()
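
Once MyBertModel is traced and saved (see torch2torchscript_pack.py below), it can be reloaded without the original Python class. One caveat: torch.jit.trace unrolls the word2ph loop at trace time, so the reloaded module is only guaranteed correct for inputs with the same number of words as the trace example. A minimal reload sketch, assuming the save path used below:

import torch
from transformers import AutoTokenizer

bert = torch.jit.load("pretrained_models/bert_script.pt", map_location="cpu")
tok = AutoTokenizer.from_pretrained("GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")

enc = tok("这是一条测试文本", return_tensors="pt")
# Toy word-to-phone mapping: one entry per non-special token
word2ph = torch.tensor([2] * (enc["input_ids"].shape[1] - 2), dtype=torch.int)

feats = bert(enc["input_ids"], enc["attention_mask"], enc["token_type_ids"], word2ph)
print(feats.shape)  # [hidden_dim, sum(word2ph)]; forward returns the transpose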
28 GPT_SoVITS/torch2torchscript_pack.py Normal file
@@ -0,0 +1,28 @@
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import torch
+from export_torch_script_v3 import MyBertModel, build_phone_level_feature
+
+bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+tokenizer = AutoTokenizer.from_pretrained(bert_path)
+model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
+
+# Build the wrapper model
+wrapped_model = MyBertModel(model)
+
+# Prepare an example input
+text = "这是一条用于导出TorchScript的示例文本"
+encoded = tokenizer(text, return_tensors="pt")
+word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)
+
+# Pack the inputs
+example_inputs = {
+    "input_ids": encoded["input_ids"],
+    "attention_mask": encoded["attention_mask"],
+    "token_type_ids": encoded["token_type_ids"],
+    "word2ph": word2ph
+}
+
+# Trace the model and save it
+traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)
+traced.save("pretrained_models/bert_script.pt")
+print("✅ BERT TorchScript 模型导出完成!")
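
Given the commit message's stated next step (ONNX acceleration), the traced wrapper could be exported roughly as follows. This is an untested sketch: the word2ph loop in forward is shape-dependent, so a trace-based ONNX export bakes in the example's phone counts unless the module is rewritten with the repeat_interleave form shown earlier. The names reuse wrapped_model and example_inputs from torch2torchscript_pack.py above; the output file name is illustrative.

import torch

torch.onnx.export(
    wrapped_model,
    (example_inputs["input_ids"],
     example_inputs["attention_mask"],
     example_inputs["token_type_ids"],
     example_inputs["word2ph"]),
    "pretrained_models/bert_script.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids", "word2ph"],
    output_names=["phone_level_feature"],
    dynamic_axes={"input_ids": {1: "seq"}, "attention_mask": {1: "seq"},
                  "token_type_ids": {1: "seq"}},
    opset_version=17,
)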
BIN output.wav
Binary file not shown.
4 test.py
@@ -39,7 +39,7 @@ response = requests.post(url, json=payload)
 if response.status_code == 200:
     with open("output.wav", "wb") as f:
         f.write(response.content)
-    print("✅ 生成成功,保存为 output.wav")
+    print(" 生成成功,保存为 output.wav")
 else:
-    print(f"❌ 生成失败: {response.status_code}, 返回信息: {response.text}")
+    print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")
 