From e537e129a58eb74da4374083edad2c6de41638b7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=9D=CE=B1=CF=81=CE=BF=CF=85=CF=83=CE=AD=C2=B7=CE=BC?=
 =?UTF-8?q?=C2=B7=CE=B3=CE=B9=CE=BF=CF=85=CE=BC=CE=B5=CE=BC=CE=AF=C2=B7?=
 =?UTF-8?q?=CE=A7=CE=B9=CE=BD=CE=B1=CE=BA=CE=AC=CE=BD=CE=BD=CE=B1?=
 <40709280+NaruseMioShirakana@users.noreply.github.com>
Date: Thu, 3 Apr 2025 16:40:44 +0800
Subject: [PATCH] onnx export

onnx export
---
 GPT_SoVITS/export_onnx.py | 306 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 306 insertions(+)
 create mode 100644 GPT_SoVITS/export_onnx.py

diff --git a/GPT_SoVITS/export_onnx.py b/GPT_SoVITS/export_onnx.py
new file mode 100644
index 0000000..677936c
--- /dev/null
+++ b/GPT_SoVITS/export_onnx.py
@@ -0,0 +1,306 @@
+import os
+import json
+
+import onnx
+import torch
+import onnxsim
+
+from io import BytesIO
+from torch.nn import Module
+from feature_extractor import cnhubert
+from pytorch_lightning import LightningModule
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import AR.models.t2s_model_onnx as t2s
+
+from module.models_onnx import SynthesizerTrn
+
+root_path = os.path.dirname(os.path.abspath(__file__))
+onnx_path = os.path.join(root_path, "onnx")
+os.makedirs(onnx_path, exist_ok=True)
+
+
+class BertWrapper(Module):
+    def __init__(self):
+        super(BertWrapper, self).__init__()
+        bert_path = os.environ.get(
+            "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+        )
+        self.model = AutoModelForMaskedLM.from_pretrained(bert_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)
+
+    def forward(self, input_ids):
+        attention_mask = torch.ones_like(input_ids)
+        token_type_ids = torch.zeros_like(input_ids)
+        res = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            output_hidden_states=True,
+        )
+        # Third-to-last hidden layer, batch dim dropped, [CLS]/[SEP] stripped.
+        return torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+
+    def export_onnx(self):
+        vocab_dict = {k: v for k, v in self.tokenizer.get_vocab().items()}
+        vocab_path = os.path.join(onnx_path, "Vocab.json")
+        with open(vocab_path, "w", encoding="utf-8") as f:
+            json.dump(vocab_dict, f, indent=4)
+        dummy_input = torch.randint(0, 100, (1, 20)).long()
+        torch.onnx.export(
+            self,
+            dummy_input,
+            os.path.join(onnx_path, "Bert.onnx"),
+            input_names=["input_ids"],
+            output_names=["output"],
+            dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
+            opset_version=18,
+        )
+        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "Bert.onnx"))
+        onnx.save_model(sim, os.path.join(onnx_path, "Bert.onnx"))
+        print("Exported BERT to ONNX format.")
+
+
+class CnHubertWrapper(Module):
+    def __init__(self):
+        super(CnHubertWrapper, self).__init__()
+        cnhubert_base_path = os.environ.get(
+            "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
+        )
+        cnhubert.cnhubert_base_path = cnhubert_base_path
+        self.model = cnhubert.get_model().model
+
+    def forward(self, signal):
+        return self.model(signal)["last_hidden_state"]
+
+    def export_onnx(self):
+        dummy_input = torch.randn(1, 16000 * 10)
+        torch.onnx.export(
+            self,
+            dummy_input,
+            os.path.join(onnx_path, "CnHubert.onnx"),
+            input_names=["signal"],
+            output_names=["output"],
+            dynamic_axes={"signal": {0: "batch_size", 1: "sequence_length"}},
+            opset_version=18,
+        )
+        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "CnHubert.onnx"))
+        onnx.save_model(sim, os.path.join(onnx_path, "CnHubert.onnx"))
+        print("Exported CN-HuBERT to ONNX format.")
+
+
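+# Hedged usage sketch for the two exports above (assumes onnxruntime is
+# installed; the tensor names match the input_names/output_names declared
+# here, while token_ids_int64_1xN and waveform_float32_1xT are placeholders):
+#
+#   from onnxruntime import InferenceSession
+#   bert = InferenceSession(os.path.join(onnx_path, "Bert.onnx"))
+#   feats = bert.run(["output"], {"input_ids": token_ids_int64_1xN})[0]
+#   hubert = InferenceSession(os.path.join(onnx_path, "CnHubert.onnx"))
+#   units = hubert.run(["output"], {"signal": waveform_float32_1xT})[0]
+
+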
+class Text2SemanticLightningModule(LightningModule):
+    def __init__(self, path, top_k=20, cache_size=2000):
+        super().__init__()
+        dict_s1 = torch.load(path, map_location="cpu")
+        config = dict_s1["config"]
+        self.model = t2s.Text2SemanticDecoder(config=config)
+        self.load_state_dict(dict_s1["weight"])
+        self.cache_size = cache_size
+        self.top_k = top_k
+
+
+def export_ar(path, top_k=20, cache_size=2000):
+    model_l = Text2SemanticLightningModule(path, top_k=top_k, cache_size=cache_size)
+    model = model_l.model
+
+    # Dummy inputs: phoneme ids (x), prompt semantic tokens (y), BERT features,
+    # and the sampling parameters.
+    x = torch.randint(0, 100, (1, 20)).long()
+    x_len = torch.tensor([20]).long()
+    y = torch.randint(0, 100, (1, 20)).long()
+    y_len = torch.tensor([20]).long()
+    bert_feature = torch.randn(1, 20, 1024)
+    top_p = torch.tensor([0.8])
+    repetition_penalty = torch.tensor([1.35])
+    temperature = torch.tensor([0.6])
+
+    prompt_processor = t2s.PromptProcessor(cache_len=cache_size, model=model, top_k=top_k)
+    decode_next_token = t2s.DecodeNextToken(cache_len=cache_size, model=model, top_k=top_k)
+
+    torch.onnx.export(
+        prompt_processor,
+        (x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature),
+        os.path.join(onnx_path, "PromptProcessor.onnx"),
+        input_names=["x", "x_len", "y", "y_len", "bert_feature", "top_p", "repetition_penalty", "temperature"],
+        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
+        dynamic_axes={
+            "x": {0: "batch_size", 1: "sequence_length"},
+            "y": {0: "batch_size", 1: "sequence_length"},
+            "bert_feature": {0: "batch_size", 1: "sequence_length"},
+        },
+        opset_version=18,
+    )
+
+    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "PromptProcessor.onnx"))
+    onnx.save_model(sim, os.path.join(onnx_path, "PromptProcessor.onnx"))
+
+    # Run the prompt processor once so the decode-step export sees tensors
+    # with realistic shapes (KV cache, position embedding, sampled token).
+    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_processor(
+        x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature
+    )
+
+    torch.onnx.export(
+        decode_next_token,
+        (y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature),
+        os.path.join(onnx_path, "DecodeNextToken.onnx"),
+        input_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "top_p", "repetition_penalty", "temperature"],
+        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
+        dynamic_axes={
+            "y": {0: "batch_size", 1: "sequence_length"},
+            "k_cache": {1: "batch_size", 2: "sequence_length"},
+            "v_cache": {1: "batch_size", 2: "sequence_length"},
+        },
+        opset_version=18,
+    )
+
+    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "DecodeNextToken.onnx"))
+    onnx.save_model(sim, os.path.join(onnx_path, "DecodeNextToken.onnx"))
+
+
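+# Rough sketch of the intended decode loop over the two graphs above, run
+# with onnxruntime (hedged: the actual EOS id and stop condition live in
+# AR.models.t2s_model_onnx; EOS_TOKEN below is a placeholder):
+#
+#   outs = prompt_sess.run(None, {"x": x, "x_len": x_len, "y": y,
+#                                 "y_len": y_len, "bert_feature": bert,
+#                                 "top_p": top_p, "repetition_penalty": rp,
+#                                 "temperature": temp})
+#   y, k_cache, v_cache, xy_pos, y_idx, sample = outs
+#   while sample != EOS_TOKEN:
+#       y, k_cache, v_cache, xy_pos, y_idx, sample = decode_sess.run(
+#           None, {"y": y, "k_cache": k_cache, "v_cache": v_cache,
+#                  "xy_pos": xy_pos, "y_idx": y_idx, "top_p": top_p,
+#                  "repetition_penalty": rp, "temperature": temp})
+
+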
+def load_sovits_new(sovits_path):
+    with open(sovits_path, "rb") as f:
+        meta = f.read(2)
+        if meta != b"PK":
+            # Checkpoint with a mangled zip magic: restore the b"PK" header
+            # and load the archive from memory.
+            data = b"PK" + f.read()
+            bio = BytesIO()
+            bio.write(data)
+            bio.seek(0)
+            return torch.load(bio, map_location="cpu", weights_only=False)
+    return torch.load(sovits_path, map_location="cpu", weights_only=False)
+
+
+class DictToAttrRecursive(dict):
+    def __init__(self, input_dict):
+        super().__init__(input_dict)
+        for key, value in input_dict.items():
+            if isinstance(value, dict):
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window,
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
+
+
+class Extractor(Module):
+    def __init__(self, model):
+        super(Extractor, self).__init__()
+        self.model = model
+
+    def forward(self, x):
+        return self.model.extract_latent(x.transpose(1, 2))
+
+
+class V1V2(Module):
+    def __init__(self, path):
+        super(V1V2, self).__init__()
+        dict_s2 = load_sovits_new(path)
+        hps = dict_s2["config"]
+        hps = DictToAttrRecursive(hps)
+        hps.model.semantic_frame_rate = "25hz"
+        if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
+            hps.model.version = "v2"  # v3 model, v2 symbols
+        elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
+            hps.model.version = "v1"
+        else:
+            hps.model.version = "v2"
+        # print("SoVITS version:", hps.model.version)
+        vq_model = SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        )
+        vq_model.load_state_dict(dict_s2["weight"], strict=False)
+        vq_model.eval()
+        self.vq_model = vq_model
+        self.hps = hps
+        self.ext = Extractor(self.vq_model)
+
+    def forward(self, text_seq, pred_semantic, ref_audio):
+        refer = spectrogram_torch(
+            ref_audio,
+            self.hps.data.filter_length,
+            self.hps.data.sampling_rate,
+            self.hps.data.hop_length,
+            self.hps.data.win_length,
+            center=False,
+        )
+        return self.vq_model(pred_semantic.unsqueeze(0), text_seq, refer)[0, 0]
+
+    def export(self):
+        test_seq = torch.randint(0, 100, (1, 20)).long()
+        pred_semantic = torch.randint(0, 100, (1, 20)).long()
+        ref_audio = torch.randn(1, 16000 * 10)
+        torch.onnx.export(
+            self,
+            (test_seq, pred_semantic, ref_audio),
+            os.path.join(onnx_path, "GptSoVitsV1V2.onnx"),
+            input_names=["text_seq", "pred_semantic", "ref_audio"],
+            output_names=["output"],
+            dynamic_axes={
+                "text_seq": {0: "batch_size", 1: "sequence_length"},
+                "pred_semantic": {0: "batch_size", 1: "sequence_length"},
+                "ref_audio": {0: "batch_size", 1: "sequence_length"},
+            },
+            opset_version=18,
+        )
+
+        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
+        onnx.save_model(sim, os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
+
+        # Latent extractor: reference-audio SSL units -> prompt semantic codes.
+        ref_units = torch.randn(1, 20, 768)
+        torch.onnx.export(
+            self.ext,
+            ref_units,
+            os.path.join(onnx_path, "Extractor.onnx"),
+            input_names=["ref_units"],
+            output_names=["output"],
+            dynamic_axes={
+                "ref_units": {0: "batch_size", 1: "sequence_length"},
+            },
+            opset_version=18,
+        )
+
+
+if __name__ == "__main__":
+    # CnHubertWrapper().export_onnx()
+    # BertWrapper().export_onnx()
+    V1V2("D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\SoVITS_weights\\小特.pth").export()
+    # export_ar(
+    #     "D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\GPT_weights\\小特.ckpt",
+    #     top_k=10,
+    #     cache_size=1500,
+    # )
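+
+
+# End-to-end order of the exported graphs (a sketch, not an API contract;
+# text/symbol preprocessing must mirror the PyTorch inference path):
+#   1. Bert.onnx             text token ids -> BERT features
+#   2. CnHubert.onnx         reference audio -> SSL units
+#      Extractor.onnx        SSL units -> prompt semantic tokens
+#   3. PromptProcessor.onnx  first autoregressive step (builds the KV cache)
+#   4. DecodeNextToken.onnx  subsequent steps, repeated until EOS
+#   5. GptSoVitsV1V2.onnx    semantic tokens + text + reference audio -> waveform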