mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 03:04:56 +08:00
feat(dataset): pad short videos by repeating last frame
When loading videos with fewer frames than max_num_frames, repeat the last frame to reach the required length instead of failing. This ensures consistent tensor dimensions across the dataset while preserving as much original video content as possible.
This commit is contained in:
parent
249fadfb76
commit
07766001f6
@ -126,13 +126,20 @@ def preprocess_video_with_resize(
|
||||
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
|
||||
video_num_frames = len(video_reader)
|
||||
if video_num_frames < max_num_frames:
|
||||
raise ValueError(f"video frame count in {video_path} is less than {max_num_frames}.")
|
||||
|
||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||
frames = video_reader.get_batch(indices)
|
||||
frames = frames[:max_num_frames].float()
|
||||
frames = frames.permute(0, 3, 1, 2).contiguous()
|
||||
return frames
|
||||
# Get all frames first
|
||||
frames = video_reader.get_batch(list(range(video_num_frames)))
|
||||
# Repeat the last frame until we reach max_num_frames
|
||||
last_frame = frames[-1:]
|
||||
num_repeats = max_num_frames - video_num_frames
|
||||
repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
|
||||
frames = torch.cat([frames, repeated_frames], dim=0)
|
||||
return frames.float().permute(0, 3, 1, 2).contiguous()
|
||||
else:
|
||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||
frames = video_reader.get_batch(indices)
|
||||
frames = frames[:max_num_frames].float()
|
||||
frames = frames.permute(0, 3, 1, 2).contiguous()
|
||||
return frames
|
||||
|
||||
|
||||
def preprocess_video_with_buckets(
|
||||
|
Loading…
x
Reference in New Issue
Block a user