diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py
index 6b06da4..b26bb7f 100644
--- a/finetune/datasets/i2v_dataset.py
+++ b/finetune/datasets/i2v_dataset.py
@@ -13,6 +13,7 @@ from finetune.constants import LOG_LEVEL, LOG_NAME
 
 from .utils import (
     load_images,
+    load_images_from_videos,
     load_prompts,
     load_videos,
     preprocess_image_with_resize,
@@ -53,7 +54,7 @@ class BaseI2VDataset(Dataset):
         data_root: str,
         caption_column: str,
         video_column: str,
-        image_column: str,
+        image_column: str | None,
         device: torch.device,
         trainer: "Trainer" = None,
         *args,
@@ -64,7 +65,10 @@ class BaseI2VDataset(Dataset):
         data_root = Path(data_root)
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
-        self.images = load_images(data_root / image_column)
+        if image_column is not None:
+            self.images = load_images(data_root / image_column)
+        else:
+            self.images = load_images_from_videos(self.videos)
         self.trainer = trainer
 
         self.device = device
diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py
index d28975e..1048d79 100644
--- a/finetune/datasets/utils.py
+++ b/finetune/datasets/utils.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import List, Tuple
 
@@ -31,6 +32,52 @@ def load_images(image_path: Path) -> List[Path]:
         return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
 
 
+def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
+    """Extract (and cache) the first frame of every video as a PNG image.
+
+    Frames are written to a ``first_frames`` directory located next to the
+    directory that contains the first video. A frame that already exists on
+    disk is reused instead of decoding the video again.
+
+    Args:
+        videos_path: Paths of the videos to take conditioning images from.
+
+    Returns:
+        One PNG path per input video, in the same order as ``videos_path``.
+
+    Raises:
+        RuntimeError: If a video's first frame cannot be read or saved.
+    """
+    if not videos_path:
+        # Nothing to extract; also avoids IndexError on videos_path[0] below.
+        return []
+
+    first_frames_dir = videos_path[0].parent.parent / "first_frames"
+    first_frames_dir.mkdir(exist_ok=True)
+
+    first_frame_paths = []
+    for video_path in videos_path:
+        frame_path = first_frames_dir / f"{video_path.stem}.png"
+        if not frame_path.exists():
+            cap = cv2.VideoCapture(str(video_path))
+            try:
+                ret, frame = cap.read()
+            finally:
+                # Release the capture handle even when decoding fails.
+                cap.release()
+            if not ret:
+                raise RuntimeError(f"Failed to read video: {video_path}")
+
+            # cv2.imwrite signals a failed write through its boolean result.
+            if not cv2.imwrite(str(frame_path), frame):
+                raise RuntimeError(f"Failed to save first frame: {frame_path}")
+            logging.info(f"Saved first frame to {frame_path}")
+
+        first_frame_paths.append(frame_path)
+
+    return first_frame_paths
+
+
 ########## preprocessors ##########
 
 
diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py
index 3074872..b5fc781 100644
--- a/finetune/schemas/args.py
+++ b/finetune/schemas/args.py
@@ -98,7 +98,9 @@ class Args(BaseModel):
     def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("model_type") == "i2v" and not v:
-            raise ValueError("image_column must be specified when using i2v model")
+            logging.warning(
+                "No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
+            )
         return v
 
     @field_validator("validation_dir", "validation_prompts")