Mirror of https://github.com/THUDM/CogVideo.git
- Add caching for prompt embeddings
- Store cached files using the safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system reduces redundant computation and memory usage during training by:
1. Caching prompt embeddings, keyed by a hash of the prompt text
2. Caching encoded video latents, keyed by the video filename
3. Moving tensors to CPU after caching to free GPU memory
313 lines · 13 KiB · Python
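The cache layout implied by the path construction in __getitem__ below (the resolution string joins the train_resolution values with "x", e.g. "49x480x720"):

    data_root/
    └── cache/
        ├── prompt_embeddings/
        │   └── <sha256 of prompt text>.safetensors
        └── video_latent/
            └── <model_name>/
                └── <train_resolution string>/
                    └── <video filename stem>.safetensors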
import torch
import hashlib

from pathlib import Path
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing_extensions import override

from torch.utils.data import Dataset
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import save_file, load_file

from finetune.constants import LOG_NAME, LOG_LEVEL

from .utils import (
    load_prompts,
    load_videos,
    load_images,
    preprocess_image_with_resize,
    preprocess_video_with_resize,
    preprocess_video_with_buckets,
)

if TYPE_CHECKING:
    from finetune.trainer import Trainer

# decord must be imported after torch: importing it first can sometimes lead to a
# nasty segmentation fault or stack-smashing error. There are very few bug reports,
# but it does happen; see the decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class BaseI2VDataset(Dataset):
    """
    Base dataset class for Image-to-Video (I2V) training.

    This dataset loads prompts, videos and corresponding conditioning images for I2V training.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        image_column (str): Path to file containing image paths
        device (torch.device): Device to load the data on
        trainer (Trainer): Trainer instance that provides the encode_video and encode_text functions
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
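        # Assumed column-file format: plain text with one entry per line (a
        # prompt per line in caption_column, a video/image path per line in
        # video_column / image_column). This is an inference from the
        # load_prompts / load_videos / load_images helpers, not a guarantee.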
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        self.images = load_images(data_root / image_column)
        self.trainer = trainer

        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
            )

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if all image files exist
        if any(not path.is_file() for path in self.images):
            raise ValueError(
                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices, but in the sampler we
            # would like to have information about num_frames, height and width.
            # Since this is not stored as metadata, we need to read the video to get
            # this information. It could be read without loading the full video into
            # memory, but we do it anyway. To avoid loading the video twice (once to
            # get the metadata, and once to return the loaded video based on sampled
            # indices), we cache it in the BucketSampler. When the sampler is about
            # to yield, we yield the cached data instead of indices. This special
            # check ensures that data is not loaded a second time. PRs are welcome
            # for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        image = self.images[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

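        # Cache keying: prompt embeddings are keyed by the SHA-256 hash of the
        # prompt text; encoded video latents are keyed by the video filename
        # stem, partitioned by model name and training resolution.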
        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            # Move to CPU before caching so the embedding does not hold GPU memory
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # shape of image: [C, H, W]
            _, image = self.preprocess(None, self.images[index])
        else:
            frames, image = self.preprocess(video, image)
            frames = frames.to(self.device)
            image = image.to(self.device)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)

            # Prepend the conditioning image as the first frame.
            # Note: this operation may be model-specific and may change in the future.
            frames = torch.cat([image.unsqueeze(0), frames], dim=0)

            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)

            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            # Move tensors to CPU to free GPU memory; the CPU copies are cached and returned
            encoded_video = encoded_video.to("cpu")
            image = image.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        # shape of image: [C, H, W]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads and preprocesses a video and an image.
        If either path is None, no preprocessing will be done for that input.

        Args:
            video_path: Path to the video file to load
            image_path: Path to the image file to load

        Returns:
            A tuple containing:
                - video (torch.Tensor) of shape [F, C, H, W] where F is number of frames,
                  C is number of channels, H is height and W is width
                - image (torch.Tensor) of shape [C, H, W]
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to an image.

        Args:
            image (torch.Tensor): A 3D tensor representing an image
                with shape [C, H, W] where:
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed image tensor
        """
        raise NotImplementedError("Subclass must implement this method")


class I2VDatasetWithResize(BaseI2VDataset):
    """
    A dataset class for image-to-video generation that resizes inputs to fixed dimensions.

    This class preprocesses videos and images by resizing them to specified dimensions:
        - Videos are resized to max_num_frames x height x width
        - Images are resized to height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )
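        # The transform above maps uint8 pixel values in [0, 255] to floats in
        # [-1, 1]: 0 -> -1.0, 127.5 -> 0.0, 255 -> 1.0.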
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)


class I2VDatasetWithBuckets(BaseI2VDataset):
    """
    A dataset class for image-to-video generation that assigns videos to resolution buckets.

    Bucket resolutions are given in pixel space as (frames, height, width) tuples and
    are scaled down by the VAE's temporal, height and width compression ratios.
    """

    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = [
            (
                int(b[0] / vae_temporal_compression_ratio),
                int(b[1] / vae_height_compression_ratio),
                int(b[2] / vae_width_compression_ratio),
            )
            for b in video_resolution_buckets
        ]
        self.__frame_transforms = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
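
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; file names and resolution are assumptions,
# not requirements). Constructing a dataset needs a trainer that exposes
# encode_video / encode_text and whose args carry data_root (a Path),
# model_name and train_resolution, since __getitem__ reads all of these when
# reading and writing the cache.
#
#   dataset = I2VDatasetWithResize(
#       max_num_frames=49,
#       height=480,
#       width=720,
#       data_root="data",              # directory containing the column files
#       caption_column="prompts.txt",  # one prompt per line (assumed format)
#       video_column="videos.txt",     # one video path per line (assumed format)
#       image_column="images.txt",     # one image path per line (assumed format)
#       device=trainer.accelerator.device,
#       trainer=trainer,
#   )
#   sample = dataset[0]
#   # sample["prompt_embedding"]: [seq_len, hidden_size]
#   # sample["encoded_video"]:    [C, F, H, W] video latent
#   # sample["image"]:            [C, H, W] conditioning image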