feat(dataset): pad short videos by repeating last frame

When loading videos with fewer frames than max_num_frames, repeat the last
frame to reach the required length instead of failing. This ensures consistent
tensor dimensions across the dataset while preserving as much original video
content as possible.
This commit is contained in:
OleehyO 2025-01-08 02:14:56 +00:00
parent 249fadfb76
commit 07766001f6

View File

@@ -126,13 +126,20 @@ def preprocess_video_with_resize(
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
video_num_frames = len(video_reader)
if video_num_frames < max_num_frames:
raise ValueError(f"video frame count in {video_path} is less than {max_num_frames}.")
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[:max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
# Get all frames first
frames = video_reader.get_batch(list(range(video_num_frames)))
# Repeat the last frame until we reach max_num_frames
last_frame = frames[-1:]
num_repeats = max_num_frames - video_num_frames
repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
frames = torch.cat([frames, repeated_frames], dim=0)
return frames.float().permute(0, 3, 1, 2).contiguous()
else:
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[:max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
def preprocess_video_with_buckets(