🐛 fix frame padding bug

This commit is contained in:
Jia Zheng 2024-08-22 19:50:23 +08:00 committed by bertjiazheng
parent da1af26d57
commit b10e444d94

View File

@@ -145,9 +145,10 @@ def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
def pad_last_frame(tensor, num_frames):
    """Force a clip to exactly ``num_frames`` entries along dim 0.

    Clips shorter than ``num_frames`` are extended at the end with all-zero
    frames; clips that are long enough are cut to their first ``num_frames``
    frames. Note that, despite the function's name, the padding is zeros and
    not copies of the last frame.

    Args:
        tensor: Video tensor whose leading dimension is time (layout
            T, H, W, C).
        num_frames: Target length of the leading dimension.

    Returns:
        A tensor with ``num_frames`` entries on dim 0; the trailing
        dimensions, dtype and device match ``tensor``.
    """
    # T, H, W, C
    current = len(tensor)
    if current >= num_frames:
        return tensor[:num_frames]

    # The pad count comes from the time axis (dim 0). Deriving it from
    # shape[1] (height) produced a wrong, often negative, count and left
    # short clips unpadded.
    pad_length = num_frames - current
    pad_tensor = torch.zeros(
        [pad_length, *tensor.shape[1:]],
        dtype=tensor.dtype,
        device=tensor.device,
    )
    return torch.cat([tensor, pad_tensor], dim=0)
@@ -418,7 +419,7 @@ class SFTDataset(Dataset):
)
tensor_frms = pad_last_frame(
tensor_frms, num_frames
tensor_frms, max_num_frames
) # the len of indices may be less than num_frames, due to round error
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")