Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)
feat: add caching for prompt embeddings
- Add caching for prompt embeddings
- Store cached files using safetensors format
- Add cache directory structure under data_root/cache
- Optimize memory usage by moving tensors to CPU after caching
- Add debug logging for cache hits
- Add info logging for cache writes

The caching system helps reduce redundant computation and memory usage during training by:

1. Caching prompt embeddings based on prompt text hash
2. Caching encoded video latents based on video filename
3. Moving tensors to CPU after caching to free GPU memory
Parent: f731c35f70
Commit: e5b8f9a2ee
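Before the diff itself, here is a minimal, self-contained sketch of the prompt-embedding caching pattern the commit introduces: hash the prompt text, look for a matching safetensors file, and run the text encoder only on a miss. The `encode_text` argument stands in for the trainer's text-encoding method and `data_root` for the training data directory (both illustrative assumptions); the paths and tensor key follow the diff below, which touches three modules: the image-to-video dataset, the text-to-video dataset, and the Trainer.

```python
import hashlib
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def cached_prompt_embedding(prompt: str, data_root: Path, encode_text) -> torch.Tensor:
    # Cache layout from the commit: data_root/cache/prompt_embeddings/<sha256>.safetensors
    prompt_embeddings_dir = data_root / "cache" / "prompt_embeddings"
    prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

    prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
    prompt_embedding_path = prompt_embeddings_dir / f"{prompt_hash}.safetensors"

    if prompt_embedding_path.exists():
        # Cache hit: the stored tensor is already on CPU with shape [seq_len, hidden_size].
        return load_file(prompt_embedding_path)["prompt_embedding"]

    # Cache miss: encode, drop the batch dimension, move to CPU, and persist.
    prompt_embedding = encode_text(prompt)            # [1, seq_len, hidden_size]
    prompt_embedding = prompt_embedding[0].to("cpu")  # [seq_len, hidden_size]
    save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
    return prompt_embedding
```

In the dataset classes, `encode_text` is `trainer.encode_text`, so repeated epochs and resumed runs reuse the cached embedding instead of re-running the text encoder.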
@@ -1,12 +1,15 @@
 import torch
+import hashlib

 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override

 from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file

 from finetune.constants import LOG_NAME, LOG_LEVEL

 from .utils import (
@@ -17,6 +20,9 @@ from .utils import (
     preprocess_video_with_buckets
 )

+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord  # isort:skip
@@ -47,7 +53,7 @@ class BaseI2VDataset(Dataset):
         video_column: str,
         image_column: str,
         device: torch.device,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -57,9 +63,11 @@ class BaseI2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.images = load_images(data_root / image_column)
+        self.trainer = trainer

         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text

         # Check if number of prompts matches number of videos and images
         if not (len(self.videos) == len(self.prompts) == len(self.images)):
@@ -98,34 +106,63 @@ class BaseI2VDataset(Dataset):
         prompt = self.prompts[index]
         video = self.videos[index]
         image = self.images[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

-        video_latent_dir = video.parent / "latent"
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = video_latent_dir / (video.stem + ".pt")
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

         if encoded_video_path.exists():
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
+            # encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
             # shape of image: [C, H, W]
             _, image = self.preprocess(None, self.images[index])
         else:
             frames, image = self.preprocess(video, image)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            image = image.to(self.device)
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)

             # Add image into the first frame.
             # Note, **this operation maybe model-specific**, and maybe change in the future.
             frames = torch.cat([image.unsqueeze(0), frames], dim=0)

             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            image = image.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

         # shape of encoded_video: [C, F, H, W]
         # shape of image: [C, H, W]
         return {
             "prompt": prompt,
             "image": image,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],
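The hunks above change the image-to-video dataset; the hunks that follow apply the same treatment to the text-to-video dataset. For orientation, this is the on-disk cache layout those paths produce. The model name and training resolution below are example values only; in the diff they come from `self.trainer.args`.

```python
from pathlib import Path

# Example values standing in for self.trainer.args fields.
data_root = Path("data")
model_name = "cogvideox-i2v"        # example model name
train_resolution = (49, 480, 720)   # example (frames, height, width)
train_resolution_str = "x".join(str(x) for x in train_resolution)

cache_dir = data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / model_name / train_resolution_str
prompt_embeddings_dir = cache_dir / "prompt_embeddings"

# Resulting layout:
#   data/cache/video_latent/<model_name>/<train_resolution>/<video_stem>.safetensors
#   data/cache/prompt_embeddings/<sha256_of_prompt>.safetensors
print(video_latent_dir)       # data/cache/video_latent/cogvideox-i2v/49x480x720
print(prompt_embeddings_dir)  # data/cache/prompt_embeddings
```

Because latents are keyed by model name, training resolution, and video stem, switching resolution or model simply populates a new subdirectory instead of invalidating old entries.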
@@ -1,12 +1,14 @@
+import hashlib
 import torch

 from pathlib import Path
-from typing import Any, Dict, List, Tuple, Callable
+from typing import Any, Dict, List, Tuple, TYPE_CHECKING
 from typing_extensions import override

-from accelerate.logging import get_logger
 from torch.utils.data import Dataset
 from torchvision import transforms
+from accelerate.logging import get_logger
+from safetensors.torch import save_file, load_file

 from finetune.constants import LOG_NAME, LOG_LEVEL

@@ -16,6 +18,9 @@ from .utils import (
     preprocess_video_with_buckets
 )

+if TYPE_CHECKING:
+    from finetune.trainer import Trainer
+
 # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
 # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
 import decord  # isort:skip
@@ -45,7 +50,7 @@ class BaseT2VDataset(Dataset):
         caption_column: str,
         video_column: str,
         device: torch.device = None,
-        encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+        trainer: "Trainer" = None,
         *args,
         **kwargs
     ) -> None:
@@ -55,7 +60,9 @@ class BaseT2VDataset(Dataset):
         self.prompts = load_prompts(data_root / caption_column)
         self.videos = load_videos(data_root / video_column)
         self.device = device
-        self.encode_video_fn = encode_video_fn
+        self.encode_video = trainer.encode_video
+        self.encode_text = trainer.encode_text
+        self.trainer = trainer

         # Check if all video files exist
         if any(not path.is_file() for path in self.videos):
@@ -87,30 +94,53 @@

         prompt = self.prompts[index]
         video = self.videos[index]
+        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

-        latent_dir = video.parent / "latent"
-        latent_dir.mkdir(parents=True, exist_ok=True)
-        encoded_video_path = latent_dir / (video.stem + ".pt")
+        cache_dir = self.trainer.args.data_root / "cache"
+        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
+        video_latent_dir.mkdir(parents=True, exist_ok=True)
+        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
+        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
+        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
+
+        if prompt_embedding_path.exists():
+            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
+            logger.debug(f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}", main_process_only=False)
+        else:
+            prompt_embedding = self.encode_text(prompt)
+            prompt_embedding = prompt_embedding.to("cpu")
+            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
+            prompt_embedding = prompt_embedding[0]
+            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
+            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

         if encoded_video_path.exists():
-            # shape of encoded_video: [C, F, H, W]
-            encoded_video = torch.load(encoded_video_path, weights_only=True)
-            # shape of image: [C, H, W]
+            # encoded_video = torch.load(encoded_video_path, weights_only=True)
+            encoded_video = load_file(encoded_video_path)["encoded_video"]
+            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
         else:
             frames = self.preprocess(video)
             frames = frames.to(self.device)
-            # current shape of frames: [F, C, H, W]
+            # Current shape of frames: [F, C, H, W]
             frames = self.video_transform(frames)
             # Convert to [B, C, F, H, W]
             frames = frames.unsqueeze(0)
             frames = frames.permute(0, 2, 1, 3, 4).contiguous()
-            encoded_video = self.encode_video_fn(frames)
-            # [B, C, F, H, W] -> [C, F, H, W]
-            encoded_video = encoded_video[0].cpu()
-            torch.save(encoded_video, encoded_video_path)
+            encoded_video = self.encode_video(frames)
+
+            # [1, C, F, H, W] -> [C, F, H, W]
+            encoded_video = encoded_video[0]
+            encoded_video = encoded_video.to("cpu")
+            save_file({"encoded_video": encoded_video}, encoded_video_path)
+            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

         # shape of encoded_video: [C, F, H, W]
         return {
             "prompt": prompt,
+            "prompt_embedding": prompt_embedding,
             "encoded_video": encoded_video,
             "video_metadata": {
                 "num_frames": encoded_video.shape[1],
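Both dataset classes perform the same tensor rearrangement before encoding, and the remaining hunks below modify the Trainer. As a quick stand-alone check of the shape bookkeeping spelled out in the comments above ([F, C, H, W] -> [B, C, F, H, W], then dropping the batch dimension after encoding), with arbitrary stand-in sizes:

```python
import torch

# Stand-in clip: 8 frames, 3 channels, 16x16 pixels, laid out [F, C, H, W].
frames = torch.randn(8, 3, 16, 16)

frames = frames.unsqueeze(0)                         # [1, 8, 3, 16, 16] = [B, F, C, H, W]
frames = frames.permute(0, 2, 1, 3, 4).contiguous()  # [1, 3, 8, 16, 16] = [B, C, F, H, W]
print(frames.shape)

# encode_video(frames) would return [1, C', F', H', W']; indexing [0] drops the batch
# dimension, giving the [C, F, H, W] latent that is written to the safetensors cache.
encoded_video = frames[0]   # placeholder for the VAE output, only to show the indexing
print(encoded_video.shape)  # torch.Size([3, 8, 16, 16])
```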
@@ -25,6 +25,7 @@ from accelerate.utils import (
     gather_object,
 )

+from diffusers.pipelines import DiffusionPipeline
 from diffusers.optimization import get_scheduler
 from diffusers.utils.export_utils import export_to_video
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
@@ -36,6 +37,7 @@ from finetune.utils import (

     get_memory_statistics,
     free_memory,
     unload_model,

     get_latest_ckpt_path_to_resume_from,
     get_intermediate_ckpt_path,
@@ -63,6 +65,8 @@ _DTYPE_MAP = {


 class Trainer:
+    # If set, should be a list of components to unload (refer to `Components``)
+    UNLOAD_LIST: List[str] = None

     def __init__(self, args: Args) -> None:
         self.args = args
@@ -132,6 +136,15 @@ class Trainer:
         if self.accelerator.is_main_process:
             self.args.output_dir = Path(self.args.output_dir)
             self.args.output_dir.mkdir(parents=True, exist_ok=True)

+    def check_setting(self) -> None:
+        # Check for unload_list
+        if self.UNLOAD_LIST is None:
+            logger.warning("\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m")
+        else:
+            for name in self.UNLOAD_LIST:
+                if name not in self.components.model_fields:
+                    raise ValueError(f"Invalid component name in unload_list: {name}")
+
     def prepare_models(self) -> None:
         logger.info("Initializing models")
@@ -150,33 +163,39 @@ class Trainer:
     def prepare_dataset(self) -> None:
         logger.info("Initializing dataset and dataloader")

+        # self.state.train_frames includes one padding frame for image conditioning
+        # so we only sample train_frames - 1 frames from the actual video
+        max_num_frames = self.state.train_frames - 1

         if self.args.model_type == "i2v":
             self.dataset = I2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         elif self.args.model_type == "t2v":
             self.dataset = T2VDatasetWithResize(
                 **(self.args.model_dump()),
                 device=self.accelerator.device,
-                encode_video_fn=self.encode_video,
-                max_num_frames=self.state.train_frames,
+                max_num_frames=max_num_frames,
                 height=self.state.train_height,
-                width=self.state.train_width
+                width=self.state.train_width,
+                trainer=self
             )
         else:
             raise ValueError(f"Invalid model type: {self.args.model_type}")

-        # Prepare VAE for encoding
+        # Prepare VAE and text encoder for encoding
         self.components.vae = self.components.vae.to(self.accelerator.device)
         self.components.vae.requires_grad_(False)
+        self.components.text_encoder = self.components.text_encoder.to(self.accelerator.device)
+        self.components.text_encoder.requires_grad_(False)

-        # Precompute latent for video
-        logger.info("Precomputing latent for video ...")
+        # Precompute latent for video and prompt embedding
+        logger.info("Precomputing latent for video and prompt embedding ...")
         tmp_data_loader = torch.utils.data.DataLoader(
             self.dataset,
             collate_fn=self.collate_fn,
@@ -186,7 +205,12 @@ class Trainer:
         )
         tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
         for _ in tmp_data_loader: ...
-        logger.info("Precomputing latent for video ... Done")
+        self.accelerator.wait_for_everyone()
+        logger.info("Precomputing latent for video and prompt embedding ... Done")

         unload_model(self.components.vae)
+        unload_model(self.components.text_encoder)
         free_memory()

         self.data_loader = torch.utils.data.DataLoader(
             self.dataset,
@@ -216,7 +240,7 @@ class Trainer:
                 "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
             )

-        self.__move_components_to_device()
+        self.__load_components()

         if self.args.gradient_checkpointing:
             self.components.transformer.enable_gradient_checkpointing()
@@ -234,7 +258,7 @@ class Trainer:
         logger.info("Initializing optimizer and lr scheduler")

         # Make sure the trainable params are in float32
-        if self.args.mixed_precision == "fp16":
+        if self.args.mixed_precision != "no":
             # only upcast trainable parameters (LoRA) into fp32
             cast_training_params([self.components.transformer], dtype=torch.float32)

@@ -423,17 +447,20 @@ class Trainer:
                     global_step += 1
                     self.__maybe_save_checkpoint(global_step)

-                logs["loss"] = loss.detach().item()
-                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
-                progress_bar.set_postfix(logs)
-
                 # Maybe run validation
                 should_run_validation = (
                     self.args.do_validation
                     and global_step % self.args.validation_steps == 0
                 )
                 if should_run_validation:
+                    del loss
+                    free_memory()
                     self.validate(global_step)

+                logs["loss"] = loss.detach().item()
+                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
+                progress_bar.set_postfix(logs)
                 accelerator.log(logs, step=global_step)

                 if global_step >= self.args.train_steps:
@@ -445,6 +472,7 @@ class Trainer:
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
         if self.args.do_validation:
+            free_memory()
             self.validate(global_step)

         del self.components
@@ -465,10 +493,22 @@ class Trainer:
             return

+        self.components.transformer.eval()
+        torch.set_grad_enabled(False)
+
         memory_statistics = get_memory_statistics()
         logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

+        ##### Initialize pipeline #####
+        pipe = self.initialize_pipeline()
+
+        # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
+        pipe.enable_model_cpu_offload(device=self.accelerator.device)
+
+        # Convert all model weights to training dtype
+        # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
+        pipe = pipe.to(dtype=self.state.weight_dtype)
+        #################################
+
         all_processes_artifacts = []
         for i in range(num_validation_samples):
             # Skip current validation on all processes but one
@@ -504,7 +544,7 @@ class Trainer:
                 "prompt": prompt,
                 "image": image,
                 "video": video
-            })
+            }, pipe)
             prompt_filename = string_to_filename(prompt)[:25]
             artifacts = {
                 "image": {"type": "image", "value": image},
@@ -555,6 +595,12 @@ class Trainer:
                 step=step,
             )

+        del pipe
+        # Unload loaded models except those needed for training
+        self.__unload_components()
+        # Change LoRA weights back to fp32
+        cast_training_params([self.components.transformer], dtype=torch.float32)
+
         accelerator.wait_for_everyone()

         free_memory()
@@ -562,9 +608,12 @@ class Trainer:
         logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
         torch.cuda.reset_peak_memory_stats(accelerator.device)

+        torch.set_grad_enabled(True)
+        self.components.transformer.train()
+
     def fit(self):
+        self.check_setting()
         self.prepare_models()
         self.prepare_dataset()
         self.prepare_trainable_parameters()
@@ -580,9 +629,17 @@ class Trainer:

     def load_components(self) -> Components:
         raise NotImplementedError

+    def initialize_pipeline(self) -> DiffusionPipeline:
+        raise NotImplementedError
+
     def encode_video(self, video: torch.Tensor) -> torch.Tensor:
+        # shape of input video: [B, C, F, H, W], where B = 1
+        # shape of output video: [B, C', F', H', W'], where B = 1
         raise NotImplementedError

+    def encode_text(self, text: str) -> torch.Tensor:
+        # shape of output text: [batch size, sequence length, embedding dimension]
+        raise NotImplementedError
+
     def compute_loss(self, batch) -> torch.Tensor:
@@ -601,11 +658,21 @@ class Trainer:
         else:
             raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

-    def __move_components_to_device(self):
+    def __load_components(self):
         components = self.components.model_dump()
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, 'to'):
-                setattr(self.components, name, component.to(self.accelerator.device))
+                if name in self.UNLOAD_LIST:
+                    continue
+                # setattr(self.components, name, component.to(self.accelerator.device))
+                setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
+
+    def __unload_components(self):
+        components = self.components.model_dump()
+        for name, component in components.items():
+            if not isinstance(component, type) and hasattr(component, 'to'):
+                if name in self.UNLOAD_LIST:
+                    setattr(self.components, name, component.to('cpu'))

     def __prepare_saving_loading_hooks(self, transformer_lora_config):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
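Finally, a condensed sketch of the precompute pass that `prepare_dataset` now performs: iterating a throwaway DataLoader once forces every `__getitem__` call, which fills the latent and prompt-embedding caches, after which the VAE and text encoder can be moved off the GPU. The Trainer-like object and the batch size are illustrative assumptions; the repository uses its own `unload_model`/`free_memory` helpers where this sketch falls back to plain PyTorch calls.

```python
import torch


def precompute_caches(trainer) -> None:
    # Encoders must sit on the accelerator device while the cache is being filled.
    trainer.components.vae = trainer.components.vae.to(trainer.accelerator.device)
    trainer.components.text_encoder = trainer.components.text_encoder.to(trainer.accelerator.device)

    # One pass over the dataset: every miss writes a .safetensors file, so every later
    # epoch (and every resumed run) becomes a cache hit.
    tmp_loader = torch.utils.data.DataLoader(
        trainer.dataset, collate_fn=trainer.collate_fn, batch_size=1  # batch size illustrative
    )
    tmp_loader = trainer.accelerator.prepare_data_loader(tmp_loader)
    for _ in tmp_loader:
        pass
    trainer.accelerator.wait_for_everyone()

    # With everything cached, the encoders are no longer needed on the GPU.
    trainer.components.vae = trainer.components.vae.to("cpu")
    trainer.components.text_encoder = trainer.components.text_encoder.to("cpu")
    torch.cuda.empty_cache()
```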