Mirror of https://github.com/THUDM/CogVideo.git
feat: add latent caching for video encodings
- Add caching mechanism to store VAE-encoded video latents to disk
- Cache latents in a "latent" subdirectory alongside video files
- Skip re-encoding when a cached latent file exists
- Add logging for successful cache saves
- Minor code cleanup and formatting improvements

This change improves training efficiency by avoiding redundant video encoding operations.
commit 6eae5c201e (parent 2a6cca0656)
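At its core the change is a read-through cache keyed by the video filename. Below is a minimal sketch of the pattern, pulled out of context; `load_or_encode` is a hypothetical helper name (in the commit this logic lives inside each dataset's `__getitem__`, shown in the diff that follows), and `encode_video_fn` is assumed to wrap a VAE encoder.

# Minimal sketch of the caching pattern introduced by this commit.
from pathlib import Path
import torch

def load_or_encode(video_path: Path, frames: torch.Tensor, encode_video_fn) -> torch.Tensor:
    latent_dir = video_path.parent / "latent"             # cache sits next to the video
    latent_dir.mkdir(parents=True, exist_ok=True)
    latent_path = latent_dir / (video_path.stem + ".pt")
    if latent_path.exists():                              # cache hit: skip the VAE pass
        return torch.load(latent_path, weights_only=True)
    # cache miss: lift [F, C, H, W] frames to a [B, C, F, H, W] batch of one
    batch = frames.unsqueeze(0).permute(0, 2, 1, 3, 4).contiguous()
    encoded = encode_video_fn(batch)[0].cpu()             # back to [C, F, H, W], off-GPU
    torch.save(encoded, latent_path)
    return encoded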
I2V dataset module:

@@ -1,11 +1,13 @@
-import torch
-
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable
 from typing_extensions import override

+import torch
 from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from finetune.constants import LOG_NAME, LOG_LEVEL
+
 from .utils import (
     load_prompts, load_videos, load_images,
@@ -21,12 +23,22 @@ import decord  # isort:skip

 decord.bridge.set_bridge("torch")

-logger = get_logger(__name__)
+logger = get_logger(LOG_NAME, LOG_LEVEL)


 class BaseI2VDataset(Dataset):
     """
     Base dataset class for Image-to-Video (I2V) training.
+
+    This dataset loads prompts, videos and corresponding conditioning images for I2V training.
+
+    Args:
+        data_root (str): Root directory containing the dataset files
+        caption_column (str): Path to file containing text prompts/captions
+        video_column (str): Path to file containing video paths
+        image_column (str): Path to file containing image paths
+        device (torch.device): Device to load the data on
+        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """
     def __init__(
         self,
@@ -34,6 +46,8 @@ class BaseI2VDataset(Dataset):
         caption_column: str,
         video_column: str,
         image_column: str,
+        device: torch.device,
+        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
         *args,
         **kwargs
     ) -> None:
@@ -44,6 +58,9 @@ class BaseI2VDataset(Dataset):
         self.videos = load_videos(data_root / video_column)
         self.images = load_images(data_root / image_column)
+
+        self.device = device
+        self.encode_video_fn = encode_video_fn

         # Check if number of prompts matches number of videos and images
         if not (len(self.videos) == len(self.prompts) == len(self.images)):
             raise ValueError(
@@ -79,28 +96,48 @@ class BaseI2VDataset(Dataset):
             return index

         prompt = self.prompts[index]
         video = self.videos[index]
         image = self.images[index]

-        # shape of frames: [F, C, H, W]
-        frames, image = self.preprocess(self.videos[index], self.images[index])
-
-        frames = self.video_transform(frames)
+        video_latent_dir = video.parent / "latent"
+        video_latent_dir.mkdir(parents=True, exist_ok=True)
+        encoded_video_path = video_latent_dir / (video.stem + ".pt")
+
+        if encoded_video_path.exists():
+            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            # shape of image: [C, H, W]
+            _, image = self.preprocess(None, self.images[index])
+        else:
+            frames, image = self.preprocess(video, image)
+            frames = frames.to(self.device)
+            # current shape of frames: [F, C, H, W]
+            frames = self.video_transform(frames)
+            # Convert to [B, C, F, H, W]
+            frames = frames.unsqueeze(0)
+            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+            encoded_video = self.encode_video_fn(frames)
+            # [B, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0].cpu()
+            torch.save(encoded_video, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

+        # shape of encoded_video: [C, F, H, W]
+        # shape of image: [C, H, W]
         image = self.image_transform(image)

         return {
             "prompt": prompt,
-            "video": frames,
-            "image": image,
+            "encoded_video": encoded_video,
             "video_metadata": {
-                "num_frames": frames.shape[0],
-                "height": frames.shape[2],
-                "width": frames.shape[3],
+                "num_frames": encoded_video.shape[1],
+                "height": encoded_video.shape[2],
+                "width": encoded_video.shape[3],
             },
+            "image": image
         }

-    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Loads and preprocesses a video and an image.
+        If either path is None, no preprocessing will be done for that input.

         Args:
             video_path: Path to the video file to load
@@ -150,7 +187,16 @@ class BaseI2VDataset(Dataset):

 class I2VDatasetWithResize(BaseI2VDataset):
     """
     A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
+
+    This class preprocesses videos and images by resizing them to specified dimensions:
+        - Videos are resized to max_num_frames x height x width
+        - Images are resized to height x width
+
+    Args:
+        max_num_frames (int): Maximum number of frames to extract from videos
+        height (int): Target height for resizing videos and images
+        width (int): Target width for resizing videos and images
     """
     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -164,26 +210,49 @@ class I2VDatasetWithResize(BaseI2VDataset):
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
             ]
         )

         self.__image_transforms = self.__frame_transforms

     @override
-    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
-        video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
-        image = preprocess_image_with_resize(image_path, self.height, self.width)
+    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if video_path is not None:
+            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+        else:
+            video = None
+        if image_path is not None:
+            image = preprocess_image_with_resize(image_path, self.height, self.width)
+        else:
+            image = None
         return video, image

     @override
     def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
         return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

     @override
     def image_transform(self, image: torch.Tensor) -> torch.Tensor:
         return self.__image_transforms(image)


 class I2VDatasetWithBuckets(BaseI2VDataset):
     """

     """
-    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+    def __init__(
+        self,
+        video_resolution_buckets: List[Tuple[int, int, int]],
+        vae_temporal_compression_ratio: int,
+        vae_height_compression_ratio: int,
+        vae_width_compression_ratio: int,
+        *args, **kwargs
+    ) -> None:
         super().__init__(*args, **kwargs)

-        self.video_resolution_buckets = video_resolution_buckets
+        self.video_resolution_buckets = [
+            (
+                int(b[0] / vae_temporal_compression_ratio),
+                int(b[1] / vae_height_compression_ratio),
+                int(b[2] / vae_width_compression_ratio),
+            )
+            for b in video_resolution_buckets
+        ]
         self.__frame_transforms = transforms.Compose(
             [
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
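For context, `encode_video_fn` is supplied by the caller. A hedged sketch of how a training script might wire it up with a diffusers VAE; the checkpoint id, dtype, and use of the scaling factor are illustrative assumptions, not taken from this commit:

import torch
from diffusers import AutoencoderKLCogVideoX

# Assumed checkpoint; any CogVideoX VAE with the same interface would do.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae.requires_grad_(False)

def encode_video(frames: torch.Tensor) -> torch.Tensor:
    # frames: [B, C, F, H, W], already normalized to [-1, 1] by video_transform
    with torch.no_grad():
        latent_dist = vae.encode(frames.to(vae.device, dtype=vae.dtype)).latent_dist
    return latent_dist.sample() * vae.config.scaling_factor

dataset = I2VDatasetWithResize(
    max_num_frames=49, height=480, width=720,
    data_root="data", caption_column="prompts.txt",
    video_column="videos.txt", image_column="images.txt",
    device=torch.device("cuda"), encode_video_fn=encode_video,
)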
T2V dataset module:

@@ -1,12 +1,15 @@
-import torch
-
 from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Callable
 from typing_extensions import override

+import torch
 from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+
+from finetune.constants import LOG_NAME, LOG_LEVEL

 from .utils import (
     load_prompts, load_videos,
     preprocess_video_with_resize,
@@ -19,18 +22,30 @@ import decord  # isort:skip

 decord.bridge.set_bridge("torch")

-logger = get_logger(__name__)
+logger = get_logger(LOG_NAME, LOG_LEVEL)


 class BaseT2VDataset(Dataset):
     """
     Base dataset class for Text-to-Video (T2V) training.
+
+    This dataset loads prompts and videos for T2V training.
+
+    Args:
+        data_root (str): Root directory containing the dataset files
+        caption_column (str): Path to file containing text prompts/captions
+        video_column (str): Path to file containing video paths
+        device (torch.device): Device to load the data on
+        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
     """

     def __init__(
         self,
         data_root: str,
         caption_column: str,
         video_column: str,
+        device: torch.device = None,
+        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
         *args,
         **kwargs
     ) -> None:
@@ -39,6 +54,8 @@ class BaseT2VDataset(Dataset):
         data_root = Path(data_root)
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
+        self.device = device
+        self.encode_video_fn = encode_video_fn

         # Check if all video files exist
         if any(not path.is_file() for path in self.videos):
@@ -69,18 +86,36 @@ class BaseT2VDataset(Dataset):
             return index

         prompt = self.prompts[index]
         video = self.videos[index]

-        # shape of frames: [F, C, H, W]
-        frames = self.preprocess(self.videos[index])
-        frames = self.video_transform(frames)
+        latent_dir = video.parent / "latent"
+        latent_dir.mkdir(parents=True, exist_ok=True)
+        encoded_video_path = latent_dir / (video.stem + ".pt")
+
+        if encoded_video_path.exists():
+            # shape of encoded_video: [C, F, H, W]
+            encoded_video = torch.load(encoded_video_path, weights_only=True)
+        else:
+            frames = self.preprocess(video)
+            frames = frames.to(self.device)
+            # current shape of frames: [F, C, H, W]
+            frames = self.video_transform(frames)
+            # Convert to [B, C, F, H, W]
+            frames = frames.unsqueeze(0)
+            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
+            encoded_video = self.encode_video_fn(frames)
+            # [B, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0].cpu()
+            torch.save(encoded_video, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

         return {
             "prompt": prompt,
-            "video": frames,
+            "encoded_video": encoded_video,
             "video_metadata": {
-                "num_frames": frames.shape[0],
-                "height": frames.shape[2],
-                "width": frames.shape[3],
+                "num_frames": encoded_video.shape[1],
+                "height": encoded_video.shape[2],
+                "width": encoded_video.shape[3],
             },
         }

@@ -113,15 +148,24 @@ class BaseT2VDataset(Dataset):
             - W is width

         Returns:
-            torch.Tensor: The transformed video tensor
+            torch.Tensor: The transformed video tensor with the same shape as the input
         """
         raise NotImplementedError("Subclass must implement this method")


 class T2VDatasetWithResize(BaseT2VDataset):
     """
     A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
+
+    This class preprocesses videos by resizing them to specified dimensions:
+        - Videos are resized to max_num_frames x height x width
+
+    Args:
+        max_num_frames (int): Maximum number of frames to extract from videos
+        height (int): Target height for resizing videos
+        width (int): Target width for resizing videos
     """

     def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)

@@ -147,18 +191,28 @@ class T2VDatasetWithResize(BaseT2VDataset):


 class T2VDatasetWithBuckets(BaseT2VDataset):
     """

     """
-    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+    def __init__(
+        self,
+        video_resolution_buckets: List[Tuple[int, int, int]],
+        vae_temporal_compression_ratio: int,
+        vae_height_compression_ratio: int,
+        vae_width_compression_ratio: int,
+        *args, **kwargs
+    ) -> None:
+        """
+        Args:
+            resolution_buckets: List of tuples representing the resolution buckets.
+                Each tuple contains three integers: (max_num_frames, height, width).
+        """
         super().__init__(*args, **kwargs)

-        self.video_resolution_buckets = video_resolution_buckets
+        self.video_resolution_buckets = [
+            (
+                int(b[0] / vae_temporal_compression_ratio),
+                int(b[1] / vae_height_compression_ratio),
+                int(b[2] / vae_width_compression_ratio),
+            )
+            for b in video_resolution_buckets
+        ]

         self.__frame_transform = transforms.Compose(
             [
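The bucket rescaling in both `*WithBuckets` constructors maps pixel-space buckets into latent space, so cached latents can be bucketed directly without decoding. A quick illustration of the arithmetic, assuming the 4x temporal / 8x spatial compression commonly reported for the CogVideoX VAE (the actual ratios would be read from the VAE config):

# Illustrative pixel-space buckets: (max_num_frames, height, width)
video_resolution_buckets = [(49, 480, 720), (49, 768, 1360)]

latent_buckets = [
    (int(f / 4), int(h / 8), int(w / 8))  # temporal / height / width compression
    for f, h, w in video_resolution_buckets
]
print(latent_buckets)  # [(12, 60, 90), (12, 96, 170)]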