mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-12-03 03:02:09 +08:00
Add cache to videos and cpu offloading during trainning
This commit is contained in:
parent
532f246d7c
commit
8aae1e1e42
2
.gitignore
vendored
2
.gitignore
vendored
@ -8,3 +8,5 @@ logs/
|
||||
.idea
|
||||
output*
|
||||
test*
|
||||
finetune/cogvideox-lora-single-node/
|
||||
venv/
|
||||
24
finetune/accelerate_config_machine_single_weak.yaml
Normal file
24
finetune/accelerate_config_machine_single_weak.yaml
Normal file
@ -0,0 +1,24 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
dynamo_backend: 'no'
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
55
finetune/finetune_single_rank_weak.sh
Executable file
55
finetune/finetune_single_rank_weak.sh
Executable file
@ -0,0 +1,55 @@
|
||||
#!/bin/bash
|
||||
|
||||
export MODEL_PATH="THUDM/CogVideoX-2b"
|
||||
export CACHE_PATH="~/.cache"
|
||||
# "Disney-VideoGeneration-Dataset"
|
||||
# /home/rodrigo/generator/dataset/Disney-VideoGeneration-Dataset
|
||||
export DATASET_PATH="/home/rodrigo/generator/dataset/Disney-VideoGeneration-Dataset"
|
||||
export OUTPUT_PATH="cogvideox-lora-single-node"
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
|
||||
accelerate launch --config_file accelerate_config_machine_single_weak.yaml --multi_gpu \
|
||||
train_cogvideox_lora.py \
|
||||
--gradient_checkpointing \
|
||||
--pretrained_model_name_or_path $MODEL_PATH \
|
||||
--cache_dir $CACHE_PATH \
|
||||
--enable_tiling \
|
||||
--enable_slicing \
|
||||
--instance_data_root $DATASET_PATH \
|
||||
--caption_column prompt.txt \
|
||||
--video_column videos.txt \
|
||||
--validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
|
||||
--validation_prompt_separator ::: \
|
||||
--num_validation_videos 1 \
|
||||
--validation_epochs 100 \
|
||||
--seed 42 \
|
||||
--rank 128 \
|
||||
--lora_alpha 64 \
|
||||
--mixed_precision bf16 \
|
||||
--output_dir $OUTPUT_PATH \
|
||||
--height 480 \
|
||||
--width 720 \
|
||||
--fps 8 \
|
||||
--max_num_frames 49 \
|
||||
--skip_frames_start 0 \
|
||||
--skip_frames_end 0 \
|
||||
--train_batch_size 1 \
|
||||
--num_train_epochs 30 \
|
||||
--checkpointing_steps 1000 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 1e-3 \
|
||||
--lr_scheduler cosine_with_restarts \
|
||||
--lr_warmup_steps 200 \
|
||||
--lr_num_cycles 1 \
|
||||
--enable_slicing \
|
||||
--enable_tiling \
|
||||
--gradient_checkpointing \
|
||||
--optimizer AdamW \
|
||||
--adam_beta1 0.9 \
|
||||
--adam_beta2 0.95 \
|
||||
--max_grad_norm 1.0 \
|
||||
--allow_tf32 \
|
||||
--offload_to_cpu \
|
||||
--cache_preprocessed_data
|
||||
@ -401,8 +401,67 @@ def get_args():
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_to_cpu",
|
||||
action="store_true",
|
||||
help="Whether or not to offload the model to the CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_preprocessed_data",
|
||||
action="store_true",
|
||||
help="Whether or not to cache preprocessed data.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
return parsed_args
|
||||
|
||||
|
||||
class CachedVideoList:
|
||||
ACCELERATOR_DEVICE = 'cpu'
|
||||
CACHE_ENABLED = False
|
||||
VAE = None
|
||||
OUTPUT_DIR = None
|
||||
|
||||
@classmethod
|
||||
def is_cached(cls, video_name):
|
||||
if cls.CACHE_ENABLED:
|
||||
return os.path.exists(os.path.join(cls.cache_dir(), f'{video_name}.pt'))
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def cache_dir(cls):
|
||||
result = os.path.join(cls.OUTPUT_DIR, "cached_videos")
|
||||
if not os.path.exists(result):
|
||||
os.makedirs(result)
|
||||
return result
|
||||
|
||||
def __init__(self):
|
||||
if not self.CACHE_ENABLED:
|
||||
raise ValueError("CachedVideoList is not enabled. Please enable the cache before using it.")
|
||||
if self.OUTPUT_DIR is None:
|
||||
raise ValueError("Output directory not set. Please set the output directory before using the CachedVideoList.")
|
||||
if self.VAE is None:
|
||||
raise ValueError("VAE model not set. Please set the VAE model before using the CachedVideoList.")
|
||||
|
||||
self.video_names = []
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_names)
|
||||
|
||||
def append(self, video: Tuple[str, torch.Tensor]):
|
||||
self.video_names.append(video[0])
|
||||
if video[1] is not None:
|
||||
torch.save(video[1], os.path.join(CachedVideoList.cache_dir(), f'{video[0]}.pt'))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if index >= len(self.video_names):
|
||||
raise IndexError("Index out of bounds")
|
||||
pt_path = os.path.join(self.cache_dir, f'{video_names[index]}.pt')
|
||||
if not os.path.exists(pt_path):
|
||||
raise FileNotFoundError(f"Video {pt_path} not found in cache.")
|
||||
return torch.load(pt_path)
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
@ -534,6 +593,17 @@ class VideoDataset(Dataset):
|
||||
|
||||
return instance_prompts, instance_videos
|
||||
|
||||
def encode_videos(self, vae, device):
|
||||
if not CachedVideoList.CACHE_ENABLED:
|
||||
self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
|
||||
|
||||
def encode_video(self, video, vae, device):
|
||||
video = video.to(device, dtype=vae.dtype).unsqueeze(0)
|
||||
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
vae.to(device)
|
||||
latent_dist = vae.encode(video.to(device)).latent_dist
|
||||
return latent_dist
|
||||
|
||||
def _preprocess_data(self):
|
||||
try:
|
||||
import decord
|
||||
@ -544,7 +614,7 @@ class VideoDataset(Dataset):
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
videos = []
|
||||
videos = CachedVideoList() if CachedVideoList.CACHE_ENABLED else []
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
||||
@ -552,6 +622,11 @@ class VideoDataset(Dataset):
|
||||
)
|
||||
|
||||
for filename in self.instance_video_paths:
|
||||
if CachedVideoList.CACHE_ENABLED and videos.is_cached(filename.stem):
|
||||
videos.append((filename.stem, None))
|
||||
continue
|
||||
if CachedVideoList.CACHE_ENABLED:
|
||||
print(f"Pre-processing video {filename.as_posix()}")
|
||||
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
|
||||
video_num_frames = len(video_reader)
|
||||
|
||||
@ -580,7 +655,17 @@ class VideoDataset(Dataset):
|
||||
# Training transforms
|
||||
frames = frames.float()
|
||||
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
|
||||
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
|
||||
frames = frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
||||
|
||||
if CachedVideoList.CACHE_ENABLED:
|
||||
print(f"Encoding video {filename.stem} device {CachedVideoList.ACCELERATOR_DEVICE}")
|
||||
|
||||
if CachedVideoList.CACHE_ENABLED:
|
||||
videos.append((filename.stem, self.encode_video(frames, CachedVideoList.VAE, CachedVideoList.ACCELERATOR_DEVICE)))
|
||||
else:
|
||||
videos.append(frames) # [F, C, H, W]
|
||||
if CachedVideoList.CACHE_ENABLED:
|
||||
print(f"Video Encoded {filename.stem}")
|
||||
|
||||
return videos
|
||||
|
||||
@ -1054,11 +1139,18 @@ def main(args):
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
transformer.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.offload_to_cpu:
|
||||
text_encoder.to('cpu', dtype=weight_dtype)
|
||||
transformer.to('cpu', dtype=weight_dtype)
|
||||
vae.to('cpu', dtype=weight_dtype)
|
||||
from accelerate import cpu_offload
|
||||
cpu_offload(transformer, accelerator.device, offload_buffers=False)
|
||||
else:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
transformer.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
if args.gradient_checkpointing or args.offload_to_cpu:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
@ -1163,6 +1255,12 @@ def main(args):
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||
|
||||
if args.cache_preprocessed_data:
|
||||
CachedVideoList.OUTPUT_DIR = args.output_dir
|
||||
CachedVideoList.CACHE_ENABLED = True
|
||||
CachedVideoList.VAE = vae
|
||||
CachedVideoList.ACCELERATOR_DEVICE = accelerator.device
|
||||
|
||||
# Dataset and DataLoader
|
||||
train_dataset = VideoDataset(
|
||||
instance_data_root=args.instance_data_root,
|
||||
@ -1180,13 +1278,13 @@ def main(args):
|
||||
id_token=args.id_token,
|
||||
)
|
||||
|
||||
def encode_video(video):
|
||||
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
|
||||
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(video).latent_dist
|
||||
return latent_dist
|
||||
if args.offload_to_cpu:
|
||||
vae.to('cuda')
|
||||
|
||||
train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
|
||||
train_dataset.encode_videos(vae, accelerator.device)
|
||||
|
||||
if args.offload_to_cpu:
|
||||
vae.to('cpu')
|
||||
|
||||
def collate_fn(examples):
|
||||
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user