mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
feat: auto-extract first frames as conditioning images for i2v model
When training i2v models without specifying `image_column`, automatically extract the first frame of each training video and use it as the conditioning image. This includes:
- Add a `load_images_from_videos()` utility function that extracts and caches first frames.
- Update `BaseI2VDataset` to support auto-extraction when `image_column` is `None`.
- Add validation and a warning message in the `Args` schema for i2v models without `image_column`.
The first frames are extracted once and cached to avoid repeated video loading.
This commit is contained in:
parent
96e511b413
commit
e084a4a270
@ -13,6 +13,7 @@ from finetune.constants import LOG_LEVEL, LOG_NAME
|
|||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
load_images,
|
load_images,
|
||||||
|
load_images_from_videos,
|
||||||
load_prompts,
|
load_prompts,
|
||||||
load_videos,
|
load_videos,
|
||||||
preprocess_image_with_resize,
|
preprocess_image_with_resize,
|
||||||
@ -53,7 +54,7 @@ class BaseI2VDataset(Dataset):
|
|||||||
data_root: str,
|
data_root: str,
|
||||||
caption_column: str,
|
caption_column: str,
|
||||||
video_column: str,
|
video_column: str,
|
||||||
image_column: str,
|
image_column: str | None,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
trainer: "Trainer" = None,
|
trainer: "Trainer" = None,
|
||||||
*args,
|
*args,
|
||||||
@ -64,7 +65,10 @@ class BaseI2VDataset(Dataset):
|
|||||||
data_root = Path(data_root)
|
data_root = Path(data_root)
|
||||||
self.prompts = load_prompts(data_root / caption_column)
|
self.prompts = load_prompts(data_root / caption_column)
|
||||||
self.videos = load_videos(data_root / video_column)
|
self.videos = load_videos(data_root / video_column)
|
||||||
self.images = load_images(data_root / image_column)
|
if image_column is not None:
|
||||||
|
self.images = load_images(data_root / image_column)
|
||||||
|
else:
|
||||||
|
self.images = load_images_from_videos(self.videos)
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
@ -31,6 +32,37 @@ def load_images(image_path: Path) -> List[Path]:
|
|||||||
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
||||||
|
|
||||||
|
|
||||||
|
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    """Extract the first frame of each video as a PNG and return the PNG paths.

    Frames are stored in a ``first_frames`` directory that sits next to the
    directory containing the first video (``videos_path[0].parent.parent``).
    A frame file that already exists is reused, so each video is decoded at
    most once across runs.

    NOTE: the cache key is the video's stem, so two videos with the same stem
    (in different directories or with different extensions) share one frame.

    Args:
        videos_path: Paths of the videos to process.

    Returns:
        One PNG path per input video, in the same order. An empty input
        yields an empty list.

    Raises:
        RuntimeError: If the first frame of a video cannot be read, or if the
            extracted frame cannot be written to disk.
    """
    # Nothing to do (and no videos_path[0] to derive the cache directory from).
    if not videos_path:
        return []

    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(parents=True, exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        frame_path = first_frames_dir / f"{video_path.stem}.png"

        # Only decode the video when its first frame is not cached yet.
        if not frame_path.exists():
            cap = cv2.VideoCapture(str(video_path))
            try:
                ret, frame = cap.read()
            finally:
                # Release the capture handle even when reading fails.
                cap.release()

            if not ret:
                raise RuntimeError(f"Failed to read video: {video_path}")

            # cv2.imwrite reports failure through its return value; without
            # this check a missing file would be handed to the dataset.
            if not cv2.imwrite(str(frame_path), frame):
                raise RuntimeError(f"Failed to save first frame: {frame_path}")
            logging.info(f"Saved first frame to {frame_path}")

        first_frame_paths.append(frame_path)

    return first_frame_paths
|
||||||
|
|
||||||
|
|
||||||
########## preprocessors ##########
|
########## preprocessors ##########
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,7 +98,9 @@ class Args(BaseModel):
|
|||||||
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
|
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||||
values = info.data
|
values = info.data
|
||||||
if values.get("model_type") == "i2v" and not v:
|
if values.get("model_type") == "i2v" and not v:
|
||||||
raise ValueError("image_column must be specified when using i2v model")
|
logging.warning(
|
||||||
|
"No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("validation_dir", "validation_prompts")
|
@field_validator("validation_dir", "validation_prompts")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user