From 71cbe28e68cb6f58aaacb7975da438550e3a3517 Mon Sep 17 00:00:00 2001 From: zpeng11 Date: Tue, 19 Aug 2025 21:31:42 -0400 Subject: [PATCH] feat:optimize looping --- playground/freerun.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/playground/freerun.py b/playground/freerun.py index 684c664a..4377e44c 100644 --- a/playground/freerun.py +++ b/playground/freerun.py @@ -105,7 +105,7 @@ sdec = ort.InferenceSession(MODEL_PATH+"_export_t2s_sdec.onnx") "prompts": prompts }) -early_stop_num = -1 + prefix_len = prompts.shape[1] stop = False @@ -118,11 +118,7 @@ for idx in tqdm(range(1, 1500)): "iy_emb": y_emb, "ix_example": x_example }) - if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: - stop = True - if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: - stop = True - if stop: + if np.argmax(logits, axis=-1)[0] == 1024 or samples[0, 0] == 1024: # 1024 is the EOS token break y[0, -1] = 0