mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
- Add dataset implementations for text-to-video and image-to-video - Include bucket sampler for efficient batch processing - Add utility functions for data processing - Create dataset package structure with proper initialization
141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
import torch
|
|
import cv2
|
|
|
|
from typing import List, Tuple
|
|
from pathlib import Path
|
|
from torchvision import transforms
|
|
from torchvision.transforms.functional import resize
|
|
|
|
# decord must be imported after torch: importing it first can occasionally cause a
# segmentation fault or a stack-smashing error. Bug reports are rare, but it does
# happen; see the decord GitHub issues for more information.
|
|
import decord # isort:skip
|
|
|
|
# Switch decord's output bridge to torch so that VideoReader.get_batch returns
# torch tensors; the video preprocessors in this module call torch methods
# (.float(), .permute()) directly on the decoded frames.
decord.bridge.set_bridge("torch")
|
|
|
|
|
|
########## loaders ##########
|
|
|
|
def load_prompts(prompt_path: Path) -> List[str]:
    """Return the non-blank lines of a UTF-8 prompt file, whitespace-stripped.

    Args:
        prompt_path: Text file containing one prompt per line.

    Returns:
        The prompts in file order. Lines that are empty after stripping
        surrounding whitespace are skipped.
    """
    with open(prompt_path, "r", encoding="utf-8") as handle:
        # filter(None, ...) drops the empty strings left by blank lines.
        return list(filter(None, (raw.strip() for raw in handle)))
|
|
|
|
|
|
def load_videos(video_path: Path) -> List[Path]:
    """Read a list of video paths from a UTF-8 text file.

    Each non-blank line is stripped and joined onto the directory that
    contains the list file itself (``video_path.parent / line``).

    Args:
        video_path: Text file with one video path per line.

    Returns:
        One Path per non-blank line, in file order.
    """
    with open(video_path, "r", encoding="utf-8") as handle:
        relative_names = [raw.strip() for raw in handle]
    return [video_path.parent / name for name in relative_names if name]
|
|
|
|
|
|
def load_images(image_path: Path) -> List[Path]:
    """Read a list of image paths from a UTF-8 text file.

    Blank lines are ignored; every other line is stripped and joined onto the
    directory holding the list file (``image_path.parent / line``).

    Args:
        image_path: Text file with one image path per line.

    Returns:
        The resulting Paths in the order they appear in the file.
    """
    entries: List[Path] = []
    with open(image_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            name = raw.strip()
            if not name:
                continue
            entries.append(image_path.parent / name)
    return entries
|
|
|
|
|
|
########## preprocessors ##########
|
|
|
|
def preprocess_image_with_resize(
    image_path: Path,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Float image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width
        Pixel values are cast to float but not rescaled (no division by 255).

    Raises:
        ValueError: If OpenCV cannot read or decode the file at ``image_path``.
    """
    image = cv2.imread(image_path.as_posix())
    # cv2.imread does not raise on a missing or undecodable file; it returns
    # None, which would otherwise fail later inside cvtColor with an opaque
    # cv2.error that does not mention the offending path.
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")

    # OpenCV decodes into BGR channel order; convert to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Note: cv2.resize takes the target size as (width, height).
    image = cv2.resize(image, (width, height))

    # [H, W, C] -> [C, H, W], as a float tensor.
    return torch.from_numpy(image).float().permute(2, 0, 1).contiguous()
|
|
|
|
|
|
def preprocess_video_with_resize(
    video_path: Path,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
        1. Frames are decoded directly at (height, width) by decord.
        2. If the video frame count > max_num_frames, frames are downsampled
           with a fixed stride starting at frame 0 and truncated to
           max_num_frames. Otherwise every frame is kept, so the result may
           contain fewer than max_num_frames frames.

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep (must be > 0).
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A float torch.Tensor with shape [F, C, H, W] where:
            F = number of frames (<= max_num_frames)
            C = number of channels (3 for RGB)
            H = height
            W = width

    Raises:
        ValueError: If max_num_frames is not positive or the video has no frames.
    """
    if max_num_frames <= 0:
        raise ValueError(f"max_num_frames must be positive, got {max_num_frames}")

    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)
    if video_num_frames == 0:
        raise ValueError(f"video {video_path} contains no frames")

    # Clamp the stride to 1. When the clip is shorter than max_num_frames, the
    # floor division is 0, and range() with a zero step raises ValueError.
    step = max(1, video_num_frames // max_num_frames)
    # range() can produce more than max_num_frames indices (e.g. 100 frames,
    # max 49 -> step 2 -> 50 indices). Truncate before decoding so no frame is
    # decoded only to be thrown away.
    indices = list(range(0, video_num_frames, step))[:max_num_frames]

    frames = video_reader.get_batch(indices).float()
    # Reorder [F, H, W, C] -> [F, C, H, W].
    return frames.permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
|
def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Loads a video and fits it to the closest matching resolution bucket.

    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Float video tensor with shape [F, C, H, W] where:
            F = number of frames (the chosen frame bucket)
            C = number of channels (3 for RGB)
            H = height
            W = width

    Raises:
        ValueError: If the video has fewer frames than every bucket.

    The function processes the video through these steps:
        1. Keeps only buckets whose frame count is <= the video frame count and
           picks the largest such frame count (nearest frame bucket).
        2. Downsamples frames with a fixed stride starting at frame 0 and
           truncates to the bucket's frame count.
        3. Among the same frame-eligible buckets, finds the one whose
           (height, width) is nearest to the video's, by L1 distance.
        4. Resizes all frames to that bucket's resolution.

    NOTE(review): step 3 may select a bucket whose num_frames differs from the
    frame bucket chosen in step 1. Confirm that this is intended.
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)

    # Keep the caller's list intact so the error message can report the
    # buckets that were actually offered (previously it always printed []).
    candidate_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if not candidate_buckets:
        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")

    # Smallest non-negative gap means the largest frame bucket that still fits.
    nearest_frame_bucket = min(candidate_buckets, key=lambda bucket: video_num_frames - bucket[0])[0]

    # nearest_frame_bucket <= video_num_frames, so the stride is >= 1 for any
    # positive bucket size. Truncate the indices before decoding so no surplus
    # frames are decoded only to be discarded.
    step = video_num_frames // nearest_frame_bucket
    frame_indices = list(range(0, video_num_frames, step))[:nearest_frame_bucket]
    frames = video_reader.get_batch(frame_indices).float()
    # Reorder [F, H, W, C] -> [F, C, H, W].
    frames = frames.permute(0, 3, 1, 2).contiguous()

    frame_height, frame_width = frames.shape[2], frames.shape[3]
    nearest_res = min(
        candidate_buckets,
        key=lambda bucket: abs(bucket[1] - frame_height) + abs(bucket[2] - frame_width),
    )
    target_size = (nearest_res[1], nearest_res[2])

    # torchvision's tensor resize operates on the trailing two dimensions, so
    # the whole [F, C, H, W] clip is resized in a single call. Each frame gets
    # the same result as resizing it individually.
    return resize(frames, target_size)