# GPT_SoVITS/export_onnx.py
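"""Export the GPT-SoVITS pipeline to ONNX.

Covers the BERT text frontend, the CN-HuBERT feature extractor, the
autoregressive text-to-semantic decoder (split into a prompt-processing
graph and a single-step decode graph), and the v1/v2 SoVITS synthesizer.
"""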
import json
import os
from io import BytesIO

import onnx
import onnxsim
import torch
from torch.nn import Module

from feature_extractor import cnhubert
from onnxruntime import InferenceSession
from pytorch_lightning import LightningModule
from transformers import AutoModelForMaskedLM, AutoTokenizer

import AR.models.t2s_model_onnx as t2s
from module.models_onnx import SynthesizerTrn

root_path = os.path.dirname(os.path.abspath(__file__))
onnx_path = os.path.join(root_path, "onnx")
if not os.path.exists(onnx_path):
    os.makedirs(onnx_path)
class BertWrapper(Module):
    def __init__(self):
        super(BertWrapper, self).__init__()
        bert_path = os.environ.get(
            "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
        )
        self.model = AutoModelForMaskedLM.from_pretrained(bert_path)
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)

    def forward(self, input_ids):
        attention_mask = torch.ones_like(input_ids)
        token_type_ids = torch.zeros_like(input_ids)
        res = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Take the third-from-last hidden layer, drop the batch dim and the
        # [CLS]/[SEP] positions.
        return torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]

    def export_onnx(self):
        # Dump the tokenizer vocabulary so ONNX consumers can tokenize
        # without depending on transformers.
        vocab_dict = {k: v for k, v in self.tokenizer.get_vocab().items()}
        vocab_path = os.path.join(onnx_path, "Vocab.json")
        with open(vocab_path, "w") as f:
            json.dump(vocab_dict, f, indent=4)
        dummy_input = torch.randint(0, 100, (1, 20)).long()
        torch.onnx.export(
            self,
            dummy_input,
            os.path.join(onnx_path, "Bert.onnx"),
            input_names=["input_ids"],
            output_names=["output"],
            dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
            opset_version=18,
        )
        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "Bert.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "Bert.onnx"))
        print("Exported BERT to ONNX format.")
class CnHubertWrapper(Module):
    def __init__(self):
        super(CnHubertWrapper, self).__init__()
        cnhubert_base_path = os.environ.get(
            "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
        )
        cnhubert.cnhubert_base_path = cnhubert_base_path
        self.model = cnhubert.get_model().model

    def forward(self, signal):
        return self.model(signal)["last_hidden_state"]

    def export_onnx(self):
        dummy_input = torch.randn(1, 16000 * 10)
        torch.onnx.export(
            self,
            dummy_input,
            os.path.join(onnx_path, "CnHubert.onnx"),
            input_names=["signal"],
            output_names=["output"],
            dynamic_axes={"signal": {0: "batch_size", 1: "sequence_length"}},
            opset_version=18,
        )
        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "CnHubert.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "CnHubert.onnx"))
        print("Exported CN-Hubert to ONNX format.")
class Text2SemanticLightningModule(LightningModule):
    def __init__(self, path, top_k=20, cache_size=2000):
        super().__init__()
        dict_s1 = torch.load(path, map_location="cpu")
        config = dict_s1["config"]
        self.model = t2s.Text2SemanticDecoder(config=config)
        self.load_state_dict(dict_s1["weight"])
        self.cache_size = cache_size
        self.top_k = top_k
def export_ar(path, top_k=20, cache_size=2000):
    model_l = Text2SemanticLightningModule(path, top_k=top_k, cache_size=cache_size)
    model = model_l.model
    # Dummy phoneme sequence, semantic prompt, and BERT features for tracing.
    x = torch.randint(0, 100, (1, 20)).long()
    x_len = torch.tensor([20]).long()
    y = torch.randint(0, 100, (1, 20)).long()
    y_len = torch.tensor([20]).long()
    bert_feature = torch.randn(1, 20, 1024)
    top_p = torch.tensor([0.8])
    repetition_penalty = torch.tensor([1.35])
    temperature = torch.tensor([0.6])
    prompt_processor = t2s.PromptProcessor(cache_len=cache_size, model=model, top_k=top_k)
    decode_next_token = t2s.DecodeNextToken(cache_len=cache_size, model=model, top_k=top_k)
    torch.onnx.export(
        prompt_processor,
        (x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature),
        os.path.join(onnx_path, "PromptProcessor.onnx"),
        input_names=["x", "x_len", "y", "y_len", "bert_feature", "top_p", "repetition_penalty", "temperature"],
        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
        dynamic_axes={
            "x": {0: "batch_size", 1: "sequence_length"},
            "y": {0: "batch_size", 1: "sequence_length"},
            "bert_feature": {0: "batch_size", 1: "sequence_length"},
        },
        opset_version=18,
    )
    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "PromptProcessor.onnx"))
    onnx.save_model(sim, os.path.join(onnx_path, "PromptProcessor.onnx"))
    # Run the prompt stage once so DecodeNextToken can be traced with
    # realistically shaped KV-cache tensors.
    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_processor(
        x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature
    )
    torch.onnx.export(
        decode_next_token,
        (y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature),
        os.path.join(onnx_path, "DecodeNextToken.onnx"),
        input_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "top_p", "repetition_penalty", "temperature"],
        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
        dynamic_axes={
            "y": {0: "batch_size", 1: "sequence_length"},
            "k_cache": {1: "batch_size", 2: "sequence_length"},
            "v_cache": {1: "batch_size", 2: "sequence_length"},
        },
        opset_version=18,
    )
    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "DecodeNextToken.onnx"))
    onnx.save_model(sim, os.path.join(onnx_path, "DecodeNextToken.onnx"))
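# Sketch of how the two AR graphs are meant to be driven at inference time
# (an assumption inferred from the graph I/O above, not code from this
# script): PromptProcessor runs once to prime the KV cache, then
# DecodeNextToken is stepped token by token until a stop condition fires.
#
#   pp = InferenceSession(os.path.join(onnx_path, "PromptProcessor.onnx"))
#   step = InferenceSession(os.path.join(onnx_path, "DecodeNextToken.onnx"))
#   y, k_cache, v_cache, xy_pos, y_idx, samples = pp.run(None, prompt_feeds)
#   while not stop_condition(samples):  # stop_condition is a placeholder
#       y, k_cache, v_cache, xy_pos, y_idx, samples = step.run(
#           None,
#           {"y": y, "k_cache": k_cache, "v_cache": v_cache,
#            "xy_pos": xy_pos, "y_idx": y_idx, "top_p": top_p,
#            "repetition_penalty": repetition_penalty,
#            "temperature": temperature},
#       )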
def load_sovits_new(sovits_path):
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        if meta != b"PK":
            # torch.save's zipfile format starts with the ZIP magic b"PK".
            # Checkpoints whose first two bytes were replaced are repaired
            # in memory before loading.
            data = b"PK" + f.read()
            bio = BytesIO()
            bio.write(data)
            bio.seek(0)
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(sovits_path, map_location="cpu", weights_only=False)
class DictToAttrRecursive(dict):
    """Wrap a nested dict so keys are reachable as attributes (e.g. hps.data.hop_length)."""

    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    # Linear-magnitude spectrogram used as the reference-voice conditioning.
    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
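# Shape note (derived from the STFT arguments above, assuming n_fft - hop_size
# is even): an input of length L is reflect-padded by n_fft - hop_size samples
# in total, and with center=False the result is
# (batch, n_fft // 2 + 1, 1 + (L - hop_size) // hop_size).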
class Extractor(Module):
    def __init__(self, model):
        super(Extractor, self).__init__()
        self.model = model

    def forward(self, x):
        # extract_latent expects (batch, channels, frames), hence the transpose.
        return self.model.extract_latent(x.transpose(1, 2))
class V1V2(Module):
    def __init__(self, path):
        super(V1V2, self).__init__()
        dict_s2 = load_sovits_new(path)
        hps = dict_s2["config"]
        hps = DictToAttrRecursive(hps)
        hps.model.semantic_frame_rate = "25hz"
        if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
            hps.model.version = "v2"  # v3 model, v2 symbols
        elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            hps.model.version = "v1"
        else:
            hps.model.version = "v2"
        version = hps.model.version
        # print("SoVITS version:", hps.model.version)
        vq_model = SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        )
        vq_model.load_state_dict(dict_s2["weight"], strict=False)
        vq_model.eval()
        self.vq_model = vq_model
        self.hps = hps
        self.ext = Extractor(self.vq_model)

    def forward(self, text_seq, pred_semantic, ref_audio):
        refer = spectrogram_torch(
            ref_audio,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False,
        )
        return self.vq_model(pred_semantic.unsqueeze(0), text_seq, refer)[0, 0]

    def export(self):
        test_seq = torch.randint(0, 100, (1, 20)).long()
        pred_semantic = torch.randint(0, 100, (1, 20)).long()
        ref_audio = torch.randn(1, 16000 * 10)
        torch.onnx.export(
            self,
            (test_seq, pred_semantic, ref_audio),
            os.path.join(onnx_path, "GptSoVitsV1V2.onnx"),
            input_names=["text_seq", "pred_semantic", "ref_audio"],
            output_names=["output"],
            dynamic_axes={
                "text_seq": {0: "batch_size", 1: "sequence_length"},
                "pred_semantic": {0: "batch_size", 1: "sequence_length"},
                "ref_audio": {0: "batch_size", 1: "sequence_length"},
            },
            opset_version=18,
        )
        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
        ref_units = torch.randn(1, 20, 768)
        torch.onnx.export(
            self.ext,
            ref_units,
            os.path.join(onnx_path, "Extractor.onnx"),
            input_names=["ref_units"],
            output_names=["output"],
            dynamic_axes={
                "ref_units": {0: "batch_size", 1: "sequence_length"},
            },
            opset_version=18,
        )
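# Note (a reading of the wiring above, not from upstream docs): the
# GptSoVitsV1V2 graph takes phoneme ids, predicted semantic tokens, and raw
# reference audio and returns the synthesized waveform, while Extractor.onnx
# turns reference CN-HuBERT features (batch, frames, 768) into the semantic
# prompt consumed by the AR stage.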
if __name__ == "__main__":
    # CnHubertWrapper().export_onnx()
    # BertWrapper().export_onnx()
    V1V2(
        "D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\SoVITS_weights\\小特.pth"
    ).export()
    # export_ar(
    #     "D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\GPT_weights\\小特.ckpt",
    #     top_k=10,
    #     cache_size=1500,
    # )