From 6eae5c201e24150b59a87f2cb39102698aa2484f Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Mon, 30 Dec 2024 16:10:06 +0000
Subject: [PATCH] feat: add latent caching for video encodings

- Add caching mechanism to store VAE-encoded video latents to disk
- Cache latents in a "latent" subdirectory alongside video files
- Skip re-encoding when cached latent file exists
- Add logging for successful cache saves
- Minor code cleanup and formatting improvements

This change improves training efficiency by avoiding redundant video
encoding operations.
---
 finetune/datasets/i2v_dataset.py | 115 ++++++++++++++++++++++++-------
 finetune/datasets/t2v_dataset.py |  90 +++++++++++++++++++-----
 2 files changed, 164 insertions(+), 41 deletions(-)
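
Note on wiring: this patch only changes the datasets; the trainer is expected
to pass device and encode_video_fn when constructing them. A minimal sketch of
such a function, assuming a diffusers CogVideoX VAE (the model id, data paths,
and column names below are illustrative, not part of this patch):

    import torch
    from diffusers import AutoencoderKLCogVideoX

    from finetune.datasets.i2v_dataset import I2VDatasetWithResize

    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.bfloat16
    ).to("cuda")

    def encode_video(frames: torch.Tensor) -> torch.Tensor:
        # frames: [B, C, F, H, W], already normalized to [-1, 1] by the
        # dataset's video_transform before this function is called
        with torch.no_grad():
            latent_dist = vae.encode(frames.to(vae.dtype)).latent_dist
        return latent_dist.sample() * vae.config.scaling_factor

    dataset = I2VDatasetWithResize(
        data_root="data",
        caption_column="prompts.txt",
        video_column="videos.txt",
        image_column="images.txt",
        device=torch.device("cuda"),
        encode_video_fn=encode_video,
        max_num_frames=49,
        height=480,
        width=720,
    )

The datasets call encode_video_fn on a single-video batch and cache
encoded_video[0], so the function must take and return 5-D tensors of shape
[B, C, F, H, W].
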
diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py
index 2137dce..e993c96 100644
--- a/finetune/datasets/i2v_dataset.py
+++ b/finetune/datasets/i2v_dataset.py
@@ -1,11 +1,13 @@
+import torch
+
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable
 from typing_extensions import override
-import torch
 from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from finetune.constants import LOG_NAME, LOG_LEVEL
 from .utils import (
     load_prompts,
     load_videos,
     load_images,
@@ -21,12 +23,22 @@
 import decord  # isort:skip
 
 decord.bridge.set_bridge("torch")
 
-logger = get_logger(__name__)
+logger = get_logger(LOG_NAME, LOG_LEVEL)
 
 
 class BaseI2VDataset(Dataset):
     """
+    Base dataset class for Image-to-Video (I2V) training.
+    This dataset loads prompts, videos and corresponding conditioning images for I2V training.
+
+    Args:
+        data_root (str): Root directory containing the dataset files
+        caption_column (str): Path to file containing text prompts/captions
+        video_column (str): Path to file containing video paths
+        image_column (str): Path to file containing image paths
+        device (torch.device): Device to load the data on
+        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
     def __init__(
         self,
@@ -34,6 +46,8 @@ class BaseI2VDataset(Dataset):
         caption_column: str,
         video_column: str,
         image_column: str,
+        device: torch.device,
+        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
         *args,
         **kwargs
     ) -> None:
@@ -44,6 +58,9 @@ class BaseI2VDataset(Dataset):
         self.videos = load_videos(data_root / video_column)
         self.images = load_images(data_root / image_column)
 
+        self.device = device
+        self.encode_video_fn = encode_video_fn
+
         # Check if number of prompts matches number of videos and images
         if not (len(self.videos) == len(self.prompts) == len(self.images)):
             raise ValueError(
@@ -79,28 +96,48 @@ class BaseI2VDataset(Dataset):
             return index
 
         prompt = self.prompts[index]
+        video = self.videos[index]
+        image = self.images[index]
 
-        # shape of frames: [F, C, H, W]
+        video_latent_dir = video.parent / "latent"
+        video_latent_dir.mkdir(parents=True, exist_ok=True)
+        encoded_video_path = video_latent_dir / (video.stem + ".pt")
+
+        if encoded_video_path.exists():
+            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            # shape of image: [C, H, W]
+            _, image = self.preprocess(None, self.images[index])
+        else:
+            frames, image = self.preprocess(video, image)
+            frames = frames.to(self.device)
+            # current shape of frames: [F, C, H, W]
+            frames = self.video_transform(frames)
+            # Convert to [B, C, F, H, W]
+            frames = frames.unsqueeze(0)
+            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+
+            encoded_video = self.encode_video_fn(frames)
+            # [B, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0].cpu()
+            torch.save(encoded_video, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
+
+        # shape of encoded_video: [C, F, H, W]
         # shape of image: [C, H, W]
-        frames, image = self.preprocess(self.videos[index], self.images[index])
-
-        frames = self.video_transform(frames)
-        image = self.image_transform(image)
-
         return {
             "prompt": prompt,
-            "video": frames,
+            "image": image,
+            "encoded_video": encoded_video,
             "video_metadata": {
-                "num_frames": frames.shape[0],
-                "height": frames.shape[2],
-                "width": frames.shape[3],
+                "num_frames": encoded_video.shape[1],
+                "height": encoded_video.shape[2],
+                "width": encoded_video.shape[3],
             },
-            "image": image
         }
 
-    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Loads and preprocesses a video and an image.
+        If either path is None, no preprocessing will be done for that input.
 
         Args:
             video_path: Path to the video file to load
@@ -150,7 +187,16 @@ class BaseI2VDataset(Dataset):
 
 class I2VDatasetWithResize(BaseI2VDataset):
     """
+    A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
+    This class preprocesses videos and images by resizing them to specified dimensions:
+    - Videos are resized to max_num_frames x height x width
+    - Images are resized to height x width
+
+    Args:
+        max_num_frames (int): Maximum number of frames to extract from videos
+        height (int): Target height for resizing videos and images
+        width (int): Target width for resizing videos and images
     """
 
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -164,26 +210,49 @@ class I2VDatasetWithResize(BaseI2VDataset):
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
             ]
         )
-
+        self.__image_transforms = self.__frame_transforms
+
     @override
-    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
-        video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
-        image = preprocess_image_with_resize(image_path, self.height, self.width)
+    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if video_path is not None:
+            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+        else:
+            video = None
+        if image_path is not None:
+            image = preprocess_image_with_resize(image_path, self.height, self.width)
+        else:
+            image = None
         return video, image
 
     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
 
+    @override
+    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+        return self.__image_transforms(image)
+
 
 class I2VDatasetWithBuckets(BaseI2VDataset):
-    """
-    """
-    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+    def __init__(
+        self,
+        video_resolution_buckets: List[Tuple[int, int, int]],
+        vae_temporal_compression_ratio: int,
+        vae_height_compression_ratio: int,
+        vae_width_compression_ratio: int,
+        *args, **kwargs
+    ) -> None:
         super().__init__(*args, **kwargs)
-        self.video_resolution_buckets = video_resolution_buckets
+        self.video_resolution_buckets = [
+            (
+                int(b[0] / vae_temporal_compression_ratio),
+                int(b[1] / vae_height_compression_ratio),
+                int(b[2] / vae_width_compression_ratio),
+            )
+            for b in video_resolution_buckets
+        ]
         self.__frame_transforms = transforms.Compose(
             [
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
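
Note on the bucket changes (they appear in both files): buckets are now
expressed in latent space, so each (frames, height, width) bucket is divided
by the VAE's compression ratios. A worked example, assuming CogVideoX-style
ratios of 4x temporal and 8x8 spatial compression (the concrete numbers are
illustrative):

    # Convert a pixel-space bucket to latent space, as the new __init__ does
    bucket = (49, 480, 720)   # (frames, height, width) in pixel space
    ratios = (4, 8, 8)        # (temporal, height, width) compression
    latent_bucket = tuple(int(b / r) for b, r in zip(bucket, ratios))
    print(latent_bucket)      # (12, 60, 90)

Since int() truncates, the frame dimension may differ by one from what a
causal VAE actually emits (CogVideoX, for instance, maps 49 pixel frames to
(49 - 1) // 4 + 1 = 13 latent frames), so these buckets are approximate
boundaries rather than exact latent shapes.
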
diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py
index aa4ed72..9afe53a 100644
--- a/finetune/datasets/t2v_dataset.py
+++ b/finetune/datasets/t2v_dataset.py
@@ -1,12 +1,15 @@
+import torch
+
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable
 from typing_extensions import override
-import torch
 from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
 from .utils import (
     load_prompts,
     load_videos,
     preprocess_video_with_resize,
@@ -19,18 +22,30 @@
 import decord  # isort:skip
 
 decord.bridge.set_bridge("torch")
 
-logger = get_logger(__name__)
+logger = get_logger(LOG_NAME, LOG_LEVEL)
 
 
 class BaseT2VDataset(Dataset):
     """
+    Base dataset class for Text-to-Video (T2V) training.
+    This dataset loads prompts and videos for T2V training.
+
+    Args:
+        data_root (str): Root directory containing the dataset files
+        caption_column (str): Path to file containing text prompts/captions
+        video_column (str): Path to file containing video paths
+        device (torch.device): Device to load the data on
+        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
+
     def __init__(
         self,
         data_root: str,
         caption_column: str,
         video_column: str,
+        device: torch.device = None,
+        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
         *args,
         **kwargs
     ) -> None:
@@ -39,6 +54,8 @@ class BaseT2VDataset(Dataset):
         data_root = Path(data_root)
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
+        self.device = device
+        self.encode_video_fn = encode_video_fn
 
         # Check if all video files exist
         if any(not path.is_file() for path in self.videos):
@@ -69,18 +86,36 @@ class BaseT2VDataset(Dataset):
             return index
 
         prompt = self.prompts[index]
+        video = self.videos[index]
 
-        # shape of frames: [F, C, H, W]
-        frames = self.preprocess(self.videos[index])
-        frames = self.video_transform(frames)
+        latent_dir = video.parent / "latent"
+        latent_dir.mkdir(parents=True, exist_ok=True)
+        encoded_video_path = latent_dir / (video.stem + ".pt")
+
+        if encoded_video_path.exists():
+            # shape of encoded_video: [C, F, H, W]
+            encoded_video = torch.load(encoded_video_path, weights_only=True)
+        else:
+            frames = self.preprocess(video)
+            frames = frames.to(self.device)
+            # current shape of frames: [F, C, H, W]
+            frames = self.video_transform(frames)
+            # Convert to [B, C, F, H, W]
+            frames = frames.unsqueeze(0)
+            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+            encoded_video = self.encode_video_fn(frames)
+            # [B, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0].cpu()
+            torch.save(encoded_video, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
 
         return {
             "prompt": prompt,
-            "video": frames,
+            "encoded_video": encoded_video,
             "video_metadata": {
-                "num_frames": frames.shape[0],
-                "height": frames.shape[2],
-                "width": frames.shape[3],
+                "num_frames": encoded_video.shape[1],
+                "height": encoded_video.shape[2],
+                "width": encoded_video.shape[3],
             },
         }
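
The caching flow above mirrors the I2V version: latents are written next to
the source videos, keyed only by the video's stem. With a hypothetical layout
such as

    data/videos/
    ├── clip_0001.mp4
    ├── clip_0002.mp4
    └── latent/
        ├── clip_0001.pt
        └── clip_0002.pt

a cached latent can be inspected directly (the path and shape below assume the
layout above and a 16-channel CogVideoX-style VAE):

    import torch

    latent = torch.load("data/videos/latent/clip_0001.pt", weights_only=True)
    print(latent.shape)  # [C, F', H', W'], e.g. torch.Size([16, 13, 60, 90])

Two consequences worth noting: weights_only=True restricts torch.load to a
safe subset of types, so a cache file cannot smuggle in arbitrary pickled
objects; and because the cache key carries no resolution or frame-count
information, changing height, width, or max_num_frames does not invalidate
stale latents. Deleting the latent/ directories forces a re-encode.
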
@@ -113,15 +148,24 @@ class BaseT2VDataset(Dataset):
         - W is width
 
         Returns:
-            torch.Tensor: The transformed video tensor
+            torch.Tensor: The transformed video tensor with the same shape as the input
         """
         raise NotImplementedError("Subclass must implement this method")
 
 
 class T2VDatasetWithResize(BaseT2VDataset):
     """
+    A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
+    This class preprocesses videos by resizing them to specified dimensions:
+    - Videos are resized to max_num_frames x height x width
+
+    Args:
+        max_num_frames (int): Maximum number of frames to extract from videos
+        height (int): Target height for resizing videos
+        width (int): Target width for resizing videos
     """
+
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
 
@@ -147,18 +191,28 @@ class T2VDatasetWithResize(BaseT2VDataset):
 
 
 class T2VDatasetWithBuckets(BaseT2VDataset):
-    """
-    """
-    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+    def __init__(
+        self,
+        video_resolution_buckets: List[Tuple[int, int, int]],
+        vae_temporal_compression_ratio: int,
+        vae_height_compression_ratio: int,
+        vae_width_compression_ratio: int,
+        *args, **kwargs
+    ) -> None:
         """
-        Args:
-            resolution_buckets: List of tuples representing the resolution buckets.
-                Each tuple contains three integers: (max_num_frames, height, width).
+
         """
         super().__init__(*args, **kwargs)
-        self.video_resolution_buckets = video_resolution_buckets
+        self.video_resolution_buckets = [
+            (
+                int(b[0] / vae_temporal_compression_ratio),
+                int(b[1] / vae_height_compression_ratio),
+                int(b[2] / vae_width_compression_ratio),
+            )
+            for b in video_resolution_buckets
+        ]
         self.__frame_transform = transforms.Compose(
             [