mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
feat:export onnx with combined graph ready, todo:link weights in onnx graph
This commit is contained in:
parent
16d30ce1e4
commit
77794a5923
@ -7,12 +7,10 @@ from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from sv import SV
|
from sv import SV
|
||||||
import onnx
|
import onnx
|
||||||
|
from onnx import helper, TensorProto
|
||||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||||
from transformers import HubertModel, HubertConfig
|
from transformers import HubertModel, HubertConfig
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import soundfile
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from text import cleaned_text_to_sequence
|
from text import cleaned_text_to_sequence
|
||||||
|
|
||||||
@ -104,8 +102,8 @@ class T2SInitStep(nn.Module):
|
|||||||
bert = bert.unsqueeze(0)
|
bert = bert.unsqueeze(0)
|
||||||
prompt = prompt_semantic.unsqueeze(0)
|
prompt = prompt_semantic.unsqueeze(0)
|
||||||
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
|
[y, k, v, y_emb, x_example] = self.fsdc(self.encoder(all_phoneme_ids, bert), prompt)
|
||||||
fake_logits = torch.randn((1, 1025)) # Dummy logits for ONNX export
|
fake_logits = torch.randn((1, 1025), dtype=torch.float32) # Dummy logits for ONNX export
|
||||||
fake_samples = torch.randn((1, 1)) # Dummy samples for ONNX export
|
fake_samples = torch.zeros((1, 1), dtype=torch.int32) # Dummy samples for ONNX export
|
||||||
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
return y, k, v, y_emb, x_example, fake_logits, fake_samples
|
||||||
|
|
||||||
class T2SStageStep(nn.Module):
|
class T2SStageStep(nn.Module):
|
||||||
@ -115,7 +113,7 @@ class T2SStageStep(nn.Module):
|
|||||||
|
|
||||||
def forward(self, iy, ik, iv, iy_emb, ix_example):
|
def forward(self, iy, ik, iv, iy_emb, ix_example):
|
||||||
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
|
[y, k, v, y_emb, logits, samples] = self.stage_decoder(iy, ik, iv, iy_emb, ix_example)
|
||||||
fake_x_example = torch.randn((1, 512)) # Dummy x_example for ONNX export
|
fake_x_example = torch.randn((1, 512), dtype=torch.float32) # Dummy x_example for ONNX export
|
||||||
return y, k, v, y_emb, fake_x_example, logits, samples
|
return y, k, v, y_emb, fake_x_example, logits, samples
|
||||||
|
|
||||||
class T2SModel(nn.Module):
|
class T2SModel(nn.Module):
|
||||||
@ -302,20 +300,63 @@ class AudioPreprocess(nn.Module):
|
|||||||
|
|
||||||
return ssl_content, spectrum, sv_emb
|
return ssl_content, spectrum, sv_emb
|
||||||
|
|
||||||
# def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path):
|
def combineInitStepAndStageStep(init_step_onnx_path, stage_step_onnx_path, combined_onnx_path):
|
||||||
# init_step_model = onnx.load(init_step_onnx_path)
|
init_step_model = onnx.load(init_step_onnx_path)
|
||||||
# stage_step_model = onnx.load(stage_step_onnx_path)
|
stage_step_model = onnx.load(stage_step_onnx_path)
|
||||||
|
|
||||||
# # Combine the models (this is a simplified example; actual combination logic may vary)
|
then_graph = init_step_model.graph
|
||||||
# combined_graph = helper.make_graph(
|
then_graph.name = "init_step_graph"
|
||||||
# nodes=init_step_model.graph.node + stage_step_model.graph.node,
|
else_graph = stage_step_model.graph
|
||||||
# name="combined_graph",
|
else_graph.name = "stage_step_graph"
|
||||||
# inputs=init_step_model.graph.input,
|
|
||||||
# outputs=stage_step_model.graph.output
|
data_inputs_init = [input for input in init_step_model.graph.input]
|
||||||
# )
|
data_inputs_stage = [input for input in stage_step_model.graph.input]
|
||||||
|
|
||||||
|
del then_graph.input[:]
|
||||||
|
del else_graph.input[:]
|
||||||
|
|
||||||
|
# The output names of the subgraphs must be the same.
|
||||||
|
# The 'If' node will have an output with this same name.
|
||||||
|
subgraph_output_names = [output.name for output in then_graph.output]
|
||||||
|
for i, output in enumerate(else_graph.output):
|
||||||
|
assert subgraph_output_names[i] == output.name, "Subgraph output names must match"
|
||||||
|
|
||||||
|
# Define the inputs for the main graph
|
||||||
|
# 1. The boolean condition to select the branch
|
||||||
|
cond_input = helper.make_tensor_value_info('if_init_step', TensorProto.BOOL, [])
|
||||||
|
|
||||||
|
main_outputs = [output for output in init_step_model.graph.output]
|
||||||
|
|
||||||
|
# Create the 'If' node
|
||||||
|
if_node = helper.make_node(
|
||||||
|
'If',
|
||||||
|
inputs=['if_init_step'],
|
||||||
|
outputs=subgraph_output_names, # This name MUST match the subgraph's output name
|
||||||
|
then_branch=then_graph,
|
||||||
|
else_branch=else_graph
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine the models (this is a simplified example; actual combination logic may vary)
|
||||||
|
main_graph = helper.make_graph(
|
||||||
|
nodes=[if_node],
|
||||||
|
name="t2s_combined_graph",
|
||||||
|
inputs=[cond_input] + data_inputs_init + data_inputs_stage,
|
||||||
|
outputs=main_outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the final combined model, specifying the opset and IR version
|
||||||
|
opset_version = 16
|
||||||
|
final_model = helper.make_model(main_graph,
|
||||||
|
producer_name='GSV-ONNX-Exporter',
|
||||||
|
ir_version=9, # For compatibility with older onnxruntime
|
||||||
|
opset_imports=[helper.make_opsetid("", opset_version)])
|
||||||
|
# Check the model for correctness
|
||||||
|
onnx.checker.check_model(final_model)
|
||||||
|
|
||||||
|
# Save the combined model
|
||||||
|
onnx.save(final_model, combined_onnx_path)
|
||||||
|
print(f"Combined model saved to {combined_onnx_path}")
|
||||||
|
|
||||||
# combined_model = helper.make_model(combined_graph, producer_name='onnx-combiner')
|
|
||||||
# onnx.save(combined_model, f"onnx/{project_name}/{project_name}_combined.onnx")
|
|
||||||
|
|
||||||
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
def export(vits_path, gpt_path, project_name, voice_model_version="v2"):
|
||||||
vits = VitsModel(vits_path, version=voice_model_version)
|
vits = VitsModel(vits_path, version=voice_model_version)
|
||||||
@ -428,5 +469,6 @@ if __name__ == "__main__":
|
|||||||
exp_path = "v2proplus_export"
|
exp_path = "v2proplus_export"
|
||||||
version = "v2ProPlus"
|
version = "v2ProPlus"
|
||||||
export(vits_path, gpt_path, exp_path, version)
|
export(vits_path, gpt_path, exp_path, version)
|
||||||
|
combineInitStepAndStageStep('onnx/v2proplus_export/v2proplus_export_t2s_init_step.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_sdec.onnx', 'onnx/v2proplus_export/v2proplus_export_t2s_combined.onnx')
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ def preprocess_text(text:str):
|
|||||||
|
|
||||||
# input_phones_saved = np.load("playground/ref/input_phones.npy")
|
# input_phones_saved = np.load("playground/ref/input_phones.npy")
|
||||||
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
|
# input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32)
|
||||||
[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追")
|
[input_phones, input_bert] = preprocess_text("地上的人儿吵吵闹闹在地上追")
|
||||||
|
|
||||||
|
|
||||||
# ref_phones = np.load("playground/ref/ref_phones.npy")
|
# ref_phones = np.load("playground/ref/ref_phones.npy")
|
||||||
@ -74,29 +74,33 @@ def preprocess_text(text:str):
|
|||||||
[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
|
[audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav")
|
||||||
|
|
||||||
|
|
||||||
init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
|
t2s_combined = ort.InferenceSession(MODEL_PATH+"_export_t2s_combined.onnx")
|
||||||
|
|
||||||
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = init_step.run(None, {
|
[y, k, v, y_emb, x_example, fake_logits, fake_samples] = t2s_combined.run(None, {
|
||||||
|
"if_init_step": np.array(True, dtype=bool),
|
||||||
"input_text_phones": input_phones,
|
"input_text_phones": input_phones,
|
||||||
"input_text_bert": input_bert,
|
"input_text_bert": input_bert,
|
||||||
"ref_text_phones": ref_phones,
|
"ref_text_phones": ref_phones,
|
||||||
"ref_text_bert": ref_bert,
|
"ref_text_bert": ref_bert,
|
||||||
"hubert_ssl_content": audio_prompt_hubert
|
"hubert_ssl_content": audio_prompt_hubert,
|
||||||
|
"iy":np.empty((1, 0), dtype=np.int64),
|
||||||
|
"ik":np.empty((24, 0, 1, 512), dtype=np.float32),
|
||||||
|
"iv":np.empty((24, 0, 1, 512), dtype=np.float32),
|
||||||
|
"iy_emb":np.empty((1, 0, 512), dtype=np.float32),
|
||||||
|
"ix_example":np.empty((1, 0), dtype=np.float32)
|
||||||
})
|
})
|
||||||
|
|
||||||
# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
|
|
||||||
sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
|
|
||||||
|
|
||||||
# for i in tqdm(range(10000)):
|
|
||||||
# [y, k, v, y_emb, x_example] = fsdec.run(None, {
|
|
||||||
# "x": x,
|
|
||||||
# "prompts": prompts
|
|
||||||
# })
|
|
||||||
|
|
||||||
|
|
||||||
for idx in tqdm(range(1, 1500)):
|
for idx in tqdm(range(1, 1500)):
|
||||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
[y, k, v, y_emb, fake_x_example, logits, samples] = sdec.run(None, {
|
[y, k, v, y_emb, fake_x_example, logits, samples] = t2s_combined.run(None, {
|
||||||
|
"if_init_step": np.array(False, dtype=bool),
|
||||||
|
"input_text_phones": np.empty((1, 0), dtype=np.int64),
|
||||||
|
"input_text_bert": np.empty((0, 1024), dtype=np.float32),
|
||||||
|
"ref_text_phones": np.empty((1, 0), dtype=np.int64),
|
||||||
|
"ref_text_bert": np.empty((0, 1024), dtype=np.float32),
|
||||||
|
"hubert_ssl_content": np.empty((1, 768, 0), dtype=np.float32),
|
||||||
"iy": y,
|
"iy": y,
|
||||||
"ik": k,
|
"ik": k,
|
||||||
"iv": v,
|
"iv": v,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user