feat:optimize looping

This commit is contained in:
zpeng11 2025-08-19 21:31:42 -04:00
parent 5c08328cf3
commit 71cbe28e68

View File

@ -105,7 +105,7 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx")
"prompts": prompts
})
early_stop_num = -1
prefix_len = prompts.shape[1]
stop = False
@ -118,11 +118,7 @@ for idx in tqdm(range(1, 1500)):
"iy_emb": y_emb,
"ix_example": x_example
})
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
stop = True
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024:
stop = True
if stop:
if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token
break
y[0, -1] = 0