From 21693ca7703770be077184c0e92134f161d1161f Mon Sep 17 00:00:00 2001 From: Zheng Guang Cong Date: Fri, 6 Dec 2024 20:14:43 +0800 Subject: [PATCH 01/25] fix bugs of image-to-video without image-condition --- finetune/train_cogvideox_image_to_video_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py index 188d955..79e6223 100644 --- a/finetune/train_cogvideox_image_to_video_lora.py +++ b/finetune/train_cogvideox_image_to_video_lora.py @@ -1283,7 +1283,7 @@ def main(args): image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) - noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] + noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] image_latent_dist = vae.encode(noisy_image).latent_dist return latent_dist, image_latent_dist From 48ac9c10660b636e5e434de0aee14833a49426f8 Mon Sep 17 00:00:00 2001 From: Gforky Date: Sat, 14 Dec 2024 16:12:57 +0800 Subject: [PATCH 02/25] [fix]fix typo in train_cogvideox_image_to_video_lora.py --- finetune/train_cogvideox_image_to_video_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py index 79e6223..abf245f 100644 --- a/finetune/train_cogvideox_image_to_video_lora.py +++ b/finetune/train_cogvideox_image_to_video_lora.py @@ -1246,11 +1246,11 @@ def main(args): use_deepspeed_optimizer = ( accelerator.state.deepspeed_plugin is not None - and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() == "none" + and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() != "none" ) use_deepspeed_scheduler = ( accelerator.state.deepspeed_plugin is not None - and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() == "none" + and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() != "none" ) optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) From 2508c8353bbcc6f3c75ca5f4da807227e5e80021 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 18 Dec 2024 07:38:10 +0000 Subject: [PATCH 03/25] [bugfix] fix specific resolution setting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Different models use different resolutions, for example, for the CogVideoX1.5 series models, the optimal generation resolution is 1360x768, But for CogVideoX, the best resolution is 720x480. --- inference/cli_demo.py | 49 +++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index ea8b4fc..f4dbc28 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -17,9 +17,9 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths. 
""" -import warnings +import logging import argparse -from typing import Literal +from typing import Literal, Optional import torch from diffusers import ( @@ -31,6 +31,20 @@ from diffusers import ( from diffusers.utils import export_to_video, load_image, load_video +logging.basicConfig(level=logging.INFO) + +# Recommended resolution for each model (width, height) +RESOLUTION_MAP = { + # cogvideox1.5-* + "cogvideox1.5-5b-i2v": (1360, 768), + "cogvideox1.5-5b": (1360, 768), + + # cogvideox-* + "cogvideox-5b-i2v": (720, 480), + "cogvideox-5b": (720, 480), + "cogvideox-2b": (720, 480), +} + def generate_video( prompt: str, @@ -38,8 +52,8 @@ def generate_video( lora_path: str = None, lora_rank: int = 128, num_frames: int = 81, - width: int = 1360, - height: int = 768, + width: Optional[int] = None, + height: Optional[int] = None, output_path: str = "./output.mp4", image_or_video_path: str = "", num_inference_steps: int = 50, @@ -48,7 +62,7 @@ def generate_video( dtype: torch.dtype = torch.bfloat16, generate_type: str = Literal["t2v", "i2v", "v2v"], # i2v: image to video, v2v: video to video seed: int = 42, - fps: int = 8, + fps: int = 16, ): """ Generates a video based on the given prompt and saves it to the specified path. @@ -78,10 +92,19 @@ def generate_video( image = None video = None - if (width != 1360 or height != 768) and "cogvideox1.5-5b-i2v" in model_path.lower(): - warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideoX1.5-5B-I2V. The best resolution for CogVideoX1.5-5B-I2V is 1360x768.") - elif (width != 720 or height != 480) and "cogvideox-5b-i2v" in model_path.lower(): - warnings.warn(f"The width({width}) and height({height}) are not recommended for CogVideo-5B-I2V. The best resolution for CogVideo-5B-I2V is 720x480.") + model_name = model_path.split("/")[-1].lower() + desired_resolution = RESOLUTION_MAP[model_name] + if width is None or height is None: + width, height = desired_resolution + logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m") + elif (width, height) != desired_resolution: + if generate_type == "i2v": + # For i2v models, use user-defined width and height + logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m") + else: + # Otherwise, use the recommended width and height + logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. 
Setting back to default resolution {desired_resolution}.\033[0m") + width, height = desired_resolution if generate_type == "i2v": pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype) @@ -132,6 +155,8 @@ def generate_video( ).frames[0] elif generate_type == "t2v": video_generate = pipe( + height=height, + width=width, prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, num_inference_steps=num_inference_steps, @@ -142,6 +167,8 @@ def generate_video( ).frames[0] else: video_generate = pipe( + height=height, + width=width, prompt=prompt, video=video, # The path of the video to be used as the background of the video num_videos_per_prompt=num_videos_per_prompt, @@ -172,8 +199,8 @@ if __name__ == "__main__": parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process") - parser.add_argument("--width", type=int, default=1360, help="Number of steps for the inference process") - parser.add_argument("--height", type=int, default=768, help="Number of steps for the inference process") + parser.add_argument("--width", type=int, default=None, help="Number of steps for the inference process") + parser.add_argument("--height", type=int, default=None, help="Number of steps for the inference process") parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process") parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation") From ba85627577407ca3bbbcf16f1e42e54ed92d94a3 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 18 Dec 2024 12:30:13 +0000 Subject: [PATCH 04/25] [docs] improve help messages in argument parser Fix and clarify help documentation in parser.add_argument() to better describe command-line arguments. 
--- inference/cli_demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index f4dbc28..b9820c9 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -199,9 +199,9 @@ if __name__ == "__main__": parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process") - parser.add_argument("--width", type=int, default=None, help="Number of steps for the inference process") - parser.add_argument("--height", type=int, default=None, help="Number of steps for the inference process") - parser.add_argument("--fps", type=int, default=16, help="Number of steps for the inference process") + parser.add_argument("--width", type=int, default=None, help="The width of the generated video") + parser.add_argument("--height", type=int, default=None, help="The height of the generated video") + parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video") parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation") parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation") From 5cb93032865003bf343b41a1d16d5235149aae4f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 27 Dec 2024 09:42:42 +0000 Subject: [PATCH 05/25] chore: update .gitignore - Add new ignore patterns for dataset and model directories - Update rules for development files --- .gitignore | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 6be6f4b..2ee7d3d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,17 @@ logs/ .idea output* test* -venv \ No newline at end of file +venv +**/.swp +**/*.log +**/*.debug +**/.vscode + +**/*debug* +**/.gitignore +**/finetune/*-lora-* +**/finetune/Disney-* +**/wandb +**/results +**/*.mp4 +**/validation_set From 7b282246dd204629ddacede4eea004bf21a71ee1 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 27 Dec 2024 09:47:45 +0000 Subject: [PATCH 06/25] chore: remove unused configuration files after refactoring Delete accelerate configs, deepspeed config and host file that are no longer needed --- finetune/accelerate_config_machine_multi.yaml | 26 ------------------- .../accelerate_config_machine_single.yaml | 20 -------------- finetune/ds_config.json | 20 -------------- finetune/hostfile.txt | 2 -- 4 files changed, 68 deletions(-) delete mode 100644 finetune/accelerate_config_machine_multi.yaml delete mode 100644 finetune/accelerate_config_machine_single.yaml delete mode 100644 finetune/ds_config.json delete mode 100644 finetune/hostfile.txt diff --git a/finetune/accelerate_config_machine_multi.yaml b/finetune/accelerate_config_machine_multi.yaml deleted file mode 100644 index 856db57..0000000 --- a/finetune/accelerate_config_machine_multi.yaml +++ /dev/null @@ -1,26 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: true -deepspeed_config: - deepspeed_hostfile: hostfile.txt - deepspeed_multinode_launcher: pdsh - gradient_accumulation_steps: 1 - gradient_clipping: 1.0 - offload_optimizer_device: none - offload_param_device: none - zero3_init_flag: true - zero_stage: 3 -distributed_type: DEEPSPEED 
-downcast_bf16: 'yes' -enable_cpu_affinity: true -main_process_ip: 10.250.128.19 -main_process_port: 12355 -main_training_function: main -mixed_precision: bf16 -num_machines: 4 -num_processes: 32 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false diff --git a/finetune/accelerate_config_machine_single.yaml b/finetune/accelerate_config_machine_single.yaml deleted file mode 100644 index 90d993a..0000000 --- a/finetune/accelerate_config_machine_single.yaml +++ /dev/null @@ -1,20 +0,0 @@ -compute_environment: LOCAL_MACHINE -gpu_ids: "0" -debug: false -deepspeed_config: - deepspeed_config_file: ds_config.json - zero3_init_flag: false -distributed_type: DEEPSPEED -downcast_bf16: 'no' -enable_cpu_affinity: false -machine_rank: 0 -main_training_function: main -dynamo_backend: 'no' -num_machines: 1 -num_processes: 1 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file diff --git a/finetune/ds_config.json b/finetune/ds_config.json deleted file mode 100644 index 9a708a7..0000000 --- a/finetune/ds_config.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "scheduler": { - "type": "WarmupDecayLR", - "params": { - "warmup_min_lr": "auto", - "warmup_max_lr": "auto", - "warmup_num_steps": "auto", - "total_num_steps": "auto" - } - }, - "zero_optimization": { - "stage": 2, - "allgather_partitions": true, - "allgather_bucket_size": 2e8, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 1e8, - "contiguous_gradients": true - } -} \ No newline at end of file diff --git a/finetune/hostfile.txt b/finetune/hostfile.txt deleted file mode 100644 index d0b8045..0000000 --- a/finetune/hostfile.txt +++ /dev/null @@ -1,2 +0,0 @@ -node1 slots=8 -node2 slots=8 \ No newline at end of file From e3f6def234d4f01a27881f9af28cf46a4bab787c Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 27 Dec 2024 09:50:47 +0000 Subject: [PATCH 07/25] feat: add video frame extraction tool Add utility script to extract first frames from videos, helping users convert T2V datasets to I2V format --- finetune/scripts/extract_images.py | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 finetune/scripts/extract_images.py diff --git a/finetune/scripts/extract_images.py b/finetune/scripts/extract_images.py new file mode 100644 index 0000000..a5cabae --- /dev/null +++ b/finetune/scripts/extract_images.py @@ -0,0 +1,52 @@ +import argparse +import os +from pathlib import Path +import cv2 + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory") + return parser.parse_args() + +args = parse_args() + +# Create data/images directory if it doesn't exist +data_dir = Path(args.datadir) +image_dir = data_dir / "images" +image_dir.mkdir(exist_ok=True) + +# Read videos.txt +videos_file = data_dir / "videos.txt" +with open(videos_file, "r") as f: + video_paths = [line.strip() for line in f.readlines() if line.strip()] + +# Process each video file and collect image paths +image_paths = [] +for video_rel_path in video_paths: + video_path = data_dir / video_rel_path + + # Open video + cap = cv2.VideoCapture(str(video_path)) + + # Read first frame + ret, frame = cap.read() + if not ret: + print(f"Failed to read video: {video_path}") + continue + + # Save frame as PNG with same name as video + image_name = f"images/{video_path.stem}.png" + 
image_path = data_dir / image_name + cv2.imwrite(str(image_path), frame) + + # Release video capture + cap.release() + + print(f"Extracted first frame from {video_path} to {image_path}") + image_paths.append(image_name) + +# Write images.txt +images_file = data_dir / "images.txt" +with open(images_file, "w") as f: + for path in image_paths: + f.write(f"{path}\n") \ No newline at end of file From 918ebb5a5471db0435dd6fb8ca7557384cc90253 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 27 Dec 2024 09:57:37 +0000 Subject: [PATCH 08/25] feat(datasets): implement video dataset modules - Add dataset implementations for text-to-video and image-to-video - Include bucket sampler for efficient batch processing - Add utility functions for data processing - Create dataset package structure with proper initialization --- finetune/datasets/__init__.py | 12 ++ finetune/datasets/bucket_sampler.py | 73 ++++++++++ finetune/datasets/i2v_dataset.py | 206 ++++++++++++++++++++++++++++ finetune/datasets/t2v_dataset.py | 177 ++++++++++++++++++++++++ finetune/datasets/utils.py | 141 +++++++++++++++++++ 5 files changed, 609 insertions(+) create mode 100644 finetune/datasets/__init__.py create mode 100644 finetune/datasets/bucket_sampler.py create mode 100644 finetune/datasets/i2v_dataset.py create mode 100644 finetune/datasets/t2v_dataset.py create mode 100644 finetune/datasets/utils.py diff --git a/finetune/datasets/__init__.py b/finetune/datasets/__init__.py new file mode 100644 index 0000000..8c9b61a --- /dev/null +++ b/finetune/datasets/__init__.py @@ -0,0 +1,12 @@ +from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets +from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets +from .bucket_sampler import BucketSampler + + +__all__ = [ + "I2VDatasetWithResize", + "I2VDatasetWithBuckets", + "T2VDatasetWithResize", + "T2VDatasetWithBuckets", + "BucketSampler" +] diff --git a/finetune/datasets/bucket_sampler.py b/finetune/datasets/bucket_sampler.py new file mode 100644 index 0000000..bf1beb1 --- /dev/null +++ b/finetune/datasets/bucket_sampler.py @@ -0,0 +1,73 @@ +import random +import logging + +from torch.utils.data import Sampler +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +class BucketSampler(Sampler): + r""" + PyTorch Sampler that groups 3D data by height, width and frames. + + Args: + data_source (`VideoDataset`): + A PyTorch dataset object that is an instance of `VideoDataset`. + batch_size (`int`, defaults to `8`): + The batch size to use for training. + shuffle (`bool`, defaults to `True`): + Whether or not to shuffle the data in each batch before dispatching to dataloader. + drop_last (`bool`, defaults to `False`): + Whether or not to drop incomplete buckets of data after completely iterating over all data + in the dataset. If set to True, only batches that have `batch_size` number of entries will + be yielded. If set to False, it is guaranteed that all data in the dataset will be processed + and batches that do not have `batch_size` number of entries will also be yielded. 
+ """ + + def __init__( + self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + ) -> None: + self.data_source = data_source + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + + self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets} + + self._raised_warning_for_drop_last = False + + + def __len__(self): + if self.drop_last and not self._raised_warning_for_drop_last: + self._raised_warning_for_drop_last = True + logger.warning( + "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training." + ) + return (len(self.data_source) + self.batch_size - 1) // self.batch_size + + + def __iter__(self): + for index, data in enumerate(self.data_source): + video_metadata = data["video_metadata"] + f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + + self.buckets[(f, h, w)].append(data) + if len(self.buckets[(f, h, w)]) == self.batch_size: + if self.shuffle: + random.shuffle(self.buckets[(f, h, w)]) + yield self.buckets[(f, h, w)] + del self.buckets[(f, h, w)] + self.buckets[(f, h, w)] = [] + + if self.drop_last: + return + + for fhw, bucket in list(self.buckets.items()): + if len(bucket) == 0: + continue + if self.shuffle: + random.shuffle(bucket) + yield bucket + del self.buckets[fhw] + self.buckets[fhw] = [] diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py new file mode 100644 index 0000000..2137dce --- /dev/null +++ b/finetune/datasets/i2v_dataset.py @@ -0,0 +1,206 @@ +from pathlib import Path +from typing import Any, Dict, List, Tuple +from typing_extensions import override + +import torch +from accelerate.logging import get_logger +from torch.utils.data import Dataset +from torchvision import transforms + +from .utils import ( + load_prompts, load_videos, load_images, + + preprocess_image_with_resize, + preprocess_video_with_resize, + preprocess_video_with_buckets +) + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + + +class BaseI2VDataset(Dataset): + """ + + """ + def __init__( + self, + data_root: str, + caption_column: str, + video_column: str, + image_column: str, + *args, + **kwargs + ) -> None: + super().__init__() + + data_root = Path(data_root) + self.prompts = load_prompts(data_root / caption_column) + self.videos = load_videos(data_root / video_column) + self.images = load_images(data_root / image_column) + + # Check if number of prompts matches number of videos and images + if not (len(self.videos) == len(self.prompts) == len(self.images)): + raise ValueError( + f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset." + ) + + # Check if all video files exist + if any(not path.is_file() for path in self.videos): + raise ValueError( + f"Some video files were not found. Please ensure that all video files exist in the dataset directory. 
Missing file: {next(path for path in self.videos if not path.is_file())}" + ) + + # Check if all image files exist + if any(not path.is_file() for path in self.images): + raise ValueError( + f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}" + ) + + def __len__(self) -> int: + return len(self.videos) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + prompt = self.prompts[index] + + # shape of frames: [F, C, H, W] + # shape of image: [C, H, W] + frames, image = self.preprocess(self.videos[index], self.images[index]) + + frames = self.video_transform(frames) + image = self.image_transform(image) + + return { + "prompt": prompt, + "video": frames, + "video_metadata": { + "num_frames": frames.shape[0], + "height": frames.shape[2], + "width": frames.shape[3], + }, + "image": image + } + + def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Loads and preprocesses a video and an image. + + Args: + video_path: Path to the video file to load + image_path: Path to the image file to load + + Returns: + A tuple containing: + - video(torch.Tensor) of shape [F, C, H, W] where F is number of frames, + C is number of channels, H is height and W is width + - image(torch.Tensor) of shape [C, H, W] + """ + raise NotImplementedError("Subclass must implement this method") + + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + """ + Applies transformations to a video. + + Args: + frames (torch.Tensor): A 4D tensor representing a video + with shape [F, C, H, W] where: + - F is number of frames + - C is number of channels (3 for RGB) + - H is height + - W is width + + Returns: + torch.Tensor: The transformed video tensor + """ + raise NotImplementedError("Subclass must implement this method") + + def image_transform(self, image: torch.Tensor) -> torch.Tensor: + """ + Applies transformations to an image. 
+ + Args: + image (torch.Tensor): A 3D tensor representing an image + with shape [C, H, W] where: + - C is number of channels (3 for RGB) + - H is height + - W is width + + Returns: + torch.Tensor: The transformed image tensor + """ + raise NotImplementedError("Subclass must implement this method") + + +class I2VDatasetWithResize(BaseI2VDataset): + """ + + """ + def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.max_num_frames = max_num_frames + self.height = height + self.width = width + + self.__frame_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) + ] + ) + + @override + def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width) + image = preprocess_image_with_resize(image_path, self.height, self.width) + return video, image + + @override + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + return torch.stack([self.__frame_transforms(f) for f in frames], dim=0) + + +class I2VDatasetWithBuckets(BaseI2VDataset): + """ + + """ + def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.video_resolution_buckets = video_resolution_buckets + self.__frame_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) + ] + ) + self.__image_transforms = self.__frame_transforms + + @override + def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets) + image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3]) + return video, image + + @override + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + return torch.stack([self.__frame_transforms(f) for f in frames], dim=0) + + @override + def image_transform(self, image: torch.Tensor) -> torch.Tensor: + return self.__image_transforms(image) diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py new file mode 100644 index 0000000..aa4ed72 --- /dev/null +++ b/finetune/datasets/t2v_dataset.py @@ -0,0 +1,177 @@ +from pathlib import Path +from typing import Any, Dict, List, Tuple +from typing_extensions import override + +import torch +from accelerate.logging import get_logger +from torch.utils.data import Dataset +from torchvision import transforms + +from .utils import ( + load_prompts, load_videos, + preprocess_video_with_resize, + preprocess_video_with_buckets +) + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + +logger = get_logger(__name__) + + +class BaseT2VDataset(Dataset): + """ + + """ + def __init__( + self, + data_root: str, + caption_column: str, + video_column: str, + *args, + **kwargs + ) -> None: + super().__init__() + + data_root = Path(data_root) + self.prompts = load_prompts(data_root / caption_column) + self.videos = load_videos(data_root / video_column) + + # Check if all video files exist + if any(not path.is_file() for path in self.videos): + raise ValueError( + f"Some video files were not found. 
Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}" + ) + + # Check if number of prompts matches number of videos + if len(self.videos) != len(self.prompts): + raise ValueError( + f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + def __len__(self) -> int: + return len(self.videos) + + def __getitem__(self, index: int) -> Dict[str, Any]: + if isinstance(index, list): + # Here, index is actually a list of data objects that we need to return. + # The BucketSampler should ideally return indices. But, in the sampler, we'd like + # to have information about num_frames, height and width. Since this is not stored + # as metadata, we need to read the video to get this information. You could read this + # information without loading the full video in memory, but we do it anyway. In order + # to not load the video twice (once to get the metadata, and once to return the loaded video + # based on sampled indices), we cache it in the BucketSampler. When the sampler is + # to yield, we yield the cache data instead of indices. So, this special check ensures + # that data is not loaded a second time. PRs are welcome for improvements. + return index + + prompt = self.prompts[index] + + # shape of frames: [F, C, H, W] + frames = self.preprocess(self.videos[index]) + frames = self.video_transform(frames) + + return { + "prompt": prompt, + "video": frames, + "video_metadata": { + "num_frames": frames.shape[0], + "height": frames.shape[2], + "width": frames.shape[3], + }, + } + + def preprocess(self, video_path: Path) -> torch.Tensor: + """ + Loads and preprocesses a video. + + Args: + video_path: Path to the video file to load. + + Returns: + torch.Tensor: Video tensor of shape [F, C, H, W] where: + - F is number of frames + - C is number of channels (3 for RGB) + - H is height + - W is width + """ + raise NotImplementedError("Subclass must implement this method") + + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + """ + Applies transformations to a video. + + Args: + frames (torch.Tensor): A 4D tensor representing a video + with shape [F, C, H, W] where: + - F is number of frames + - C is number of channels (3 for RGB) + - H is height + - W is width + + Returns: + torch.Tensor: The transformed video tensor + """ + raise NotImplementedError("Subclass must implement this method") + + +class T2VDatasetWithResize(BaseT2VDataset): + """ + + """ + def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.max_num_frames = max_num_frames + self.height = height + self.width = width + + self.__frame_transform = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) + ] + ) + + @override + def preprocess(self, video_path: Path) -> torch.Tensor: + return preprocess_video_with_resize( + video_path, self.max_num_frames, self.height, self.width, + ) + + @override + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + return torch.stack([self.__frame_transform(f) for f in frames], dim=0) + + +class T2VDatasetWithBuckets(BaseT2VDataset): + """ + + """ + def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None: + """ + Args: + resolution_buckets: List of tuples representing the resolution buckets. 
+ Each tuple contains three integers: (max_num_frames, height, width). + """ + super().__init__(*args, **kwargs) + + self.video_resolution_buckets = video_resolution_buckets + + self.__frame_transform = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) + ] + ) + + @override + def preprocess(self, video_path: Path) -> torch.Tensor: + return preprocess_video_with_buckets( + video_path, self.video_resolution_buckets + ) + + @override + def video_transform(self, frames: torch.Tensor) -> torch.Tensor: + return torch.stack([self.__frame_transform(f) for f in frames], dim=0) diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py new file mode 100644 index 0000000..cf0525e --- /dev/null +++ b/finetune/datasets/utils.py @@ -0,0 +1,141 @@ +import torch +import cv2 + +from typing import List, Tuple +from pathlib import Path +from torchvision import transforms +from torchvision.transforms.functional import resize + +# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error +# Very few bug reports but it happens. Look in decord Github issues for more relevant information. +import decord # isort:skip + +decord.bridge.set_bridge("torch") + + +########## loaders ########## + +def load_prompts(prompt_path: Path) -> List[str]: + with open(prompt_path, "r", encoding="utf-8") as file: + return [line.strip() for line in file.readlines() if len(line.strip()) > 0] + + +def load_videos(video_path: Path) -> List[Path]: + with open(video_path, "r", encoding="utf-8") as file: + return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0] + + +def load_images(image_path: Path) -> List[Path]: + with open(image_path, "r", encoding="utf-8") as file: + return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0] + + +########## preprocessors ########## + +def preprocess_image_with_resize( + image_path: Path, + height: int, + width: int, +) -> torch.Tensor: + """ + Loads and resizes a single image. + + Args: + image_path: Path to the image file. + height: Target height for resizing. + width: Target width for resizing. + + Returns: + torch.Tensor: Image tensor with shape [C, H, W] where: + C = number of channels (3 for RGB) + H = height + W = width + """ + image = cv2.imread(image_path.as_posix()) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (width, height)) + image = torch.from_numpy(image).float() + image = image.permute(2, 0, 1).contiguous() + return image + + +def preprocess_video_with_resize( + video_path: Path, + max_num_frames: int, + height: int, + width: int, +) -> torch.Tensor: + """ + Loads and resizes a single video. + + The function processes the video through these steps: + 1. If video frame count > max_num_frames, downsample frames evenly + 2. If video dimensions don't match (height, width), resize frames + + Args: + video_path: Path to the video file. + max_num_frames: Maximum number of frames to keep. + height: Target height for resizing. + width: Target width for resizing. 
+ + Returns: + A torch.Tensor with shape [F, C, H, W] where: + F = number of frames + C = number of channels (3 for RGB) + H = height + W = width + """ + video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height) + video_num_frames = len(video_reader) + + indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) + frames = video_reader.get_batch(indices) + frames = frames[: max_num_frames].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + return frames + + +def preprocess_video_with_buckets( + video_path: Path, + resolution_buckets: List[Tuple[int, int, int]], +) -> torch.Tensor: + """ + Args: + video_path: Path to the video file. + resolution_buckets: List of tuples (num_frames, height, width) representing + available resolution buckets. + + Returns: + torch.Tensor: Video tensor with shape [F, C, H, W] where: + F = number of frames + C = number of channels (3 for RGB) + H = height + W = width + + The function processes the video through these steps: + 1. Finds nearest frame bucket <= video frame count + 2. Downsamples frames evenly to match bucket size + 3. Finds nearest resolution bucket based on dimensions + 4. Resizes frames to match bucket resolution + """ + video_reader = decord.VideoReader(uri=video_path.as_posix()) + video_num_frames = len(video_reader) + resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames] + if len(resolution_buckets) == 0: + raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}") + + nearest_frame_bucket = min( + resolution_buckets, + key=lambda bucket: video_num_frames - bucket[0], + default=1, + )[0] + frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) + frames = video_reader.get_batch(frame_indices) + frames = frames[:nearest_frame_bucket].float() + frames = frames.permute(0, 3, 1, 2).contiguous() + + nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])) + nearest_res = (nearest_res[1], nearest_res[2]) + frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0) + + return frames \ No newline at end of file From 85e00a1082c041904691551cc33800b6ab4502c7 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 27 Dec 2024 09:59:49 +0000 Subject: [PATCH 09/25] feat(models): add scaffolding --- finetune/models/__init__.py | 12 ++++ .../models/cogvideox1dot5_i2v/lora_trainer.py | 29 +++++++++ .../models/cogvideox1dot5_t2v/lora_trainer.py | 29 +++++++++ finetune/models/cogvideox_i2v/lora_trainer.py | 29 +++++++++ finetune/models/utils.py | 62 +++++++++++++++++++ 5 files changed, 161 insertions(+) create mode 100644 finetune/models/__init__.py create mode 100644 finetune/models/cogvideox1dot5_i2v/lora_trainer.py create mode 100644 finetune/models/cogvideox1dot5_t2v/lora_trainer.py create mode 100644 finetune/models/cogvideox_i2v/lora_trainer.py create mode 100644 finetune/models/utils.py diff --git a/finetune/models/__init__.py b/finetune/models/__init__.py new file mode 100644 index 0000000..b315ff5 --- /dev/null +++ b/finetune/models/__init__.py @@ -0,0 +1,12 @@ +import importlib +from pathlib import Path + + +package_dir = Path(__file__).parent + +for subdir in package_dir.iterdir(): + if subdir.is_dir() and not subdir.name.startswith('_'): + for module_path in subdir.glob('*.py'): + module_name = module_path.stem + full_module_name = f".{subdir.name}.{module_name}" + importlib.import_module(full_module_name, 
package=__name__) diff --git a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py new file mode 100644 index 0000000..6ef9dd4 --- /dev/null +++ b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py @@ -0,0 +1,29 @@ +import torch + +from typing_extensions import override +from typing import Any, Dict, List + +from finetune.trainer import Trainer +from ..utils import register + + +class CogVideoX1dot5I2VLoraTrainer(Trainer): + + @override + def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + raise NotImplementedError + + @override + def load_components(self) -> Dict[str, Any]: + raise NotImplementedError + + @override + def compute_loss(self, batch) -> torch.Tensor: + raise NotImplementedError + + @override + def validate(self) -> None: + raise NotImplementedError + + +register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer) diff --git a/finetune/models/cogvideox1dot5_t2v/lora_trainer.py b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py new file mode 100644 index 0000000..dfc2a78 --- /dev/null +++ b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py @@ -0,0 +1,29 @@ +import torch + +from typing_extensions import override +from typing import Any, Dict, List + +from finetune.trainer import Trainer +from ..utils import register + + +class CogVideoX1dot5T2VLoraTrainer(Trainer): + + @override + def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + raise NotImplementedError + + @override + def load_components(self) -> Dict[str, Any]: + raise NotImplementedError + + @override + def compute_loss(self, batch) -> torch.Tensor: + raise NotImplementedError + + @override + def validate(self) -> None: + raise NotImplementedError + + +register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer) diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py new file mode 100644 index 0000000..d625f18 --- /dev/null +++ b/finetune/models/cogvideox_i2v/lora_trainer.py @@ -0,0 +1,29 @@ +import torch + +from typing_extensions import override +from typing import Any, Dict, List + +from finetune.trainer import Trainer +from ..utils import register + + +class CogVideoXI2VLoraTrainer(Trainer): + + @override + def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + raise NotImplementedError + + @override + def load_components(self) -> Dict[str, Any]: + raise NotImplementedError + + @override + def compute_loss(self, batch) -> torch.Tensor: + raise NotImplementedError + + @override + def validate(self) -> None: + raise NotImplementedError + + +register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer) \ No newline at end of file diff --git a/finetune/models/utils.py b/finetune/models/utils.py new file mode 100644 index 0000000..fd5a455 --- /dev/null +++ b/finetune/models/utils.py @@ -0,0 +1,62 @@ +from typing import Literal, Dict + +from finetune.trainer import Trainer + + +SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {} + + +def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer): + """Register a model and its associated functions for a specific training type. + + Args: + model_name (str): Name of the model to register (e.g. "cogvideox-5b") + training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft" + trainer_cls (Trainer): Trainer class to register. 
+ """ + + # Check if model_name exists in SUPPORTED_MODELS + if model_name not in SUPPORTED_MODELS: + SUPPORTED_MODELS[model_name] = {} + else: + raise ValueError(f"Model {model_name} already exists") + + # Check if training_type exists for this model + if training_type not in SUPPORTED_MODELS[model_name]: + SUPPORTED_MODELS[model_name][training_type] = {} + else: + raise ValueError(f"Training type {training_type} already exists for model {model_name}") + + SUPPORTED_MODELS[model_name][training_type] = trainer_cls + + +def show_supported_models(): + """Print all currently supported models and their training types.""" + + print("\nSupported Models:") + print("================") + + for model_name, training_types in SUPPORTED_MODELS.items(): + print(f"\n{model_name}") + print("-" * len(model_name)) + for training_type in training_types: + print(f" • {training_type}") + + +def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer: + """Get the trainer class for a specific model and training type.""" + if model_type not in SUPPORTED_MODELS: + print(f"\nModel '{model_type}' is not supported.") + print("\nSupported models are:") + for supported_model in SUPPORTED_MODELS: + print(f" • {supported_model}") + raise ValueError(f"Model '{model_type}' is not supported") + + if training_type not in SUPPORTED_MODELS[model_type]: + print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.") + print(f"\nSupported training types for '{model_type}' are:") + for supported_type in SUPPORTED_MODELS[model_type]: + print(f" • {supported_type}") + raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'") + + return SUPPORTED_MODELS[model_type][training_type] From 78f655a9a4e7261fa8149b2c81f9155e8d063570 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sun, 29 Dec 2024 15:06:33 +0000 Subject: [PATCH 10/25] Add utils --- finetune/utils/__init__.py | 4 + finetune/utils/checkpointing.py | 53 +++++++++ finetune/utils/file_utils.py | 47 ++++++++ finetune/utils/memory_utils.py | 60 ++++++++++ finetune/utils/optimizer_utils.py | 180 ++++++++++++++++++++++++++++++ finetune/utils/torch_utils.py | 52 +++++++++ 6 files changed, 396 insertions(+) create mode 100644 finetune/utils/__init__.py create mode 100644 finetune/utils/checkpointing.py create mode 100644 finetune/utils/file_utils.py create mode 100644 finetune/utils/memory_utils.py create mode 100644 finetune/utils/optimizer_utils.py create mode 100644 finetune/utils/torch_utils.py diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py new file mode 100644 index 0000000..5c6fe5d --- /dev/null +++ b/finetune/utils/__init__.py @@ -0,0 +1,4 @@ +from .torch_utils import * +from .optimizer_utils import * +from .memory_utils import * +from .checkpointing import * diff --git a/finetune/utils/checkpointing.py b/finetune/utils/checkpointing.py new file mode 100644 index 0000000..1797153 --- /dev/null +++ b/finetune/utils/checkpointing.py @@ -0,0 +1,53 @@ +import os +from pathlib import Path +from typing import Tuple +from accelerate.logging import get_logger + +from finetune.constants import LOG_NAME, LOG_LEVEL +from ..utils.file_utils import find_files, delete_files + + +logger = get_logger(LOG_NAME, LOG_LEVEL) + + +def get_latest_ckpt_path_to_resume_from( + resume_from_checkpoint: str | None, num_update_steps_per_epoch: int +) -> Tuple[str | None, int, int, int]: + if resume_from_checkpoint is None: + initial_global_step = 0 + global_step = 0 + first_epoch = 0 + 
resume_from_checkpoint_path = None + else: + resume_from_checkpoint_path = Path(resume_from_checkpoint) + if not resume_from_checkpoint_path.exists(): + logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") + initial_global_step = 0 + global_step = 0 + first_epoch = 0 + resume_from_checkpoint_path = None + else: + logger.info(f"Resuming from checkpoint {resume_from_checkpoint}") + global_step = int(resume_from_checkpoint_path.name.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch + + +def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str: + # before saving state, check if this save would set us over the `checkpointing_limit` + if checkpointing_limit is not None: + checkpoints = find_files(output_dir, prefix="checkpoint") + + # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= checkpointing_limit: + num_to_remove = len(checkpoints) - checkpointing_limit + 1 + checkpoints_to_remove = checkpoints[0:num_to_remove] + delete_files(checkpoints_to_remove) + + logger.info(f"Checkpointing at step {step}") + save_path = os.path.join(output_dir, f"checkpoint-{step}") + logger.info(f"Saving state to {save_path}") + return save_path diff --git a/finetune/utils/file_utils.py b/finetune/utils/file_utils.py new file mode 100644 index 0000000..f04dd85 --- /dev/null +++ b/finetune/utils/file_utils.py @@ -0,0 +1,47 @@ +import logging +import os +import shutil + +from pathlib import Path +from typing import Any, Dict, List, Union +from accelerate.logging import get_logger +from finetune.constants import LOG_NAME, LOG_LEVEL + + +logger = get_logger(LOG_NAME, LOG_LEVEL) + + +def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]: + if not isinstance(dir, Path): + dir = Path(dir) + if not dir.exists(): + return [] + checkpoints = os.listdir(dir.as_posix()) + checkpoints = [c for c in checkpoints if c.startswith(prefix)] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [dir / c for c in checkpoints] + return checkpoints + + +def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: + if not isinstance(dirs, list): + dirs = [dirs] + dirs = [Path(d) if isinstance(d, str) else d for d in dirs] + logger.info(f"Deleting files: {dirs}") + for dir in dirs: + if not dir.exists(): + continue + shutil.rmtree(dir, ignore_errors=True) + + +def string_to_filename(s: str) -> str: + return ( + s.replace(" ", "-") + .replace("/", "-") + .replace(":", "-") + .replace(".", "-") + .replace(",", "-") + .replace(";", "-") + .replace("!", "-") + .replace("?", "-") + ) diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py new file mode 100644 index 0000000..a7a136b --- /dev/null +++ b/finetune/utils/memory_utils.py @@ -0,0 +1,60 @@ +import gc +import torch + +from typing import Any, Dict, Union +from accelerate.logging import get_logger + +from finetune.constants import LOG_NAME, LOG_LEVEL + + +logger = get_logger(LOG_NAME, LOG_LEVEL) + + +def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: + memory_allocated = None + memory_reserved = None + max_memory_allocated = None + max_memory_reserved = None + + if torch.cuda.is_available(): + device = torch.cuda.current_device() + memory_allocated = 
torch.cuda.memory_allocated(device) + memory_reserved = torch.cuda.memory_reserved(device) + max_memory_allocated = torch.cuda.max_memory_allocated(device) + max_memory_reserved = torch.cuda.max_memory_reserved(device) + + elif torch.mps.is_available(): + memory_allocated = torch.mps.current_allocated_memory() + + else: + logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.") + + return { + "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), + "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), + "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), + "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), + } + + +def bytes_to_gigabytes(x: int) -> float: + if x is not None: + return x / 1024**3 + + +def free_memory() -> None: + if torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + # TODO(aryan): handle non-cuda devices + + +def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if isinstance(x, torch.Tensor): + return x.contiguous() + elif isinstance(x, dict): + return {k: make_contiguous(v) for k, v in x.items()} + else: + return x diff --git a/finetune/utils/optimizer_utils.py b/finetune/utils/optimizer_utils.py new file mode 100644 index 0000000..bd93f9c --- /dev/null +++ b/finetune/utils/optimizer_utils.py @@ -0,0 +1,180 @@ +import inspect +import torch + +from accelerate.logging import get_logger + +from finetune.constants import LOG_NAME, LOG_LEVEL + + +logger = get_logger(LOG_NAME, LOG_LEVEL) + + +def get_optimizer( + params_to_optimize, + optimizer_name: str = "adam", + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.98, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + prodigy_decouple: bool = False, + prodigy_use_bias_correction: bool = False, + prodigy_safeguard_warmup: bool = False, + use_8bit: bool = False, + use_4bit: bool = False, + use_torchao: bool = False, + use_deepspeed: bool = False, + use_cpu_offload_optimizer: bool = False, + offload_gradients: bool = False, +) -> torch.optim.Optimizer: + optimizer_name = optimizer_name.lower() + + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=learning_rate, + betas=(beta1, beta2), + eps=epsilon, + weight_decay=weight_decay, + ) + + if use_8bit and use_4bit: + raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") + + if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: + try: + import torchao + + torchao.__version__ + except ImportError: + raise ImportError( + "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." + ) + + if not use_torchao and use_4bit: + raise ValueError("4-bit Optimizers are only supported with torchao.") + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy", "came"] + if optimizer_name not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." 
+ ) + optimizer_name = "adamw" + + if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: + raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + + if use_8bit: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if optimizer_name == "adamw": + if use_torchao: + from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit + + optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + else: + optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "adam": + if use_torchao: + from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + + optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam + else: + optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam + + init_kwargs = { + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + elif optimizer_name == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + init_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "beta3": beta3, + "eps": epsilon, + "weight_decay": weight_decay, + "decouple": prodigy_decouple, + "use_bias_correction": prodigy_use_bias_correction, + "safeguard_warmup": prodigy_safeguard_warmup, + } + + elif optimizer_name == "came": + try: + import came_pytorch + except ImportError: + raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") + + optimizer_class = came_pytorch.CAME + + init_kwargs = { + "lr": learning_rate, + "eps": (1e-30, 1e-16), + "betas": (beta1, beta2, beta3), + "weight_decay": weight_decay, + } + + if use_cpu_offload_optimizer: + from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + + if "fused" in inspect.signature(optimizer_class.__init__).parameters: + init_kwargs.update({"fused": True}) + + optimizer = CPUOffloadOptimizer( + params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs + ) + else: + optimizer = optimizer_class(params_to_optimize, **init_kwargs) + + return optimizer + + +def gradient_norm(parameters): + norm = 0 + for param in parameters: + if param.grad is None: + continue + local_norm = param.grad.detach().data.norm(2) + norm += local_norm.item() ** 2 + norm = norm**0.5 + return norm + + +def max_gradient(parameters): + max_grad_value = float("-inf") + for param in parameters: + if param.grad is None: + continue + local_max_grad = param.grad.detach().data.abs().max() + max_grad_value = max(max_grad_value, local_max_grad.item()) + return max_grad_value diff --git a/finetune/utils/torch_utils.py b/finetune/utils/torch_utils.py new file mode 100644 index 0000000..8a74271 --- /dev/null +++ b/finetune/utils/torch_utils.py @@ -0,0 +1,52 @@ +from typing import Dict, Optional, Union, List + +import torch +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + + +def unwrap_model(accelerator: 
Accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + +def align_device_and_dtype( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + if isinstance(x, torch.Tensor): + if device is not None: + x = x.to(device) + if dtype is not None: + x = x.to(dtype) + elif isinstance(x, dict): + if device is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + if dtype is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + return x + + +def expand_tensor_to_dims(tensor, ndim): + while len(tensor.shape) < ndim: + tensor = tensor.unsqueeze(-1) + return tensor + + +def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): + """ + Casts the training parameters of the model to the specified data type. + + Args: + model: The PyTorch model whose parameters will be cast. + dtype: The data type to which the model parameters will be cast. + """ + if not isinstance(model, list): + model = [model] + for m in model: + for param in m.parameters(): + # only upcast trainable parameters into fp32 + if param.requires_grad: + param.data = param.to(dtype) \ No newline at end of file From a505f2e312b07174dea456d565c5230265f2f4a2 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sun, 29 Dec 2024 15:06:52 +0000 Subject: [PATCH 11/25] Add constants.py --- finetune/constants.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 finetune/constants.py diff --git a/finetune/constants.py b/finetune/constants.py new file mode 100644 index 0000000..0842efa --- /dev/null +++ b/finetune/constants.py @@ -0,0 +1,2 @@ +LOG_NAME = "finetrainer" +LOG_LEVEL = "INFO" \ No newline at end of file From 60f6a3d7eed6c8d29abce8795d3a85f19941bf72 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sun, 29 Dec 2024 15:27:43 +0000 Subject: [PATCH 12/25] feat: add base trainer implementation and training script - Add Trainer base class with core training loop functionality - Implement distributed training setup with Accelerate - Add training script with model/trainer initialization - Support LoRA fine-tuning with checkpointing and validation --- finetune/train.py | 18 ++ finetune/trainer.py | 512 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 530 insertions(+) create mode 100644 finetune/train.py create mode 100644 finetune/trainer.py diff --git a/finetune/train.py b/finetune/train.py new file mode 100644 index 0000000..5f49f4b --- /dev/null +++ b/finetune/train.py @@ -0,0 +1,18 @@ +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +from finetune.schemas import Args +from finetune.models.utils import get_model_cls + + +def main(): + args = Args.parse_args() + trainer_cls = get_model_cls(args.model_name, args.training_type) + trainer = trainer_cls(args) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/finetune/trainer.py b/finetune/trainer.py new file mode 100644 index 0000000..cc9882c --- /dev/null +++ b/finetune/trainer.py @@ -0,0 +1,512 @@ +import torch +import logging +import transformers +import diffusers +import math +import json +import multiprocessing + +from datetime import timedelta +from pathlib import Path +from tqdm import tqdm +from typing import Dict, Any, List + +from torch.utils.data import Dataset, DataLoader +from accelerate.logging import get_logger +from accelerate.accelerator 
import Accelerator, DistributedType +from accelerate.utils import ( + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + set_seed, + gather_object, +) + +from diffusers.optimization import get_scheduler +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict + +from finetune.schemas import Args, State, Components +from finetune.utils import ( + unwrap_model, cast_training_params, + get_optimizer, + + get_memory_statistics, + free_memory, + + get_latest_ckpt_path_to_resume_from, + get_intermediate_ckpt_path, + get_latest_ckpt_path_to_resume_from, + get_intermediate_ckpt_path, +) +from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler + +from finetune.constants import LOG_NAME, LOG_LEVEL + + +logger = get_logger(LOG_NAME, LOG_LEVEL) + +_DTYPE_MAP = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +class Trainer: + + def __init__(self, args: Args) -> None: + self.args = args + self.state = State(weight_dtype=self.__get_training_dtype()) + + self.components = Components() + self.accelerator: Accelerator = None + self.dataset: Dataset = None + self.data_loader: DataLoader = None + + self.optimizer = None + self.lr_scheduler = None + + self._init_distributed() + self._init_logging() + self._init_directories() + + + def _init_distributed(self): + logging_dir = Path(self.args.output_dir, "logs") + project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_process_group_kwargs = InitProcessGroupKwargs( + backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) + ) + mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision + report_to = None if self.args.report_to.lower() == "none" else self.args.report_to + + accelerator = Accelerator( + project_config=project_config, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + mixed_precision=mixed_precision, + log_with=report_to, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + + # Disable AMP for MPS. 
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + + self.accelerator = accelerator + + if self.args.seed is not None: + set_seed(self.args.seed) + + + def _init_logging(self) -> None: + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=LOG_LEVEL, + ) + if self.accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + logger.info("Initialized Trainer") + logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False) + + + def _init_directories(self) -> None: + if self.accelerator.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + + + def prepare_dataset(self) -> None: + logger.info("Initializing dataset and dataloader") + + if self.args.model_type == "i2v": + self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump())) + elif self.args.model_type == "t2v": + self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump())) + else: + raise ValueError(f"Invalid model type: {self.args.model_type}") + + self.data_loader = torch.utils.data.DataLoader( + self.dataset, + batch_size=1, + sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), + collate_fn=self.collate_fn, + num_workers=self.args.num_workers, + pin_memory=self.args.pin_memory, + ) + + def prepare_models(self) -> None: + logger.info("Initializing models") + + # Initialize model components + self.components = self.load_components() + + if self.components.vae is not None: + if self.args.enable_slicing: + self.components.vae.enable_slicing() + if self.args.enable_tiling: + self.components.vae.enable_tiling() + + def prepare_trainable_parameters(self): + logger.info("Initializing trainable parameters") + + # For now only lora is supported + for attr_name, component in vars(self.components).items(): + if hasattr(component, 'requires_grad_'): + component.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = self.state.weight_dtype + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + + self.__move_components_to_device() + + if self.args.gradient_checkpointing: + self.components.transformer.enable_gradient_checkpointing() + + transformer_lora_config = LoraConfig( + r=self.args.rank, + lora_alpha=self.args.lora_alpha, + init_lora_weights=True, + target_modules=self.args.target_modules, + ) + self.components.transformer.add_adapter(transformer_lora_config) + self.__prepare_saving_loading_hooks(transformer_lora_config) + + def prepare_optimizer(self) -> None: + logger.info("Initializing optimizer and lr scheduler") + + # Make sure the trainable params are in float32 + if self.args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([self.components.transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters())) + transformer_parameters_with_lr = { + "params": transformer_lora_parameters, + "lr": self.args.learning_rate, + } + params_to_optimize = [transformer_parameters_with_lr] + self.state.num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters) + + use_deepspeed_opt = ( + self.accelerator.state.deepspeed_plugin is not None + and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config + ) + optimizer = get_optimizer( + params_to_optimize=params_to_optimize, + optimizer_name=self.args.optimizer, + learning_rate=self.args.learning_rate, + beta1=self.args.beta1, + beta2=self.args.beta2, + beta3=self.args.beta3, + epsilon=self.args.epsilon, + weight_decay=self.args.weight_decay, + use_deepspeed=use_deepspeed_opt, + ) + + num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps) + if self.args.train_steps is None: + self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch + self.state.overwrote_max_train_steps = True + + use_deepspeed_lr_scheduler = ( + self.accelerator.state.deepspeed_plugin is not None + and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config + ) + total_training_steps = self.args.train_steps * self.accelerator.num_processes + num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes + + if use_deepspeed_lr_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=self.args.lr_scheduler, + optimizer=optimizer, + total_num_steps=total_training_steps, + num_warmup_steps=num_warmup_steps, + ) + else: + lr_scheduler = get_scheduler( + name=self.args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_training_steps, + num_cycles=self.args.lr_num_cycles, + power=self.args.lr_power, + ) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def prepare_for_training(self) -> None: + self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare( + self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps) + if self.state.overwrote_max_train_steps: + self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch) + self.state.num_update_steps_per_epoch = num_update_steps_per_epoch + + def prepare_trackers(self) -> None: + logger.info("Initializing trackers") + + tracker_name = self.args.tracker_name or "finetrainers-experiment" + self.accelerator.init_trackers(tracker_name, config=self.args.model_dump()) + + def train(self) -> None: + logger.info("Starting training") + + memory_statistics = get_memory_statistics() + logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") + + self.state.total_batch_size_count = ( + self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps + ) + info = { + "trainable parameters": self.state.num_trainable_parameters, + "total samples": len(self.dataset), + "train epochs": self.args.train_epochs, + "train steps": self.args.train_steps, + "batches per device": self.args.batch_size, + "total batches observed per epoch": len(self.data_loader), + "train batch size total count": self.state.total_batch_size_count, + "gradient accumulation steps": self.args.gradient_accumulation_steps, + } + logger.info(f"Training configuration: {json.dumps(info, indent=4)}") + + global_step = 0 + first_epoch = 0 + initial_global_step = 0 + + # Potentially load in the weights and states from a previous save + ( + resume_from_checkpoint_path, + initial_global_step, + global_step, + first_epoch, + ) = get_latest_ckpt_path_to_resume_from( + resume_from_checkpoint=self.args.resume_from_checkpoint, + num_update_steps_per_epoch=self.state.num_update_steps_per_epoch, + ) + if resume_from_checkpoint_path is not None: + self.accelerator.load_state(resume_from_checkpoint_path) + + progress_bar = tqdm( + range(0, self.args.train_steps), + initial=initial_global_step, + desc="Training steps", + disable=not self.accelerator.is_local_main_process, + ) + + accelerator = self.accelerator + generator = torch.Generator(device=accelerator.device) + if self.args.seed is not None: + generator = generator.manual_seed(self.args.seed) + + for epoch in range(first_epoch, self.args.train_epochs): + logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})") + + self.components.transformer.train() + models_to_accumulate = [self.components.transformer] + + for step, batch in enumerate(self.data_loader): + logger.debug(f"Starting step {step + 1}") + logs = {} + + with accelerator.accumulate(models_to_accumulate): + # These weighting schemes use a uniform timestep sampling and instead post-weight the loss + loss = self.compute_loss(batch) + accelerator.backward(loss) + + if accelerator.sync_gradients: + if accelerator.distributed_type == DistributedType.DEEPSPEED: + grad_norm = self.components.transformer.get_global_grad_norm() + # In some cases the grad norm may not return a float + if torch.is_tensor(grad_norm): + grad_norm = grad_norm.item() + else: + grad_norm = accelerator.clip_grad_norm_( + self.components.transformer.parameters(), self.args.max_grad_norm + ) + if torch.is_tensor(grad_norm): + grad_norm = grad_norm.item() + + logs["grad_norm"] = grad_norm + + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has 
performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + self.__maybe_save_checkpoint(global_step) + + # Maybe run validation + should_run_validation = ( + self.args.validation_steps is not None + and global_step % self.args.validation_steps == 0 + ) + if should_run_validation: + self.validate(global_step) + + logs["loss"] = loss.detach().item() + logs["lr"] = self.lr_scheduler.get_last_lr()[0] + progress_bar.set_postfix(logs) + accelerator.log(logs, step=global_step) + + if global_step >= self.args.train_steps: + break + + memory_statistics = get_memory_statistics() + logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}") + + accelerator.wait_for_everyone() + self.__maybe_save_checkpoint(global_step, must_save=True) + self.validate(global_step) + + del self.components + free_memory() + memory_statistics = get_memory_statistics() + logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + + accelerator.end_training() + + def fit(self): + self.prepare_dataset() + self.prepare_models() + self.prepare_trainable_parameters() + self.prepare_optimizer() + self.prepare_for_training() + self.prepare_trackers() + self.train() + + def collate_fn(self, examples: List[List[Dict[str, Any]]]): + """ + Since we use BucketSampler, the examples parameter is a nested list where the outer list contains only one element, + which is the batch data we need. Therefore, when processing the data, we need to access the batch through examples[0]. + """ + raise NotImplementedError + + def load_components(self) -> Components: + raise NotImplementedError + + def compute_loss(self, batch) -> torch.Tensor: + raise NotImplementedError + + def validate(self) -> None: + raise NotImplementedError + + def __get_training_dtype(self) -> torch.dtype: + if self.args.mixed_precision == "no": + return _DTYPE_MAP["fp32"] + elif self.args.mixed_precision == "fp16": + return _DTYPE_MAP["fp16"] + elif self.args.mixed_precision == "bf16": + return _DTYPE_MAP["bf16"] + else: + raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}") + + def __move_components_to_device(self): + components = self.components.model_dump() + for name, component in components.items(): + if not isinstance(component, type) and hasattr(component, 'to'): + setattr(self.components, name, component.to(self.accelerator.device)) + + def __prepare_saving_loading_hooks(self, transformer_lora_config): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if self.accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance( + unwrap_model(self.accelerator, model), + type(unwrap_model(self.accelerator, self.components.transformer)), + ): + model = unwrap_model(self.accelerator, model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + self.components.pipeline_cls.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + if not self.accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + if isinstance( + unwrap_model(self.accelerator, 
model), + type(unwrap_model(self.accelerator, self.components.transformer)), + ): + transformer_ = unwrap_model(self.accelerator, model) + else: + raise ValueError( + f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}" + ) + else: + transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained( + self.args.model_path, subfolder="transformer" + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = self.components.piepeline_cls.lora_state_dict(input_dir) + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("transformer.") + } + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if self.args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + self.accelerator.register_save_state_pre_hook(save_model_hook) + self.accelerator.register_load_state_pre_hook(load_model_hook) + + def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False): + if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process: + if must_save or global_step % self.args.checkpointing_steps == 0: + save_path = get_intermediate_ckpt_path( + checkpointing_limit=self.args.checkpointing_limit, + step=global_step, + output_dir=self.args.output_dir, + ) + self.accelerator.save_state(save_path) From 6971364591aad3b3f9262267db8ff48fbf89c75e Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 06:49:45 +0000 Subject: [PATCH 13/25] Export file_utils.py --- finetune/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/finetune/utils/__init__.py b/finetune/utils/__init__.py index 5c6fe5d..7becfd5 100644 --- a/finetune/utils/__init__.py +++ b/finetune/utils/__init__.py @@ -2,3 +2,4 @@ from .torch_utils import * from .optimizer_utils import * from .memory_utils import * from .checkpointing import * +from .file_utils import * From fa4659fb2cc1c31f6be1228b39da80141e0d68a1 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 06:51:03 +0000 Subject: [PATCH 14/25] feat(trainer): add validation functionality to Trainer class Add validation capabilities to the Trainer class including: - Support for validating images and videos during training - Periodic validation based on validation_steps parameter - Artifact logging to wandb for validation results - Memory tracking during validation process --- finetune/trainer.py | 158 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 150 insertions(+), 8 deletions(-) diff --git a/finetune/trainer.py b/finetune/trainer.py index cc9882c..501a7e5 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -1,15 +1,18 @@ -import torch +import os import logging -import transformers -import diffusers import math import json -import multiprocessing + +import torch +import transformers +import diffusers +import wandb from 
datetime import timedelta from pathlib import Path from tqdm import tqdm -from typing import Dict, Any, List +from typing import Dict, Any, List, Tuple +from PIL import Image from torch.utils.data import Dataset, DataLoader from accelerate.logging import get_logger @@ -23,6 +26,7 @@ from accelerate.utils import ( ) from diffusers.optimization import get_scheduler +from diffusers.utils.export_utils import export_to_video from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from finetune.schemas import Args, State, Components @@ -37,8 +41,14 @@ from finetune.utils import ( get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from, get_intermediate_ckpt_path, + + string_to_filename ) from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler +from finetune.datasets.utils import ( + load_prompts, load_images, load_videos, + preprocess_image_with_resize, preprocess_video_with_resize +) from finetune.constants import LOG_NAME, LOG_LEVEL @@ -263,6 +273,24 @@ class Trainer: # Afterwards we recalculate our number of training epochs self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch) self.state.num_update_steps_per_epoch = num_update_steps_per_epoch + + def prepare_for_validation(self): + self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')] + validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts) + + if self.args.validation_images is not None: + validation_images = load_images(self.args.validation_dir / self.args.validation_images) + else: + validation_images = [None] * len(validation_prompts) + + if self.args.validation_videos is not None: + validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos) + else: + validation_videos = [None] * len(validation_prompts) + + self.state.validation_prompts = validation_prompts + self.state.validation_images = validation_images + self.state.validation_videos = validation_videos def prepare_trackers(self) -> None: logger.info("Initializing trackers") @@ -319,6 +347,7 @@ class Trainer: generator = torch.Generator(device=accelerator.device) if self.args.seed is not None: generator = generator.manual_seed(self.args.seed) + self.state.generator = generator for epoch in range(first_epoch, self.args.train_epochs): logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})") @@ -362,7 +391,7 @@ class Trainer: # Maybe run validation should_run_validation = ( - self.args.validation_steps is not None + self.args.do_validation and global_step % self.args.validation_steps == 0 ) if should_run_validation: @@ -381,7 +410,8 @@ class Trainer: accelerator.wait_for_everyone() self.__maybe_save_checkpoint(global_step, must_save=True) - self.validate(global_step) + if self.args.do_validation: + self.validate(global_step) del self.components free_memory() @@ -390,12 +420,124 @@ class Trainer: accelerator.end_training() + def validate(self, step: int) -> None: + logger.info("Starting validation") + + accelerator = self.accelerator + num_validation_samples = len(self.state.validation_prompts) + + if num_validation_samples == 0: + logger.warning("No validation samples found. 
Skipping validation.") + return + + self.components.transformer.eval() + + memory_statistics = get_memory_statistics() + logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + + all_processes_artifacts = [] + for i in range(num_validation_samples): + # Skip current validation on all processes but one + if i % accelerator.num_processes != accelerator.process_index: + continue + + prompt = self.state.validation_prompts[i] + image = self.state.validation_images[i] + video = self.state.validation_videos[i] + + if image is not None: + image = preprocess_image_with_resize( + image, self.state.gen_height, self.state.gen_width + ) + # Convert image tensor (C, H, W) to PIL images + image = image.to(torch.uint8) + image = image.permute(1, 2, 0).cpu().numpy() + image = Image.fromarray(image) + + if video is not None: + video = preprocess_video_with_resize( + video, self.state.gen_frames, self.state.gen_height, self.state.gen_width + ) + # Convert video tensor (F, C, H, W) to list of PIL images + video = (video * 255).round().clamp(0, 255).to(torch.uint8) + video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video] + + logger.debug( + f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}", + main_process_only=False, + ) + validation_artifacts = self.validation_step({ + "prompt": prompt, + "image": image, + "video": video + }) + prompt_filename = string_to_filename(prompt)[:25] + artifacts = { + "image": {"type": "image", "value": image}, + "video": {"type": "video", "value": video}, + } + for i, (artifact_type, artifact_value) in enumerate(validation_artifacts): + artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}) + logger.debug( + f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", + main_process_only=False, + ) + + for key, value in list(artifacts.items()): + artifact_type = value["type"] + artifact_value = value["value"] + if artifact_type not in ["image", "video"] or artifact_value is None: + continue + + extension = "png" if artifact_type == "image" else "mp4" + filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}" + validation_path = self.args.output_dir / "validation_res" + validation_path.mkdir(parents=True, exist_ok=True) + filename = str(validation_path / filename) + + if artifact_type == "image": + logger.debug(f"Saving image to {filename}") + artifact_value.save(filename) + artifact_value = wandb.Image(filename) + elif artifact_type == "video": + logger.debug(f"Saving video to {filename}") + export_to_video(artifact_value, filename, fps=self.args.gen_fps) + artifact_value = wandb.Video(filename, caption=prompt) + + all_processes_artifacts.append(artifact_value) + + all_artifacts = gather_object(all_processes_artifacts) + + if accelerator.is_main_process: + tracker_key = "validation" + for tracker in accelerator.trackers: + if tracker.name == "wandb": + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + tracker.log( + { + tracker_key: {"images": image_artifacts, "videos": video_artifacts}, + }, + step=step, + ) + + accelerator.wait_for_everyone() + + free_memory() + memory_statistics = get_memory_statistics() + logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + 
torch.cuda.reset_peak_memory_stats(accelerator.device) + + self.components.transformer.train() + def fit(self): self.prepare_dataset() self.prepare_models() self.prepare_trainable_parameters() self.prepare_optimizer() self.prepare_for_training() + if self.args.do_validation: + self.prepare_for_validation() self.prepare_trackers() self.train() @@ -412,7 +554,7 @@ class Trainer: def compute_loss(self, batch) -> torch.Tensor: raise NotImplementedError - def validate(self) -> None: + def validation_step(self) -> List[Tuple[str, Image.Image | List[Image.Image]]]: raise NotImplementedError def __get_training_dtype(self) -> torch.dtype: From 2a6cca0656e1c91db466d77ec58cd77af1ffc510 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 06:53:23 +0000 Subject: [PATCH 15/25] Add type conversion and validation checks --- finetune/datasets/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index cf0525e..ba7bddf 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -33,7 +33,7 @@ def load_images(image_path: Path) -> List[Path]: ########## preprocessors ########## def preprocess_image_with_resize( - image_path: Path, + image_path: Path | str, height: int, width: int, ) -> torch.Tensor: @@ -51,6 +51,8 @@ def preprocess_image_with_resize( H = height W = width """ + if isinstance(image_path, str): + image_path = Path(image_path) image = cv2.imread(image_path.as_posix()) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.resize(image, (width, height)) @@ -60,7 +62,7 @@ def preprocess_image_with_resize( def preprocess_video_with_resize( - video_path: Path, + video_path: Path | str, max_num_frames: int, height: int, width: int, @@ -85,8 +87,12 @@ def preprocess_video_with_resize( H = height W = width """ + if isinstance(video_path, str): + video_path = Path(video_path) video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height) video_num_frames = len(video_reader) + if video_num_frames < max_num_frames: + raise ValueError(f"video's frames is less than {max_num_frames}.") indices = list(range(0, video_num_frames, video_num_frames // max_num_frames)) frames = video_reader.get_batch(indices) From 6eae5c201e24150b59a87f2cb39102698aa2484f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 16:10:06 +0000 Subject: [PATCH 16/25] feat: add latent caching for video encodings - Add caching mechanism to store VAE-encoded video latents to disk - Cache latents in a "latent" subdirectory alongside video files - Skip re-encoding when cached latent file exists - Add logging for successful cache saves - Minor code cleanup and formatting improvements This change improves training efficiency by avoiding redundant video encoding operations. 
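In plain Python, the caching pattern this patch adds to the datasets looks roughly like the sketch below (the helper name get_or_encode_latent is illustrative only; encode_video_fn stands for the VAE-backed encoding callback the datasets receive, and frames are assumed to arrive as [F, C, H, W] tensors as elsewhere in this series):

    import torch
    from pathlib import Path

    def get_or_encode_latent(video_path: Path, frames: torch.Tensor, encode_video_fn) -> torch.Tensor:
        # Cache the encoded latent in a "latent" subdirectory next to the video file.
        latent_dir = video_path.parent / "latent"
        latent_dir.mkdir(parents=True, exist_ok=True)
        latent_path = latent_dir / (video_path.stem + ".pt")

        if latent_path.exists():
            # Skip re-encoding: load the cached latent from disk.
            return torch.load(latent_path, weights_only=True)

        # [F, C, H, W] -> [B, C, F, H, W] before encoding, mirroring the dataset code.
        batch = frames.unsqueeze(0).permute(0, 2, 1, 3, 4).contiguous()
        encoded = encode_video_fn(batch)[0].cpu()  # back to [C, F, H, W]
        torch.save(encoded, latent_path)
        return encoded
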
--- finetune/datasets/i2v_dataset.py | 115 ++++++++++++++++++++++++------- finetune/datasets/t2v_dataset.py | 90 +++++++++++++++++++----- 2 files changed, 164 insertions(+), 41 deletions(-) diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py index 2137dce..e993c96 100644 --- a/finetune/datasets/i2v_dataset.py +++ b/finetune/datasets/i2v_dataset.py @@ -1,11 +1,13 @@ +import torch + from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Callable from typing_extensions import override -import torch from accelerate.logging import get_logger from torch.utils.data import Dataset from torchvision import transforms +from finetune.constants import LOG_NAME, LOG_LEVEL from .utils import ( load_prompts, load_videos, load_images, @@ -21,12 +23,22 @@ import decord # isort:skip decord.bridge.set_bridge("torch") -logger = get_logger(__name__) +logger = get_logger(LOG_NAME, LOG_LEVEL) class BaseI2VDataset(Dataset): """ + Base dataset class for Image-to-Video (I2V) training. + This dataset loads prompts, videos and corresponding conditioning images for I2V training. + + Args: + data_root (str): Root directory containing the dataset files + caption_column (str): Path to file containing text prompts/captions + video_column (str): Path to file containing video paths + image_column (str): Path to file containing image paths + device (torch.device): Device to load the data on + encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos """ def __init__( self, @@ -34,6 +46,8 @@ class BaseI2VDataset(Dataset): caption_column: str, video_column: str, image_column: str, + device: torch.device, + encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None, *args, **kwargs ) -> None: @@ -44,6 +58,9 @@ class BaseI2VDataset(Dataset): self.videos = load_videos(data_root / video_column) self.images = load_images(data_root / image_column) + self.device = device + self.encode_video_fn = encode_video_fn + # Check if number of prompts matches number of videos and images if not (len(self.videos) == len(self.prompts) == len(self.images)): raise ValueError( @@ -79,28 +96,48 @@ class BaseI2VDataset(Dataset): return index prompt = self.prompts[index] + video = self.videos[index] + image = self.images[index] - # shape of frames: [F, C, H, W] + video_latent_dir = video.parent / "latent" + video_latent_dir.mkdir(parents=True, exist_ok=True) + encoded_video_path = video_latent_dir / (video.stem + ".pt") + + if encoded_video_path.exists(): + encoded_video = torch.load(encoded_video_path, weights_only=True) + # shape of image: [C, H, W] + _, image = self.preprocess(None, self.images[index]) + else: + frames, image = self.preprocess(video, image) + frames = frames.to(self.device) + # current shape of frames: [F, C, H, W] + frames = self.video_transform(frames) + # Convert to [B, C, F, H, W] + frames = frames.unsqueeze(0) + frames = frames.permute(0, 2, 1, 3, 4).contiguous() + encoded_video = self.encode_video_fn(frames) + # [B, C, F, H, W] -> [C, F, H, W] + encoded_video = encoded_video[0].cpu() + torch.save(encoded_video, encoded_video_path) + logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False) + + # shape of encoded_video: [C, F, H, W] # shape of image: [C, H, W] - frames, image = self.preprocess(self.videos[index], self.images[index]) - - frames = self.video_transform(frames) - image = self.image_transform(image) - return { "prompt": prompt, - "video": frames, + "image": image, + 
"encoded_video": encoded_video, "video_metadata": { - "num_frames": frames.shape[0], - "height": frames.shape[2], - "width": frames.shape[3], + "num_frames": encoded_video.shape[1], + "height": encoded_video.shape[2], + "width": encoded_video.shape[3], }, - "image": image } - def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: + def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]: """ Loads and preprocesses a video and an image. + If either path is None, no preprocessing will be done for that input. Args: video_path: Path to the video file to load @@ -150,7 +187,16 @@ class BaseI2VDataset(Dataset): class I2VDatasetWithResize(BaseI2VDataset): """ + A dataset class for image-to-video generation that resizes inputs to fixed dimensions. + This class preprocesses videos and images by resizing them to specified dimensions: + - Videos are resized to max_num_frames x height x width + - Images are resized to height x width + + Args: + max_num_frames (int): Maximum number of frames to extract from videos + height (int): Target height for resizing videos and images + width (int): Target width for resizing videos and images """ def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -164,26 +210,49 @@ class I2VDatasetWithResize(BaseI2VDataset): transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) ] ) - + self.__image_transforms = self.__frame_transforms + @override - def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: - video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width) - image = preprocess_image_with_resize(image_path, self.height, self.width) + def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]: + if video_path is not None: + video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width) + else: + video = None + if image_path is not None: + image = preprocess_image_with_resize(image_path, self.height, self.width) + else: + image = None return video, image @override def video_transform(self, frames: torch.Tensor) -> torch.Tensor: return torch.stack([self.__frame_transforms(f) for f in frames], dim=0) + + @override + def image_transform(self, image: torch.Tensor) -> torch.Tensor: + return self.__image_transforms(image) class I2VDatasetWithBuckets(BaseI2VDataset): - """ - """ - def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None: + def __init__( + self, + video_resolution_buckets: List[Tuple[int, int, int]], + vae_temporal_compression_ratio: int, + vae_height_compression_ratio: int, + vae_width_compression_ratio: int, + *args, **kwargs + ) -> None: super().__init__(*args, **kwargs) - self.video_resolution_buckets = video_resolution_buckets + self.video_resolution_buckets = [ + ( + int(b[0] / vae_temporal_compression_ratio), + int(b[1] / vae_height_compression_ratio), + int(b[2] / vae_width_compression_ratio), + ) + for b in video_resolution_buckets + ] self.__frame_transforms = transforms.Compose( [ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0) diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py index aa4ed72..9afe53a 100644 --- a/finetune/datasets/t2v_dataset.py +++ b/finetune/datasets/t2v_dataset.py @@ -1,12 +1,15 @@ +import torch + from pathlib import Path -from typing import Any, 
Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Callable from typing_extensions import override -import torch from accelerate.logging import get_logger from torch.utils.data import Dataset from torchvision import transforms +from finetune.constants import LOG_NAME, LOG_LEVEL + from .utils import ( load_prompts, load_videos, preprocess_video_with_resize, @@ -19,18 +22,30 @@ import decord # isort:skip decord.bridge.set_bridge("torch") -logger = get_logger(__name__) +logger = get_logger(LOG_NAME, LOG_LEVEL) class BaseT2VDataset(Dataset): """ + Base dataset class for Text-to-Video (T2V) training. + This dataset loads prompts and videos for T2V training. + + Args: + data_root (str): Root directory containing the dataset files + caption_column (str): Path to file containing text prompts/captions + video_column (str): Path to file containing video paths + device (torch.device): Device to load the data on + encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos """ + def __init__( self, data_root: str, caption_column: str, video_column: str, + device: torch.device = None, + encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None, *args, **kwargs ) -> None: @@ -39,6 +54,8 @@ class BaseT2VDataset(Dataset): data_root = Path(data_root) self.prompts = load_prompts(data_root / caption_column) self.videos = load_videos(data_root / video_column) + self.device = device + self.encode_video_fn = encode_video_fn # Check if all video files exist if any(not path.is_file() for path in self.videos): @@ -69,18 +86,36 @@ class BaseT2VDataset(Dataset): return index prompt = self.prompts[index] + video = self.videos[index] - # shape of frames: [F, C, H, W] - frames = self.preprocess(self.videos[index]) - frames = self.video_transform(frames) + latent_dir = video.parent / "latent" + latent_dir.mkdir(parents=True, exist_ok=True) + encoded_video_path = latent_dir / (video.stem + ".pt") + + if encoded_video_path.exists(): + # shape of encoded_video: [C, F, H, W] + encoded_video = torch.load(encoded_video_path, weights_only=True) + else: + frames = self.preprocess(video) + frames = frames.to(self.device) + # current shape of frames: [F, C, H, W] + frames = self.video_transform(frames) + # Convert to [B, C, F, H, W] + frames = frames.unsqueeze(0) + frames = frames.permute(0, 2, 1, 3, 4).contiguous() + encoded_video = self.encode_video_fn(frames) + # [B, C, F, H, W] -> [C, F, H, W] + encoded_video = encoded_video[0].cpu() + torch.save(encoded_video, encoded_video_path) + logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False) return { "prompt": prompt, - "video": frames, + "encoded_video": encoded_video, "video_metadata": { - "num_frames": frames.shape[0], - "height": frames.shape[2], - "width": frames.shape[3], + "num_frames": encoded_video.shape[1], + "height": encoded_video.shape[2], + "width": encoded_video.shape[3], }, } @@ -113,15 +148,24 @@ class BaseT2VDataset(Dataset): - W is width Returns: - torch.Tensor: The transformed video tensor + torch.Tensor: The transformed video tensor with the same shape as the input """ raise NotImplementedError("Subclass must implement this method") class T2VDatasetWithResize(BaseT2VDataset): """ + A dataset class for text-to-video generation that resizes inputs to fixed dimensions. 
+ This class preprocesses videos by resizing them to specified dimensions: + - Videos are resized to max_num_frames x height x width + + Args: + max_num_frames (int): Maximum number of frames to extract from videos + height (int): Target height for resizing videos + width (int): Target width for resizing videos """ + def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -147,18 +191,28 @@ class T2VDatasetWithResize(BaseT2VDataset): class T2VDatasetWithBuckets(BaseT2VDataset): - """ - """ - def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None: + def __init__( + self, + video_resolution_buckets: List[Tuple[int, int, int]], + vae_temporal_compression_ratio: int, + vae_height_compression_ratio: int, + vae_width_compression_ratio: int, + *args, **kwargs + ) -> None: """ - Args: - resolution_buckets: List of tuples representing the resolution buckets. - Each tuple contains three integers: (max_num_frames, height, width). + """ super().__init__(*args, **kwargs) - self.video_resolution_buckets = video_resolution_buckets + self.video_resolution_buckets = [ + ( + int(b[0] / vae_temporal_compression_ratio), + int(b[1] / vae_height_compression_ratio), + int(b[2] / vae_width_compression_ratio), + ) + for b in video_resolution_buckets + ] self.__frame_transform = transforms.Compose( [ From 45d40450a1ba50ceaf1d57b6fec7917cd13f8d80 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Mon, 30 Dec 2024 16:14:46 +0000 Subject: [PATCH 17/25] refactor: simplify dataset implementation and add latent precomputation - Replace bucket-based dataset with simpler resize-based implementation - Add video latent precomputation during dataset initialization - Improve code readability and user experience - Remove complexity of bucket sampling for better maintainability This change makes the codebase more straightforward and easier to use while maintaining functionality through resize-based video processing. 
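The precomputation this patch adds to the trainer amounts to one throwaway pass over the dataset before the real training dataloader is built, so every __getitem__ call hits the caching path from the previous patch and writes its latent to disk. A minimal sketch (the helper name precompute_latents is illustrative, not part of the patch):

    import torch

    def precompute_latents(dataset, collate_fn, accelerator, pin_memory: bool = True) -> None:
        # Iterate once with batch_size=1 and no workers so encoding stays on the
        # accelerator device; the dataset itself saves each latent to disk.
        tmp_loader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=collate_fn,
            batch_size=1,
            num_workers=0,
            pin_memory=pin_memory,
        )
        tmp_loader = accelerator.prepare_data_loader(tmp_loader)
        for _ in tmp_loader:
            pass
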
--- finetune/trainer.py | 95 +++++++++++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 34 deletions(-) diff --git a/finetune/trainer.py b/finetune/trainer.py index 501a7e5..5b02f0e 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -44,7 +44,7 @@ from finetune.utils import ( string_to_filename ) -from finetune.datasets import I2VDatasetWithBuckets, T2VDatasetWithBuckets, BucketSampler +from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, BucketSampler from finetune.datasets.utils import ( load_prompts, load_images, load_videos, preprocess_image_with_resize, preprocess_video_with_resize @@ -80,7 +80,6 @@ class Trainer: self._init_logging() self._init_directories() - def _init_distributed(self): logging_dir = Path(self.args.output_dir, "logs") project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) @@ -108,7 +107,6 @@ class Trainer: if self.args.seed is not None: set_seed(self.args.seed) - def _init_logging(self) -> None: logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -124,33 +122,12 @@ class Trainer: logger.info("Initialized Trainer") logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False) - def _init_directories(self) -> None: if self.accelerator.is_main_process: self.args.output_dir = Path(self.args.output_dir) self.args.output_dir.mkdir(parents=True, exist_ok=True) - - def prepare_dataset(self) -> None: - logger.info("Initializing dataset and dataloader") - - if self.args.model_type == "i2v": - self.dataset = I2VDatasetWithBuckets(**(self.args.model_dump())) - elif self.args.model_type == "t2v": - self.dataset = T2VDatasetWithBuckets(**(self.args.model_dump())) - else: - raise ValueError(f"Invalid model type: {self.args.model_type}") - - self.data_loader = torch.utils.data.DataLoader( - self.dataset, - batch_size=1, - sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), - collate_fn=self.collate_fn, - num_workers=self.args.num_workers, - pin_memory=self.args.pin_memory, - ) - def prepare_models(self) -> None: logger.info("Initializing models") @@ -163,6 +140,57 @@ class Trainer: if self.args.enable_tiling: self.components.vae.enable_tiling() + def prepare_dataset(self) -> None: + logger.info("Initializing dataset and dataloader") + + if self.args.model_type == "i2v": + self.dataset = I2VDatasetWithResize( + **(self.args.model_dump()), + device=self.accelerator.device, + encode_fn=self.encode_video, + max_num_frames=self.args.train_resolution[0], + height=self.args.train_resolution[1], + width=self.args.train_resolution[2] + ) + elif self.args.model_type == "t2v": + self.dataset = T2VDatasetWithResize( + **(self.args.model_dump()), + device=self.accelerator.device, + encode_fn=self.encode_video, + max_num_frames=self.args.train_resolution[0], + height=self.args.train_resolution[1], + width=self.args.train_resolution[2] + ) + else: + raise ValueError(f"Invalid model type: {self.args.model_type}") + + # Prepare VAE for encoding + self.components.vae = self.components.vae.to(self.accelerator.device) + self.components.vae.requires_grad_(False) + + # Precompute latent for video + logger.info("Precomputing latent for video ...") + tmp_data_loader = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_size=1, + num_workers=0, + pin_memory=self.args.pin_memory, + ) + tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader) + for _ in tmp_data_loader: ... 
+ logger.info("Precomputing latent for video ... Done") + + self.data_loader = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=self.args.pin_memory, + shuffle=True + ) + + def prepare_trainable_parameters(self): logger.info("Initializing trainable parameters") @@ -275,7 +303,6 @@ class Trainer: self.state.num_update_steps_per_epoch = num_update_steps_per_epoch def prepare_for_validation(self): - self.state.gen_frames, self.state.gen_height, self.state.gen_width = [int(elem) for elem in self.args.gen_video_resolution.split('x')] validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts) if self.args.validation_images is not None: @@ -447,7 +474,7 @@ class Trainer: if image is not None: image = preprocess_image_with_resize( - image, self.state.gen_height, self.state.gen_width + image, self.args.train_resolution[1], self.args.train_resolution[2] ) # Convert image tensor (C, H, W) to PIL images image = image.to(torch.uint8) @@ -456,7 +483,7 @@ class Trainer: if video is not None: video = preprocess_video_with_resize( - video, self.state.gen_frames, self.state.gen_height, self.state.gen_width + video, self.args.train_resolution[0], self.args.train_resolution[1], self.args.train_resolution[2] ) # Convert video tensor (F, C, H, W) to list of PIL images video = (video * 255).round().clamp(0, 255).to(torch.uint8) @@ -531,8 +558,8 @@ class Trainer: self.components.transformer.train() def fit(self): - self.prepare_dataset() self.prepare_models() + self.prepare_dataset() self.prepare_trainable_parameters() self.prepare_optimizer() self.prepare_for_training() @@ -541,15 +568,15 @@ class Trainer: self.prepare_trackers() self.train() - def collate_fn(self, examples: List[List[Dict[str, Any]]]): - """ - Since we use BucketSampler, the examples parameter is a nested list where the outer list contains only one element, - which is the batch data we need. Therefore, when processing the data, we need to access the batch through examples[0]. 
- """ + def collate_fn(self, examples: List[Dict[str, Any]]): raise NotImplementedError def load_components(self) -> Components: raise NotImplementedError + + def encode_video(self, video: torch.Tensor) -> torch.Tensor: + # shape of input video: [B, C, F, H, W], where B = 1 + raise NotImplementedError def compute_loss(self, batch) -> torch.Tensor: raise NotImplementedError @@ -617,7 +644,7 @@ class Trainer: ) transformer_.add_adapter(transformer_lora_config) - lora_state_dict = self.components.piepeline_cls.lora_state_dict(input_dir) + lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir) transformer_state_dict = { f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() From 91d79fd9a45a3b03436c9edca64e071d3538331f Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 31 Dec 2024 17:25:02 +0000 Subject: [PATCH 18/25] feat: add schemas module for configuration and state management Add Pydantic models to handle: - CLI arguments and configuration (Args) - Model components and pipeline (Components) - Training state and parameters (State) --- finetune/schemas/__init__.py | 5 ++ finetune/schemas/args.py | 147 +++++++++++++++++++++++++++++++++ finetune/schemas/components.py | 27 ++++++ finetune/schemas/state.py | 26 ++++++ 4 files changed, 205 insertions(+) create mode 100644 finetune/schemas/__init__.py create mode 100644 finetune/schemas/args.py create mode 100644 finetune/schemas/components.py create mode 100644 finetune/schemas/state.py diff --git a/finetune/schemas/__init__.py b/finetune/schemas/__init__.py new file mode 100644 index 0000000..76a6bf8 --- /dev/null +++ b/finetune/schemas/__init__.py @@ -0,0 +1,5 @@ +from .args import Args +from .state import State +from .components import Components + +__all__ = ["Args", "State", "Components"] \ No newline at end of file diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py new file mode 100644 index 0000000..4745f7c --- /dev/null +++ b/finetune/schemas/args.py @@ -0,0 +1,147 @@ +import datetime +import argparse +from typing import Dict, Any, Literal, List, Tuple +from pydantic import BaseModel, field_validator + +from pathlib import Path + + +class Args(BaseModel): + ########## Model ########## + model_path: Path + model_name: str + model_type: Literal["i2v", "t2v"] + training_type: Literal["lora", "sft"] = "lora" + + ########## Output ########## + output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now())) + report_to: Literal["tensorboard", "wandb", "all"] | None = None + tracker_name: str = "finetrainer-cogvideo" + + ########## Data ########### + data_root: Path + caption_column: Path + image_column: Path | None = None + video_column: Path + + ########## Training ######### + resume_from_checkpoint: Path | None = None + + seed: int | None = None + train_epochs: int + train_steps: int | None = None + checkpointing_steps: int = 500 + checkpointing_limit: int = 10 + + batch_size: int + gradient_accumulation_steps: int = 1 + + train_resolution: Tuple[int, int, int] # shape: (frames, height, width) + + #### deprecated args: video_resolution_buckets + # if use bucket for training, should not be None + # Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset + # Note2: For cogvideox, cogvideox1.5 + # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8) + # The height and width set in the bucket must be an integer multiple of 8 
(temporal_compression_rate[8]) + # video_resolution_buckets: List[Tuple[int, int, int]] | None = None + + mixed_precision: Literal["no", "fp16", "bf16"] + + learning_rate: float = 2e-5 + optimizer: str = "adamw" + beta1: float = 0.9 + beta2: float = 0.95 + beta3: float = 0.98 + epsilon: float = 1e-8 + weight_decay: float = 1e-4 + max_grad_norm: float = 1.0 + + lr_scheduler: str = "constant_with_warmup" + lr_warmup_steps: int = 100 + lr_num_cycles: int = 1 + lr_power: float = 1.0 + + num_workers: int = 8 + pin_memory: bool = True + + gradient_checkpointing: bool = True + enable_slicing: bool = True + enable_tiling: bool = True + nccl_timeout: int = 1800 + + ########## Lora ########## + rank: int = 128 + lora_alpha: int = 64 + target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"] + + ########## Validation ########## + do_validation: bool = False + validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps + validation_dir: Path | None # if set do_validation, should not be None + validation_prompts: str | None # if set do_validation, should not be None + validation_images: str | None # if set do_validation and model_type == i2v, should not be None + validation_videos: str | None # if set do_validation and model_type == v2v, should not be None + gen_fps: int = 15 + + #### deprecated args: gen_video_resolution + # 1. If set do_validation, should not be None + # 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning + # or the resolution recommended by the model + # 3. Note: For cogvideox, cogvideox1.5 + # The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8) + # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8]) + # gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width) + + + @classmethod + def parse_args(cls): + """Parse command line arguments and return Args instance""" + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--model_type", type=str, required=True) + parser.add_argument("--training_type", type=str, required=True) + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--seed", type=int, required=True) + parser.add_argument("--nccl_timeout", type=int, required=True) + parser.add_argument("--mixed_precision", type=str, required=True) + parser.add_argument("--gradient_accumulation_steps", type=int, required=True) + parser.add_argument("--data_root", type=str, required=True) + parser.add_argument("--caption_column", type=str, required=True) + parser.add_argument("--video_column", type=str, required=True) + parser.add_argument("--image_column", type=str) + parser.add_argument("--train_resolution", type=str, required=True) + parser.add_argument("--batch_size", type=int, required=True) + parser.add_argument("--num_workers", type=int, required=True) + parser.add_argument("--pin_memory", type=str, required=True) + parser.add_argument("--report_to", type=str, required=True) + parser.add_argument("--train_epochs", type=int, required=True) + parser.add_argument("--checkpointing_steps", type=int, required=True) + parser.add_argument("--checkpointing_limit", type=int, required=True) + + parser.add_argument("--do_validation", type=bool) + parser.add_argument("--validation_steps", 
type=int) + parser.add_argument("--validation_dir", type=str) + parser.add_argument("--validation_prompts", type=str) + parser.add_argument("--validation_images", type=str) + parser.add_argument("--validation_videos", type=str) + parser.add_argument("--gen_fps", type=int) + + parser.add_argument("--resume_from_checkpoint", type=str) + + args = parser.parse_args() + + # Convert video_resolution_buckets string to list of tuples + frames, height, width = args.train_resolution.split("x") + args.train_resolution = (int(frames), int(height), int(width)) + + return cls(**vars(args)) + + # @field_validator("...", mode="after") + # def foo(cls, foobar): + # ... + + # @field_validator("...", mode="before") + # def bar(cls, barbar): + # ... diff --git a/finetune/schemas/components.py b/finetune/schemas/components.py new file mode 100644 index 0000000..2d3fef5 --- /dev/null +++ b/finetune/schemas/components.py @@ -0,0 +1,27 @@ +from typing import Any +from pydantic import BaseModel + + +class Components(BaseModel): + # pipeline cls + pipeline_cls: Any = None + + # Tokenizers + tokenizer: Any = None + tokenizer_2: Any = None + tokenizer_3: Any = None + + # Text encoders + text_encoder: Any = None + text_encoder_2: Any = None + text_encoder_3: Any = None + + # Autoencoder + vae: Any = None + + # Denoiser + transformer: Any = None + unet: Any = None + + # Scheduler + scheduler: Any = None diff --git a/finetune/schemas/state.py b/finetune/schemas/state.py new file mode 100644 index 0000000..9d5363c --- /dev/null +++ b/finetune/schemas/state.py @@ -0,0 +1,26 @@ +import torch + +from pathlib import Path +from typing import List, Dict, Any +from pydantic import BaseModel, field_validator + +class State(BaseModel): + model_config = {"arbitrary_types_allowed": True} + + train_frames: int + train_height: int + train_width: int + + transformer_config: Dict[str, Any] = None + + weight_dtype: torch.dtype = torch.float32 + num_trainable_parameters: int = 0 + overwrote_max_train_steps: bool = False + num_update_steps_per_epoch: int = 0 + total_batch_size_count: int = 0 + + generator: torch.Generator | None = None + + validation_prompts: List[str] = [] + validation_images: List[Path | None] = [] + validation_videos: List[Path | None] = [] From a0018428346d00970cabb03a74369a2cf531e514 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 31 Dec 2024 17:27:47 +0000 Subject: [PATCH 19/25] feat: implement CogVideoX trainers for I2V and T2V tasks Add and refactor trainers for CogVideoX model variants: - Implement CogVideoXT2VLoraTrainer for text-to-video generation - Refactor CogVideoXI2VLoraTrainer for image-to-video generation Both trainers support LoRA fine-tuning with proper handling of: - Model components loading and initialization - Video encoding and batch collation - Loss computation with noise prediction - Validation step for generation --- .../models/cogvideox1dot5_i2v/lora_trainer.py | 26 +- .../models/cogvideox1dot5_t2v/lora_trainer.py | 26 +- finetune/models/cogvideox_i2v/lora_trainer.py | 229 +++++++++++++++++- finetune/models/cogvideox_t2v/lora_trainer.py | 214 ++++++++++++++++ finetune/trainer.py | 31 ++- 5 files changed, 459 insertions(+), 67 deletions(-) create mode 100644 finetune/models/cogvideox_t2v/lora_trainer.py diff --git a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py index 6ef9dd4..09d4b70 100644 --- a/finetune/models/cogvideox1dot5_i2v/lora_trainer.py +++ b/finetune/models/cogvideox1dot5_i2v/lora_trainer.py @@ -1,29 +1,9 @@ -import torch - 
-from typing_extensions import override -from typing import Any, Dict, List - -from finetune.trainer import Trainer from ..utils import register +from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer -class CogVideoX1dot5I2VLoraTrainer(Trainer): - - @override - def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: - raise NotImplementedError - - @override - def load_components(self) -> Dict[str, Any]: - raise NotImplementedError - - @override - def compute_loss(self, batch) -> torch.Tensor: - raise NotImplementedError - - @override - def validate(self) -> None: - raise NotImplementedError +class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer): + pass register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer) diff --git a/finetune/models/cogvideox1dot5_t2v/lora_trainer.py b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py index dfc2a78..79504bc 100644 --- a/finetune/models/cogvideox1dot5_t2v/lora_trainer.py +++ b/finetune/models/cogvideox1dot5_t2v/lora_trainer.py @@ -1,29 +1,9 @@ -import torch - -from typing_extensions import override -from typing import Any, Dict, List - -from finetune.trainer import Trainer +from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer from ..utils import register -class CogVideoX1dot5T2VLoraTrainer(Trainer): - - @override - def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: - raise NotImplementedError - - @override - def load_components(self) -> Dict[str, Any]: - raise NotImplementedError - - @override - def compute_loss(self, batch) -> torch.Tensor: - raise NotImplementedError - - @override - def validate(self) -> None: - raise NotImplementedError +class CogVideoX1dot5T2VLoraTrainer(CogVideoXT2VLoraTrainer): + pass register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer) diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py index d625f18..442f769 100644 --- a/finetune/models/cogvideox_i2v/lora_trainer.py +++ b/finetune/models/cogvideox_i2v/lora_trainer.py @@ -1,29 +1,240 @@ import torch from typing_extensions import override -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple +from PIL import Image + +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers import ( + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, +) from finetune.trainer import Trainer +from finetune.schemas import Components +from finetune.utils import unwrap_model from ..utils import register class CogVideoXI2VLoraTrainer(Trainer): - @override - def collate_fn(self, samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: - raise NotImplementedError - @override def load_components(self) -> Dict[str, Any]: - raise NotImplementedError + components = Components() + model_path = str(self.args.model_path) + + components.pipeline_cls = CogVideoXImageToVideoPipeline + + components.tokenizer = AutoTokenizer.from_pretrained( + model_path, subfolder="tokenizer" + ) + + components.text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder" + ) + + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer" + ) + + components.vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, subfolder="vae" + ) + + 
components.scheduler = CogVideoXDPMScheduler.from_pretrained( + model_path, subfolder="scheduler" + ) + + return components + + @override + def encode_video(self, video: torch.Tensor) -> torch.Tensor: + # shape of input video: [B, C, F, H, W] + vae = self.components.vae + video = video.to(vae.device, dtype=vae.dtype) + latent_dist = vae.encode(video).latent_dist + latent = latent_dist.sample() * vae.config.scaling_factor + return latent + + @override + def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: + ret = { + "encoded_videos": [], + "prompt_token_ids": [], + "images": [] + } + + for sample in samples: + encoded_video = sample["encoded_video"] + prompt = sample["prompt"] + image = sample["image"] + + # tokenize prompt + text_inputs = self.components.tokenizer( + prompt, + padding="max_length", + max_length=self.state.transformer_config.max_text_seq_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + ret["encoded_videos"].append(encoded_video) + ret["prompt_token_ids"].append(text_input_ids[0]) + ret["images"].append(image) + + ret["encoded_videos"] = torch.stack(ret["encoded_videos"]) + ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"]) + ret["images"] = torch.stack(ret["images"]) + + return ret @override def compute_loss(self, batch) -> torch.Tensor: - raise NotImplementedError + prompt_token_ids = batch["prompt_token_ids"] + latent = batch["encoded_videos"] + images = batch["images"] + + batch_size, num_channels, num_frames, height, width = latent.shape + + # Get prompt embeddings + prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1) + + # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W] + images = images.unsqueeze(2) + # Add noise to images + image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device) + image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype) + noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = self.components.vae.encode(noisy_images).latent_dist + image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor + + # Sample a random timestep for each sample + timesteps = torch.randint( + 0, self.components.scheduler.config.num_train_timesteps, + (batch_size,), device=self.accelerator.device + ) + timesteps = timesteps.long() + + # from [B, C, F, H, W] to [B, F, C, H, W] + latent = latent.permute(0, 2, 1, 3, 4) + image_latents = image_latents.permute(0, 2, 1, 3, 4) + assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) + + # Padding image_latents to the same frame number as latent + padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:]) + latent_padding = image_latents.new_zeros(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + # Add noise to latent + noise = torch.randn_like(latent) + latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps) + + # Concatenate latent and image_latents in the channel dimension + latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2) + + # Prepare rotary embeds + vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1) + transformer_config = self.state.transformer_config + rotary_emb = ( + 
self.prepare_rotary_positional_embeddings( + height=height * vae_scale_factor_spatial, + width=width * vae_scale_factor_spatial, + num_frames=num_frames, + transformer_config=transformer_config, + vae_scale_factor_spatial=vae_scale_factor_spatial, + device=self.accelerator.device, + ) + if transformer_config.use_rotary_positional_embeddings + else None + ) + + # Predict noise + ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0) + predicted_noise = self.components.transformer( + hidden_states=latent_img_noisy, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + ofs=ofs_emb, + image_rotary_emb=rotary_emb, + return_dict=False, + )[0] + + # Denoise + latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps) + + alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(latent_pred.shape): + weights = weights.unsqueeze(-1) + + loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + + return loss @override - def validate(self) -> None: - raise NotImplementedError + def validation_step( + self, eval_data: Dict[str, Any] + ) -> List[Tuple[str, Image.Image | List[Image.Image]]]: + """ + Return the data that needs to be saved. For videos, the data format is List[PIL], + and for images, the data format is PIL + """ + prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"] + + pipe = self.components.pipeline_cls( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=unwrap_model(self.accelerator, self.components.transformer), + scheduler=self.components.scheduler + ) + video_generate = pipe( + num_frames=self.state.train_frames, + height=self.state.train_height, + width=self.state.train_width, + prompt=prompt, + image=image, + generator=self.state.generator + ).frames[0] + return [("video", video_generate)] + + def prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + transformer_config: Dict, + vae_scale_factor_spatial: int, + device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size) + grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size) + + if transformer_config.patch_size_t is None: + base_num_frames = num_frames + else: + base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=transformer_config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(grid_height, grid_width), + device=device, + ) + + return freqs_cos, freqs_sin register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer) \ No newline at end of file diff --git a/finetune/models/cogvideox_t2v/lora_trainer.py b/finetune/models/cogvideox_t2v/lora_trainer.py new file mode 100644 index 0000000..2e92486 --- /dev/null +++ b/finetune/models/cogvideox_t2v/lora_trainer.py @@ -0,0 +1,214 @@ +import torch + +from typing_extensions import override +from typing import Any, Dict, List, Tuple + +from PIL import Image + +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers.pipelines.cogvideo.pipeline_cogvideox import 
get_resize_crop_region_for_grid +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers import ( + CogVideoXPipeline, + CogVideoXTransformer3DModel, + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, +) + +from finetune.trainer import Trainer +from finetune.schemas import Components +from finetune.utils import unwrap_model +from ..utils import register + + +class CogVideoXT2VLoraTrainer(Trainer): + + @override + def load_components(self) -> Components: + components = Components() + model_path = str(self.args.model_path) + + components.pipeline_cls = CogVideoXPipeline + + components.tokenizer = AutoTokenizer.from_pretrained( + model_path, subfolder="tokenizer" + ) + + components.text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder" + ) + + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer" + ) + + components.vae = AutoencoderKLCogVideoX.from_pretrained( + model_path, subfolder="vae" + ) + + components.scheduler = CogVideoXDPMScheduler.from_pretrained( + model_path, subfolder="scheduler" + ) + + return components + + @override + def encode_video(self, video: torch.Tensor) -> torch.Tensor: + # shape of input video: [B, C, F, H, W] + vae = self.components.vae + video = video.to(vae.device, dtype=vae.dtype) + latent_dist = vae.encode(video).latent_dist + latent = latent_dist.sample() * vae.config.scaling_factor + return latent + + @override + def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]: + ret = { + "encoded_videos": [], + "prompt_token_ids": [] + } + + for sample in samples: + encoded_video = sample["encoded_video"] + prompt = sample["prompt"] + + # tokenize prompt + text_inputs = self.components.tokenizer( + prompt, + padding="max_length", + max_length=226, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + ret["encoded_videos"].append(encoded_video) + ret["prompt_token_ids"].append(text_input_ids[0]) + + ret["encoded_videos"] = torch.stack(ret["encoded_videos"]) + ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"]) + + return ret + + @override + def compute_loss(self, batch) -> torch.Tensor: + prompt_token_ids = batch["prompt_token_ids"] + latent = batch["encoded_videos"] + + batch_size, num_channels, num_frames, height, width = latent.shape + + # Get prompt embeddings + prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1) + assert prompt_embeds.requires_grad is False + + # Sample a random timestep for each sample + timesteps = torch.randint( + 0, self.components.scheduler.config.num_train_timesteps, + (batch_size,), device=self.accelerator.device + ) + timesteps = timesteps.long() + + # Add noise to latent + latent = latent.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + noise = torch.randn_like(latent) + latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps) + + # Prepare rotary embeds + vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1) + transformer_config = self.state.transformer_config + rotary_emb = ( + self.prepare_rotary_positional_embeddings( + height=height * vae_scale_factor_spatial, + width=width * vae_scale_factor_spatial, + num_frames=num_frames, + transformer_config=transformer_config, + vae_scale_factor_spatial=vae_scale_factor_spatial, + 
device=self.accelerator.device, + ) + if transformer_config.use_rotary_positional_embeddings + else None + ) + + # Predict noise + predicted_noise = self.components.transformer( + hidden_states=latent_added_noise, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=rotary_emb, + return_dict=False, + )[0] + + # Denoise + latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps) + + alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(latent_pred.shape): + weights = weights.unsqueeze(-1) + + loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + + return loss + + @override + def validation_step( + self, eval_data: Dict[str, Any] + ) -> List[Tuple[str, Image.Image | List[Image.Image]]]: + """ + Return the data that needs to be saved. For videos, the data format is List[PIL], + and for images, the data format is PIL + """ + prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"] + + pipe = self.components.pipeline_cls( + tokenizer=self.components.tokenizer, + text_encoder=self.components.text_encoder, + vae=self.components.vae, + transformer=unwrap_model(self.accelerator, self.components.transformer), + scheduler=self.components.scheduler + ) + video_generate = pipe( + num_frames=self.state.train_frames, + height=self.state.train_height, + width=self.state.train_width, + prompt=prompt, + generator=self.state.generator + ).frames[0] + return [("video", video_generate)] + + def prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + transformer_config: Dict, + vae_scale_factor_spatial: int, + device: torch.device + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size) + grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size) + + if transformer_config.patch_size_t is None: + base_num_frames = num_frames + else: + base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=transformer_config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(grid_height, grid_width), + device=device, + ) + + return freqs_cos, freqs_sin + + +register("cogvideox-t2v", "lora", CogVideoXT2VLoraTrainer) diff --git a/finetune/trainer.py b/finetune/trainer.py index 5b02f0e..6c9ec82 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -44,7 +44,7 @@ from finetune.utils import ( string_to_filename ) -from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize, BucketSampler +from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize from finetune.datasets.utils import ( load_prompts, load_images, load_videos, preprocess_image_with_resize, preprocess_video_with_resize @@ -66,7 +66,12 @@ class Trainer: def __init__(self, args: Args) -> None: self.args = args - self.state = State(weight_dtype=self.__get_training_dtype()) + self.state = State( + weight_dtype=self.__get_training_dtype(), + train_frames=self.args.train_resolution[0], + train_height=self.args.train_resolution[1], + train_width=self.args.train_resolution[2] + ) self.components = Components() self.accelerator: Accelerator = None @@ -140,6 +145,8 @@ class 
Trainer: if self.args.enable_tiling: self.components.vae.enable_tiling() + self.state.transformer_config = self.components.transformer.config + def prepare_dataset(self) -> None: logger.info("Initializing dataset and dataloader") @@ -147,19 +154,19 @@ class Trainer: self.dataset = I2VDatasetWithResize( **(self.args.model_dump()), device=self.accelerator.device, - encode_fn=self.encode_video, - max_num_frames=self.args.train_resolution[0], - height=self.args.train_resolution[1], - width=self.args.train_resolution[2] + encode_video_fn=self.encode_video, + max_num_frames=self.state.train_frames, + height=self.state.train_height, + width=self.state.train_width ) elif self.args.model_type == "t2v": self.dataset = T2VDatasetWithResize( **(self.args.model_dump()), device=self.accelerator.device, - encode_fn=self.encode_video, - max_num_frames=self.args.train_resolution[0], - height=self.args.train_resolution[1], - width=self.args.train_resolution[2] + encode_video_fn=self.encode_video, + max_num_frames=self.state.train_frames, + height=self.state.train_height, + width=self.state.train_width ) else: raise ValueError(f"Invalid model type: {self.args.model_type}") @@ -474,7 +481,7 @@ class Trainer: if image is not None: image = preprocess_image_with_resize( - image, self.args.train_resolution[1], self.args.train_resolution[2] + image, self.state.train_height, self.state.train_width ) # Convert image tensor (C, H, W) to PIL images image = image.to(torch.uint8) @@ -483,7 +490,7 @@ class Trainer: if video is not None: video = preprocess_video_with_resize( - video, self.args.train_resolution[0], self.args.train_resolution[1], self.args.train_resolution[2] + video, self.state.train_frames, self.state.train_height, self.state.train_width ) # Convert video tensor (F, C, H, W) to list of PIL images video = (video * 255).round().clamp(0, 255).to(torch.uint8) From 04a60e743584c17f2fc798261eddd601a8937c55 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 1 Jan 2025 14:22:59 +0000 Subject: [PATCH 20/25] Change logger name to trainer --- finetune/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune/constants.py b/finetune/constants.py index 0842efa..30ef0eb 100644 --- a/finetune/constants.py +++ b/finetune/constants.py @@ -1,2 +1,2 @@ -LOG_NAME = "finetrainer" +LOG_NAME = "trainer" LOG_LEVEL = "INFO" \ No newline at end of file From 26b87cd4ff2d696303ab8f4a65a27692fe0e440d Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 1 Jan 2025 14:40:09 +0000 Subject: [PATCH 21/25] feat(args): add validation and arg interface for training parameters - Add field validators for model type and validation settings - Implement command line argument parsing with argparse - Add type hints and documentation for training parameters - Support configuration of model, training, and validation parameters --- finetune/schemas/args.py | 122 +++++++++++++++++++++++++++++---------- 1 file changed, 93 insertions(+), 29 deletions(-) diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index 4745f7c..dca76e3 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -1,7 +1,7 @@ import datetime import argparse from typing import Dict, Any, Literal, List, Tuple -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, ValidationInfo from pathlib import Path @@ -30,7 +30,7 @@ class Args(BaseModel): seed: int | None = None train_epochs: int train_steps: int | None = None - checkpointing_steps: int = 500 + checkpointing_steps: int = 200 
checkpointing_limit: int = 10 batch_size: int @@ -93,42 +93,114 @@ class Args(BaseModel): # The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8]) # gen_video_resolution: Tuple[int, int, int] | None # shape: (frames, height, width) + @field_validator("image_column") + def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("model_type") == "i2v" and not v: + raise ValueError("image_column must be specified when using i2v model") + return v + + @field_validator("validation_dir", "validation_prompts") + def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any: + values = info.data + if values.get("do_validation") and not v: + field_name = info.field_name + raise ValueError(f"{field_name} must be specified when do_validation is True") + return v + + @field_validator("validation_images") + def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("do_validation") and values.get("model_type") == "i2v" and not v: + raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v") + return v + + @field_validator("validation_videos") + def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None: + values = info.data + if values.get("do_validation") and values.get("model_type") == "v2v" and not v: + raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v") + return v + + @field_validator("validation_steps") + def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None: + values = info.data + if values.get("do_validation"): + if v is None: + raise ValueError("validation_steps must be specified when do_validation is True") + if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0: + raise ValueError("validation_steps must be a multiple of checkpointing_steps") + return v + @classmethod def parse_args(cls): """Parse command line arguments and return Args instance""" parser = argparse.ArgumentParser() + # Required arguments parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--model_name", type=str, required=True) parser.add_argument("--model_type", type=str, required=True) parser.add_argument("--training_type", type=str, required=True) parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument("--seed", type=int, required=True) - parser.add_argument("--nccl_timeout", type=int, required=True) - parser.add_argument("--mixed_precision", type=str, required=True) - parser.add_argument("--gradient_accumulation_steps", type=int, required=True) parser.add_argument("--data_root", type=str, required=True) parser.add_argument("--caption_column", type=str, required=True) parser.add_argument("--video_column", type=str, required=True) - parser.add_argument("--image_column", type=str) parser.add_argument("--train_resolution", type=str, required=True) - parser.add_argument("--batch_size", type=int, required=True) - parser.add_argument("--num_workers", type=int, required=True) - parser.add_argument("--pin_memory", type=str, required=True) parser.add_argument("--report_to", type=str, required=True) - parser.add_argument("--train_epochs", type=int, required=True) - parser.add_argument("--checkpointing_steps", type=int, required=True) - parser.add_argument("--checkpointing_limit", type=int, required=True) - 
parser.add_argument("--do_validation", type=bool) - parser.add_argument("--validation_steps", type=int) - parser.add_argument("--validation_dir", type=str) - parser.add_argument("--validation_prompts", type=str) - parser.add_argument("--validation_images", type=str) - parser.add_argument("--validation_videos", type=str) - parser.add_argument("--gen_fps", type=int) + # Training hyperparameters + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--train_epochs", type=int, default=10) + parser.add_argument("--train_steps", type=int, default=None) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--optimizer", type=str, default="adamw") + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument("--beta3", type=float, default=0.98) + parser.add_argument("--epsilon", type=float, default=1e-8) + parser.add_argument("--weight_decay", type=float, default=1e-4) + parser.add_argument("--max_grad_norm", type=float, default=1.0) - parser.add_argument("--resume_from_checkpoint", type=str) + # Learning rate scheduler + parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=100) + parser.add_argument("--lr_num_cycles", type=int, default=1) + parser.add_argument("--lr_power", type=float, default=1.0) + + # Data loading + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--pin_memory", type=bool, default=True) + parser.add_argument("--image_column", type=str, default=None) + + # Model configuration + parser.add_argument("--mixed_precision", type=str, default="no") + parser.add_argument("--gradient_checkpointing", type=bool, default=True) + parser.add_argument("--enable_slicing", type=bool, default=True) + parser.add_argument("--enable_tiling", type=bool, default=True) + parser.add_argument("--nccl_timeout", type=int, default=1800) + + # LoRA parameters + parser.add_argument("--rank", type=int, default=128) + parser.add_argument("--lora_alpha", type=int, default=64) + parser.add_argument("--target_modules", type=str, nargs="+", + default=["to_q", "to_k", "to_v", "to_out.0"]) + + # Checkpointing + parser.add_argument("--checkpointing_steps", type=int, default=200) + parser.add_argument("--checkpointing_limit", type=int, default=10) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + + # Validation + parser.add_argument("--do_validation", type=bool, default=False) + parser.add_argument("--validation_steps", type=int, default=None) + parser.add_argument("--validation_dir", type=str, default=None) + parser.add_argument("--validation_prompts", type=str, default=None) + parser.add_argument("--validation_images", type=str, default=None) + parser.add_argument("--validation_videos", type=str, default=None) + parser.add_argument("--gen_fps", type=int, default=15) args = parser.parse_args() @@ -137,11 +209,3 @@ class Args(BaseModel): args.train_resolution = (int(frames), int(height), int(width)) return cls(**vars(args)) - - # @field_validator("...", mode="after") - # def foo(cls, foobar): - # ... - - # @field_validator("...", mode="before") - # def bar(cls, barbar): - # ... 
From 6e794724178423de53e0e0ad697f0d66780cefc5 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 1 Jan 2025 14:53:45 +0000 Subject: [PATCH 22/25] feat: add training launch scripts for I2V and T2V models Add two shell scripts to simplify model training: - accelerate_train_i2v.sh: Launch script for Image-to-Video training - accelerate_train_t2v.sh: Launch script for Text-to-Video training Both scripts provide comprehensive configurations for: - Model settings - Data pipeline - Training parameters - System resources - Checkpointing - Validation --- finetune/accelerate_train_i2v.sh | 46 ++++++++++++++++++++++++++++++++ finetune/accelerate_train_t2v.sh | 45 +++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 finetune/accelerate_train_i2v.sh create mode 100644 finetune/accelerate_train_t2v.sh diff --git a/finetune/accelerate_train_i2v.sh b/finetune/accelerate_train_i2v.sh new file mode 100644 index 0000000..372d2c4 --- /dev/null +++ b/finetune/accelerate_train_i2v.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# Prevent tokenizer parallelism issues +export TOKENIZERS_PARALLELISM=false + +# Launch training with accelerate +accelerate launch train.py \ + ########## Model Configuration ########## + --model_path "THUDM/CogVideoX1.5-5B-I2V" \ + --model_name "cogvideox1.5-i2v" \ + --model_type "i2v" \ + --training_type "lora" \ + + ########## Output Configuration ########## + --output_dir "/path/to/output/dir" \ + --report_to "tensorboard" \ + + ########## Data Configuration ########## + --data_root "/path/to/data/dir" \ + --caption_column "prompt.txt" \ + --video_column "videos.txt" \ + --image_column "images.txt" \ + --train_resolution "48x768x1360" \ + + ########## Training Configuration ########## + --train_epochs 10 \ + --batch_size 1 \ + --gradient_accumulation_steps 1 \ + --mixed_precision "bf16" \ + --seed 42 \ + + ########## System Configuration ########## + --num_workers 8 \ + --pin_memory True \ + --nccl_timeout 1800 \ + + ########## Checkpointing Configuration ########## + --checkpointing_steps 200 \ + --checkpointing_limit 10 \ + + ########## Validation Configuration ########## + --do_validation False \ + --validation_dir "path/to/validation/dir" \ + --validation_steps 400 \ + --validation_prompts "prompts.txt" \ + --gen_fps 15 diff --git a/finetune/accelerate_train_t2v.sh b/finetune/accelerate_train_t2v.sh new file mode 100644 index 0000000..bdb0140 --- /dev/null +++ b/finetune/accelerate_train_t2v.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +# Prevent tokenizer parallelism issues +export TOKENIZERS_PARALLELISM=false + +# Launch training with accelerate +accelerate launch train.py \ + ########## Model Configuration ########## + --model_path "THUDM/CogVideoX1.5-5B" \ + --model_name "cogvideox1.5-t2v" \ + --model_type "t2v" \ + --training_type "lora" \ + + ########## Output Configuration ########## + --output_dir "/path/to/output/dir" \ + --report_to "tensorboard" \ + + ########## Data Configuration ########## + --data_root "/path/to/data/dir" \ + --caption_column "prompt.txt" \ + --video_column "videos.txt" \ + --train_resolution "48x768x1360" \ + + ########## Training Configuration ########## + --train_epochs 10 \ + --batch_size 1 \ + --gradient_accumulation_steps 1 \ + --mixed_precision "bf16" \ + --seed 42 \ + + ########## System Configuration ########## + --num_workers 8 \ + --pin_memory True \ + --nccl_timeout 1800 \ + + ########## Checkpointing Configuration ########## + --checkpointing_steps 200 \ + --checkpointing_limit 10 \ + + ########## Validation 
Configuration ########## + --do_validation False \ + --validation_dir "path/to/validation/dir" \ + --validation_steps 400 \ + --validation_prompts "prompts.txt" \ + --gen_fps 15 From 6ef15dd2a5898d210a5b0b5ed6e3a34e13aa8edc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Fri, 27 Dec 2024 19:37:08 +0800 Subject: [PATCH 23/25] docs: update TOC and add friendly link in README files - Update table of contents in README.md, README_ja.md and README_zh.md - Add friendly link section to all README files --- README.md | 31 +++++++++++++++++++------------ README_ja.md | 21 ++++++++++++++------- README_zh.md | 23 +++++++++++++++-------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index a22b290..e58d8d8 100644 --- a/README.md +++ b/README.md @@ -61,18 +61,24 @@ The SAT code has already been updated, while the diffusers version is still unde Jump to a specific section: -- [Quick Start](#Quick-Start) - - [SAT](#sat) - - [Diffusers](#Diffusers) -- [CogVideoX-2B Video Works](#cogvideox-2b-gallery) -- [Introduction to the CogVideoX Model](#Model-Introduction) -- [Full Project Structure](#project-structure) - - [Inference](#inference) - - [SAT](#sat) - - [Tools](#tools) -- [Introduction to CogVideo(ICLR'23) Model](#cogvideoiclr23) -- [Citations](#Citation) -- [Model License](#Model-License) +- [Quick Start](#quick-start) + - [Prompt Optimization](#prompt-optimization) + - [SAT](#sat) + - [Diffusers](#diffusers) +- [Gallery](#gallery) + - [CogVideoX-5B](#cogvideox-5b) + - [CogVideoX-2B](#cogvideox-2b) +- [Model Introduction](#model-introduction) +- [Friendly Links](#friendly-links) +- [Project Structure](#project-structure) + - [Quick Start with Colab](#quick-start-with-colab) + - [Inference](#inference) + - [finetune](#finetune) + - [sat](#sat-1) + - [Tools](#tools) +- [CogVideo(ICLR'23)](#cogvideoiclr23) +- [Citation](#citation) +- [Model-License](#model-license) ## Quick Start @@ -321,6 +327,7 @@ works have already been adapted for CogVideoX, and we invite everyone to use the + [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): A simple ControlNet module code that includes the CogVideoX model. + [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna is the first repo that integrates multiple AI video generation models for text-to-video, image-to-video, text-to-image generation. + [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): An identity-preserving text-to-video generation model, bases on CogVideoX-5B, which keep the face consistent in the generated video by frequency decomposition. ++ [A Step by Step Tutorial](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): A step-by-step guide on installing and optimizing the CogVideoX1.5-5B-I2V model in Windows and cloud environments. Special thanks to the [FurkanGozukara](https://github.com/FurkanGozukara) for his effort and support! 
## Project Structure diff --git a/README_ja.md b/README_ja.md index c8c2d24..6c2303c 100644 --- a/README_ja.md +++ b/README_ja.md @@ -62,15 +62,21 @@ SAT バージョンのコードは [こちら](https://huggingface.co/THUDM/CogV 特定のセクションにジャンプ: - [クイックスタート](#クイックスタート) - - [SAT](#sat) - - [Diffusers](#Diffusers) -- [CogVideoX-2B ギャラリー](#CogVideoX-2B-ギャラリー) + - [プロンプトの最適化](#プロンプトの最適化) + - [SAT](#sat) + - [Diffusers](#diffusers) +- [Gallery](#gallery) + - [CogVideoX-5B](#cogvideox-5b) + - [CogVideoX-2B](#cogvideox-2b) - [モデル紹介](#モデル紹介) +- [友好的リンク](#友好的リンク) - [プロジェクト構造](#プロジェクト構造) - - [推論](#推論) - - [sat](#sat) - - [ツール](#ツール)= -- [CogVideo(ICLR'23)モデル紹介](#CogVideoICLR23) + - [Colabでのクイックスタート](#colabでのクイックスタート) + - [Inference](#inference) + - [finetune](#finetune) + - [sat](#sat-1) + - [ツール](#ツール) +- [CogVideo(ICLR'23)](#cogvideoiclr23) - [引用](#引用) - [ライセンス契約](#ライセンス契約) @@ -302,6 +308,7 @@ pipe.vae.enable_tiling() + [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): CogVideoXモデルを含むシンプルなControlNetモジュールのコード。 + [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna は、テキストからビデオ、画像からビデオ、テキストから画像生成のための複数のAIビデオ生成モデルを統合した最初のリポジトリです。 + [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): 一貫性のある顔を保持するために、周波数分解を使用するCogVideoX-5Bに基づいたアイデンティティ保持型テキストから動画生成モデル。 ++ [ステップバイステップチュートリアル](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): WindowsおよびクラウドでのCogVideoX1.5-5B-I2Vモデルのインストールと最適化に関するステップバイステップガイド。[FurkanGozukara](https://github.com/FurkanGozukara)氏の尽力とサポートに感謝いたします! ## プロジェクト構造 diff --git a/README_zh.md b/README_zh.md index b280bd2..393770a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -51,15 +51,21 @@ CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨 跳转到指定部分: - [快速开始](#快速开始) - - [SAT](#sat) - - [Diffusers](#Diffusers) -- [CogVideoX-2B 视频作品](#cogvideox-2b-视频作品) -- [CogVideoX模型介绍](#模型介绍) + - [提示词优化](#提示词优化) + - [SAT](#sat) + - [Diffusers](#diffusers) +- [视频作品](#视频作品) + - [CogVideoX-5B](#cogvideox-5b) + - [CogVideoX-2B](#cogvideox-2b) +- [模型介绍](#模型介绍) +- [友情链接](#友情链接) - [完整项目代码结构](#完整项目代码结构) - - [Inference](#inference) - - [SAT](#sat) - - [Tools](#tools) -- [CogVideo(ICLR'23)模型介绍](#cogvideoiclr23) + - [Colab 快速使用](#colab-快速使用) + - [inference](#inference) + - [finetune](#finetune) + - [sat](#sat-1) + - [tools](#tools) +- [CogVideo(ICLR'23)](#cogvideoiclr23) - [引用](#引用) - [模型协议](#模型协议) @@ -282,6 +288,7 @@ pipe.vae.enable_tiling() + [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): 一个包含 CogvideoX 模型的简单 Controlnet 模块的代码。 + [VideoTuna](https://github.com/VideoVerses/VideoTuna):VideoTuna 是首个集成多种 AI 视频生成模型的仓库,支持文本转视频、图像转视频、文本转图像生成。 + [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): 一种身份保持的文本到视频生成模型,基于 CogVideoX-5B,通过频率分解在生成的视频中保持面部一致性。 ++ [教程](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): 一个关于在Windows和云环境中安装和优化CogVideoX1.5-5B-I2V模型的分步指南。特别感谢[FurkanGozukara](https://github.com/FurkanGozukara)的努力和支持! 
## 完整项目代码结构 From 48ad17881878dcf235b24964f05a337713834af2 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 1 Jan 2025 15:52:39 +0000 Subject: [PATCH 24/25] Reorganize training script arguments --- finetune/accelerate_train_i2v.sh | 103 +++++++++++++++++++------------ finetune/accelerate_train_t2v.sh | 100 ++++++++++++++++++------------ 2 files changed, 124 insertions(+), 79 deletions(-) diff --git a/finetune/accelerate_train_i2v.sh b/finetune/accelerate_train_i2v.sh index 372d2c4..ec3922e 100644 --- a/finetune/accelerate_train_i2v.sh +++ b/finetune/accelerate_train_i2v.sh @@ -3,44 +3,67 @@ # Prevent tokenizer parallelism issues export TOKENIZERS_PARALLELISM=false -# Launch training with accelerate -accelerate launch train.py \ - ########## Model Configuration ########## - --model_path "THUDM/CogVideoX1.5-5B-I2V" \ - --model_name "cogvideox1.5-i2v" \ - --model_type "i2v" \ - --training_type "lora" \ - - ########## Output Configuration ########## - --output_dir "/path/to/output/dir" \ - --report_to "tensorboard" \ - - ########## Data Configuration ########## - --data_root "/path/to/data/dir" \ - --caption_column "prompt.txt" \ - --video_column "videos.txt" \ - --image_column "images.txt" \ - --train_resolution "48x768x1360" \ - - ########## Training Configuration ########## - --train_epochs 10 \ - --batch_size 1 \ - --gradient_accumulation_steps 1 \ - --mixed_precision "bf16" \ - --seed 42 \ - - ########## System Configuration ########## - --num_workers 8 \ - --pin_memory True \ - --nccl_timeout 1800 \ - - ########## Checkpointing Configuration ########## - --checkpointing_steps 200 \ - --checkpointing_limit 10 \ - - ########## Validation Configuration ########## - --do_validation False \ - --validation_dir "path/to/validation/dir" \ - --validation_steps 400 \ - --validation_prompts "prompts.txt" \ +# Model Configuration +MODEL_ARGS=( + --model_path "THUDM/CogVideoX1.5-5B-I2V" + --model_name "cogvideox1.5-i2v" + --model_type "i2v" + --training_type "lora" +) + +# Output Configuration +OUTPUT_ARGS=( + --output_dir "/path/to/output/dir" + --report_to "tensorboard" +) + +# Data Configuration +DATA_ARGS=( + --data_root "/path/to/data/dir" + --caption_column "prompt.txt" + --video_column "videos.txt" + --image_column "images.txt" + --train_resolution "80x768x1360" +) + +# Training Configuration +TRAIN_ARGS=( + --train_epochs 10 + --batch_size 1 + --gradient_accumulation_steps 1 + --mixed_precision "bf16" + --seed 42 +) + +# System Configuration +SYSTEM_ARGS=( + --num_workers 8 + --pin_memory True + --nccl_timeout 1800 +) + +# Checkpointing Configuration +CHECKPOINT_ARGS=( + --checkpointing_steps 200 + --checkpointing_limit 10 +) + +# Validation Configuration +VALIDATION_ARGS=( + --do_validation False + --validation_dir "/path/to/validation/dir" + --validation_steps 400 + --validation_prompts "prompts.txt" + --validation_images "images.txt" --gen_fps 15 +) + +# Combine all arguments and launch training +accelerate launch train.py \ + "${MODEL_ARGS[@]}" \ + "${OUTPUT_ARGS[@]}" \ + "${DATA_ARGS[@]}" \ + "${TRAIN_ARGS[@]}" \ + "${SYSTEM_ARGS[@]}" \ + "${CHECKPOINT_ARGS[@]}" \ + "${VALIDATION_ARGS[@]}" \ No newline at end of file diff --git a/finetune/accelerate_train_t2v.sh b/finetune/accelerate_train_t2v.sh index bdb0140..0d2b7f6 100644 --- a/finetune/accelerate_train_t2v.sh +++ b/finetune/accelerate_train_t2v.sh @@ -3,43 +3,65 @@ # Prevent tokenizer parallelism issues export TOKENIZERS_PARALLELISM=false -# Launch training with accelerate -accelerate launch train.py \ - ########## Model 
Configuration ########## - --model_path "THUDM/CogVideoX1.5-5B" \ - --model_name "cogvideox1.5-t2v" \ - --model_type "t2v" \ - --training_type "lora" \ - - ########## Output Configuration ########## - --output_dir "/path/to/output/dir" \ - --report_to "tensorboard" \ - - ########## Data Configuration ########## - --data_root "/path/to/data/dir" \ - --caption_column "prompt.txt" \ - --video_column "videos.txt" \ - --train_resolution "48x768x1360" \ - - ########## Training Configuration ########## - --train_epochs 10 \ - --batch_size 1 \ - --gradient_accumulation_steps 1 \ - --mixed_precision "bf16" \ - --seed 42 \ - - ########## System Configuration ########## - --num_workers 8 \ - --pin_memory True \ - --nccl_timeout 1800 \ - - ########## Checkpointing Configuration ########## - --checkpointing_steps 200 \ - --checkpointing_limit 10 \ - - ########## Validation Configuration ########## - --do_validation False \ - --validation_dir "path/to/validation/dir" \ - --validation_steps 400 \ - --validation_prompts "prompts.txt" \ +# Model Configuration +MODEL_ARGS=( + --model_path "THUDM/CogVideoX1.5-5B" + --model_name "cogvideox1.5-t2v" + --model_type "t2v" + --training_type "lora" +) + +# Output Configuration +OUTPUT_ARGS=( + --output_dir "/path/to/output/dir" + --report_to "tensorboard" +) + +# Data Configuration +DATA_ARGS=( + --data_root "/path/to/data/dir" + --caption_column "prompt.txt" + --video_column "videos.txt" + --train_resolution "80x768x1360" +) + +# Training Configuration +TRAIN_ARGS=( + --train_epochs 10 + --batch_size 1 + --gradient_accumulation_steps 1 + --mixed_precision "bf16" + --seed 42 +) + +# System Configuration +SYSTEM_ARGS=( + --num_workers 8 + --pin_memory True + --nccl_timeout 1800 +) + +# Checkpointing Configuration +CHECKPOINT_ARGS=( + --checkpointing_steps 200 + --checkpointing_limit 10 +) + +# Validation Configuration +VALIDATION_ARGS=( + --do_validation False + --validation_dir "/path/to/validation/dir" + --validation_steps 400 + --validation_prompts "prompts.txt" --gen_fps 15 +) + +# Combine all arguments and launch training +accelerate launch train.py \ + "${MODEL_ARGS[@]}" \ + "${OUTPUT_ARGS[@]}" \ + "${DATA_ARGS[@]}" \ + "${TRAIN_ARGS[@]}" \ + "${SYSTEM_ARGS[@]}" \ + "${CHECKPOINT_ARGS[@]}" \ + "${VALIDATION_ARGS[@]}" \ No newline at end of file From 7fa1bb48be3272841c1195411e0dffda35ca6739 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 1 Jan 2025 15:56:14 +0000 Subject: [PATCH 25/25] refactor: remove deprecated training scripts --- finetune/finetune_multi_rank.sh | 52 - finetune/finetune_single_rank.sh | 52 - .../train_cogvideox_image_to_video_lora.py | 1689 ----------------- finetune/train_cogvideox_lora.py | 1573 --------------- 4 files changed, 3366 deletions(-) delete mode 100644 finetune/finetune_multi_rank.sh delete mode 100644 finetune/finetune_single_rank.sh delete mode 100644 finetune/train_cogvideox_image_to_video_lora.py delete mode 100644 finetune/train_cogvideox_lora.py diff --git a/finetune/finetune_multi_rank.sh b/finetune/finetune_multi_rank.sh deleted file mode 100644 index f6c34de..0000000 --- a/finetune/finetune_multi_rank.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -export MODEL_PATH="THUDM/CogVideoX-2b" -export CACHE_PATH="~/.cache" -export DATASET_PATH="Disney-VideoGeneration-Dataset" -export OUTPUT_PATH="cogvideox-lora-multi-node" -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES - -# max batch-size is 2. 
-accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \ - train_cogvideox_lora.py \ - --gradient_checkpointing \ - --pretrained_model_name_or_path $MODEL_PATH \ - --cache_dir $CACHE_PATH \ - --enable_tiling \ - --enable_slicing \ - --instance_data_root $DATASET_PATH \ - --caption_column prompt.txt \ - --video_column videos.txt \ - --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \ - --validation_prompt_separator ::: \ - --num_validation_videos 1 \ - --validation_epochs 100 \ - --seed 42 \ - --rank 128 \ - --lora_alpha 64 \ - --mixed_precision bf16 \ - --output_dir $OUTPUT_PATH \ - --height 480 \ - --width 720 \ - --fps 8 \ - --max_num_frames 49 \ - --skip_frames_start 0 \ - --skip_frames_end 0 \ - --train_batch_size 1 \ - --num_train_epochs 30 \ - --checkpointing_steps 1000 \ - --gradient_accumulation_steps 1 \ - --learning_rate 1e-3 \ - --lr_scheduler cosine_with_restarts \ - --lr_warmup_steps 200 \ - --lr_num_cycles 1 \ - --enable_slicing \ - --enable_tiling \ - --gradient_checkpointing \ - --optimizer AdamW \ - --adam_beta1 0.9 \ - --adam_beta2 0.95 \ - --max_grad_norm 1.0 \ - --allow_tf32 \ - --report_to wandb \ No newline at end of file diff --git a/finetune/finetune_single_rank.sh b/finetune/finetune_single_rank.sh deleted file mode 100644 index 8b45876..0000000 --- a/finetune/finetune_single_rank.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -export MODEL_PATH="THUDM/CogVideoX-5b" -export CACHE_PATH="~/.cache" -export DATASET_PATH="Disney-VideoGeneration-Dataset" -export OUTPUT_PATH="cogvideox-lora-single-node" -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES - -# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number -accelerate launch --config_file accelerate_config_machine_single.yaml \ - train_cogvideox_lora.py \ - --gradient_checkpointing \ - --pretrained_model_name_or_path $MODEL_PATH \ - --cache_dir $CACHE_PATH \ - --enable_tiling \ - --enable_slicing \ - --instance_data_root $DATASET_PATH \ - --caption_column prompt.txt \ - --video_column videos.txt \ - --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. 
The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \ - --validation_prompt_separator ::: \ - --num_validation_videos 1 \ - --validation_epochs 100 \ - --seed 42 \ - --rank 128 \ - --lora_alpha 64 \ - --mixed_precision bf16 \ - --output_dir $OUTPUT_PATH \ - --height 480 \ - --width 720 \ - --fps 8 \ - --max_num_frames 49 \ - --skip_frames_start 0 \ - --skip_frames_end 0 \ - --train_batch_size 1 \ - --num_train_epochs 30 \ - --checkpointing_steps 1000 \ - --gradient_accumulation_steps 1 \ - --learning_rate 1e-3 \ - --lr_scheduler cosine_with_restarts \ - --lr_warmup_steps 200 \ - --lr_num_cycles 1 \ - --enable_slicing \ - --enable_tiling \ - --gradient_checkpointing \ - --optimizer AdamW \ - --adam_beta1 0.9 \ - --adam_beta2 0.95 \ - --max_grad_norm 1.0 \ - --allow_tf32 \ - --report_to wandb \ No newline at end of file diff --git a/finetune/train_cogvideox_image_to_video_lora.py b/finetune/train_cogvideox_image_to_video_lora.py deleted file mode 100644 index abf245f..0000000 --- a/finetune/train_cogvideox_image_to_video_lora.py +++ /dev/null @@ -1,1689 +0,0 @@ -# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import logging -import math -import os -import random -import shutil -from datetime import timedelta -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import torch -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer - -import diffusers -from diffusers import ( - AutoencoderKLCogVideoX, - CogVideoXDPMScheduler, - CogVideoXImageToVideoPipeline, - CogVideoXTransformer3DModel, -) -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.optimization import get_scheduler -from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid -from diffusers.training_utils import cast_training_params, free_memory -from diffusers.utils import ( - check_min_version, - convert_unet_state_dict_to_peft, - export_to_video, - is_wandb_available, - load_image, -) -from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from diffusers.utils.torch_utils import is_compiled_module -from torchvision.transforms.functional import center_crop, resize -from torchvision.transforms import InterpolationMode -import torchvision.transforms as TT -import numpy as np -from diffusers.image_processor import VaeImageProcessor - - -if is_wandb_available(): - import wandb - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") - -logger = get_logger(__name__) - - -def get_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") - - # Model information - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--variant", - type=str, - default=None, - help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - - # Dataset information - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." 
- ), - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) - parser.add_argument( - "--instance_data_root", - type=str, - default=None, - help=("A folder containing the training data."), - ) - parser.add_argument( - "--video_column", - type=str, - default="video", - help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", - ) - parser.add_argument( - "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - - # Validation - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", - ) - parser.add_argument( - "--validation_images", - type=str, - default=None, - help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", - ) - parser.add_argument( - "--validation_prompt_separator", - type=str, - default=":::", - help="String that separates multiple validation prompts", - ) - parser.add_argument( - "--num_validation_videos", - type=int, - default=1, - help="Number of videos that should be generated during validation per `validation_prompt`.", - ) - parser.add_argument( - "--validation_epochs", - type=int, - default=50, - help=( - "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." - ), - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=6, - help="The guidance scale to use while sampling validation videos.", - ) - parser.add_argument( - "--use_dynamic_cfg", - action="store_true", - default=False, - help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", - ) - - # Training information - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--rank", - type=int, - default=128, - help=("The dimension of the LoRA update matrices."), - ) - parser.add_argument( - "--lora_alpha", - type=float, - default=128, - help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="cogvideox-i2v-lora", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="All input videos are resized to this height.", - ) - parser.add_argument( - "--width", - type=int, - default=720, - help="All input videos are resized to this width.", - ) - parser.add_argument( - "--video_reshape_mode", - type=str, - default="center", - help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", - ) - parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") - parser.add_argument( - "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." - ) - parser.add_argument( - "--skip_frames_start", - type=int, - default=0, - help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", - ) - parser.add_argument( - "--skip_frames_end", - type=int, - default=0, - help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", - ) - parser.add_argument( - "--random_flip", - action="store_true", - help="whether to randomly flip videos horizontally", - ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--enable_slicing", - action="store_true", - default=False, - help="Whether or not to use VAE slicing for saving memory.", - ) - parser.add_argument( - "--enable_tiling", - action="store_true", - default=False, - help="Whether or not to use VAE tiling for saving memory.", - ) - parser.add_argument( - "--noised_image_dropout", - type=float, - default=0.05, - help="Image condition dropout probability.", - ) - - # Optimizer - parser.add_argument( - "--optimizer", - type=lambda s: s.lower(), - default="adam", - choices=["adam", "adamw", "prodigy"], - help=("The optimizer type to use."), - ) - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", - ) - parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." - ) - parser.add_argument( - "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." - ) - parser.add_argument( - "--prodigy_beta3", - type=float, - default=None, - help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", - ) - parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer and Prodigy optimizers.", - ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") - parser.add_argument( - "--prodigy_safeguard_warmup", - action="store_true", - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", - ) - - # Other information - parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help="Directory where logs are stored.", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default=None, - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' - ), - ) - parser.add_argument("--nccl_timeout", type=int, default=600, help="NCCL backend timeout in seconds.") - - return parser.parse_args() - - -class VideoDataset(Dataset): - def __init__( - self, - instance_data_root: Optional[str] = None, - dataset_name: Optional[str] = None, - dataset_config_name: Optional[str] = None, - caption_column: str = "text", - video_column: str = "video", - height: int = 480, - width: int = 720, - video_reshape_mode: str = "center", - fps: int = 8, - max_num_frames: int = 49, - skip_frames_start: int = 0, - skip_frames_end: int = 0, - cache_dir: Optional[str] = None, - id_token: Optional[str] = None, - ) -> None: - super().__init__() - - self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None - self.dataset_name = dataset_name - self.dataset_config_name = dataset_config_name - self.caption_column = caption_column - self.video_column = video_column - self.height = height - self.width = width - self.video_reshape_mode = video_reshape_mode - self.fps = fps - self.max_num_frames = max_num_frames - self.skip_frames_start = skip_frames_start - self.skip_frames_end = skip_frames_end - self.cache_dir = cache_dir - self.id_token = id_token or "" - - if dataset_name is not None: - self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() - else: - self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() - - self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts] - - self.num_instance_videos = len(self.instance_video_paths) - if self.num_instance_videos != len(self.instance_prompts): - raise ValueError( - f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." - ) - - self.instance_videos = self._preprocess_data() - - def __len__(self): - return self.num_instance_videos - - def __getitem__(self, index): - return { - "instance_prompt": self.instance_prompts[index], - "instance_video": self.instance_videos[index], - } - - def _load_dataset_from_hub(self): - try: - from datasets import load_dataset - except ImportError: - raise ImportError( - "You are trying to load your data using the datasets library. If you wish to train using custom " - "captions please install the datasets library: `pip install datasets`. If you wish to load a " - "local folder containing images only, specify --instance_data_root instead." - ) - - # Downloading and loading a dataset from the hub. 
See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script - dataset = load_dataset( - self.dataset_name, - self.dataset_config_name, - cache_dir=self.cache_dir, - ) - column_names = dataset["train"].column_names - - if self.video_column is None: - video_column = column_names[0] - logger.info(f"`video_column` defaulting to {video_column}") - else: - video_column = self.video_column - if video_column not in column_names: - raise ValueError( - f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - if self.caption_column is None: - caption_column = column_names[1] - logger.info(f"`caption_column` defaulting to {caption_column}") - else: - caption_column = self.caption_column - if self.caption_column not in column_names: - raise ValueError( - f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - instance_prompts = dataset["train"][caption_column] - instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]] - - return instance_prompts, instance_videos - - def _load_dataset_from_local_path(self): - if not self.instance_data_root.exists(): - raise ValueError("Instance videos root folder does not exist") - - prompt_path = self.instance_data_root.joinpath(self.caption_column) - video_path = self.instance_data_root.joinpath(self.video_column) - - if not prompt_path.exists() or not prompt_path.is_file(): - raise ValueError( - "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts." - ) - if not video_path.exists() or not video_path.is_file(): - raise ValueError( - "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory." - ) - - with open(prompt_path, "r", encoding="utf-8") as file: - instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] - with open(video_path, "r", encoding="utf-8") as file: - instance_videos = [ - self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0 - ] - - if any(not path.is_file() for path in instance_videos): - raise ValueError( - "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file." 
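The `_load_dataset_from_local_path` method above expects `--caption_column` and `--video_column` to name plain-text files inside `--instance_data_root`, each holding one entry per line, with video paths resolved relative to that root. The following is a minimal sketch of preparing such a layout; the directory name `my-dataset`, the file names `prompts.txt`/`videos.txt`, and the captions are placeholders, not values taken from the script.

```py
# Minimal sketch of the on-disk layout VideoDataset expects when --dataset_name
# is not given. All names below are placeholders; pass them to the trainer via
# --instance_data_root, --caption_column and --video_column.
from pathlib import Path

root = Path("my-dataset")                      # --instance_data_root
(root / "videos").mkdir(parents=True, exist_ok=True)

captions = ["A panda eating bamboo in a forest.", "Waves crashing on a rocky shore."]
clips = ["videos/clip_000.mp4", "videos/clip_001.mp4"]   # paths relative to the root

(root / "prompts.txt").write_text("\n".join(captions), encoding="utf-8")   # --caption_column prompts.txt
(root / "videos.txt").write_text("\n".join(clips), encoding="utf-8")       # --video_column videos.txt
```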
- ) - - return instance_prompts, instance_videos - - def _resize_for_rectangle_crop(self, arr): - image_size = self.height, self.width - reshape_mode = self.video_reshape_mode - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) - - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def _preprocess_data(self): - try: - import decord - except ImportError: - raise ImportError( - "The `decord` package is required for loading the video dataset. Install with `pip install decord`" - ) - - decord.bridge.set_bridge("torch") - - progress_dataset_bar = tqdm( - range(0, len(self.instance_video_paths)), - desc="Loading progress resize and crop videos", - ) - - videos = [] - - for filename in self.instance_video_paths: - video_reader = decord.VideoReader(uri=filename.as_posix()) - video_num_frames = len(video_reader) - - start_frame = min(self.skip_frames_start, video_num_frames) - end_frame = max(0, video_num_frames - self.skip_frames_end) - if end_frame <= start_frame: - frames = video_reader.get_batch([start_frame]) - elif end_frame - start_frame <= self.max_num_frames: - frames = video_reader.get_batch(list(range(start_frame, end_frame))) - else: - indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) - frames = video_reader.get_batch(indices) - - # Ensure that we don't go over the limit - frames = frames[: self.max_num_frames] - selected_num_frames = frames.shape[0] - - # Choose first (4k + 1) frames as this is how many is required by the VAE - remainder = (3 + (selected_num_frames % 4)) % 4 - if remainder != 0: - frames = frames[:-remainder] - selected_num_frames = frames.shape[0] - - assert (selected_num_frames - 1) % 4 == 0 - - # Training transforms - frames = (frames - 127.5) / 127.5 - frames = frames.permute(0, 3, 1, 2) # [F, C, H, W] - progress_dataset_bar.set_description( - f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}" - ) - frames = self._resize_for_rectangle_crop(frames) - videos.append(frames.contiguous()) # [F, C, H, W] - progress_dataset_bar.update(1) - - progress_dataset_bar.close() - - return videos - - -def save_model_card( - repo_id: str, - videos=None, - base_model: str = None, - validation_prompt=None, - repo_folder=None, - fps=8, -): - widget_dict = [] - if videos is not None: - for i, video in enumerate(videos): - video_path = f"final_video_{i}.mp4" - export_to_video(video, os.path.join(repo_folder, video_path, fps=fps)) - widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": video_path}}, - ) - - model_description = f""" -# CogVideoX LoRA - {repo_id} - - - -## Model description - -These are {repo_id} LoRA weights for {base_model}. 
- -The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_image_to_video_lora.py). - -Was LoRA for the text encoder enabled? No. - -## Download model - -[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. - -## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) - -```py -import torch -from diffusers import CogVideoXImageToVideoPipeline -from diffusers.utils import load_image, export_to_video - -pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") -pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-i2v-lora"]) - -# The LoRA adapter weights are determined by what was used for training. -# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. -# It can be made lower or higher from what was used in training to decrease or amplify the effect -# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. -pipe.set_adapters(["cogvideox-i2v-lora"], [32 / 64]) - -image = load_image("/path/to/image") -video = pipe(image=image, "{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] -export_to_video(video, "output.mp4", fps=8) -``` - -For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) - -## License - -Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). -""" - model_card = load_or_create_model_card( - repo_id_or_path=repo_id, - from_training=True, - license="other", - base_model=base_model, - prompt=validation_prompt, - model_description=model_description, - widget=widget_dict, - ) - tags = [ - "image-to-video", - "diffusers-training", - "diffusers", - "lora", - "cogvideox", - "cogvideox-diffusers", - "template:sd-lora", - ] - - model_card = populate_model_card(model_card, tags=tags) - model_card.save(os.path.join(repo_folder, "README.md")) - - -def log_validation( - pipe, - args, - accelerator, - pipeline_args, - epoch, - is_final_validation: bool = False, -): - logger.info( - f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." - ) - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipe.scheduler.config: - variance_type = pipe.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) - pipe = pipe.to(accelerator.device) - # pipe.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - - videos = [] - for _ in range(args.num_validation_videos): - pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0] - pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])]) - - image_np = VaeImageProcessor.pt_to_numpy(pt_images) - image_pil = VaeImageProcessor.numpy_to_pil(image_np) - - videos.append(image_pil) - - for tracker in accelerator.trackers: - phase_name = "test" if is_final_validation else "validation" - if tracker.name == "wandb": - video_filenames = [] - for i, video in enumerate(videos): - prompt = ( - pipeline_args["prompt"][:25] - .replace(" ", "_") - .replace(" ", "_") - .replace("'", "_") - .replace('"', "_") - .replace("/", "_") - ) - filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") - export_to_video(video, filename, fps=8) - video_filenames.append(filename) - - tracker.log( - { - phase_name: [ - wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") - for i, filename in enumerate(video_filenames) - ] - } - ) - - del pipe - free_memory() - - return videos - - -def _get_t5_prompt_embeds( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - else: - if text_input_ids is None: - raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - -def encode_prompt( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt_embeds = _get_t5_prompt_embeds( - tokenizer, - text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - text_input_ids=text_input_ids, - ) - return prompt_embeds - - -def 
compute_prompt_embeddings( - tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False -): - if requires_grad: - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - else: - with torch.no_grad(): - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - return prompt_embeds - - -def prepare_rotary_positional_embeddings( - height: int, - width: int, - num_frames: int, - vae_scale_factor_spatial: int = 8, - patch_size: int = 2, - patch_size_t: int = 1, - attention_head_dim: int = 64, - device: Optional[torch.device] = None, - base_height: int = 480, - base_width: int = 720, -) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (vae_scale_factor_spatial * patch_size) - grid_width = width // (vae_scale_factor_spatial * patch_size) - base_size_width = base_width // (vae_scale_factor_spatial * patch_size) - base_size_height = base_height // (vae_scale_factor_spatial * patch_size) - - p_t = patch_size_t - base_num_frames = (num_frames + p_t - 1) // p_t - - grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin - - -def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): - # Use DeepSpeed optimzer - if use_deepspeed: - from accelerate.utils import DummyOptim - - return DummyOptim( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_epsilon, - weight_decay=args.adam_weight_decay, - ) - - # Optimizer creation - supported_optimizers = ["adam", "adamw", "prodigy"] - if args.optimizer not in supported_optimizers: - logger.warning( - f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW" - ) - args.optimizer = "adamw" - - if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]: - logger.warning( - f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " - f"set to {args.optimizer.lower()}" - ) - - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
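`prepare_rotary_positional_embeddings` above works on the latent grid rather than on pixels: height and width are divided by the spatial VAE scale factor times the transformer patch size, and the frame count by the temporal patch size. Below is a short worked example with the script's default 480x720 resolution and 49-frame clips; the numbers are illustrative and assume the commonly used CogVideoX values (spatial scale factor 8, patch size 2, temporal patch size 1, VAE temporal compression 4).

```py
# Worked example of the grid arithmetic in prepare_rotary_positional_embeddings.
# Assumes the argparse defaults above (height=480, width=720, max_num_frames=49)
# and typical CogVideoX values: vae_scale_factor_spatial=8, patch_size=2, patch_size_t=1.
height, width, num_frames = 480, 720, 13      # ~13 latent frames for a 49-frame clip (4x temporal compression)
vae_scale_factor_spatial, patch_size, patch_size_t = 8, 2, 1

grid_height = height // (vae_scale_factor_spatial * patch_size)    # 480 // 16 = 30
grid_width = width // (vae_scale_factor_spatial * patch_size)      # 720 // 16 = 45
base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t  # 13

print(grid_height, grid_width, base_num_frames)  # 30 45 13 -> one 3D RoPE frequency per latent patch
```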
- ) - - if args.optimizer.lower() == "adamw": - optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW - - optimizer = optimizer_class( - params_to_optimize, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_epsilon, - weight_decay=args.adam_weight_decay, - ) - elif args.optimizer.lower() == "adam": - optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam - - optimizer = optimizer_class( - params_to_optimize, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_epsilon, - weight_decay=args.adam_weight_decay, - ) - elif args.optimizer.lower() == "prodigy": - try: - import prodigyopt - except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") - - optimizer_class = prodigyopt.Prodigy - - if args.learning_rate <= 0.1: - logger.warning( - "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" - ) - - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - beta3=args.prodigy_beta3, - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - decouple=args.prodigy_decouple, - use_bias_correction=args.prodigy_use_bias_correction, - safeguard_warmup=args.prodigy_safeguard_warmup, - ) - - return optimizer - - -def main(args): - if args.report_to == "wandb" and args.hub_token is not None: - raise ValueError( - "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." - " Please use `huggingface-cli login` to authenticate with the Hub." - ) - - if torch.backends.mps.is_available() and args.mixed_precision == "bf16": - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." - ) - - logging_dir = Path(args.output_dir, args.logging_dir) - - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=args.report_to, - project_config=accelerator_project_config, - kwargs_handlers=[ddp_kwargs, init_kwargs], - ) - - # Disable AMP for MPS. - if torch.backends.mps.is_available(): - accelerator.native_amp = False - - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. 
- if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - ).repo_id - - # Prepare models and scheduler - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision - ) - - text_encoder = T5EncoderModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - - # CogVideoX-2b weights are stored in float16 - # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 - load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 - transformer = CogVideoXTransformer3DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=load_dtype, - revision=args.revision, - variant=args.variant, - ) - - vae = AutoencoderKLCogVideoX.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant - ) - - scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - - if args.enable_slicing: - vae.enable_slicing() - if args.enable_tiling: - vae.enable_tiling() - - # We only train the additional adapter LoRA layers - text_encoder.requires_grad_(False) - transformer.requires_grad_(False) - vae.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.state.deepspeed_plugin: - # DeepSpeed is handling precision, use what's in the DeepSpeed config - if ( - "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] - ): - weight_dtype = torch.float16 - if ( - "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] - ): - weight_dtype = torch.float16 - else: - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
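For reference, the mapping from the accelerator's mixed-precision setting to the dtype used for the frozen weights is conventionally fp16 → `torch.float16` and bf16 → `torch.bfloat16`, while the trainable LoRA parameters are later upcast to float32. A tiny sketch of that mapping; the helper name is illustrative and not part of the script.

```py
# Illustrative helper only: the usual mixed-precision -> torch dtype mapping for
# the frozen text encoder / transformer / VAE weights. LoRA params stay in fp32.
import torch

def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)

assert resolve_weight_dtype("bf16") == torch.bfloat16
assert resolve_weight_dtype("no") == torch.float32
```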
- ) - - text_encoder.to(accelerator.device, dtype=weight_dtype) - transformer.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) - - if args.gradient_checkpointing: - transformer.enable_gradient_checkpointing() - - # now we will add new LoRA weights to the attention layers - transformer_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.lora_alpha, - init_lora_weights=True, - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) - transformer.add_adapter(transformer_lora_config) - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - CogVideoXImageToVideoPipeline.save_lora_weights( - output_dir, - transformer_lora_layers=transformer_lora_layers_to_save, - ) - - def load_model_hook(models, input_dir): - transformer_ = None - - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) - - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer_]) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32 and torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Make sure the trainable params are in float32. 
- if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer], dtype=torch.float32) - - transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - - # Optimization parameters - transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - params_to_optimize = [transformer_parameters_with_lr] - - use_deepspeed_optimizer = ( - accelerator.state.deepspeed_plugin is not None - and accelerator.state.deepspeed_plugin.deepspeed_config.get("optimizer", "none").lower() != "none" - ) - use_deepspeed_scheduler = ( - accelerator.state.deepspeed_plugin is not None - and accelerator.state.deepspeed_plugin.deepspeed_config.get("scheduler", "none").lower() != "none" - ) - - optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) - - # Dataset and DataLoader - train_dataset = VideoDataset( - instance_data_root=args.instance_data_root, - dataset_name=args.dataset_name, - dataset_config_name=args.dataset_config_name, - caption_column=args.caption_column, - video_column=args.video_column, - height=args.height, - width=args.width, - video_reshape_mode=args.video_reshape_mode, - fps=args.fps, - max_num_frames=args.max_num_frames, - skip_frames_start=args.skip_frames_start, - skip_frames_end=args.skip_frames_end, - cache_dir=args.cache_dir, - id_token=args.id_token, - ) - - def encode_video(video, bar): - bar.update(1) - video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) - video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - image = video[:, :, :1].clone() - - latent_dist = vae.encode(video).latent_dist - - image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) - image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) - noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] - image_latent_dist = vae.encode(noisy_image).latent_dist - - return latent_dist, image_latent_dist - - train_dataset.instance_prompts = [ - compute_prompt_embeddings( - tokenizer, - text_encoder, - [prompt], - transformer.config.max_text_seq_length, - accelerator.device, - weight_dtype, - requires_grad=False, - ) - for prompt in train_dataset.instance_prompts - ] - - progress_encode_bar = tqdm( - range(0, len(train_dataset.instance_videos)), - desc="Loading Encode videos", - ) - train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos] - progress_encode_bar.close() - - def collate_fn(examples): - videos = [] - images = [] - for example in examples: - latent_dist, image_latent_dist = example["instance_video"] - - video_latents = latent_dist.sample() * vae.config.scaling_factor - image_latents = image_latent_dist.sample() * vae.config.scaling_factor - video_latents = video_latents.permute(0, 2, 1, 3, 4) - image_latents = image_latents.permute(0, 2, 1, 3, 4) - - padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) - latent_padding = image_latents.new_zeros(padding_shape) - image_latents = torch.cat([image_latents, latent_padding], dim=1) - - if random.random() < args.noised_image_dropout: - image_latents = torch.zeros_like(image_latents) - - videos.append(video_latents) - images.append(image_latents) - - videos = torch.cat(videos) - images = torch.cat(images) - videos = videos.to(memory_format=torch.contiguous_format).float() - images = 
images.to(memory_format=torch.contiguous_format).float() - - prompts = [example["instance_prompt"] for example in examples] - prompts = torch.cat(prompts) - - return { - "videos": (videos, images), - "prompts": prompts, - } - - train_dataloader = DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - if use_deepspeed_scheduler: - from accelerate.utils import DummyScheduler - - lr_scheduler = DummyScheduler( - optimizer=optimizer, - total_num_steps=args.max_train_steps * accelerator.num_processes, - warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes, - ) - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - tracker_name = args.tracker_name or "cogvideox-i2v-lora" - accelerator.init_trackers(tracker_name, config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) - - logger.info("***** Running training *****") - logger.info(f" Num trainable parameters = {num_trainable_parameters}") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if not args.resume_from_checkpoint: - initial_global_step = 0 - else: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. - disable=not accelerator.is_local_main_process, - ) - vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) - - # For DeepSpeed training - model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config - - for epoch in range(first_epoch, args.num_train_epochs): - transformer.train() - - for step, batch in enumerate(train_dataloader): - models_to_accumulate = [transformer] - - with accelerator.accumulate(models_to_accumulate): - video_latents, image_latents = batch["videos"] - prompt_embeds = batch["prompts"] - - video_latents = video_latents.to(dtype=weight_dtype) # [B, F, C, H, W] - image_latents = image_latents.to(dtype=weight_dtype) # [B, F, C, H, W] - - batch_size, num_frames, num_channels, height, width = video_latents.shape - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device - ) - timesteps = timesteps.long() - - # Sample noise that will be added to the latents - noise = torch.randn_like(video_latents) - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) - noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) - - # Prepare rotary embeds - image_rotary_emb = ( - prepare_rotary_positional_embeddings( - height=args.height, - width=args.width, - num_frames=num_frames, - vae_scale_factor_spatial=vae_scale_factor_spatial, - patch_size=model_config.patch_size, - patch_size_t=model_config.patch_size_t, - attention_head_dim=model_config.attention_head_dim, - device=accelerator.device, - ) - if model_config.use_rotary_positional_embeddings - else None - ) - - # Predict the noise residual - model_output = transformer( - hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timesteps, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) - - alphas_cumprod = scheduler.alphas_cumprod[timesteps] - weights = 1 / (1 
- alphas_cumprod) - while len(weights.shape) < len(model_pred.shape): - weights = weights.unsqueeze(-1) - - target = video_latents - - loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) - loss = loss.mean() - accelerator.backward(loss) - - if accelerator.sync_gradients: - params_to_clip = transformer.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - if accelerator.state.deepspeed_plugin is None: - optimizer.step() - optimizer.zero_grad() - - lr_scheduler.step() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if accelerator.is_main_process: - if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: - # Create pipeline - pipe = CogVideoXImageToVideoPipeline.from_pretrained( - args.pretrained_model_name_or_path, - transformer=unwrap_model(transformer), - scheduler=scheduler, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - validation_images = args.validation_images.split(args.validation_prompt_separator) - - for validation_image, validation_prompt in zip(validation_images, validation_prompts): - pipeline_args = { - "image": load_image(validation_image), - "prompt": validation_prompt, - "guidance_scale": args.guidance_scale, - "use_dynamic_cfg": args.use_dynamic_cfg, - "height": args.height, - "width": args.width, - } - - validation_outputs = log_validation( - pipe=pipe, - args=args, - accelerator=accelerator, - pipeline_args=pipeline_args, - epoch=epoch, - ) - - # Save the lora layers - accelerator.wait_for_everyone() - if accelerator.is_main_process: - transformer = unwrap_model(transformer) - dtype = ( - torch.float16 - if args.mixed_precision == "fp16" - else torch.bfloat16 - if args.mixed_precision == "bf16" - else torch.float32 - ) - transformer = transformer.to(dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) - - 
CogVideoXImageToVideoPipeline.save_lora_weights( - save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers, - ) - - # Cleanup trained models to save memory - del transformer - free_memory() - - # Final test inference - pipe = CogVideoXImageToVideoPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) - - if args.enable_slicing: - pipe.vae.enable_slicing() - if args.enable_tiling: - pipe.vae.enable_tiling() - - # Load LoRA weights - lora_scaling = args.lora_alpha / args.rank - pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-i2v-lora") - pipe.set_adapters(["cogvideox-i2v-lora"], [lora_scaling]) - - # Run inference - validation_outputs = [] - if args.validation_prompt and args.num_validation_videos > 0: - validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - validation_images = args.validation_images.split(args.validation_prompt_separator) - - for validation_image, validation_prompt in zip(validation_images, validation_prompts): - pipeline_args = { - "image": load_image(validation_image), - "prompt": validation_prompt, - "guidance_scale": args.guidance_scale, - "use_dynamic_cfg": args.use_dynamic_cfg, - "height": args.height, - "width": args.width, - } - - video = log_validation( - pipe=pipe, - args=args, - accelerator=accelerator, - pipeline_args=pipeline_args, - epoch=epoch, - is_final_validation=True, - ) - validation_outputs.extend(video) - - if args.push_to_hub: - validation_prompt = args.validation_prompt or "" - validation_prompt = validation_prompt.split(args.validation_prompt_separator)[0] - save_model_card( - repo_id, - videos=validation_outputs, - base_model=args.pretrained_model_name_or_path, - validation_prompt=validation_prompt, - repo_folder=args.output_dir, - fps=args.fps, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - - accelerator.end_training() - - -if __name__ == "__main__": - args = get_args() - main(args) diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py deleted file mode 100644 index e12b3d5..0000000 --- a/finetune/train_cogvideox_lora.py +++ /dev/null @@ -1,1573 +0,0 @@ -# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
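Once training finishes, the script above saves the adapter to `--output_dir` and reloads it with a weight of `lora_alpha / rank` for the final validation pass. The same pattern can be reused standalone at inference time; this is a hedged sketch with placeholder image path and prompt, assuming the default output directory and the default alpha/rank of 128/128.

```py
# Sketch of reusing the LoRA weights written to --output_dir by the trainer above.
# The image path and prompt are placeholders; the adapter weight mirrors the
# script's lora_scaling = lora_alpha / rank.
import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

lora_alpha, rank = 128, 128                              # values passed to the trainer
pipe.load_lora_weights("cogvideox-i2v-lora", adapter_name="cogvideox-i2v-lora")
pipe.set_adapters(["cogvideox-i2v-lora"], [lora_alpha / rank])

image = load_image("path/to/first_frame.png")
video = pipe(image=image, prompt="<your prompt>", guidance_scale=6, use_dynamic_cfg=True).frames[0]
export_to_video(video, "lora_sample.mp4", fps=8)
```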
- -import argparse -import logging -import math -import os -import shutil -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import torch -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms -from tqdm.auto import tqdm -from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer - -import diffusers -from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.optimization import get_scheduler -from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid -from diffusers.training_utils import ( - cast_training_params, - free_memory, -) -from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available -from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card -from diffusers.utils.torch_utils import is_compiled_module - - -if is_wandb_available(): - import wandb - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") - -logger = get_logger(__name__) - - -def get_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") - - # Model information - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - default=None, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--revision", - type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", - ) - parser.add_argument( - "--variant", - type=str, - default=None, - help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - - # Dataset information - parser.add_argument( - "--dataset_name", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." - ), - ) - parser.add_argument( - "--dataset_config_name", - type=str, - default=None, - help="The config of the Dataset, leave as None if there's only one config.", - ) - parser.add_argument( - "--instance_data_root", - type=str, - default=None, - help=("A folder containing the training data."), - ) - parser.add_argument( - "--video_column", - type=str, - default="video", - help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing the instance prompt for each video. 
Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", - ) - parser.add_argument( - "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help=( - "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." - ), - ) - - # Validation - parser.add_argument( - "--validation_prompt", - type=str, - default=None, - help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", - ) - parser.add_argument( - "--validation_prompt_separator", - type=str, - default=":::", - help="String that separates multiple validation prompts", - ) - parser.add_argument( - "--num_validation_videos", - type=int, - default=1, - help="Number of videos that should be generated during validation per `validation_prompt`.", - ) - parser.add_argument( - "--validation_epochs", - type=int, - default=50, - help=( - "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." - ), - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=6, - help="The guidance scale to use while sampling validation videos.", - ) - parser.add_argument( - "--use_dynamic_cfg", - action="store_true", - default=False, - help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", - ) - - # Training information - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--rank", - type=int, - default=128, - help=("The dimension of the LoRA update matrices."), - ) - parser.add_argument( - "--lora_alpha", - type=float, - default=128, - help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - parser.add_argument( - "--output_dir", - type=str, - default="cogvideox-lora", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="All input videos are resized to this height.", - ) - parser.add_argument( - "--width", - type=int, - default=720, - help="All input videos are resized to this width.", - ) - parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") - parser.add_argument( - "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." - ) - parser.add_argument( - "--skip_frames_start", - type=int, - default=0, - help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", - ) - parser.add_argument( - "--skip_frames_end", - type=int, - default=0, - help="Number of frames to skip from the end of each input video. 
Useful if training data contains outro sequences.", - ) - parser.add_argument( - "--random_flip", - action="store_true", - help="whether to randomly flip videos horizontally", - ) - parser.add_argument( - "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." - ) - parser.add_argument("--num_train_epochs", type=int, default=1) - parser.add_argument( - "--max_train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpoints_total_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' - ), - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) - parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - default=False, - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") - parser.add_argument( - "--enable_slicing", - action="store_true", - default=False, - help="Whether or not to use VAE slicing for saving memory.", - ) - parser.add_argument( - "--enable_tiling", - action="store_true", - default=False, - help="Whether or not to use VAE tiling for saving memory.", - ) - - # Optimizer - parser.add_argument( - "--optimizer", - type=lambda s: s.lower(), - default="adam", - choices=["adam", "adamw", "prodigy"], - help=("The optimizer type to use."), - ) - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", - ) - parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." 
- ) - parser.add_argument( - "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." - ) - parser.add_argument( - "--prodigy_beta3", - type=float, - default=None, - help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", - ) - parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") - parser.add_argument( - "--adam_epsilon", - type=float, - default=1e-08, - help="Epsilon value for the Adam optimizer and Prodigy optimizers.", - ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") - parser.add_argument( - "--prodigy_safeguard_warmup", - action="store_true", - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", - ) - - # Other information - parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help="Directory where logs are stored.", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--report_to", - type=str, - default=None, - help=( - 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' - ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
- ), - ) - - return parser.parse_args() - - -class VideoDataset(Dataset): - def __init__( - self, - instance_data_root: Optional[str] = None, - dataset_name: Optional[str] = None, - dataset_config_name: Optional[str] = None, - caption_column: str = "text", - video_column: str = "video", - height: int = 480, - width: int = 720, - fps: int = 8, - max_num_frames: int = 49, - skip_frames_start: int = 0, - skip_frames_end: int = 0, - cache_dir: Optional[str] = None, - id_token: Optional[str] = None, - ) -> None: - super().__init__() - - self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None - self.dataset_name = dataset_name - self.dataset_config_name = dataset_config_name - self.caption_column = caption_column - self.video_column = video_column - self.height = height - self.width = width - self.fps = fps - self.max_num_frames = max_num_frames - self.skip_frames_start = skip_frames_start - self.skip_frames_end = skip_frames_end - self.cache_dir = cache_dir - self.id_token = id_token or "" - - if dataset_name is not None: - self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() - else: - self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() - - self.num_instance_videos = len(self.instance_video_paths) - if self.num_instance_videos != len(self.instance_prompts): - raise ValueError( - f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." - ) - - self.instance_videos = self._preprocess_data() - - def __len__(self): - return self.num_instance_videos - - def __getitem__(self, index): - return { - "instance_prompt": self.id_token + self.instance_prompts[index], - "instance_video": self.instance_videos[index], - } - - def _load_dataset_from_hub(self): - try: - from datasets import load_dataset - except ImportError: - raise ImportError( - "You are trying to load your data using the datasets library. If you wish to train using custom " - "captions please install the datasets library: `pip install datasets`. If you wish to load a " - "local folder containing images only, specify --instance_data_root instead." - ) - - # Downloading and loading a dataset from the hub. See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script - dataset = load_dataset( - self.dataset_name, - self.dataset_config_name, - cache_dir=self.cache_dir, - ) - column_names = dataset["train"].column_names - - if self.video_column is None: - video_column = column_names[0] - logger.info(f"`video_column` defaulting to {video_column}") - else: - video_column = self.video_column - if video_column not in column_names: - raise ValueError( - f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - - if self.caption_column is None: - caption_column = column_names[1] - logger.info(f"`caption_column` defaulting to {caption_column}") - else: - caption_column = self.caption_column - if self.caption_column not in column_names: - raise ValueError( - f"`--caption_column` value '{self.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}"
-                )
-
-        instance_prompts = dataset["train"][caption_column]
-        instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]]
-
-        return instance_prompts, instance_videos
-
-    def _load_dataset_from_local_path(self):
-        if not self.instance_data_root.exists():
-            raise ValueError("Instance videos root folder does not exist")
-
-        prompt_path = self.instance_data_root.joinpath(self.caption_column)
-        video_path = self.instance_data_root.joinpath(self.video_column)
-
-        if not prompt_path.exists() or not prompt_path.is_file():
-            raise ValueError(
-                "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts."
-            )
-        if not video_path.exists() or not video_path.is_file():
-            raise ValueError(
-                "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory."
-            )
-
-        with open(prompt_path, "r", encoding="utf-8") as file:
-            instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
-        with open(video_path, "r", encoding="utf-8") as file:
-            instance_videos = [
-                self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0
-            ]
-
-        if any(not path.is_file() for path in instance_videos):
-            raise ValueError(
-                "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found at least one path that is not a valid file."
-            )
-
-        return instance_prompts, instance_videos
-
-    def _preprocess_data(self):
-        try:
-            import decord
-        except ImportError:
-            raise ImportError(
-                "The `decord` package is required for loading the video dataset.
Install with `pip install decord`" - ) - - decord.bridge.set_bridge("torch") - - videos = [] - train_transforms = transforms.Compose( - [ - transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), - ] - ) - - for filename in self.instance_video_paths: - video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) - video_num_frames = len(video_reader) - - start_frame = min(self.skip_frames_start, video_num_frames) - end_frame = max(0, video_num_frames - self.skip_frames_end) - if end_frame <= start_frame: - frames = video_reader.get_batch([start_frame]) - elif end_frame - start_frame <= self.max_num_frames: - frames = video_reader.get_batch(list(range(start_frame, end_frame))) - else: - indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) - frames = video_reader.get_batch(indices) - - # Ensure that we don't go over the limit - frames = frames[: self.max_num_frames] - selected_num_frames = frames.shape[0] - - # Choose first (4k + 1) frames as this is how many is required by the VAE - remainder = (3 + (selected_num_frames % 4)) % 4 - if remainder != 0: - frames = frames[:-remainder] - selected_num_frames = frames.shape[0] - - assert (selected_num_frames - 1) % 4 == 0 - - # Training transforms - frames = frames.float() - frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) - videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] - - return videos - - -def save_model_card( - repo_id: str, - videos=None, - base_model: str = None, - validation_prompt=None, - repo_folder=None, - fps=8, -): - widget_dict = [] - if videos is not None: - for i, video in enumerate(videos): - export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) - widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}} - ) - - model_description = f""" -# CogVideoX LoRA - {repo_id} - - - -## Model description - -These are {repo_id} LoRA weights for {base_model}. - -The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). - -Was LoRA for the text encoder enabled? No. - -## Download model - -[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. - -## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) - -```py -from diffusers import CogVideoXPipeline -import torch - -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") -pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"]) - -# The LoRA adapter weights are determined by what was used for training. -# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. -# It can be made lower or higher from what was used in training to decrease or amplify the effect -# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. 
-pipe.set_adapters(["cogvideox-lora"], [32 / 64]) - -video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] -``` - -For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) - -## License - -Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). -""" - model_card = load_or_create_model_card( - repo_id_or_path=repo_id, - from_training=True, - license="other", - base_model=base_model, - prompt=validation_prompt, - model_description=model_description, - widget=widget_dict, - ) - tags = [ - "text-to-video", - "diffusers-training", - "diffusers", - "lora", - "cogvideox", - "cogvideox-diffusers", - "template:sd-lora", - ] - - model_card = populate_model_card(model_card, tags=tags) - model_card.save(os.path.join(repo_folder, "README.md")) - - -def log_validation( - pipe, - args, - accelerator, - pipeline_args, - epoch, - is_final_validation: bool = False, -): - logger.info( - f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." - ) - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipe.scheduler.config: - variance_type = pipe.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) - pipe = pipe.to(accelerator.device) - # pipe.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - - videos = [] - for _ in range(args.num_validation_videos): - video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] - videos.append(video) - - for tracker in accelerator.trackers: - phase_name = "test" if is_final_validation else "validation" - if tracker.name == "wandb": - video_filenames = [] - for i, video in enumerate(videos): - prompt = ( - pipeline_args["prompt"][:25] - .replace(" ", "_") - .replace(" ", "_") - .replace("'", "_") - .replace('"', "_") - .replace("/", "_") - ) - filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") - export_to_video(video, filename, fps=8) - video_filenames.append(filename) - - tracker.log( - { - phase_name: [ - wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") - for i, filename in enumerate(video_filenames) - ] - } - ) - - free_memory() - - return videos - - -def _get_t5_prompt_embeds( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - else: - if text_input_ids is None: - raise 
ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - -def encode_prompt( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt_embeds = _get_t5_prompt_embeds( - tokenizer, - text_encoder, - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - text_input_ids=text_input_ids, - ) - return prompt_embeds - - -def compute_prompt_embeddings( - tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False -): - if requires_grad: - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - else: - with torch.no_grad(): - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - return prompt_embeds - - -def prepare_rotary_positional_embeddings( - height: int, - width: int, - num_frames: int, - vae_scale_factor_spatial: int = 8, - patch_size: int = 2, - patch_size_t: int = 1, - attention_head_dim: int = 64, - device: Optional[torch.device] = None, - base_height: int = 480, - base_width: int = 720, -) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (vae_scale_factor_spatial * patch_size) - grid_width = width // (vae_scale_factor_spatial * patch_size) - base_size_width = base_width // (vae_scale_factor_spatial * patch_size) - base_size_height = base_height // (vae_scale_factor_spatial * patch_size) - - p_t = patch_size_t - base_num_frames = (num_frames + p_t - 1) // p_t - - grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin - - -def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): - # Use DeepSpeed optimzer - if use_deepspeed: - from accelerate.utils import DummyOptim - - return DummyOptim( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - eps=args.adam_epsilon, - weight_decay=args.adam_weight_decay, - ) - - # Optimizer creation - supported_optimizers = ["adam", "adamw", "prodigy"] - if args.optimizer not in supported_optimizers: - logger.warning( - f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. 
Defaulting to AdamW"
-        )
-        args.optimizer = "adamw"
-
-    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
-        logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
-            f"set to {args.optimizer.lower()}"
-        )
-
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError(
-                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-            )
-
-    if args.optimizer.lower() == "adamw":
-        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
-
-        optimizer = optimizer_class(
-            params_to_optimize,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_epsilon,
-            weight_decay=args.adam_weight_decay,
-        )
-    elif args.optimizer.lower() == "adam":
-        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
-
-        optimizer = optimizer_class(
-            params_to_optimize,
-            betas=(args.adam_beta1, args.adam_beta2),
-            eps=args.adam_epsilon,
-            weight_decay=args.adam_weight_decay,
-        )
-    elif args.optimizer.lower() == "prodigy":
-        try:
-            import prodigyopt
-        except ImportError:
-            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
-
-        optimizer_class = prodigyopt.Prodigy
-
-        if args.learning_rate <= 0.1:
-            logger.warning(
-                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
-            )
-
-        optimizer = optimizer_class(
-            params_to_optimize,
-            lr=args.learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            beta3=args.prodigy_beta3,
-            weight_decay=args.adam_weight_decay,
-            eps=args.adam_epsilon,
-            decouple=args.prodigy_decouple,
-            use_bias_correction=args.prodigy_use_bias_correction,
-            safeguard_warmup=args.prodigy_safeguard_warmup,
-        )
-
-    return optimizer
-
-
-def main(args):
-    if args.report_to == "wandb" and args.hub_token is not None:
-        raise ValueError(
-            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
-            " Please use `huggingface-cli login` to authenticate with the Hub."
-        )
-
-    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
-        # due to pytorch#99272, MPS does not yet support bfloat16.
-        raise ValueError(
-            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
-        )
-
-    logging_dir = Path(args.output_dir, args.logging_dir)
-
-    expected_mixed_precision = "bf16" if "5b" in args.pretrained_model_name_or_path.lower() else "fp16"
-    if args.mixed_precision != expected_mixed_precision:
-        raise ValueError(f"Mixed precision {args.mixed_precision} does not match the model precision, should be {expected_mixed_precision}")
-
-    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-    accelerator = Accelerator(
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision,
-        log_with=args.report_to,
-        project_config=accelerator_project_config,
-        kwargs_handlers=[kwargs],
-    )
-
-    if accelerator.state.deepspeed_plugin:
-        # Set deepspeed config according to args
-        config = {
-            'optimizer': {
-                'type': args.optimizer,
-                'params': {
-                    'lr': args.learning_rate,
-                    'betas': [args.adam_beta1, args.adam_beta2]
-                },
-                'torch_adam': True
-            },
-            'bf16': {
-                'enabled': True if args.mixed_precision == "bf16" else False
-            },
-            'fp16': {
-                'enabled': True if args.mixed_precision == "fp16" else False
-            },
-            'gradient_accumulation_steps': args.gradient_accumulation_steps,
-            'train_batch_size': args.train_batch_size
-        }
-        accelerator.state.deepspeed_plugin.deepspeed_config.update(config)
-
-    # Disable AMP for MPS.
-    if torch.backends.mps.is_available():
-        accelerator.native_amp = False
-
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-
-    # Make one log on every process with the configuration for debugging.
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-    )
-    logger.info(accelerator.state, main_process_only=False)
-    if accelerator.is_local_main_process:
-        transformers.utils.logging.set_verbosity_warning()
-        diffusers.utils.logging.set_verbosity_info()
-    else:
-        transformers.utils.logging.set_verbosity_error()
-        diffusers.utils.logging.set_verbosity_error()
-
-    # If passed along, set the training seed now.
- if args.seed is not None: - set_seed(args.seed) - - # Handle the repository creation - if accelerator.is_main_process: - if args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) - - if args.push_to_hub: - repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, - exist_ok=True, - ).repo_id - - # Prepare models and scheduler - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision - ) - - text_encoder = T5EncoderModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - - # CogVideoX-2b weights are stored in float16 - # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 - load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 - transformer = CogVideoXTransformer3DModel.from_pretrained( - args.pretrained_model_name_or_path, - subfolder="transformer", - torch_dtype=load_dtype, - revision=args.revision, - variant=args.variant, - ) - - vae = AutoencoderKLCogVideoX.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant - ) - - scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") - - if args.enable_slicing: - vae.enable_slicing() - if args.enable_tiling: - vae.enable_tiling() - - # We only train the additional adapter LoRA layers - text_encoder.requires_grad_(False) - transformer.requires_grad_(False) - vae.requires_grad_(False) - - # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision - # as these weights are only used for inference, keeping weights in full precision is not required. - weight_dtype = torch.float32 - if accelerator.state.deepspeed_plugin: - # DeepSpeed is handling precision, use what's in the DeepSpeed config - if ( - "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] - ): - weight_dtype = torch.float16 - if ( - "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config - and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] - ): - weight_dtype = torch.bfloat16 - else: - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: - # due to pytorch#99272, MPS does not yet support bfloat16. - raise ValueError( - "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
- ) - - text_encoder.to(accelerator.device, dtype=weight_dtype) - transformer.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) - - if args.gradient_checkpointing: - transformer.enable_gradient_checkpointing() - - # now we will add new LoRA weights to the attention layers - transformer_lora_config = LoraConfig( - r=args.rank, - lora_alpha=args.lora_alpha, - init_lora_weights=True, - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - ) - transformer.add_adapter(transformer_lora_config) - - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if accelerator.is_main_process: - transformer_lora_layers_to_save = None - - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - else: - raise ValueError(f"unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - CogVideoXPipeline.save_lora_weights( - output_dir, - transformer_lora_layers=transformer_lora_layers_to_save, - ) - - def load_model_hook(models, input_dir): - transformer_ = None - - while len(models) > 0: - model = models.pop() - - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) - - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") - } - transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. More details: - # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 - if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer_]) - - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) - - # Enable TF32 for faster training on Ampere GPUs, - # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if args.allow_tf32 and torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes - ) - - # Make sure the trainable params are in float32. 
- if args.mixed_precision == "fp16": - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params([transformer], dtype=torch.float32) - - transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) - - # Optimization parameters - transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} - params_to_optimize = [transformer_parameters_with_lr] - - use_deepspeed_optimizer = ( - accelerator.state.deepspeed_plugin is not None - and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config - ) - use_deepspeed_scheduler = ( - accelerator.state.deepspeed_plugin is not None - and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config - ) - - optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) - - # Dataset and DataLoader - train_dataset = VideoDataset( - instance_data_root=args.instance_data_root, - dataset_name=args.dataset_name, - dataset_config_name=args.dataset_config_name, - caption_column=args.caption_column, - video_column=args.video_column, - height=args.height, - width=args.width, - fps=args.fps, - max_num_frames=args.max_num_frames, - skip_frames_start=args.skip_frames_start, - skip_frames_end=args.skip_frames_end, - cache_dir=args.cache_dir, - id_token=args.id_token, - ) - - def encode_video(video): - video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) - video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - latent_dist = vae.encode(video).latent_dist - return latent_dist - - train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] - - def collate_fn(examples): - videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples] - prompts = [example["instance_prompt"] for example in examples] - - videos = torch.cat(videos) - videos = videos.to(memory_format=torch.contiguous_format).float() - - return { - "videos": videos, - "prompts": prompts, - } - - train_dataloader = DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=args.dataloader_num_workers, - ) - - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - - if use_deepspeed_scheduler: - from accelerate.utils import DummyScheduler - - lr_scheduler = DummyScheduler( - optimizer=optimizer, - total_num_steps=args.max_train_steps * accelerator.num_processes, - warmup_num_steps=args.lr_warmup_steps * accelerator.num_processes, - ) - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - ) - - # Prepare everything with our `accelerator`. - transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - transformer, optimizer, train_dataloader, lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - tracker_name = args.tracker_name or "cogvideox-lora" - accelerator.init_trackers(tracker_name, config=vars(args)) - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) - - logger.info("***** Running training *****") - logger.info(f" Num trainable parameters = {num_trainable_parameters}") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num epochs = {args.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if not args.resume_from_checkpoint: - initial_global_step = 0 - else: - if args.resume_from_checkpoint != "latest": - path = os.path.basename(args.resume_from_checkpoint) - else: - # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - accelerator.print( - f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." - ) - args.resume_from_checkpoint = None - initial_global_step = 0 - else: - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=initial_global_step, - desc="Steps", - # Only show the progress bar once on each machine. 
- disable=not accelerator.is_local_main_process, - ) - vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) - - # For DeepSpeed training - model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config - - for epoch in range(first_epoch, args.num_train_epochs): - transformer.train() - - for step, batch in enumerate(train_dataloader): - models_to_accumulate = [transformer] - - with accelerator.accumulate(models_to_accumulate): - model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] - prompts = batch["prompts"] - - # encode prompts - prompt_embeds = compute_prompt_embeddings( - tokenizer, - text_encoder, - prompts, - model_config.max_text_seq_length, - accelerator.device, - weight_dtype, - requires_grad=False, - ) - - # Sample noise that will be added to the latents - noise = torch.randn_like(model_input) - batch_size, num_frames, num_channels, height, width = model_input.shape - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device - ) - timesteps = timesteps.long() - - # Prepare rotary embeds - image_rotary_emb = ( - prepare_rotary_positional_embeddings( - height=args.height, - width=args.width, - num_frames=num_frames, - vae_scale_factor_spatial=vae_scale_factor_spatial, - patch_size=model_config.patch_size, - patch_size_t=model_config.patch_size_t if model_config.patch_size_t is not None else 1, - attention_head_dim=model_config.attention_head_dim, - device=accelerator.device, - ) - if model_config.use_rotary_positional_embeddings - else None - ) - - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) - - # Predict the noise residual - model_output = transformer( - hidden_states=noisy_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timesteps, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) - - alphas_cumprod = scheduler.alphas_cumprod[timesteps] - weights = 1 / (1 - alphas_cumprod) - while len(weights.shape) < len(model_pred.shape): - weights = weights.unsqueeze(-1) - - target = model_input - - loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) - loss = loss.mean() - accelerator.backward(loss) - - if accelerator.sync_gradients: - params_to_clip = transformer.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - if accelerator.state.deepspeed_plugin is None: - optimizer.step() - optimizer.zero_grad() - - lr_scheduler.step() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - if accelerator.is_main_process: - if global_step % args.checkpointing_steps == 0: - # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` - if args.checkpoints_total_limit is not None: - checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) - - # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = 
len(checkpoints) - args.checkpoints_total_limit + 1 - removing_checkpoints = checkpoints[0:num_to_remove] - - logger.info( - f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" - ) - logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") - - for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) - shutil.rmtree(removing_checkpoint) - - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") - - logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if accelerator.is_main_process: - if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: - # Create pipeline - pipe = CogVideoXPipeline.from_pretrained( - args.pretrained_model_name_or_path, - transformer=unwrap_model(transformer), - text_encoder=unwrap_model(text_encoder), - vae=unwrap_model(vae), - scheduler=scheduler, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - for validation_prompt in validation_prompts: - pipeline_args = { - "prompt": validation_prompt, - "guidance_scale": args.guidance_scale, - "use_dynamic_cfg": args.use_dynamic_cfg, - "height": args.height, - "width": args.width, - } - - validation_outputs = log_validation( - pipe=pipe, - args=args, - accelerator=accelerator, - pipeline_args=pipeline_args, - epoch=epoch, - ) - - # Save the lora layers - accelerator.wait_for_everyone() - if accelerator.is_main_process: - transformer = unwrap_model(transformer) - dtype = ( - torch.float16 - if args.mixed_precision == "fp16" - else torch.bfloat16 - if args.mixed_precision == "bf16" - else torch.float32 - ) - transformer = transformer.to(dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) - - CogVideoXPipeline.save_lora_weights( - save_directory=args.output_dir, - transformer_lora_layers=transformer_lora_layers, - ) - - # Final test inference - pipe = CogVideoXPipeline.from_pretrained( - args.pretrained_model_name_or_path, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) - - if args.enable_slicing: - pipe.vae.enable_slicing() - if args.enable_tiling: - pipe.vae.enable_tiling() - - # Load LoRA weights - lora_scaling = args.lora_alpha / args.rank - pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") - pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) - - # Run inference - validation_outputs = [] - if args.validation_prompt and args.num_validation_videos > 0: - validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - for validation_prompt in validation_prompts: - pipeline_args = { - "prompt": validation_prompt, - "guidance_scale": args.guidance_scale, - "use_dynamic_cfg": args.use_dynamic_cfg, - "height": args.height, - "width": args.width, - } - - video = log_validation( - pipe=pipe, - args=args, - accelerator=accelerator, - pipeline_args=pipeline_args, - epoch=epoch, - is_final_validation=True, - ) - validation_outputs.extend(video) - - if args.push_to_hub: - save_model_card( - repo_id, - videos=validation_outputs, - 
base_model=args.pretrained_model_name_or_path, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - fps=args.fps, - ) - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - - accelerator.end_training() - - -if __name__ == "__main__": - args = get_args() - main(args)
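
A minimal usage sketch of the artifacts this trainer produces: it loads the LoRA weights saved by `CogVideoXPipeline.save_lora_weights` and applies them at a scale of `lora_alpha / rank`, mirroring the final test-inference block of the script above. The base model, output directory, prompt, and the 128 / 128 defaults are assumptions for illustration, not values taken from the patch.

```py
# Sketch: load a LoRA adapter saved by the trainer above and run inference.
# The paths, base model, adapter scale (lora_alpha / rank, here the script
# defaults of 128 / 128), and prompt are placeholder assumptions.
import torch
from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

# Mirror the script's final test inference: load the adapter from the training
# output directory and scale it by lora_alpha / rank.
pipe.load_lora_weights("./cogvideox-lora", adapter_name="cogvideox-lora")
pipe.set_adapters(["cogvideox-lora"], [128 / 128])

video = pipe(
    "<your prompt here>",
    guidance_scale=6,
    use_dynamic_cfg=True,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```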