From b10e444d945146f173b74bbc993d941f8afc93d3 Mon Sep 17 00:00:00 2001 From: Jia Zheng Date: Thu, 22 Aug 2024 19:50:23 +0800 Subject: [PATCH] :bug: fix frame padding bug --- sat/data_video.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sat/data_video.py b/sat/data_video.py index 8f2ef23..d16667f 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -145,9 +145,10 @@ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): def pad_last_frame(tensor, num_frames): # T, H, W, C - if tensor.shape[0] < num_frames: - last_frame = tensor[-int(num_frames - tensor.shape[1]) :] - padded_tensor = torch.cat([tensor, last_frame], dim=0) + if len(tensor) < num_frames: + pad_length = num_frames - len(tensor) + pad_tensor = torch.zeros([pad_length, *tensor.shape[1:]], dtype=tensor.dtype, device=tensor.device) + padded_tensor = torch.cat([tensor, pad_tensor], dim=0) return padded_tensor else: return tensor[:num_frames] @@ -418,7 +419,7 @@ class SFTDataset(Dataset): ) tensor_frms = pad_last_frame( - tensor_frms, num_frames + tensor_frms, max_num_frames ) # the len of indices may be less than num_frames, due to round error tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W] tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")