feat: add caching for prompt embeddings

- Add caching for prompt embeddings
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system helps reduce redundant computation and memory usage during training by:
1. Caching prompt embeddings based on prompt text hash
2. Caching encoded video latents based on video filename
3. Moving tensors to CPU after caching to free GPU memory
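
For illustration, a minimal sketch of the cache layout this commit introduces (not the repository's actual dataset code): prompt embeddings are keyed by a SHA-256 hash of the prompt text, video latents by the source video's file stem under per-model, per-resolution subdirectories, and both are written as .safetensors files under data_root/cache. The helper name get_or_compute, the example paths, and the tensor shapes below are made up for the sketch.

import hashlib
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def get_or_compute(path: Path, key: str, compute) -> torch.Tensor:
    # Return the cached tensor if the file exists; otherwise compute it,
    # move it to CPU (so the GPU copy can be freed), cache it, and return it.
    if path.exists():
        return load_file(path)[key]
    tensor = compute().to("cpu")
    path.parent.mkdir(parents=True, exist_ok=True)
    save_file({key: tensor}, path)
    return tensor


cache_dir = Path("data") / "cache"  # i.e. data_root / "cache"

# Prompt embedding, keyed by the SHA-256 hash of the prompt text
prompt = "a cat playing with a ball"
prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
prompt_embedding = get_or_compute(
    cache_dir / "prompt_embeddings" / f"{prompt_hash}.safetensors",
    "prompt_embedding",
    lambda: torch.randn(226, 4096),  # stand-in for encode_text(prompt)[0]
)

# Video latent, keyed by the video file stem, model name, and train resolution
video = Path("videos/clip_0001.mp4")  # hypothetical video file
encoded_video = get_or_compute(
    cache_dir / "video_latent" / "some_model" / "49x480x720" / f"{video.stem}.safetensors",
    "encoded_video",
    lambda: torch.randn(16, 13, 60, 90),  # stand-in for encode_video(frames)[0]
)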
OleehyO 2025-01-03 07:52:10 +00:00
parent f731c35f70
commit e5b8f9a2ee
3 changed files with 180 additions and 46 deletions

View File

@@ -1,12 +1,15 @@
 import torch
+import hashlib
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override
-from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
 
 from .utils import (
@@ -17,6 +20,9 @@ from .utils import (
     preprocess_video_with_buckets
 )
 
+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord  # isort:skip
@@ -47,7 +53,7 @@ class BaseI2VDataset(Dataset):
         video_column: str,
         image_column: str,
         device: torch.device,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -57,9 +63,11 @@ class BaseI2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.images = load_images(data_root / image_column)
+        self.trainer = trainer
         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text
 
         # Check if number of prompts matches number of videos and images
         if not (len(self.videos) == len(self.prompts) == len(self.images)):
@@ -98,34 +106,63 @@ class BaseI2VDataset(Dataset):
         prompt = self.prompts[index]
         video = self.videos[index]
         image = self.images[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
 
-        video_latent_dir = video.parent / "latent"
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = video_latent_dir / (video.stem + ".pt")
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
 
         if encoded_video_path.exists():
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            # encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
             # shape of image: [C, H, W]
             _, image = self.preprocess(None, self.images[index])
         else:
             frames, image = self.preprocess(video, image)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            image = image.to(self.device)
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)
+
+            # Add image into the first frame.
+            # Note, **this operation maybe model-specific**, and maybe change in the future.
+            frames = torch.cat([image.unsqueeze(0), frames], dim=0)
+
             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            image = image.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
             logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
 
         # shape of encoded_video: [C, F, H, W]
         # shape of image: [C, H, W]
         return {
-            "prompt": prompt,
             "image": image,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],

View File

@@ -1,12 +1,14 @@
+import hashlib
 import torch
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override
-from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file
 
 from finetune.constants import LOG_NAME, LOG_LEVEL
@@ -16,6 +18,9 @@ from .utils import (
     preprocess_video_with_buckets
 )
 
+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord  # isort:skip
@@ -45,7 +50,7 @@ class BaseT2VDataset(Dataset):
         caption_column: str,
         video_column: str,
         device: torch.device = None,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -55,7 +60,9 @@ class BaseT2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text
+        self.trainer = trainer
 
         # Check if all video files exist
         if any(not path.is_file() for path in self.videos):
@@ -87,30 +94,53 @@ class BaseT2VDataset(Dataset):
         prompt = self.prompts[index]
         video = self.videos[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
 
-        latent_dir = video.parent / "latent"
-        latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = latent_dir / (video.stem + ".pt")
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
+        video_latent_dir.mkdir(parents=True, exist_ok=True)
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
 
         if encoded_video_path.exists():
-            # shape of encoded_video: [C, F, H, W]
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            # encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
+            # shape of image: [C, H, W]
         else:
             frames = self.preprocess(video)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)
             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
             logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
 
+        # shape of encoded_video: [C, F, H, W]
         return {
-            "prompt": prompt,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],

View File

@@ -25,6 +25,7 @@ from accelerate.utils import (
     gather_object,
 )
 
+from diffusers.pipelines import DiffusionPipeline
 from diffusers.optimization import get_scheduler
 from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
@@ -36,6 +37,7 @@ from finetune.utils import (
     get_memory_statistics,
     free_memory,
+    unload_model,
     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
@@ -63,6 +65,8 @@ _DTYPE_MAP = {
 class Trainer:
+    # If set, should be a list of components to unload (refer to `Components``)
+    UNLOAD_LIST: List[str] = None
 
     def __init__(self, args: Args) -> None:
         self.args = args
@@ -132,6 +136,15 @@
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)
 
+    def check_setting(self) -> None:
+        # Check for unload_list
+        if self.UNLOAD_LIST is None:
+            logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+        else:
+            for name in self.UNLOAD_LIST:
+                if name not in self.components.model_fields:
+                    raise ValueError(f"Invalid component name in unload_list: {name}")
+
     def prepare_models(self) -> None:
         logger.info("Initializing models")
@@ -150,33 +163,39 @@
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")
 
+        # self.state.train_frames includes one padding frame for image conditioning
+        # so we only sample train_frames - 1 frames from the actual video
+        max_num_frames = self.state.train_frames - 1
+
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")
 
-        # Prepare VAE for encoding
+        # Prepare VAE and text encoder for encoding
         self.components.vae = self.components.vae.to(self.accelerator.device)
         self.components.vae.requires_grad_(False)
+        self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
+        self.components.text_encoder.requires_grad_(False)
 
-        # Precompute latent for video
-        logger.info("Precomputing latent for video ...")
+        # Precompute latent for video and prompt embedding
+        logger.info("Precomputing latent for video and prompt embedding ...")
         tmp_data_loader = torch.utils.data.DataLoader(
             self.dataset,
             collate_fn=self.collate_fn,
@@ -186,7 +205,12 @@
         )
         tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
         for _ in tmp_data_loader: ...
-        logger.info("Precomputing latent for video ... Done")
+        self.accelerator.wait_for_everyone()
+        logger.info("Precomputing latent for video and prompt embedding ... Done")
+
+        unload_model(self.components.vae)
+        unload_model(self.components.text_encoder)
+        free_memory()
 
         self.data_loader = torch.utils.data.DataLoader(
             self.dataset,
@@ -216,7 +240,7 @@
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )
 
-        self.__move_components_to_device()
+        self.__load_components()
 
         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
@@ -234,7 +258,7 @@
         logger.info("Initializing optimizer and lr scheduler")
 
         # Make sure the trainable params are in float32
-        if self.args.mixed_precision == "fp16":
+        if self.args.mixed_precision != "no":
             # only upcast trainable parameters (LoRA) into fp32
             cast_training_params([self.components.transformer], dtype=torch.float32)
@@ -423,17 +447,20 @@
                     global_step += 1
                     self.__maybe_save_checkpoint(global_step)
 
+                    logs["loss"] = loss.detach().item()
+                    logs["lr"] = self.lr_scheduler.get_last_lr()[0]
+                    progress_bar.set_postfix(logs)
+
                     # Maybe run validation
                     should_run_validation = (
                         self.args.do_validation
                         and global_step % self.args.validation_steps == 0
                     )
                     if should_run_validation:
+                        del loss
+                        free_memory()
                         self.validate(global_step)
 
-                    logs["loss"] = loss.detach().item()
-                    logs["lr"] = self.lr_scheduler.get_last_lr()[0]
-                    progress_bar.set_postfix(logs)
                     accelerator.log(logs, step=global_step)
 
                     if global_step >= self.args.train_steps:
@@ -445,6 +472,7 @@
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
         if self.args.do_validation:
+            free_memory()
             self.validate(global_step)
 
         del self.components
@@ -465,10 +493,22 @@
             return
 
         self.components.transformer.eval()
+        torch.set_grad_enabled(False)
 
         memory_statistics = get_memory_statistics()
         logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
 
+        #####  Initialize pipeline  #####
+        pipe = self.initialize_pipeline()
+
+        # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
+        pipe.enable_model_cpu_offload(device=self.accelerator.device)
+
+        # Convert all model weights to training dtype
+        # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
+        pipe = pipe.to(dtype=self.state.weight_dtype)
+        #################################
+
         all_processes_artifacts = []
         for i in range(num_validation_samples):
             # Skip current validation on all processes but one
@@ -504,7 +544,7 @@
                 "prompt": prompt,
                 "image": image,
                 "video": video
-            })
+            }, pipe)
 
             prompt_filename = string_to_filename(prompt)[:25]
             artifacts = {
                 "image": {"type": "image", "value": image},
@@ -555,6 +595,12 @@
                     step=step,
                 )
 
+        del pipe
+        # Unload loaded models except those needed for training
+        self.__unload_components()
+        # Change LoRA weights back to fp32
+        cast_training_params([self.components.transformer], dtype=torch.float32)
+
         accelerator.wait_for_everyone()
         free_memory()
@@ -562,9 +608,12 @@
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)
 
+        torch.set_grad_enabled(True)
         self.components.transformer.train()
 
     def fit(self):
+        self.check_setting()
         self.prepare_models()
         self.prepare_dataset()
         self.prepare_trainable_parameters()
@@ -580,9 +629,17 @@
     def load_components(self) -> Components:
         raise NotImplementedError
 
+    def initialize_pipeline(self) -> DiffusionPipeline:
+        raise NotImplementedError
+
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
         # shape of input video: [B, C, F, H, W], where B = 1
+        # shape of output video: [B, C', F', H', W'], where B = 1
+        raise NotImplementedError
+
+    def encode_text(self, text: str) -> torch.Tensor:
+        # shape of output text: [batch size, sequence length, embedding dimension]
         raise NotImplementedError
 
     def compute_loss(self, batch) -> torch.Tensor:
@@ -601,11 +658,21 @@
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
 
-    def __move_components_to_device(self):
+    def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, 'to'):
-                setattr(self.components, name, component.to(self.accelerator.device))
+                if name in self.UNLOAD_LIST:
+                    continue
+                # setattr(self.components, name, component.to(self.accelerator.device))
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
+
+    def __unload_components(self):
+        components = self.components.model_dump()
+        for name, component in components.items():
+            if not isinstance(component, type) and hasattr(component, 'to'):
+                if name in self.UNLOAD_LIST:
+                    setattr(self.components, name, component.to('cpu'))
 
     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format