From e5b8f9a2ee6eb79dbf4d9c0350e06518a3628704 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Fri, 3 Jan 2025 07:52:10 +0000
Subject: [PATCH] feat: add caching for prompt embeddings and video latents

- Add caching for prompt embeddings
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system reduces redundant computation and memory usage during
training by:

1. Caching prompt embeddings, keyed by a hash of the prompt text
2. Caching encoded video latents, keyed by the video filename
3. Moving tensors to CPU after caching to free GPU memory

---
 finetune/datasets/i2v_dataset.py |  63 +++++++++++++++----
 finetune/datasets/t2v_dataset.py |  60 +++++++++++++-----
 finetune/trainer.py              | 103 +++++++++++++++++++++++++------
 3 files changed, 180 insertions(+), 46 deletions(-)
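For reviewers, the core of the change is a content-addressed cache: prompt
embeddings are keyed by a SHA-256 hash of the prompt text, video latents by
the video filename plus model name and training resolution. A minimal sketch
of the lookup (a self-contained approximation, not the exact code below;
`encode_text` stands in for the trainer hook this patch adds):

    import hashlib
    from pathlib import Path

    import torch
    from safetensors.torch import load_file, save_file

    def get_prompt_embedding(encode_text, prompt: str, cache_root: Path) -> torch.Tensor:
        cache_dir = cache_root / "prompt_embeddings"
        cache_dir.mkdir(parents=True, exist_ok=True)
        path = cache_dir / (hashlib.sha256(prompt.encode()).hexdigest() + ".safetensors")
        if path.exists():
            # Cache hit: the [seq_len, hidden_size] tensor is read straight from disk
            return load_file(path)["prompt_embedding"]
        # Cache miss: encode, drop the batch dim, move to CPU, persist
        embedding = encode_text(prompt)[0].to("cpu")
        save_file({"prompt_embedding": embedding}, path)
        return embedding

Moving the tensors to CPU before `save_file` both frees GPU memory between
samples and keeps the cached copies device-agnostic.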
diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py
index e993c96..2258a0a 100644
--- a/finetune/datasets/i2v_dataset.py
+++ b/finetune/datasets/i2v_dataset.py
@@ -1,12 +1,15 @@
 import torch
+import hashlib
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override

-from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms

+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file
+
 from finetune.constants import LOG_NAME, LOG_LEVEL

 from .utils import (
@@ -17,6 +20,9 @@ from .utils import (
     preprocess_video_with_buckets
 )

+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord # isort:skip

@@ -47,7 +53,7 @@ class BaseI2VDataset(Dataset):
         video_column: str,
         image_column: str,
         device: torch.device,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -57,9 +63,11 @@ class BaseI2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.images = load_images(data_root / image_column)
+        self.trainer = trainer

         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text

         # Check if number of prompts matches number of videos and images
         if not (len(self.videos) == len(self.prompts) == len(self.images)):
@@ -98,34 +106,63 @@ class BaseI2VDataset(Dataset):
         prompt = self.prompts[index]
         video = self.videos[index]
         image = self.images[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

-        video_latent_dir = video.parent / "latent"
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = video_latent_dir / (video.stem + ".pt")
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

         if encoded_video_path.exists():
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
             # shape of image: [C, H, W]
             _, image = self.preprocess(None, self.images[index])
         else:
             frames, image = self.preprocess(video, image)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            image = image.to(self.device)
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)
+
+            # Add the image as the first frame.
+            # Note: **this operation may be model-specific** and may change in the future.
+            frames = torch.cat([image.unsqueeze(0), frames], dim=0)
+
             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            image = image.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
             logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

         # shape of encoded_video: [C, F, H, W]
         # shape of image: [C, H, W]
         return {
-            "prompt": prompt,
             "image": image,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],
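The I2V branch above packs the conditioning image in as an extra leading
frame before encoding, then converts the frame-major layout to the batch
layout the video encoder expects. In isolation (a stand-alone sketch with
dummy tensors; the sizes are illustrative, not taken from this diff):

    import torch

    F, C, H, W = 16, 3, 480, 720          # illustrative sizes only
    frames = torch.randn(F, C, H, W)      # video frames after transforms
    image = torch.randn(C, H, W)          # conditioning image

    # Prepend the image as a first frame: [F, C, H, W] -> [F+1, C, H, W]
    frames = torch.cat([image.unsqueeze(0), frames], dim=0)

    # Add a batch dim and reorder to [B, C, F, H, W] for the video encoder
    frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4).contiguous()
    assert frames.shape == (1, C, F + 1, H, W)

This is also why the trainer below samples train_frames - 1 frames from each
video: the prepended image restores the count to train_frames.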
diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py
index 9afe53a..7a82a5e 100644
--- a/finetune/datasets/t2v_dataset.py
+++ b/finetune/datasets/t2v_dataset.py
@@ -1,12 +1,14 @@
+import hashlib
 import torch
 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override

-from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms

+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file

 from finetune.constants import LOG_NAME, LOG_LEVEL

@@ -16,6 +18,9 @@ from .utils import (
     preprocess_video_with_buckets
 )

+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord # isort:skip

@@ -45,7 +50,7 @@ class BaseT2VDataset(Dataset):
         caption_column: str,
         video_column: str,
         device: torch.device = None,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -55,7 +60,9 @@ class BaseT2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text
+        self.trainer = trainer

         # Check if all video files exist
         if any(not path.is_file() for path in self.videos):
@@ -87,30 +94,53 @@
         prompt = self.prompts[index]
         video = self.videos[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

-        latent_dir = video.parent / "latent"
-        latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = latent_dir / (video.stem + ".pt")
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
+        video_latent_dir.mkdir(parents=True, exist_ok=True)
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

         if encoded_video_path.exists():
-            # shape of encoded_video: [C, F, H, W]
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
         else:
             frames = self.preprocess(video)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)

             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
             logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

+        # shape of encoded_video: [C, F, H, W]
         return {
-            "prompt": prompt,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],
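Since both datasets encode lazily inside `__getitem__`, the trainer changes
that follow warm the caches up front with a throwaway dataloader, and only
then move the VAE and text encoder off the GPU. Roughly (a condensed sketch;
the function wrapper and `batch_size=1` are illustrative assumptions, not
lifted from the diff):

    import torch
    from accelerate import Accelerator

    def warm_caches(accelerator: Accelerator, dataset: torch.utils.data.Dataset) -> None:
        loader = torch.utils.data.DataLoader(dataset, batch_size=1)
        loader = accelerator.prepare_data_loader(loader)
        # Iterating once triggers encode + save_file inside __getitem__;
        # the batches themselves are discarded.
        for _ in loader:
            ...
        # Every rank must finish writing before the encoders are unloaded
        accelerator.wait_for_everyone()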
diff --git a/finetune/trainer.py b/finetune/trainer.py
index 6c9ec82..a6be790 100644
--- a/finetune/trainer.py
+++ b/finetune/trainer.py
@@ -25,6 +25,7 @@ from accelerate.utils import (
     gather_object,
 )

+from diffusers.pipelines import DiffusionPipeline
 from diffusers.optimization import get_scheduler
 from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
@@ -36,6 +37,7 @@ from finetune.utils import (
     get_memory_statistics,
     free_memory,
+    unload_model,

     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
@@ -63,6 +65,8 @@ _DTYPE_MAP = {


 class Trainer:
+    # If set, should be a list of component names to unload (refer to `Components`)
+    UNLOAD_LIST: List[str] = None

     def __init__(self, args: Args) -> None:
         self.args = args
@@ -132,6 +136,15 @@ class Trainer:
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def check_setting(self) -> None:
+        # Check UNLOAD_LIST
+        if self.UNLOAD_LIST is None:
+            logger.warning("\033[91mNo UNLOAD_LIST specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+        else:
+            for name in self.UNLOAD_LIST:
+                if name not in self.components.model_fields:
+                    raise ValueError(f"Invalid component name in UNLOAD_LIST: {name}")

     def prepare_models(self) -> None:
         logger.info("Initializing models")
@@ -150,33 +163,39 @@ class Trainer:
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")

+        # self.state.train_frames includes one padding frame for image conditioning,
+        # so we only sample train_frames - 1 frames from the actual video
+        max_num_frames = self.state.train_frames - 1
+
         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")

-        # Prepare VAE for encoding
+        # Prepare VAE and text encoder for encoding
         self.components.vae = self.components.vae.to(self.accelerator.device)
         self.components.vae.requires_grad_(False)
+        self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
+        self.components.text_encoder.requires_grad_(False)

-        # Precompute latent for video
-        logger.info("Precomputing latent for video ...")
+        # Precompute latent for video and prompt embedding
+        logger.info("Precomputing latent for video and prompt embedding ...")
         tmp_data_loader = torch.utils.data.DataLoader(
             self.dataset,
             collate_fn=self.collate_fn,
@@ -186,7 +205,12 @@ class Trainer:
         )
         tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
         for _ in tmp_data_loader: ...
-        logger.info("Precomputing latent for video ... Done")
+        self.accelerator.wait_for_everyone()
+        logger.info("Precomputing latent for video and prompt embedding ... Done")
+
+        unload_model(self.components.vae)
+        unload_model(self.components.text_encoder)
+        free_memory()

         self.data_loader = torch.utils.data.DataLoader(
             self.dataset,
Done") + + unload_model(self.components.vae) + unload_model(self.components.text_encoder) + free_memory() self.data_loader = torch.utils.data.DataLoader( self.dataset, @@ -216,7 +240,7 @@ class Trainer: "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) - self.__move_components_to_device() + self.__load_components() if self.args.gradient_checkpointing: self.components.transformer.enable_gradient_checkpointing() @@ -234,7 +258,7 @@ class Trainer: logger.info("Initializing optimizer and lr scheduler") # Make sure the trainable params are in float32 - if self.args.mixed_precision == "fp16": + if self.args.mixed_precision != "no": # only upcast trainable parameters (LoRA) into fp32 cast_training_params([self.components.transformer], dtype=torch.float32) @@ -423,17 +447,20 @@ class Trainer: global_step += 1 self.__maybe_save_checkpoint(global_step) + logs["loss"] = loss.detach().item() + logs["lr"] = self.lr_scheduler.get_last_lr()[0] + progress_bar.set_postfix(logs) + # Maybe run validation should_run_validation = ( self.args.do_validation and global_step % self.args.validation_steps == 0 ) if should_run_validation: + del loss + free_memory() self.validate(global_step) - logs["loss"] = loss.detach().item() - logs["lr"] = self.lr_scheduler.get_last_lr()[0] - progress_bar.set_postfix(logs) accelerator.log(logs, step=global_step) if global_step >= self.args.train_steps: @@ -445,6 +472,7 @@ class Trainer: accelerator.wait_for_everyone() self.__maybe_save_checkpoint(global_step, must_save=True) if self.args.do_validation: + free_memory() self.validate(global_step) del self.components @@ -465,10 +493,22 @@ class Trainer: return self.components.transformer.eval() + torch.set_grad_enabled(False) memory_statistics = get_memory_statistics() logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + ##### Initialize pipeline ##### + pipe = self.initialize_pipeline() + + # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage + pipe.enable_model_cpu_offload(device=self.accelerator.device) + + # Convert all model weights to training dtype + # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32 + pipe = pipe.to(dtype=self.state.weight_dtype) + ################################# + all_processes_artifacts = [] for i in range(num_validation_samples): # Skip current validation on all processes but one @@ -504,7 +544,7 @@ class Trainer: "prompt": prompt, "image": image, "video": video - }) + }, pipe) prompt_filename = string_to_filename(prompt)[:25] artifacts = { "image": {"type": "image", "value": image}, @@ -555,6 +595,12 @@ class Trainer: step=step, ) + del pipe + # Unload loaded models except those needed for training + self.__unload_components() + # Change LoRA weights back to fp32 + cast_training_params([self.components.transformer], dtype=torch.float32) + accelerator.wait_for_everyone() free_memory() @@ -562,9 +608,12 @@ class Trainer: logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") torch.cuda.reset_peak_memory_stats(accelerator.device) + torch.set_grad_enabled(True) self.components.transformer.train() + def fit(self): + self.check_setting() self.prepare_models() self.prepare_dataset() self.prepare_trainable_parameters() @@ -580,9 +629,17 @@ class Trainer: def load_components(self) -> Components: raise NotImplementedError + + def initialize_pipeline(self) -> DiffusionPipeline: + raise 
@@ -580,9 +629,17 @@ class Trainer:

     def load_components(self) -> Components:
         raise NotImplementedError
+
+    def initialize_pipeline(self) -> DiffusionPipeline:
+        raise NotImplementedError

     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
         # shape of input video: [B, C, F, H, W], where B = 1
+        # shape of output video: [B, C', F', H', W'], where B = 1
+        raise NotImplementedError
+
+    def encode_text(self, text: str) -> torch.Tensor:
+        # shape of output embedding: [batch_size, seq_len, hidden_size]
         raise NotImplementedError

     def compute_loss(self, batch) -> torch.Tensor:
@@ -601,11 +658,21 @@ class Trainer:
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

-    def __move_components_to_device(self):
+    def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, 'to'):
-                setattr(self.components, name, component.to(self.accelerator.device))
+                # UNLOAD_LIST may be None (see check_setting); treat that as "load everything"
+                if self.UNLOAD_LIST is not None and name in self.UNLOAD_LIST:
+                    continue
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
+
+    def __unload_components(self):
+        components = self.components.model_dump()
+        for name, component in components.items():
+            if not isinstance(component, type) and hasattr(component, 'to'):
+                if self.UNLOAD_LIST is not None and name in self.UNLOAD_LIST:
+                    setattr(self.components, name, component.to('cpu'))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
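Taken together, a model-specific subclass is expected to wire into the
trainer roughly like this (a hypothetical example; the component names,
tokenizer call, and VAE encode call are illustrative, not part of this diff):

    import torch

    class MyModelTrainer(Trainer):
        # Components kept on CPU except during precomputation and validation
        UNLOAD_LIST = ["text_encoder", "vae"]

        def encode_video(self, video: torch.Tensor) -> torch.Tensor:
            # [1, C, F, H, W] frames -> [1, C', F', H', W'] latents
            vae = self.components.vae
            latent_dist = vae.encode(video.to(vae.dtype)).latent_dist
            return latent_dist.sample() * vae.config.scaling_factor

        def encode_text(self, text: str) -> torch.Tensor:
            # [batch_size, seq_len, hidden_size] prompt embeddings
            tokens = self.components.tokenizer(text, return_tensors="pt").input_ids
            return self.components.text_encoder(tokens.to(self.accelerator.device))[0]

With UNLOAD_LIST set, __load_components skips these components when moving
everything to the GPU for training, and __unload_components returns them to
the CPU after each validation run.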