update namings

This commit is contained in:
zpeng11 2025-08-20 20:21:42 -04:00
parent 4e0cc57052
commit bb529e7e4a
2 changed files with 13 additions and 14 deletions

View File

@@ -157,14 +157,14 @@ class T2SModel(nn.Module):
self.init_step, self.init_step,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content), (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
f"onnx/{project_name}/{project_name}_t2s_init_step.onnx", f"onnx/{project_name}/{project_name}_t2s_init_step.onnx",
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], input_names=["ref_text_phones", "input_text_phones", "ref_text_bert", "input_text_bert", "hubert_ssl_content"],
output_names=["y", "k", "v", "y_emb", "x_example"], output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={ dynamic_axes={
"ref_seq": {1: "ref_length"}, "ref_text_phones": {1: "ref_length"},
"text_seq": {1: "text_length"}, "input_text_phones": {1: "text_length"},
"ref_bert": {0: "ref_length"}, "ref_text_bert": {0: "ref_length"},
"text_bert": {0: "text_length"}, "input_text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"}, "hubert_ssl_content": {2: "ssl_length"},
}, },
opset_version=16, opset_version=16,
) )
@@ -254,9 +254,8 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "spectrum", "sv_emb"], input_names=["text_seq", "pred_semantic", "spectrum", "sv_emb"],
output_names=["audio"], output_names=["audio"],
dynamic_axes={ dynamic_axes={
"text_seq": {1: "text_length"}, "input_text_phones": {1: "text_length"},
"pred_semantic": {2: "pred_length"}, "pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
"spectrum": {2: "spectrum_length"}, "spectrum": {2: "spectrum_length"},
}, },
opset_version=17, opset_version=17,

View File

@@ -77,11 +77,11 @@ def preprocess_text(text:str):
init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx") init_step = ort.InferenceSession(MODEL_PATH+"_export_t2s_init_step.onnx")
[y, k, v, y_emb, x_example] = init_step.run(None, { [y, k, v, y_emb, x_example] = init_step.run(None, {
"text_seq": input_phones, "input_text_phones": input_phones,
"text_bert": input_bert, "input_text_bert": input_bert,
"ref_seq": ref_phones, "ref_text_phones": ref_phones,
"ref_bert": ref_bert, "ref_text_bert": ref_bert,
"ssl_content": audio_prompt_hubert "hubert_ssl_content": audio_prompt_hubert
}) })
# fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx") # fsdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_fsdec.onnx")
@@ -120,7 +120,7 @@ ref_audio = waveform.numpy().astype(np.float32)
vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx") vtis = ort.InferenceSession(MODEL_PATH+"_export_vits.onnx")
[audio] = vtis.run(None, { [audio] = vtis.run(None, {
"text_seq": input_phones, "input_text_phones": input_phones,
"pred_semantic": pred_semantic, "pred_semantic": pred_semantic,
"spectrum": spectrum.astype(np.float32), "spectrum": spectrum.astype(np.float32),
"sv_emb": sv_emb.astype(np.float32) "sv_emb": sv_emb.astype(np.float32)