diff --git a/sat/data_video.py b/sat/data_video.py index 3783340..8f2ef23 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -378,8 +378,9 @@ class SFTDataset(Dataset): num_frames = max_num_frames start = int(skip_frms_num) end = int(start + num_frames / fps * actual_fps) - indices = np.arange(start, end, (end - start) / num_frames).astype(int) - temp_frms = vr.get_batch(np.arange(start, end)) + end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen)) + indices = np.arange(start, end, (end - start) // num_frames).astype(int) + temp_frms = vr.get_batch(np.arange(start, end_safty)) assert temp_frms is not None tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] @@ -388,7 +389,7 @@ class SFTDataset(Dataset): num_frames = max_num_frames start = int(skip_frms_num) end = int(ori_vlen - skip_frms_num) - indices = np.arange(start, end, (end - start) / num_frames).astype(int) + indices = np.arange(start, end, (end - start) // num_frames).astype(int) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None tensor_frms = (