diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py
index 437b96e..e6cb262 100644
--- a/finetune/train_cogvideox_lora.py
+++ b/finetune/train_cogvideox_lora.py
@@ -401,95 +401,9 @@ def get_args():
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
         ),
     )
-    parser.add_argument(
-        "--offload_to_cpu",
-        action="store_true",
-        help="Whether or not to offload the model to the CPU.",
-    )
-    parser.add_argument(
-        "--cache_preprocessed_data",
-        action="store_true",
-        help="Whether or not to cache preprocessed data.",
-    )
-    parsed_args = parser.parse_args()
-
-    return parsed_args
+    return parser.parse_args()
 
-class Offloader:
-    def __init__(self):
-        pass
-
-    def enable_sequential_cpu_offload(self, model):
-        from accelerate import cpu_offload
-        gpu_id = None
-        device = "cuda"
-
-        torch_device = torch.device(device)
-        device_index = torch_device.index
-
-        if gpu_id is not None and device_index is not None:
-            raise ValueError(
-                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
-                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
-            )
-
-        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
-        self._offload_gpu_id = gpu_id or torch_device.index or 0
-
-        device_type = torch_device.type
-        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
-        self._offload_device = device
-
-        offload_buffers = len(model._parameters) > 0
-        cpu_offload(model, device, offload_buffers=offload_buffers)
-
-class CachedVideoList:
-    ACCELERATOR_DEVICE = 'cpu'
-    CACHE_ENABLED = False
-    VAE = None
-    OUTPUT_DIR = None
-
-    @classmethod
-    def is_cached(cls, video_name):
-        if cls.CACHE_ENABLED:
-            return os.path.exists(os.path.join(cls.cache_dir(), f'{video_name}.pt'))
-        return False
-
-    @classmethod
-    def cache_dir(cls):
-        result = os.path.join(cls.OUTPUT_DIR, "cached_videos")
-        if not os.path.exists(result):
-            os.makedirs(result)
-        return result
-
-    def __init__(self):
-        if not self.CACHE_ENABLED:
-            raise ValueError("CachedVideoList is not enabled. Please enable the cache before using it.")
-        if self.OUTPUT_DIR is None:
-            raise ValueError("Output directory not set. Please set the output directory before using the CachedVideoList.")
-        if self.VAE is None:
-            raise ValueError("VAE model not set. Please set the VAE model before using the CachedVideoList.")
-
-        self.video_names = []
-
-
-    def __len__(self):
-        return len(self.video_names)
-
-    def append(self, video: Tuple[str, torch.Tensor]):
-        self.video_names.append(video[0])
-        if video[1] is not None:
-            torch.save(video[1], os.path.join(CachedVideoList.cache_dir(), f'{video[0]}.pt'))
-
-    def __getitem__(self, index):
-        if index >= len(self.video_names):
-            raise IndexError("Index out of bounds")
-        pt_path = os.path.join(CachedVideoList.cache_dir(), f'{self.video_names[index]}.pt')
-        if not os.path.exists(pt_path):
-            raise FileNotFoundError(f"Video {pt_path} not found in cache.")
-        return torch.load(pt_path)
-
 class VideoDataset(Dataset):
     def __init__(
@@ -528,9 +442,6 @@ class VideoDataset(Dataset):
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()
-
-        if self.id_token is not None:
-            self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]
 
         self.num_instance_videos = len(self.instance_video_paths)
         if self.num_instance_videos != len(self.instance_prompts):
@@ -545,7 +456,7 @@ class VideoDataset(Dataset):
 
     def __getitem__(self, index):
         return {
-            "instance_prompt": self.instance_prompts[index],
+            "instance_prompt": self.id_token + self.instance_prompts[index],
             "instance_video": self.instance_videos[index],
         }
 
@@ -622,32 +533,6 @@ class VideoDataset(Dataset):
             )
 
         return instance_prompts, instance_videos
-
-    def encode_prompts(self, tokenizer, text_encoder, device, dtype):
-        encoded_prompts = []
-        for index, prompt in enumerate(self.instance_prompts):
-            print(f"Encoding prompt {index + 1} of {len(self.instance_prompts)}")
-            prompt_embeds = compute_prompt_embeddings(
-                tokenizer,
-                text_encoder,
-                prompt,
-                max_sequence_length=226,
-                device=device,
-                dtype=dtype,
-            )
-            encoded_prompts.append(prompt_embeds)
-        self.instance_prompts = encoded_prompts
-
-    def encode_videos(self, vae, device):
-        if not CachedVideoList.CACHE_ENABLED:
-            self.instance_videos = [self.encode_video(video, vae, device) for video in self.instance_videos]
-
-    def encode_video(self, video, vae, device):
-        video = video.to(device, dtype=vae.dtype).unsqueeze(0)
-        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-        vae.to(device)
-        latent_dist = vae.encode(video.to(device)).latent_dist
-        return latent_dist
 
     def _preprocess_data(self):
         try:
@@ -659,7 +544,7 @@ class VideoDataset(Dataset):
 
         decord.bridge.set_bridge("torch")
 
-        videos = CachedVideoList() if CachedVideoList.CACHE_ENABLED else []
+        videos = []
         train_transforms = transforms.Compose(
             [
                 transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
@@ -667,11 +552,6 @@ class VideoDataset(Dataset):
             ]
         )
 
         for filename in self.instance_video_paths:
-            if CachedVideoList.CACHE_ENABLED and videos.is_cached(filename.stem):
-                videos.append((filename.stem, None))
-                continue
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Pre-processing video {filename.as_posix()}")
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
@@ -700,17 +580,7 @@ class VideoDataset(Dataset):
             # Training transforms
             frames = frames.float()
             frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
-
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Encoding video {filename.stem} device {CachedVideoList.ACCELERATOR_DEVICE}")
-
-            if CachedVideoList.CACHE_ENABLED:
-                videos.append((filename.stem, self.encode_video(frames, CachedVideoList.VAE, CachedVideoList.ACCELERATOR_DEVICE)))
-            else:
-                videos.append(frames)  # [F, C, H, W]
-            if CachedVideoList.CACHE_ENABLED:
-                print(f"Video Encoded {filename.stem}")
+            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
 
         return videos
 
@@ -887,7 +757,7 @@ def _get_t5_prompt_embeds(
         if text_input_ids is None:
             raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
 
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0].to(device)
+    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -1184,16 +1054,11 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )
 
-    if args.offload_to_cpu:
-        text_encoder.to('cpu', dtype=weight_dtype)
-        transformer.to('cpu', dtype=weight_dtype)
-        vae.to('cpu', dtype=weight_dtype)
-    else:
-        text_encoder.to(accelerator.device, dtype=weight_dtype)
-        transformer.to(accelerator.device, dtype=weight_dtype)
-        vae.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
-    if args.gradient_checkpointing or args.offload_to_cpu:
+    if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
 
     # now we will add new LoRA weights to the attention layers
@@ -1297,12 +1162,6 @@ def main(args):
     )
 
     optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
-
-    if args.cache_preprocessed_data:
-        CachedVideoList.OUTPUT_DIR = args.output_dir
-        CachedVideoList.CACHE_ENABLED = True
-        CachedVideoList.VAE = vae
-        CachedVideoList.ACCELERATOR_DEVICE = accelerator.device
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
@@ -1321,20 +1180,13 @@ def main(args):
         id_token=args.id_token,
     )
 
-    if args.offload_to_cpu:
-        vae.to('cuda')
-
-    train_dataset.encode_videos(vae, accelerator.device)
-
-    if args.offload_to_cpu:
-        vae.to('cpu')
-
-    if args.offload_to_cpu:
-        text_encoder.to(accelerator.device)
-        train_dataset.encode_prompts(tokenizer, text_encoder, accelerator.device, weight_dtype)
+    def encode_video(video):
+        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
+        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
+        latent_dist = vae.encode(video).latent_dist
+        return latent_dist
 
-    if args.offload_to_cpu:
-        text_encoder.to('cpu')
+    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
@@ -1383,15 +1235,9 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    if not args.offload_to_cpu:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
-        )
-    else:
-        offloader = Offloader()
-        print(f"Offloading to CPU on device {CachedVideoList.ACCELERATOR_DEVICE}")
-        offloader.enable_sequential_cpu_offload(transformer)
-
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1470,20 +1316,18 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
                 prompts = batch["prompts"]
-                if args.offload_to_cpu:
-                    prompt_embeds = torch.cat(prompts)
-                else:
-                    # encode prompts
-                    prompt_embeds = compute_prompt_embeddings(
-                        tokenizer,
-                        text_encoder,
-                        prompts,
-                        model_config.max_text_seq_length,
-                        accelerator.device if not args.offload_to_cpu else "cpu",
-                        weight_dtype,
-                        requires_grad=False,
-                    )
-                prompt_embeds.to(device=accelerator.device)
+
+                # encode prompts
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    model_config.max_text_seq_length,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=False,
+                )
+
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
                 batch_size, num_frames, num_channels, height, width = model_input.shape
@@ -1696,4 +1540,4 @@ def main(args):
 
 if __name__ == "__main__":
     args = get_args()
-    main(args)
+    main(args)
\ No newline at end of file