feat: auto-extract first frames as conditioning images for i2v model

When training i2v models without specifying image_column, automatically extract
and use first frames from training videos as conditioning images. This includes:

- Add load_images_from_videos() utility function to extract and cache first frames
- Update BaseI2VDataset to support auto-extraction when image_column is None
- Add validation and warning message in Args schema for i2v without image_column

The first frames are extracted once and cached to avoid repeated video loading.
This commit is contained in:
OleehyO 2025-01-07 06:43:26 +00:00
parent 96e511b413
commit e084a4a270
3 changed files with 41 additions and 3 deletions

View File

@ -13,6 +13,7 @@ from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_images,
load_images_from_videos,
load_prompts,
load_videos,
preprocess_image_with_resize,
@ -53,7 +54,7 @@ class BaseI2VDataset(Dataset):
data_root: str,
caption_column: str,
video_column: str,
image_column: str,
image_column: str | None,
device: torch.device,
trainer: "Trainer" = None,
*args,
@ -64,7 +65,10 @@ class BaseI2VDataset(Dataset):
data_root = Path(data_root)
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.images = load_images(data_root / image_column)
if image_column is not None:
self.images = load_images(data_root / image_column)
else:
self.images = load_images_from_videos(self.videos)
self.trainer = trainer
self.device = device

View File

@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import List, Tuple
@ -31,6 +32,37 @@ def load_images(image_path: Path) -> List[Path]:
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
first_frames_dir = videos_path[0].parent.parent / "first_frames"
first_frames_dir.mkdir(exist_ok=True)
first_frame_paths = []
for video_path in videos_path:
frame_path = first_frames_dir / f"{video_path.stem}.png"
if frame_path.exists():
first_frame_paths.append(frame_path)
continue
# Open video
cap = cv2.VideoCapture(str(video_path))
# Read first frame
ret, frame = cap.read()
if not ret:
raise RuntimeError(f"Failed to read video: {video_path}")
# Save frame as PNG with same name as video
cv2.imwrite(str(frame_path), frame)
logging.info(f"Saved first frame to {frame_path}")
# Release video capture
cap.release()
first_frame_paths.append(frame_path)
return first_frame_paths
########## preprocessors ##########

View File

@ -98,7 +98,9 @@ class Args(BaseModel):
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("model_type") == "i2v" and not v:
raise ValueError("image_column must be specified when using i2v model")
logging.warning(
"No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
)
return v
@field_validator("validation_dir", "validation_prompts")