mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
feat: auto-extract first frames as conditioning images for i2v model
When training i2v models without specifying image_column, automatically extract the first frame of each training video and use it as the conditioning image. This includes:
- Add load_images_from_videos() utility function to extract and cache first frames
- Update BaseI2VDataset to support auto-extraction when image_column is None
- Add validation and warning message in Args schema for i2v without image_column

The first frames are extracted once and cached to avoid repeated video loading.
This commit is contained in:
parent
96e511b413
commit
e084a4a270
@ -13,6 +13,7 @@ from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
|
||||
from .utils import (
|
||||
load_images,
|
||||
load_images_from_videos,
|
||||
load_prompts,
|
||||
load_videos,
|
||||
preprocess_image_with_resize,
|
||||
@ -53,7 +54,7 @@ class BaseI2VDataset(Dataset):
|
||||
data_root: str,
|
||||
caption_column: str,
|
||||
video_column: str,
|
||||
image_column: str,
|
||||
image_column: str | None,
|
||||
device: torch.device,
|
||||
trainer: "Trainer" = None,
|
||||
*args,
|
||||
@ -64,7 +65,10 @@ class BaseI2VDataset(Dataset):
|
||||
data_root = Path(data_root)
|
||||
self.prompts = load_prompts(data_root / caption_column)
|
||||
self.videos = load_videos(data_root / video_column)
|
||||
self.images = load_images(data_root / image_column)
|
||||
if image_column is not None:
|
||||
self.images = load_images(data_root / image_column)
|
||||
else:
|
||||
self.images = load_images_from_videos(self.videos)
|
||||
self.trainer = trainer
|
||||
|
||||
self.device = device
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
@ -31,6 +32,37 @@ def load_images(image_path: Path) -> List[Path]:
|
||||
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
||||
|
||||
|
||||
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    """Extract the first frame of every video and cache it as a PNG file.

    Frames are stored in a ``first_frames`` directory located at
    ``videos_path[0].parent.parent``. A frame file that already exists is
    reused, so a video is only decoded when its PNG is missing.

    Args:
        videos_path: Paths of the videos to extract first frames from.

    Returns:
        The first-frame PNG paths, in the same order as ``videos_path``.
        An empty input yields an empty list.

    Raises:
        RuntimeError: If the first frame of a video cannot be read, or the
            extracted frame cannot be written to disk.
    """
    # Nothing to extract; also avoids indexing into an empty list below.
    if not videos_path:
        return []

    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        # NOTE: the cache key is the video's stem only, so two videos that
        # share a stem map to the same PNG file.
        frame_path = first_frames_dir / f"{video_path.stem}.png"

        if not frame_path.exists():
            cap = cv2.VideoCapture(str(video_path))
            try:
                ret, frame = cap.read()
            finally:
                # Release the capture even when reading fails or raises.
                cap.release()

            if not ret:
                raise RuntimeError(f"Failed to read video: {video_path}")

            # imwrite reports failure through its return value; surface it
            # instead of returning a path to a file that was never written.
            if not cv2.imwrite(str(frame_path), frame):
                raise RuntimeError(f"Failed to save first frame: {frame_path}")
            logging.info(f"Saved first frame to {frame_path}")

        first_frame_paths.append(frame_path)

    return first_frame_paths
|
||||
|
||||
|
||||
########## preprocessors ##########
|
||||
|
||||
|
||||
|
@ -98,7 +98,9 @@ class Args(BaseModel):
|
||||
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||
values = info.data
|
||||
if values.get("model_type") == "i2v" and not v:
|
||||
raise ValueError("image_column must be specified when using i2v model")
|
||||
logging.warning(
|
||||
"No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("validation_dir", "validation_prompts")
|
||||
|
Loading…
x
Reference in New Issue
Block a user