From f69c3c87ecd327dcb2f90ce8685b113cef1c9a69 Mon Sep 17 00:00:00 2001
From: Rodrigo Antonio de Araujo
Date: Wed, 9 Oct 2024 10:32:16 -0300
Subject: [PATCH] Experiment: train lora lowvram

---
 ...config_machine_single_single_process.yaml} |   0
 finetune/finetune_single_rank_weak.sh         |   2 +-
 finetune/train_cogvideox_lora.py              | 100 ++++++++++++++----
 3 files changed, 80 insertions(+), 22 deletions(-)
 rename finetune/{accelerate_config_machine_single_weak.yaml => accelerate_config_machine_single_single_process.yaml} (100%)

diff --git a/finetune/accelerate_config_machine_single_weak.yaml b/finetune/accelerate_config_machine_single_single_process.yaml
similarity index 100%
rename from finetune/accelerate_config_machine_single_weak.yaml
rename to finetune/accelerate_config_machine_single_single_process.yaml
diff --git a/finetune/finetune_single_rank_weak.sh b/finetune/finetune_single_rank_weak.sh
index 3541e8f..5fa0bde 100755
--- a/finetune/finetune_single_rank_weak.sh
+++ b/finetune/finetune_single_rank_weak.sh
@@ -10,7 +10,7 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 export CUDA_VISIBLE_DEVICES="0"
 
 # if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
-accelerate launch --config_file accelerate_config_machine_single_weak.yaml --multi_gpu \
+accelerate launch --config_file accelerate_config_machine_single_single_process.yaml --multi_gpu \
   train_cogvideox_lora.py \
   --gradient_checkpointing \
   --pretrained_model_name_or_path $MODEL_PATH \
diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py
index 51feaf0..437b96e 100644
--- a/finetune/train_cogvideox_lora.py
+++ b/finetune/train_cogvideox_lora.py
@@ -416,7 +416,34 @@ def get_args():
     return parsed_args
 
 
+class Offloader:
+    def __init__(self):
+        pass
+    def enable_sequential_cpu_offload(self, model):
+        from accelerate import cpu_offload
+        gpu_id = None
+        device = "cuda"
+
+        torch_device = torch.device(device)
+        device_index = torch_device.index
+
+        if gpu_id is not None and device_index is not None:
+            raise ValueError(
+                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
+                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
+            )
+
+        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
+        self._offload_gpu_id = gpu_id or torch_device.index or 0
+
+        device_type = torch_device.type
+        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
+        self._offload_device = device
+
+        offload_buffers = len(model._parameters) > 0
+        cpu_offload(model, device, offload_buffers=offload_buffers)
+
 class CachedVideoList:
     ACCELERATOR_DEVICE = 'cpu'
     CACHE_ENABLED = False
@@ -458,7 +485,7 @@ class CachedVideoList:
     def __getitem__(self, index):
         if index >= len(self.video_names):
             raise IndexError("Index out of bounds")
-        pt_path = os.path.join(self.cache_dir, f'{video_names[index]}.pt')
+        pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
         if not os.path.exists(pt_path):
             raise FileNotFoundError(f"Video {pt_path} not found in cache.")
         return torch.load(pt_path)
@@ -501,6 +528,9 @@ class VideoDataset(Dataset):
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
+
+        if self.id_token is not None:
+            self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
 
         self.num_instance_videos = len(self.instance_video_paths)
         if self.num_instance_videos != len(self.instance_prompts):
@@ -515,7 +545,7 @@ class VideoDataset(Dataset):
 
     def __getitem__(self, index):
         return {
-            "instance_prompt": self.id_token + self.instance_prompts[index],
+            "instance_prompt": self.instance_prompts[index],
             "instance_video": self.instance_videos[index],
         }
 
@@ -593,6 +623,21 @@ class VideoDataset(Dataset):
 
         return instance_prompts, instance_videos
 
+    def encode_prompts(self, tokenizer, text_encoder, device, dtype):
+        encoded_prompts = []
+        for index, prompt in enumerate(self.instance_prompts):
+            print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
+            prompt_embeds = compute_prompt_embeddings(
+                tokenizer,
+                text_encoder,
+                prompt,
+                max_sequence_length=226,
+                device=device,
+                dtype=dtype,
+            )
+            encoded_prompts.append(prompt_embeds)
+        self.instance_prompts = encoded_prompts
+
     def encode_videos(self, vae, device):
         if not CachedVideoList.CACHE_ENABLED:
             self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
@@ -842,7 +887,7 @@ def _get_t5_prompt_embeds(
         if text_input_ids is None:
             raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
 
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+    prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -1142,9 +1187,7 @@ def main(args):
     if args.offload_to_cpu:
         text_encoder.to('cpu', dtype=weight_dtype)
         transformer.to('cpu', dtype=weight_dtype)
-        vae.to('cpu', dtype=weight_dtype)
-        from accelerate import cpu_offload
-        cpu_offload(transformer, accelerator.device, offload_buffers=False)
+        vae.to('cpu', dtype=weight_dtype)
     else:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
         transformer.to(accelerator.device, dtype=weight_dtype)
@@ -1285,6 +1328,13 @@ def main(args):
 
     if args.offload_to_cpu:
         vae.to('cpu')
+
+    if args.offload_to_cpu:
+        text_encoder.to(accelerator.device)
+    train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
+
+    if args.offload_to_cpu:
+        text_encoder.to('cpu')
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
@@ -1333,9 +1383,15 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        transformer, optimizer, train_dataloader, lr_scheduler
-    )
+    if not args.offload_to_cpu:
+        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            transformer, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        offloader = Offloader()
+        print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
+        offloader.enable_sequential_cpu_offload(transformer)
+
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1414,18 +1470,20 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]
-
-                # encode prompts
-                prompt_embeds = compute_prompt_embeddings(
-                    tokenizer,
-                    text_encoder,
-                    prompts,
-                    model_config.max_text_seq_length,
-                    accelerator.device,
-                    weight_dtype,
-                    requires_grad=False,
-                )
-
+                if args.offload_to_cpu:
+                    prompt_embeds = torch.cat(prompts)
+                else:
+                    # encode prompts
+                    prompt_embeds = compute_prompt_embeddings(
+                        tokenizer,
+                        text_encoder,
+                        prompts,
+                        model_config.max_text_seq_length,
+                        accelerator.device if not args.offload_to_cpu else "cpu",
+                        weight_dtype,
+                        requires_grad=False,
+                    )
+                    prompt_embeds = prompt_embeds.to(device=accelerator.device)
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
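
Note (illustration, not part of the patch): when args.offload_to_cpu is set, the patch skips accelerator.prepare() for the transformer and instead routes it through accelerate's cpu_offload via the new Offloader class, and prompts are pre-encoded once by VideoDataset.encode_prompts so the T5 text encoder does not need to sit in GPU memory during the training loop. The snippet below is a minimal, self-contained sketch of that offload mechanism only; ToyBlock and the tensor shapes are placeholders rather than anything from train_cogvideox_lora.py, and cpu_offload is ordinarily an inference-time tool, which is what makes its use during LoRA training an experiment.

# Minimal sketch of accelerate's sequential CPU offload (assumes accelerate is
# installed and a CUDA device is available). ToyBlock is a hypothetical
# stand-in for a large transformer, not CogVideoX code.
import torch
import torch.nn as nn
from accelerate import cpu_offload


class ToyBlock(nn.Module):
    def __init__(self, dim=4096):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


model = ToyBlock()  # weights are kept in CPU RAM

# Attach hooks that copy each submodule to cuda:0 just before its forward pass
# and release it afterwards, trading throughput for peak VRAM.
cpu_offload(model, execution_device=torch.device("cuda:0"), offload_buffers=False)

x = torch.randn(2, 4096)  # inputs may stay on CPU
y = model(x)              # the hooks move data to cuda:0 for compute and
                          # return the output on the input's original device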