From 703b0d8b2902a7ab560443f03e1e4724b48f6379 Mon Sep 17 00:00:00 2001 From: xf-4070 Date: Tue, 13 Aug 2024 11:11:29 +0800 Subject: [PATCH] fix out-of-index bugs --- sat/data_video.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sat/data_video.py b/sat/data_video.py index 3783340..8f2ef23 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -378,8 +378,9 @@ class SFTDataset(Dataset): num_frames = max_num_frames start = int(skip_frms_num) end = int(start + num_frames / fps * actual_fps) - indices = np.arange(start, end, (end - start) / num_frames).astype(int) - temp_frms = vr.get_batch(np.arange(start, end)) + end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen)) + indices = np.arange(start, end, (end - start) // num_frames).astype(int) + temp_frms = vr.get_batch(np.arange(start, end_safty)) assert temp_frms is not None tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] @@ -388,7 +389,7 @@ class SFTDataset(Dataset): num_frames = max_num_frames start = int(skip_frms_num) end = int(ori_vlen - skip_frms_num) - indices = np.arange(start, end, (end - start) / num_frames).astype(int) + indices = np.arange(start, end, (end - start) // num_frames).astype(int) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None tensor_frms = (