feat: add caching for prompt embeddings

- Add caching for prompt embeddings and encoded video latents
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system helps reduce redundant computation and memory usage during training by:
1. Caching prompt embeddings based on prompt text hash
2. Caching encoded video latents based on video filename
3. Moving tensors to CPU after caching to free GPU memory
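A minimal sketch of the prompt-embedding half of this caching scheme, for illustration only: the encode_text stub, its output shape, and the helper name get_prompt_embedding are placeholders, while the cache layout under data_root/cache and the safetensors save/load calls follow the diff below.

import hashlib
from pathlib import Path

import torch
from safetensors.torch import save_file, load_file

def encode_text(prompt: str) -> torch.Tensor:
    # Stand-in for the trainer's text encoder; returns [1, seq_len, hidden_size]
    return torch.randn(1, 226, 4096)

def get_prompt_embedding(prompt: str, data_root: Path) -> torch.Tensor:
    prompt_embeddings_dir = data_root / "cache" / "prompt_embeddings"
    prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

    # The cache key is the SHA-256 hash of the prompt text
    prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
    prompt_embedding_path = prompt_embeddings_dir / f"{prompt_hash}.safetensors"

    if prompt_embedding_path.exists():
        # Cache hit: load the precomputed embedding from disk
        return load_file(prompt_embedding_path)["prompt_embedding"]

    # Cache miss: encode, drop the batch dimension, move to CPU, persist
    prompt_embedding = encode_text(prompt)[0].to("cpu")
    save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
    return prompt_embedding

Encoded video latents are handled the same way, keyed by the video's file stem rather than a prompt hash, under cache/video_latent/<model_name>/<train_resolution>.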
Author: OleehyO
Date: 2025-01-03 07:52:10 +00:00
Parent: f731c35f70
Commit: e5b8f9a2ee
3 changed files with 180 additions and 46 deletions

File 1 of 3

@@ -1,12 +1,15 @@
import torch
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing_extensions import override
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
from .utils import (
@@ -17,6 +20,9 @@ from .utils import (
preprocess_video_with_buckets
)
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must be imported after torch because importing it earlier can sometimes lead to a nasty segmentation fault or stack smashing error
# Very few bug reports, but it happens. Look in the decord GitHub issues for more relevant information.
import decord # isort:skip
@@ -47,7 +53,7 @@ class BaseI2VDataset(Dataset):
video_column: str,
image_column: str,
device: torch.device,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
trainer: "Trainer" = None,
*args,
**kwargs
) -> None:
@@ -57,9 +63,11 @@ class BaseI2VDataset(Dataset):
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.images = load_images(data_root / image_column)
self.trainer = trainer
self.device = device
self.encode_video_fn = encode_video_fn
self.encode_video = trainer.encode_video
self.encode_text = trainer.encode_text
# Check if number of prompts matches number of videos and images
if not (len(self.videos) == len(self.prompts) == len(self.images)):
@@ -98,34 +106,63 @@ class BaseI2VDataset(Dataset):
prompt = self.prompts[index]
video = self.videos[index]
image = self.images[index]
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
video_latent_dir = video.parent / "latent"
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = video_latent_dir / (video.stem + ".pt")
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
if encoded_video_path.exists():
encoded_video = torch.load(encoded_video_path, weights_only=True)
# encoded_video = torch.load(encoded_video_path, weights_only=True)
encoded_video = load_file(encoded_video_path)["encoded_video"]
logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
# shape of image: [C, H, W]
_, image = self.preprocess(None, self.images[index])
else:
frames, image = self.preprocess(video, image)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
image = image.to(self.device)
# Current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Prepend the image as the first frame.
# Note: **this operation may be model-specific** and may change in the future.
frames = torch.cat([image.unsqueeze(0), frames], dim=0)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
encoded_video = self.encode_video(frames)
# [1, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
image = image.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
# shape of image: [C, H, W]
return {
"prompt": prompt,
"image": image,
"prompt_embedding": prompt_embedding,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],

File 2 of 3

@@ -1,12 +1,14 @@
import hashlib
import torch
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing import Any, Dict, List, Tuple, TYPE_CHECKING
from typing_extensions import override
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import save_file, load_file
from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -16,6 +18,9 @@ from .utils import (
preprocess_video_with_buckets
)
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must be imported after torch because importing it earlier can sometimes lead to a nasty segmentation fault or stack smashing error
# Very few bug reports, but it happens. Look in the decord GitHub issues for more relevant information.
import decord # isort:skip
@@ -45,7 +50,7 @@ class BaseT2VDataset(Dataset):
caption_column: str,
video_column: str,
device: torch.device = None,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
trainer: "Trainer" = None,
*args,
**kwargs
) -> None:
@@ -55,7 +60,9 @@ class BaseT2VDataset(Dataset):
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.device = device
self.encode_video_fn = encode_video_fn
self.encode_video = trainer.encode_video
self.encode_text = trainer.encode_text
self.trainer = trainer
# Check if all video files exist
if any(not path.is_file() for path in self.videos):
@@ -87,30 +94,53 @@ class BaseT2VDataset(Dataset):
prompt = self.prompts[index]
video = self.videos[index]
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
latent_dir = video.parent / "latent"
latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = latent_dir / (video.stem + ".pt")
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
if encoded_video_path.exists():
# shape of encoded_video: [C, F, H, W]
encoded_video = torch.load(encoded_video_path, weights_only=True)
# encoded_video = torch.load(encoded_video_path, weights_only=True)
encoded_video = load_file(encoded_video_path)["encoded_video"]
logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
# shape of image: [C, H, W]
else:
frames = self.preprocess(video)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
# Current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
encoded_video = self.encode_video(frames)
# [1, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
return {
"prompt": prompt,
"prompt_embedding": prompt_embedding,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],

File 3 of 3

@@ -25,6 +25,7 @@ from accelerate.utils import (
gather_object,
)
from diffusers.pipelines import DiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
@@ -36,6 +37,7 @@ from finetune.utils import (
get_memory_statistics,
free_memory,
unload_model,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
@@ -63,6 +65,8 @@ _DTYPE_MAP = {
class Trainer:
# If set, should be a list of components to unload (refer to `Components`)
UNLOAD_LIST: List[str] = None
def __init__(self, args: Args) -> None:
self.args = args
@@ -132,6 +136,15 @@ class Trainer:
if self.accelerator.is_main_process:
self.args.output_dir = Path(self.args.output_dir)
self.args.output_dir.mkdir(parents=True, exist_ok=True)
def check_setting(self) -> None:
# Check for unload_list
if self.UNLOAD_LIST is None:
logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
else:
for name in self.UNLOAD_LIST:
if name not in self.components.model_fields:
raise ValueError(f"Invalid component name in unload_list: {name}")
def prepare_models(self) -> None:
logger.info("Initializing models")
@@ -150,33 +163,39 @@ class Trainer:
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
# self.state.train_frames includes one padding frame for image conditioning
# so we only sample train_frames - 1 frames from the actual video
max_num_frames = self.state.train_frames - 1
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
max_num_frames=max_num_frames,
height=self.state.train_height,
width=self.state.train_width
width=self.state.train_width,
trainer=self
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
max_num_frames=max_num_frames,
height=self.state.train_height,
width=self.state.train_width
width=self.state.train_width,
trainer=self
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
# Prepare VAE for encoding
# Prepare VAE and text encoder for encoding
self.components.vae = self.components.vae.to(self.accelerator.device)
self.components.vae.requires_grad_(False)
self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
self.components.text_encoder.requires_grad_(False)
# Precompute latent for video
logger.info("Precomputing latent for video ...")
# Precompute latent for video and prompt embedding
logger.info("Precomputing latent for video and prompt embedding ...")
tmp_data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
@@ -186,7 +205,12 @@ class Trainer:
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader: ...
logger.info("Precomputing latent for video ... Done")
self.accelerator.wait_for_everyone()
logger.info("Precomputing latent for video and prompt embedding ... Done")
unload_model(self.components.vae)
unload_model(self.components.text_encoder)
free_memory()
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
@@ -216,7 +240,7 @@ class Trainer:
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
self.__move_components_to_device()
self.__load_components()
if self.args.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
@@ -234,7 +258,7 @@ class Trainer:
logger.info("Initializing optimizer and lr scheduler")
# Make sure the trainable params are in float32
if self.args.mixed_precision == "fp16":
if self.args.mixed_precision != "no":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
@@ -423,17 +447,20 @@ class Trainer:
global_step += 1
self.__maybe_save_checkpoint(global_step)
logs["loss"] = loss.detach().item()
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
# Maybe run validation
should_run_validation = (
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
if should_run_validation:
del loss
free_memory()
self.validate(global_step)
logs["loss"] = loss.detach().item()
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
accelerator.log(logs, step=global_step)
if global_step >= self.args.train_steps:
@@ -445,6 +472,7 @@ class Trainer:
accelerator.wait_for_everyone()
self.__maybe_save_checkpoint(global_step, must_save=True)
if self.args.do_validation:
free_memory()
self.validate(global_step)
del self.components
@@ -465,10 +493,22 @@ class Trainer:
return
self.components.transformer.eval()
torch.set_grad_enabled(False)
memory_statistics = get_memory_statistics()
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
##### Initialize pipeline #####
pipe = self.initialize_pipeline()
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
pipe.enable_model_cpu_offload(device=self.accelerator.device)
# Convert all model weights to training dtype
# Note: this will cast the LoRA weights in self.components.transformer to the training dtype instead of keeping them in fp32
pipe = pipe.to(dtype=self.state.weight_dtype)
#################################
all_processes_artifacts = []
for i in range(num_validation_samples):
# Skip current validation on all processes but one
@@ -504,7 +544,7 @@ class Trainer:
"prompt": prompt,
"image": image,
"video": video
})
}, pipe)
prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
@@ -555,6 +595,12 @@ class Trainer:
step=step,
)
del pipe
# Unload loaded models except those needed for training
self.__unload_components()
# Change LoRA weights back to fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
accelerator.wait_for_everyone()
free_memory()
@@ -562,9 +608,12 @@ class Trainer:
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
torch.set_grad_enabled(True)
self.components.transformer.train()
def fit(self):
self.check_setting()
self.prepare_models()
self.prepare_dataset()
self.prepare_trainable_parameters()
@@ -580,9 +629,17 @@ class Trainer:
def load_components(self) -> Components:
raise NotImplementedError
def initialize_pipeline(self) -> DiffusionPipeline:
raise NotImplementedError
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W], where B = 1
# shape of output video: [B, C', F', H', W'], where B = 1
raise NotImplementedError
def encode_text(self, text: str) -> torch.Tensor:
# shape of output text embedding: [batch size, sequence length, embedding dimension]
raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
@@ -601,11 +658,21 @@ class Trainer:
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self):
def __load_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
setattr(self.components, name, component.to(self.accelerator.device))
if name in self.UNLOAD_LIST:
continue
# setattr(self.components, name, component.to(self.accelerator.device))
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
def __unload_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
if name in self.UNLOAD_LIST:
setattr(self.components, name, component.to('cpu'))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format