diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py index 41408059..3a940f65 100644 --- a/GPT_SoVITS/onnx_export.py +++ b/GPT_SoVITS/onnx_export.py @@ -89,7 +89,7 @@ class DictToAttrRecursive(dict): raise AttributeError(f"Attribute {item} not found") -class T2SEncoder(nn.Module): +class T2SInitStep(nn.Module): def __init__(self, t2s, vits): super().__init__() self.encoder = t2s.onnx_encoder @@ -122,7 +122,7 @@ class T2SModel(nn.Module): self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) self.t2s_model = self.t2s_model.model self.t2s_model.init_onnx() - self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) + self.init_step = T2SInitStep(self.t2s_model, self.vits_model) self.first_stage_decoder = self.t2s_model.first_stage_decoder self.stage_decoder = self.t2s_model.stage_decoder # self.t2s_model = torch.jit.script(self.t2s_model) @@ -131,7 +131,7 @@ class T2SModel(nn.Module): early_stop_num = self.t2s_model.early_stop_num # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] - y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content) for idx in tqdm(range(1, 20)): # This is a fake one! do take this as reference # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] @@ -144,19 +144,19 @@ class T2SModel(nn.Module): return y[:, -idx:].unsqueeze(0) def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): - # self.onnx_encoder = torch.jit.script(self.onnx_encoder) + # self.init_step = torch.jit.script(self.init_step) if dynamo: export_options = torch.onnx.ExportOptions(dynamic_shapes=True) - onnx_encoder_export_output = torch.onnx.dynamo_export( - self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options + init_step_export_output = torch.onnx.dynamo_export( + self.init_step, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options ) - onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") + init_step_export_output.save(f"onnx/{project_name}/{project_name}_t2s_init_step.onnx") return torch.onnx.export( - self.onnx_encoder, + self.init_step, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), - f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", + f"onnx/{project_name}/{project_name}_t2s_init_step.onnx", input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], output_names=["y", "k", "v", "y_emb", "x_example"], dynamic_axes={ @@ -168,7 +168,7 @@ class T2SModel(nn.Module): }, opset_version=16, ) - y, k, v, y_emb, x_example = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + y, k, v, y_emb, x_example = self.init_step(ref_seq, text_seq, ref_bert, text_bert, ssl_content) # torch.onnx.export( # self.first_stage_decoder, diff --git a/playground/freerun.py b/playground/freerun.py index ab21410f..f169fd99 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -63,7 +63,7 @@ def preprocess_text(text:str): # input_phones_saved = np.load("playground/ref/input_phones.npy") # input_bert_saved = np.load("playground/ref/input_bert.npy").T.astype(np.float32) -[input_phones, input_bert] = preprocess_text("像大雨匆匆打击过的屋檐") +[input_phones, input_bert] = preprocess_text("天上的风筝在天上飞,地上的人儿在地上追") # ref_phones = np.load("playground/ref/ref_phones.npy") @@ -74,9 +74,9 @@ def preprocess_text(text:str): [audio_prompt_hubert, spectrum, sv_emb] = audio_preprocess("playground/ref/audio.wav") -encoder = ort.InferenceSession(MODEL_PATH+"_export_t2s_encoder.onnx") +init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx") -[y, k, v, y_emb, x_example] = encoder.run(None, { +[y, k, v, y_emb, x_example] = init_step.run(None, { "text_seq": input_phones, "text_bert": input_bert, "ref_seq": ref_phones,