From 4339f65660bce2ab09e40853fb312012bee6aeb4 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Sat, 5 Oct 2024 01:05:27 +0800
Subject: [PATCH] update

---
 .../train_cogvideox_image_to_video_lora.py | 97 ++++++++++++++-----
 finetune/train_cogvideox_lora.py           |  3 +-
 2 files changed, 76 insertions(+), 24 deletions(-)

diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py
index 81a6b2c..ba9b0b7 100644
--- a/finetune/train_cogvideox_image_to_video_lora.py
+++ b/finetune/train_cogvideox_image_to_video_lora.py
@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,10 +44,7 @@ from diffusers import (
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -58,6 +54,10 @@ from diffusers.utils import (
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np
 
 
 if is_wandb_available():
@@ -236,6 +236,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose from ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -442,6 +448,7 @@ class VideoDataset(Dataset):
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -450,6 +457,7 @@ class VideoDataset(Dataset):
         id_token: Optional[str] = None,
     ) -> None:
         super().__init__()
+
         self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
         self.dataset_name = dataset_name
         self.dataset_config_name = dataset_config_name
@@ -457,12 +465,14 @@ class VideoDataset(Dataset):
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
         self.skip_frames_end = skip_frames_end
         self.cache_dir = cache_dir
         self.id_token = id_token or ""
+
         if dataset_name is not None:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
@@ -561,6 +571,38 @@ class VideoDataset(Dataset):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
@@ -571,14 +613,15 @@ class VideoDataset(Dataset):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+
+        videos = []
 
         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
 
@@ -605,9 +648,12 @@ class VideoDataset(Dataset):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+
+        progress_dataset_bar.close()
 
         return videos
 
@@ -727,7 +773,7 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+        video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
         videos.append(video)
 
     for tracker in accelerator.trackers:
@@ -756,7 +802,8 @@ def log_validation(
                 }
             )
 
-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()
 
     return videos
 
@@ -1204,6 +1251,7 @@ def main(args):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1212,7 +1260,8 @@ def main(args):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         image = video[:, :, :1].clone()
@@ -1238,7 +1287,13 @@ def main(args):
         )
         for prompt in train_dataset.instance_prompts
     ]
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = []
@@ -1378,9 +1433,6 @@ def main(args):
     )
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
 
-    # Delete VAE and Text Encoder to save memory
-    clear_objs_and_retain_memory([vae, text_encoder])
-
     # For DeepSpeed training
     model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
 
@@ -1550,7 +1602,8 @@ def main(args):
         )
 
         # Cleanup trained models to save memory
-        clear_objs_and_retain_memory([transformer])
+        del transformer
+        free_memory()
 
         # Final test inference
         pipe = CogVideoXImageToVideoPipeline.from_pretrained(
diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py
index 137f322..e39827b 100644
--- a/finetune/train_cogvideox_lora.py
+++ b/finetune/train_cogvideox_lora.py
@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
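
Reviewer note (appended after the patch, not part of it): below is a minimal,
self-contained sketch of the new resize-then-crop preprocessing, handy for
checking shapes outside the trainer. It mirrors the patch's
VideoDataset._resize_for_rectangle_crop, but the standalone function name, the
input shape, and the target resolution are illustrative assumptions, and an
idiomatic membership test replaces the chained equality checks. As committed,
--video_reshape_mode 'none' falls through to the random-crop branch, which may
be worth revisiting given the flag's help text.

    import numpy as np
    import torch
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms.functional import crop, resize

    def resize_for_rectangle_crop(frames, height, width, reshape_mode="center"):
        # frames: [F, C, H, W]. Scale so the target rectangle is fully covered
        # while keeping aspect ratio, then crop to (height, width).
        if frames.shape[3] / frames.shape[2] > width / height:
            # Input is wider than the target: match the target height.
            frames = resize(
                frames,
                size=[height, int(frames.shape[3] * height / frames.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            # Input is taller than the target: match the target width.
            frames = resize(
                frames,
                size=[int(frames.shape[2] * width / frames.shape[3]), width],
                interpolation=InterpolationMode.BICUBIC,
            )
        delta_h = frames.shape[2] - height
        delta_w = frames.shape[3] - width
        if reshape_mode in ("random", "none"):  # 'none' behaves like 'random' in the patch
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError(reshape_mode)
        return crop(frames, top=top, left=left, height=height, width=width)

    # Example: a 49-frame 640x360 clip mapped to the script's default 480x720.
    video = torch.rand(49, 3, 360, 640)
    out = resize_for_rectangle_crop(video, height=480, width=720)
    assert out.shape == (49, 3, 480, 720)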