import json
import os
from io import BytesIO

import onnx
import onnxsim
import torch
from torch.nn import Module
from pytorch_lightning import LightningModule
from transformers import AutoTokenizer, AutoModelForMaskedLM
from onnxruntime import InferenceSession

import AR.models.t2s_model_onnx as t2s
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn

root_path = os.path.dirname(os.path.abspath(__file__))
onnx_path = os.path.join(root_path, "onnx")
if not os.path.exists(onnx_path):
    os.makedirs(onnx_path)

class BertWrapper(Module):
    """Wraps the Chinese RoBERTa MLM model and exports it to ONNX."""

    def __init__(self):
        bert_path = os.environ.get(
            "bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
        )
        super(BertWrapper, self).__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(bert_path)
        self.tokenizer = AutoTokenizer.from_pretrained(bert_path)

    def forward(self, input_ids):
        attention_mask = torch.ones_like(input_ids)
        token_type_ids = torch.zeros_like(input_ids)
        res = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # Take the third-from-last hidden layer and drop the [CLS]/[SEP] positions.
        return torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]

    def export_onnx(self):
        vocab_dict = self.tokenizer.get_vocab()
        vocab_path = os.path.join(onnx_path, "Vocab.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(vocab_dict, f, indent=4)
        dummy_input = torch.randint(0, 100, (1, 20)).long()
        torch.onnx.export(
            self,
            dummy_input,
            os.path.join(onnx_path, "Bert.onnx"),
            input_names=["input_ids"],
            output_names=["output"],
            dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
            opset_version=18,
        )
        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "Bert.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "Bert.onnx"))
        print("Exported BERT to ONNX format.")

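# Optional sanity check (a minimal sketch, not called anywhere by default): it
# assumes Bert.onnx has already been written by BertWrapper().export_onnx() and
# compares the ONNX output against the PyTorch wrapper on the same token IDs.
def verify_bert_onnx(wrapper, text="你好，世界"):
    input_ids = wrapper.tokenizer(text, return_tensors="pt")["input_ids"]
    session = InferenceSession(os.path.join(onnx_path, "Bert.onnx"))
    onnx_out = session.run(None, {"input_ids": input_ids.numpy()})[0]
    torch_out = wrapper(input_ids).detach().numpy()
    print("max |onnx - torch| =", abs(onnx_out - torch_out).max())
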
class CnHubertWrapper(Module):
    """Wraps the Chinese HuBERT feature extractor and exports it to ONNX."""

    def __init__(self):
        super(CnHubertWrapper, self).__init__()
        cnhubert_base_path = os.environ.get(
            "cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base"
        )
        cnhubert.cnhubert_base_path = cnhubert_base_path
        self.model = cnhubert.get_model().model

    def forward(self, signal):
        return self.model(signal)["last_hidden_state"]

    def export_onnx(self):
        dummy_input = torch.randn(1, 16000 * 10)  # ten seconds of 16 kHz audio
        torch.onnx.export(
            self,
            dummy_input,
            os.path.join(onnx_path, "CnHubert.onnx"),
            input_names=["signal"],
            output_names=["output"],
            dynamic_axes={"signal": {0: "batch_size", 1: "sequence_length"}},
            opset_version=18,
        )
        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "CnHubert.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "CnHubert.onnx"))
        print("Exported CN-HuBERT to ONNX format.")

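# A minimal sketch of how the exported graph is consumed at inference time:
# CnHubert.onnx turns a 16 kHz reference waveform into (1, T, 768) SSL units,
# which are the "ref_units" input of the Extractor.onnx graph exported further
# below. The waveform here is random noise purely to illustrate the shapes.
def run_cnhubert_onnx(signal=None):
    if signal is None:
        signal = torch.randn(1, 16000 * 5)  # five seconds of fake 16 kHz audio
    session = InferenceSession(os.path.join(onnx_path, "CnHubert.onnx"))
    units = session.run(None, {"signal": signal.numpy()})[0]
    print("SSL units shape:", units.shape)  # roughly (1, 50 * seconds, 768)
    return units
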
class Text2SemanticLightningModule(LightningModule):
    """Loads a stage-1 (GPT) checkpoint and holds the AR decoder for ONNX export."""

    def __init__(self, path, top_k=20, cache_size=2000):
        super().__init__()
        dict_s1 = torch.load(path, map_location="cpu")
        config = dict_s1["config"]
        self.model = t2s.Text2SemanticDecoder(config=config)
        self.load_state_dict(dict_s1["weight"])
        self.cache_size = cache_size
        self.top_k = top_k

def export_ar(path, top_k=20, cache_size=2000):
    """Exports the autoregressive text-to-semantic decoder as two ONNX graphs:
    PromptProcessor for the prompt (prefill) pass and DecodeNextToken for each
    subsequent decode step."""
    model_l = Text2SemanticLightningModule(path, top_k=top_k, cache_size=cache_size)
    model = model_l.model

    # Dummy inputs for tracing.
    x = torch.randint(0, 100, (1, 20)).long()  # phoneme IDs
    x_len = torch.tensor([20]).long()
    y = torch.randint(0, 100, (1, 20)).long()  # prompt semantic tokens
    y_len = torch.tensor([20]).long()
    bert_feature = torch.randn(1, 20, 1024)
    top_p = torch.tensor([0.8])
    repetition_penalty = torch.tensor([1.35])
    temperature = torch.tensor([0.6])

    prompt_processor = t2s.PromptProcessor(cache_len=cache_size, model=model, top_k=top_k)
    decode_next_token = t2s.DecodeNextToken(cache_len=cache_size, model=model, top_k=top_k)

    torch.onnx.export(
        prompt_processor,
        (x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature),
        os.path.join(onnx_path, "PromptProcessor.onnx"),
        input_names=["x", "x_len", "y", "y_len", "bert_feature", "top_p", "repetition_penalty", "temperature"],
        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
        dynamic_axes={
            "x": {0: "batch_size", 1: "sequence_length"},
            "y": {0: "batch_size", 1: "sequence_length"},
            "bert_feature": {0: "batch_size", 1: "sequence_length"},
        },
        opset_version=18,
    )

    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "PromptProcessor.onnx"))
    onnx.save_model(sim, os.path.join(onnx_path, "PromptProcessor.onnx"))

    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt_processor(
        x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature
    )

    torch.onnx.export(
        decode_next_token,
        (y, k_cache, v_cache, xy_pos, y_idx, top_p, repetition_penalty, temperature),
        os.path.join(onnx_path, "DecodeNextToken.onnx"),
        input_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "top_p", "repetition_penalty", "temperature"],
        output_names=["y", "k_cache", "v_cache", "xy_pos", "y_idx", "samples"],
        dynamic_axes={
            "y": {0: "batch_size", 1: "sequence_length"},
            "k_cache": {1: "batch_size", 2: "sequence_length"},
            "v_cache": {1: "batch_size", 2: "sequence_length"},
        },
        opset_version=18,
    )

    sim, _ = onnxsim.simplify(os.path.join(onnx_path, "DecodeNextToken.onnx"))
    onnx.save_model(sim, os.path.join(onnx_path, "DecodeNextToken.onnx"))

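# A rough sketch of how the two AR graphs above are meant to be driven with
# onnxruntime. All inputs are numpy arrays with the dtypes of the tracing
# dummies in export_ar. The stop condition is an assumption: eos_id below is a
# hypothetical placeholder for the model's end-of-sequence semantic token (check
# the checkpoint config), and the loop is also capped at max_steps.
def run_ar_onnx(x, x_len, y, y_len, bert_feature, top_p, repetition_penalty, temperature,
                eos_id=1024, max_steps=1500):
    prompt = InferenceSession(os.path.join(onnx_path, "PromptProcessor.onnx"))
    decode = InferenceSession(os.path.join(onnx_path, "DecodeNextToken.onnx"))
    # Prefill: consume the text, the semantic prompt and the BERT feature once.
    y, k_cache, v_cache, xy_pos, y_idx, samples = prompt.run(None, {
        "x": x, "x_len": x_len, "y": y, "y_len": y_len, "bert_feature": bert_feature,
        "top_p": top_p, "repetition_penalty": repetition_penalty, "temperature": temperature,
    })
    # Decode step by step, feeding the KV caches back in.
    for _ in range(max_steps):
        y, k_cache, v_cache, xy_pos, y_idx, samples = decode.run(None, {
            "y": y, "k_cache": k_cache, "v_cache": v_cache, "xy_pos": xy_pos, "y_idx": y_idx,
            "top_p": top_p, "repetition_penalty": repetition_penalty, "temperature": temperature,
        })
        # Assumes "samples" carries the newly sampled semantic token.
        if int(samples.reshape(-1)[0]) == eos_id:
            break
    return y
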
def load_sovits_new(sovits_path):
    """Loads a SoVITS checkpoint; some checkpoints are saved with the first two
    bytes of the zip header replaced, so "PK" is restored before torch.load."""
    with open(sovits_path, "rb") as f:
        meta = f.read(2)
        if meta != b"PK":
            data = b"PK" + f.read()
            bio = BytesIO()
            bio.write(data)
            bio.seek(0)
            return torch.load(bio, map_location="cpu", weights_only=False)
    return torch.load(sovits_path, map_location="cpu", weights_only=False)

class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
                self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec

class Extractor(Module):
    """Exposes the VQ front end that maps reference SSL units to semantic codes."""

    def __init__(self, model):
        super(Extractor, self).__init__()
        self.model = model

    def forward(self, x):
        return self.model.extract_latent(x.transpose(1, 2))

class V1V2(Module):
    def __init__(self, path):
        super(V1V2, self).__init__()
        dict_s2 = load_sovits_new(path)
        hps = dict_s2["config"]
        hps = DictToAttrRecursive(hps)
        hps.model.semantic_frame_rate = "25hz"
        if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
            hps.model.version = "v2"  # v3 model, v2 symbols
        elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
            hps.model.version = "v1"
        else:
            hps.model.version = "v2"
        # print("SoVITS version:", hps.model.version)
        vq_model = SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model
        )
        vq_model.load_state_dict(dict_s2["weight"], strict=False)
        vq_model.eval()
        self.vq_model = vq_model
        self.hps = hps
        self.ext = Extractor(self.vq_model)

    def forward(self, text_seq, pred_semantic, ref_audio):
        refer = spectrogram_torch(
            ref_audio,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False,
        )
        return self.vq_model(pred_semantic.unsqueeze(0), text_seq, refer)[0, 0]

    def export(self):
        test_seq = torch.randint(0, 100, (1, 20)).long()
        pred_semantic = torch.randint(0, 100, (1, 20)).long()
        ref_audio = torch.randn(1, 16000 * 10)
        torch.onnx.export(
            self,
            (test_seq, pred_semantic, ref_audio),
            os.path.join(onnx_path, "GptSoVitsV1V2.onnx"),
            input_names=["text_seq", "pred_semantic", "ref_audio"],
            output_names=["output"],
            dynamic_axes={
                "text_seq": {0: "batch_size", 1: "sequence_length"},
                "pred_semantic": {0: "batch_size", 1: "sequence_length"},
                "ref_audio": {0: "batch_size", 1: "sequence_length"},
            },
            opset_version=18,
        )

        sim, _ = onnxsim.simplify(os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
        onnx.save_model(sim, os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))

        ref_units = torch.randn(1, 20, 768)
        torch.onnx.export(
            self.ext,
            ref_units,
            os.path.join(onnx_path, "Extractor.onnx"),
            input_names=["ref_units"],
            output_names=["output"],
            dynamic_axes={
                "ref_units": {0: "batch_size", 1: "sequence_length"},
            },
            opset_version=18,
        )

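# A minimal end-to-end sketch of how the exported vocoder graph is called with
# onnxruntime. text_seq is int64 phoneme IDs, pred_semantic is the int64 token
# sequence produced by the AR stage, and ref_audio is a float32 reference
# waveform at the model's hps.data.sampling_rate; obtaining those inputs is
# application-specific and not shown here.
def run_sovits_onnx(text_seq, pred_semantic, ref_audio):
    session = InferenceSession(os.path.join(onnx_path, "GptSoVitsV1V2.onnx"))
    audio = session.run(None, {
        "text_seq": text_seq,            # int64, shape (1, n_phonemes)
        "pred_semantic": pred_semantic,  # int64, shape (1, n_semantic_tokens)
        "ref_audio": ref_audio,          # float32, shape (1, n_samples)
    })[0]
    return audio  # 1-D output waveform
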
if __name__ == "__main__":
    # CnHubertWrapper().export_onnx()
    # BertWrapper().export_onnx()
    V1V2("D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\SoVITS_weights\\小特.pth").export()
    '''export_ar(
        "D:\\VSGIT\\GPT-SoVITS-main\\GPT_SoVITS\\GPT-SoVITS-v3lora-20250228\\GPT_SoVITS\\t\\GPT_weights\\小特.ckpt",
        top_k=10,
        cache_size=1500,
    )'''