From 07766001f666ab3dead32a71f2d1e4c7d8b13e06 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 8 Jan 2025 02:14:56 +0000 Subject: [PATCH] feat(dataset): pad short videos by repeating last frame When loading videos with fewer frames than max_num_frames, repeat the last frame to reach the required length instead of failing. This ensures consistent tensor dimensions across the dataset while preserving as much original video content as possible. --- finetune/datasets/utils.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index a4c747c..9f29d4a 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -126,13 +126,20 @@ def preprocess_video_with_resize( video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height) video_num_frames = len(video_reader) if video_num_frames < max_num_frames: - raise ValueError(f"video frame count in {video_path} is less than {max_num_frames}.") - - indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) - frames = video_reader.get_batch(indices) - frames = frames[:max_num_frames].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - return frames + # Get all frames first + frames = video_reader.get_batch(list(range(video_num_frames))) + # Repeat the last frame until we reach max_num_frames + last_frame = frames[-1:] + num_repeats = max_num_frames - video_num_frames + repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1) + frames = torch.cat([frames, repeated_frames], dim=0) + return frames.float().permute(0, 3, 1, 2).contiguous() + else: + indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[:max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + return frames def preprocess_video_with_buckets(