Add caching for encoded videos and CPU offloading during training

Rodrigo Antonio de Araujo 2024-10-09 09:17:06 -03:00
parent 532f246d7c
commit 8aae1e1e42
4 changed files with 194 additions and 15 deletions

.gitignore

@@ -8,3 +8,5 @@ logs/
.idea
output*
test*
finetune/cogvideox-lora-single-node/
venv/

accelerate_config_machine_single_weak.yaml
@@ -0,0 +1,24 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
dynamo_backend: 'no'
mixed_precision: 'no'
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
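For quick sanity-checking, a minimal sketch (assuming PyYAML, and assuming the file is saved as `accelerate_config_machine_single_weak.yaml`, the name the launch script below passes to `--config_file`) that asserts the settings that matter for this single-process DeepSpeed run:

import yaml

# Load the accelerate config shown above; the filename is inferred from the launch script.
with open("accelerate_config_machine_single_weak.yaml") as f:
    cfg = yaml.safe_load(f)

# One process, DeepSpeed ZeRO stage 2, no optimizer/parameter offload.
assert cfg["distributed_type"] == "DEEPSPEED"
assert cfg["deepspeed_config"]["zero_stage"] == 2
assert cfg["deepspeed_config"]["offload_optimizer_device"] == "none"
assert cfg["num_processes"] == 1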

(new launch script, filename not shown)
@@ -0,0 +1,55 @@
#!/bin/bash
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="$HOME/.cache"
# "Disney-VideoGeneration-Dataset"
# /home/rodrigo/generator/dataset/Disney-VideoGeneration-Dataset
export DATASET_PATH="/home/rodrigo/generator/dataset/Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES="0"
# If you train with a different number of GPUs, change num_processes in `accelerate_config_machine_single_weak.yaml` to match your GPU count.
accelerate launch --config_file accelerate_config_machine_single_weak.yaml \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \
--enable_slicing \
--instance_data_root $DATASET_PATH \
--caption_column prompt.txt \
--video_column videos.txt \
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
--validation_epochs 100 \
--seed 42 \
--rank 128 \
--lora_alpha 64 \
--mixed_precision bf16 \
--output_dir $OUTPUT_PATH \
--height 480 \
--width 720 \
--fps 8 \
--max_num_frames 49 \
--skip_frames_start 0 \
--skip_frames_end 0 \
--train_batch_size 1 \
--num_train_epochs 30 \
--checkpointing_steps 1000 \
--gradient_accumulation_steps 1 \
--learning_rate 1e-3 \
--lr_scheduler cosine_with_restarts \
--lr_warmup_steps 200 \
--lr_num_cycles 1 \
--optimizer AdamW \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--max_grad_norm 1.0 \
--allow_tf32 \
--offload_to_cpu \
--cache_preprocessed_data

train_cogvideox_lora.py
@@ -401,8 +401,67 @@ def get_args():
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Whether or not to offload the model to the CPU.",
    )
    parser.add_argument(
        "--cache_preprocessed_data",
        action="store_true",
        help="Whether or not to cache preprocessed data.",
    )

    return parser.parse_args()


class CachedVideoList:
    """List-like container that keeps encoded videos on disk instead of in memory."""

    ACCELERATOR_DEVICE = "cpu"
    CACHE_ENABLED = False
    VAE = None
    OUTPUT_DIR = None

    @classmethod
    def is_cached(cls, video_name):
        if cls.CACHE_ENABLED:
            return os.path.exists(os.path.join(cls.cache_dir(), f"{video_name}.pt"))
        return False

    @classmethod
    def cache_dir(cls):
        result = os.path.join(cls.OUTPUT_DIR, "cached_videos")
        if not os.path.exists(result):
            os.makedirs(result)
        return result

    def __init__(self):
        if not self.CACHE_ENABLED:
            raise ValueError("CachedVideoList is not enabled. Please enable the cache before using it.")
        if self.OUTPUT_DIR is None:
            raise ValueError("Output directory not set. Please set the output directory before using the CachedVideoList.")
        if self.VAE is None:
            raise ValueError("VAE model not set. Please set the VAE model before using the CachedVideoList.")
        self.video_names = []

    def __len__(self):
        return len(self.video_names)

    def append(self, video: Tuple[str, torch.Tensor]):
        # A (name, encoded video) pair; a None payload marks an entry already on disk.
        self.video_names.append(video[0])
        if video[1] is not None:
            torch.save(video[1], os.path.join(CachedVideoList.cache_dir(), f"{video[0]}.pt"))

    def __getitem__(self, index):
        if index >= len(self.video_names):
            raise IndexError("Index out of bounds")
        pt_path = os.path.join(self.cache_dir(), f"{self.video_names[index]}.pt")
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"Video {pt_path} not found in cache.")
        return torch.load(pt_path)
class VideoDataset(Dataset):
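To make the contract above concrete, a hypothetical smoke test for `CachedVideoList`; the temp directory, dummy VAE placeholder, clip name, and tensor shape are all illustrative, not part of the training script:

import tempfile

import torch

# Illustrative wiring only; training sets these from args.output_dir and the real VAE.
CachedVideoList.OUTPUT_DIR = tempfile.mkdtemp()
CachedVideoList.CACHE_ENABLED = True
CachedVideoList.VAE = object()  # any non-None value satisfies the constructor guard

videos = CachedVideoList()
videos.append(("clip_0001", torch.randn(16, 13, 60, 90)))  # (stem, encoded video)
assert CachedVideoList.is_cached("clip_0001")  # clip_0001.pt now sits in cached_videos/
print(videos[0].shape)  # reloaded from disk via torch.load

During training the second tuple element is the `latent_dist` object returned by `vae.encode`, which `torch.save` serializes just like a raw tensor.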
@@ -534,6 +593,17 @@ class VideoDataset(Dataset):
        return instance_prompts, instance_videos

    def encode_videos(self, vae, device):
        # When caching is enabled, videos were already encoded during preprocessing.
        if not CachedVideoList.CACHE_ENABLED:
            self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]

    def encode_video(self, video, vae, device):
        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        vae.to(device)
        latent_dist = vae.encode(video).latent_dist
        return latent_dist

    def _preprocess_data(self):
        try:
            import decord
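As a sanity check on the layout `encode_video` above expects, a small illustrative sketch; the 49x480x720 dimensions are the script's defaults, and the VAE call itself is omitted:

import torch

frames = torch.randn(49, 3, 480, 720)   # dataset layout: [F, C, H, W]
video = frames.unsqueeze(0)             # [1, F, C, H, W]
video = video.permute(0, 2, 1, 3, 4)    # VAE layout: [B, C, F, H, W]
print(video.shape)                      # torch.Size([1, 3, 49, 480, 720])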
@@ -544,7 +614,7 @@ class VideoDataset(Dataset):
        decord.bridge.set_bridge("torch")

        videos = CachedVideoList() if CachedVideoList.CACHE_ENABLED else []

        train_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
@@ -552,6 +622,11 @@
        )

        for filename in self.instance_video_paths:
            if CachedVideoList.CACHE_ENABLED and videos.is_cached(filename.stem):
                # Already encoded on disk; register the name and skip decoding.
                videos.append((filename.stem, None))
                continue

            if CachedVideoList.CACHE_ENABLED:
                print(f"Pre-processing video {filename.as_posix()}")

            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
            video_num_frames = len(video_reader)
@@ -580,7 +655,17 @@
            # Training transforms
            frames = frames.float()
            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
            frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]

            if CachedVideoList.CACHE_ENABLED:
                print(f"Encoding video {filename.stem} device {CachedVideoList.ACCELERATOR_DEVICE}")
                videos.append((filename.stem, self.encode_video(frames, CachedVideoList.VAE, CachedVideoList.ACCELERATOR_DEVICE)))
                print(f"Video Encoded {filename.stem}")
            else:
                videos.append(frames)  # [F, C, H, W]

        return videos
@@ -1054,11 +1139,18 @@
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
text_encoder.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
if args.offload_to_cpu:
text_encoder.to('cpu', dtype=weight_dtype)
transformer.to('cpu', dtype=weight_dtype)
vae.to('cpu', dtype=weight_dtype)
from accelerate import cpu_offload
cpu_offload(transformer, accelerator.device, offload_buffers=False)
else:
text_encoder.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
if args.gradient_checkpointing or args.offload_to_cpu:
transformer.enable_gradient_checkpointing()
# now we will add new LoRA weights to the attention layers
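For readers new to `accelerate.cpu_offload`: it keeps the weights on CPU and installs forward hooks that stream them to the execution device per call. A minimal sketch on a toy module (requires a CUDA device; the linear layer is an illustrative stand-in for the CogVideoX transformer):

import torch
from accelerate import cpu_offload

model = torch.nn.Linear(16, 16)
cpu_offload(model, execution_device=torch.device("cuda:0"), offload_buffers=False)

x = torch.randn(2, 16)
y = model(x)     # hooks move the weights (and the input) to cuda:0 for this forward
print(y.device)  # cuda:0; between calls the parameters stay off the GPU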
@@ -1163,6 +1255,12 @@
    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)

    if args.cache_preprocessed_data:
        CachedVideoList.OUTPUT_DIR = args.output_dir
        CachedVideoList.CACHE_ENABLED = True
        CachedVideoList.VAE = vae
        CachedVideoList.ACCELERATOR_DEVICE = accelerator.device

    # Dataset and DataLoader
    train_dataset = VideoDataset(
        instance_data_root=args.instance_data_root,
@@ -1180,13 +1278,13 @@
        id_token=args.id_token,
    )

    if args.offload_to_cpu:
        vae.to(accelerator.device)
    train_dataset.encode_videos(vae, accelerator.device)
    if args.offload_to_cpu:
        vae.to("cpu")
    def collate_fn(examples):
        videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
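The collate step can call `.sample()` because the dataset stores the VAE posterior distribution rather than a fixed tensor, so every batch trains on a fresh draw. A small illustrative sketch (the import path matches recent diffusers releases; the channel and frame sizes are CogVideoX-like placeholders):

import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# vae.encode(video).latent_dist wraps the encoder's concatenated mean and logvar.
moments = torch.randn(1, 2 * 16, 13, 60, 90)  # [B, 2*C_latent, F', H', W']
dist = DiagonalGaussianDistribution(moments)

a = dist.sample()  # a new stochastic draw each call, as in collate_fn
b = dist.sample()
print(a.shape)               # torch.Size([1, 16, 13, 60, 90])
print(torch.allclose(a, b))  # False: successive samples differ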