feat: add latent caching for video encodings

- Add caching mechanism to store VAE-encoded video latents to disk
- Cache latents in a "latent" subdirectory alongside video files
- Skip re-encoding when cached latent file exists
- Add logging for successful cache saves
- Minor code cleanup and formatting improvements

This change improves training efficiency by avoiding redundant video encoding operations.
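
The heart of the change is a load-or-encode pattern in each dataset's `__getitem__`: look for a cached `<video stem>.pt` file in a `latent/` directory beside the video, load it if present, otherwise encode and write it. Below is a minimal sketch of that pattern in isolation; the helper name `load_or_encode_video` is illustrative rather than part of the commit, and `encode_video_fn` is assumed to map a `[B, C, F, H, W]` pixel tensor to a latent tensor of the same layout.

```python
from pathlib import Path
from typing import Callable

import torch


def load_or_encode_video(
    video_path: Path,
    frames: torch.Tensor,  # [F, C, H, W], already normalized to [-1, 1]
    encode_video_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Return the cached latent for `video_path`, encoding and caching it on a miss."""
    latent_dir = video_path.parent / "latent"
    latent_dir.mkdir(parents=True, exist_ok=True)
    latent_path = latent_dir / (video_path.stem + ".pt")

    if latent_path.exists():
        # Cache hit: reuse the stored latent instead of re-running the VAE encoder.
        return torch.load(latent_path, weights_only=True)

    # Cache miss: [F, C, H, W] -> [B, C, F, H, W], encode, then store [C, F, H, W] on disk.
    batch = frames.unsqueeze(0).permute(0, 2, 1, 3, 4).contiguous()
    encoded = encode_video_fn(batch)[0].cpu()
    torch.save(encoded, latent_path)
    return encoded
```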
Author: OleehyO
Date:   2024-12-30 16:10:06 +00:00
Parent: 2a6cca0656
Commit: 6eae5c201e

2 changed files with 164 additions and 41 deletions
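
The datasets never touch the VAE directly; they only receive an `encode_video_fn` callable and a `device`. A training script might wire this up roughly as below. This is a hedged sketch: `DummyVAE`, the import path, and the example column files and dimensions are placeholders standing in for the real model and config, not code from this commit.

```python
import torch

from finetune.datasets import T2VDatasetWithResize  # import path assumed for illustration


class DummyVAE(torch.nn.Module):
    """Stand-in for a real video VAE: 4x temporal, 8x spatial compression, 16 latent channels."""

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = x.shape
        return torch.randn(b, 16, f // 4, h // 8, w // 8, device=x.device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = DummyVAE().to(device).eval()


@torch.no_grad()
def encode_video(frames: torch.Tensor) -> torch.Tensor:
    # frames: [B, C, F, H, W] in [-1, 1]; returns a [B, C', F', H', W'] latent tensor
    return vae.encode(frames.to(device))


dataset = T2VDatasetWithResize(
    max_num_frames=49,              # example values only
    height=480,
    width=720,
    data_root="data",
    caption_column="prompts.txt",   # file listing one prompt per line
    video_column="videos.txt",      # file listing one relative video path per line
    device=device,
    encode_video_fn=encode_video,
)
```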

File 1: I2V dataset module

@@ -1,11 +1,13 @@
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
import torch
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos, load_images,
@@ -21,12 +23,22 @@ import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(__name__)
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseI2VDataset(Dataset):
"""
Base dataset class for Image-to-Video (I2V) training.
This dataset loads prompts, videos and corresponding conditioning images for I2V training.
Args:
data_root (str): Root directory containing the dataset files
caption_column (str): Path to file containing text prompts/captions
video_column (str): Path to file containing video paths
image_column (str): Path to file containing image paths
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
@@ -34,6 +46,8 @@ class BaseI2VDataset(Dataset):
caption_column: str,
video_column: str,
image_column: str,
device: torch.device,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
*args,
**kwargs
) -> None:
@@ -44,6 +58,9 @@ class BaseI2VDataset(Dataset):
self.videos = load_videos(data_root / video_column)
self.images = load_images(data_root / image_column)
self.device = device
self.encode_video_fn = encode_video_fn
# Check if number of prompts matches number of videos and images
if not (len(self.videos) == len(self.prompts) == len(self.images)):
raise ValueError(
@@ -79,28 +96,48 @@ class BaseI2VDataset(Dataset):
return index
prompt = self.prompts[index]
video = self.videos[index]
image = self.images[index]
# shape of frames: [F, C, H, W]
video_latent_dir = video.parent / "latent"
video_latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = video_latent_dir / (video.stem + ".pt")
if encoded_video_path.exists():
encoded_video = torch.load(encoded_video_path, weights_only=True)
# shape of image: [C, H, W]
_, image = self.preprocess(None, self.images[index])
else:
frames, image = self.preprocess(video, image)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
# shape of image: [C, H, W]
frames, image = self.preprocess(self.videos[index], self.images[index])
frames = self.video_transform(frames)
image = self.image_transform(image)
return {
"prompt": prompt,
"video": frames,
"image": image,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": frames.shape[0],
"height": frames.shape[2],
"width": frames.shape[3],
"num_frames": encoded_video.shape[1],
"height": encoded_video.shape[2],
"width": encoded_video.shape[3],
},
"image": image
}
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Loads and preprocesses a video and an image.
If either path is None, no preprocessing will be done for that input.
Args:
video_path: Path to the video file to load
@@ -150,7 +187,16 @@ class BaseI2VDataset(Dataset):
class I2VDatasetWithResize(BaseI2VDataset):
"""
A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
This class preprocesses videos and images by resizing them to specified dimensions:
- Videos are resized to max_num_frames x height x width
- Images are resized to height x width
Args:
max_num_frames (int): Maximum number of frames to extract from videos
height (int): Target height for resizing videos and images
width (int): Target width for resizing videos and images
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -164,26 +210,49 @@ class I2VDatasetWithResize(BaseI2VDataset):
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
image = preprocess_image_with_resize(image_path, self.height, self.width)
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
if video_path is not None:
video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
else:
video = None
if image_path is not None:
image = preprocess_image_with_resize(image_path, self.height, self.width)
else:
image = None
return video, image
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
class I2VDatasetWithBuckets(BaseI2VDataset):
"""
"""
def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.video_resolution_buckets = video_resolution_buckets
self.video_resolution_buckets = [
(
int(b[0] / vae_temporal_compression_ratio),
int(b[1] / vae_height_compression_ratio),
int(b[2] / vae_width_compression_ratio),
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
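
A side effect of caching latents is that the bucket datasets now bucket by latent shape rather than pixel shape: each `(frames, height, width)` bucket is divided by the VAE's temporal, height, and width compression ratios, as the `I2VDatasetWithBuckets` constructor above and the `T2VDatasetWithBuckets` constructor below both do. A small self-contained sketch of that conversion follows; the default ratio values are only examples, since the real ratios are passed in by the caller.

```python
from typing import List, Tuple


def to_latent_buckets(
    buckets: List[Tuple[int, int, int]],
    temporal_ratio: int = 4,  # example compression ratios; real values come from the VAE
    height_ratio: int = 8,
    width_ratio: int = 8,
) -> List[Tuple[int, int, int]]:
    """Convert (frames, height, width) buckets from pixel space to latent space."""
    return [
        (int(f / temporal_ratio), int(h / height_ratio), int(w / width_ratio))
        for f, h, w in buckets
    ]


# A 49x480x720 pixel bucket becomes a 12x60x90 latent bucket.
print(to_latent_buckets([(49, 480, 720)]))
```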

File 2: T2V dataset module

@@ -1,12 +1,15 @@
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
import torch
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
load_prompts, load_videos,
preprocess_video_with_resize,
@@ -19,18 +22,30 @@ import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger(__name__)
logger = get_logger(LOG_NAME, LOG_LEVEL)
class BaseT2VDataset(Dataset):
"""
Base dataset class for Text-to-Video (T2V) training.
This dataset loads prompts and videos for T2V training.
Args:
data_root (str): Root directory containing the dataset files
caption_column (str): Path to file containing text prompts/captions
video_column (str): Path to file containing video paths
device (torch.device): Device to load the data on
encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
"""
def __init__(
self,
data_root: str,
caption_column: str,
video_column: str,
device: torch.device = None,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
*args,
**kwargs
) -> None:
@@ -39,6 +54,8 @@ class BaseT2VDataset(Dataset):
data_root = Path(data_root)
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.device = device
self.encode_video_fn = encode_video_fn
# Check if all video files exist
if any(not path.is_file() for path in self.videos):
@@ -69,18 +86,36 @@ class BaseT2VDataset(Dataset):
return index
prompt = self.prompts[index]
video = self.videos[index]
# shape of frames: [F, C, H, W]
frames = self.preprocess(self.videos[index])
frames = self.video_transform(frames)
latent_dir = video.parent / "latent"
latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = latent_dir / (video.stem + ".pt")
if encoded_video_path.exists():
# shape of encoded_video: [C, F, H, W]
encoded_video = torch.load(encoded_video_path, weights_only=True)
else:
frames = self.preprocess(video)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
return {
"prompt": prompt,
"video": frames,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": frames.shape[0],
"height": frames.shape[2],
"width": frames.shape[3],
"num_frames": encoded_video.shape[1],
"height": encoded_video.shape[2],
"width": encoded_video.shape[3],
},
}
@@ -113,15 +148,24 @@ class BaseT2VDataset(Dataset):
- W is width
Returns:
torch.Tensor: The transformed video tensor
torch.Tensor: The transformed video tensor with the same shape as the input
"""
raise NotImplementedError("Subclass must implement this method")
class T2VDatasetWithResize(BaseT2VDataset):
"""
A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
This class preprocesses videos by resizing them to specified dimensions:
- Videos are resized to max_num_frames x height x width
Args:
max_num_frames (int): Maximum number of frames to extract from videos
height (int): Target height for resizing videos
width (int): Target width for resizing videos
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -147,18 +191,28 @@ class T2VDatasetWithResize(BaseT2VDataset):
class T2VDatasetWithBuckets(BaseT2VDataset):
"""
"""
def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
) -> None:
"""
Args:
resolution_buckets: List of tuples representing the resolution buckets.
Each tuple contains three integers: (max_num_frames, height, width).
"""
super().__init__(*args, **kwargs)
self.video_resolution_buckets = video_resolution_buckets
self.video_resolution_buckets = [
(
int(b[0] / vae_temporal_compression_ratio),
int(b[1] / vae_height_compression_ratio),
int(b[2] / vae_width_compression_ratio),
)
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose(
[