diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index a4c747c..9f29d4a 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -126,13 +126,20 @@ def preprocess_video_with_resize( video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height) video_num_frames = len(video_reader) if video_num_frames < max_num_frames: - raise ValueError(f"video frame count in {video_path} is less than {max_num_frames}.") - - indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) - frames = video_reader.get_batch(indices) - frames = frames[:max_num_frames].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - return frames + # Get all frames first + frames = video_reader.get_batch(list(range(video_num_frames))) + # Repeat the last frame until we reach max_num_frames + last_frame = frames[-1:] + num_repeats = max_num_frames - video_num_frames + repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1) + frames = torch.cat([frames, repeated_frames], dim=0) + return frames.float().permute(0, 3, 1, 2).contiguous() + else: + indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[:max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + return frames def preprocess_video_with_buckets(