diff --git a/finetune/README.md b/finetune/README.md new file mode 100644 index 0000000..1f6e50b --- /dev/null +++ b/finetune/README.md @@ -0,0 +1,192 @@ +# CogVideoX diffusers Fine-tuning Guide + +If you want to see the SAT version fine-tuning, please check [here](../sat/README.md). The dataset format is different +from this version. + +This tutorial aims to quickly fine-tune the diffusers version of the CogVideoX model. + +### Hardware Requirements + ++ CogVideoX-2B LORA: 1 * A100 ++ CogVideoX-2B SFT: 8 * A100 ++ CogVideoX-5B/5B-I2V not yet supported + +### Prepare the Dataset + +First, you need to prepare the dataset. The format of the dataset is as follows, where `videos.txt` contains paths to +the videos in the `videos` directory. + +``` +. +├── prompts.txt +├── videos +└── videos.txt +``` + +### Configuration Files and Execution + +`accelerate` configuration files are as follows: + ++ accelerate_config_machine_multi.yaml for multi-GPU use ++ accelerate_config_machine_single.yaml for single-GPU use + +The `finetune` script configuration is as follows: + +```shell +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# This command sets PyTorch's CUDA memory allocation strategy to segment-based memory management to prevent OOM (Out of Memory) errors. + +accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ + +# Use Accelerate to start training, specifying the `accelerate_config_machine_single.yaml` configuration file, and using multiple GPUs. + +train_cogvideox_lora.py \ + +# This is the training script you will execute for LoRA fine-tuning of the CogVideoX model. + +--pretrained_model_name_or_path THUDM/CogVideoX-2b \ + +# The path to the pretrained model, pointing to the CogVideoX-5b model you want to fine-tune. + +--cache_dir ~/.cache \ + +# The directory where downloaded models and datasets will be stored. + +--enable_tiling \ + +# Enable VAE tiling functionality, which reduces memory usage by processing smaller blocks of the image. + +--enable_slicing \ + +# Enable VAE slicing functionality, which slices the image across channels to save memory. + +--instance_data_root ~/disney/ \ + +# The root directory of the instance data, the folder of the dataset used during training. + +--caption_column prompts.txt \ + +# Specifies the column or file containing instance prompts (text descriptions), in this case, the `prompts.txt` file. + +--video_column videos.txt \ + +# Specifies the column or file containing paths to videos, in this case, the `videos.txt` file. + +--validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \ + +# The prompt(s) used for validation, multiple prompts should be separated by the specified delimiter (`:::`). + +--validation_prompt_separator ::: \ + +# The delimiter for validation prompts, set here as `:::`. + +--num_validation_videos 1 \ + +# The number of videos to be generated during validation, set to 1. + +--validation_epochs 2 \ + +# How many epochs to run validation, set to validate every 2 epochs. + +--seed 3407 \ + +# Sets the random seed for reproducible training, set to 3407. + +--rank 128 \ + +# The dimension of the LoRA update matrices, controlling the size of the LoRA layer parameters, set to 128. + +--mixed_precision bf16 \ + +# Use mixed precision training, set to `bf16` (bfloat16), which can reduce memory usage and speed up training. + +--output_dir cogvideox-lora-single-gpu \ + +# Output directory, where model predictions and checkpoints will be stored. + +--height 480 \ + +# The height of input videos, all videos will be resized to 480 pixels. + +--width 720 \ + +# The width of input videos, all videos will be resized to 720 pixels. + +--fps 8 \ + +# The frame rate of input videos, all videos will be processed at 8 frames per second. + +--max_num_frames 49 \ + +# The maximum number of frames for input videos, videos will be truncated to a maximum of 49 frames. + +--skip_frames_start 0 \ + +# The number of frames to skip at the beginning of each video, set to 0, indicating no frames are skipped. + +--skip_frames_end 0 \ + +# The number of frames to skip at the end of each video, set to 0, indicating no frames are skipped. + +--train_batch_size 1 \ + +# The batch size for training, set to 1 per device. + +--num_train_epochs 10 \ + +# The total number of epochs for training, set to 10. + +--checkpointing_steps 500 \ + +# Save a checkpoint every 500 steps. + +--gradient_accumulation_steps 1 \ + +# The number of gradient accumulation steps, indicating that a gradient update is performed every 1 step. + +--learning_rate 1e-4 \ + +# The initial learning rate, set to 1e-4. + +--optimizer AdamW \ + +# The type of optimizer, choosing AdamW. + +--adam_beta1 0.9 \ + +# The beta1 parameter for the Adam optimizer, set to 0.9. + +--adam_beta2 0.95 \ + +# The beta2 parameter for the Adam optimizer, set to 0.95. + +``` + +### Run the script to start fine-tuning + +Single GPU fine-tuning: + +```shell +bash finetune_single_gpu.sh +``` + +Multi-GPU fine-tuning: + +```shell +bash finetune_multi_gpus_1.sh # needs to be run on each node +``` + +### Best Practices + ++ Include 70 videos with a resolution of `200 x 480 x 720` (frames x height x width). Through data preprocessing's frame + skipping, we created two smaller datasets of 49 and 16 frames to speed up experiments, as the CogVideoX team suggests + a maximum frame count of 49. We divided the 70 videos into three groups of 10, 25, and 50 videos. These videos are + conceptually similar. ++ 25 or more videos work best when training new concepts and styles. ++ Now using an identifier token specified through `--id_token` enhances training results. This is similar to Dreambooth + training, but regular fine-tuning without this token also works. ++ The original repository uses `lora_alpha` set to 1. We found this value to be ineffective in multiple runs, likely due + to differences in model backend and training setups. Our recommendation is to set lora_alpha to the same as rank or + rank // 2. ++ Using settings with a rank of 64 or above is recommended. \ No newline at end of file diff --git a/finetune/README_zh.md b/finetune/README_zh.md new file mode 100644 index 0000000..514ce72 --- /dev/null +++ b/finetune/README_zh.md @@ -0,0 +1,157 @@ +# CogVideoX diffusers 微调方案 + +如果您想查看SAT版本微调,请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。 + +本教程旨在快速微调 diffusers 版本 CogVideoX 模型。 + +### 硬件要求 + ++ CogVideoX-2B LORA: 1 * A100 ++ CogVideoX-2B SFT: 8 * A100 ++ CogVideoX-5B/5B-I2V 暂未支持 + +### 准备数据集 + +首先,你需要准备数据集,数据集格式如下,其中,videos.txt 存放 videos 中的视频。 + +``` +. +├── prompts.txt +├── videos +└── videos.txt +``` + +### 配置文件和运行 + +`accelerate` 配置文件如下: + ++ accelerate_config_machine_multi.yaml 适合多GPU使用 ++ accelerate_config_machine_single.yaml 适合单GPU使用 + +`finetune` 脚本配置文件如下: + +```shell +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +# 这条命令设置了 PyTorch 的 CUDA 内存分配策略,将显存扩展为段式内存管理,以防止 OOM(Out of Memory)错误。 + +accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ +# 使用 Accelerate 启动训练,指定配置文件 `accelerate_config_machine_single.yaml`,并使用多 GPU。 + + train_cogvideox_lora.py \ + # 这是你要执行的训练脚本,用于 LoRA 微调 CogVideoX 模型。 + + --pretrained_model_name_or_path THUDM/CogVideoX-2b \ + # 预训练模型的路径,指向你要微调的 CogVideoX-5b 模型。 + + --cache_dir ~/.cache \ + # 模型缓存的目录,用于存储从 Hugging Face 下载的模型和数据集。 + + --enable_tiling \ + # 启用 VAE tiling 功能,通过将图像划分成更小的区块处理,减少显存占用。 + + --enable_slicing \ + # 启用 VAE slicing 功能,将图像在通道上切片处理,以节省显存。 + + --instance_data_root ~/disney/ \ + # 实例数据的根目录,训练时使用的数据集文件夹。 + + --caption_column prompts.txt \ + # 用于指定包含实例提示(文本描述)的列或文件,在本例中为 `prompts.txt` 文件。 + + --video_column videos.txt \ + # 用于指定包含视频路径的列或文件,在本例中为 `videos.txt` 文件。 + + --validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \ + # 用于验证的提示语,多个提示语用指定分隔符(例如 `:::`)分开。 + + --validation_prompt_separator ::: \ + # 验证提示语的分隔符,在此设置为 `:::`。 + + --num_validation_videos 1 \ + # 验证期间生成的视频数量,设置为 1。 + + --validation_epochs 2 \ + # 每隔多少个 epoch 运行一次验证,设置为每 2 个 epoch 验证一次。 + + --seed 3407 \ + # 设置随机数种子,确保训练的可重复性,设置为 3407。 + + --rank 128 \ + # LoRA 更新矩阵的维度,控制 LoRA 层的参数大小,设置为 128。 + + --mixed_precision bf16 \ + # 使用混合精度训练,设置为 `bf16`(bfloat16),可以减少显存占用并加速训练。 + + --output_dir cogvideox-lora-single-gpu \ + # 输出目录,存放模型预测结果和检查点。 + + --height 480 \ + # 输入视频的高度,所有视频将被调整到 480 像素。 + + --width 720 \ + # 输入视频的宽度,所有视频将被调整到 720 像素。 + + --fps 8 \ + # 输入视频的帧率,所有视频将以每秒 8 帧处理。 + + --max_num_frames 49 \ + # 输入视频的最大帧数,视频将被截取到最多 49 帧。 + + --skip_frames_start 0 \ + # 每个视频从头部开始跳过的帧数,设置为 0,表示不跳过帧。 + + --skip_frames_end 0 \ + # 每个视频从尾部跳过的帧数,设置为 0,表示不跳过尾帧。 + + --train_batch_size 1 \ + # 训练的批次大小,每个设备的训练批次设置为 1。 + + --num_train_epochs 10 \ + # 训练的总 epoch 数,设置为 10。 + + --checkpointing_steps 500 \ + # 每经过 500 步保存一次检查点。 + + --gradient_accumulation_steps 1 \ + # 梯度累积步数,表示每进行 1 步才进行一次梯度更新。 + + --learning_rate 1e-4 \ + # 初始学习率,设置为 1e-4。 + + --optimizer AdamW \ + # 优化器类型,选择 AdamW 优化器。 + + --adam_beta1 0.9 \ + # Adam 优化器的 beta1 参数,设置为 0.9。 + + --adam_beta2 0.95 \ + # Adam 优化器的 beta2 参数,设置为 0.95。 +``` + +### 运行脚本,开始微调 + +单卡微调: + +```shell +bash finetune_single_gpu.sh +``` + +多卡微调: + +```shell +bash finetune_multi_gpus_1.sh #需要在每个节点运行 +``` + +### 最佳实践 + ++ 包含70个分辨率为 `200 x 480 x 720`(帧数 x 高 x + 宽)的训练视频。通过数据预处理中的帧跳过,我们创建了两个较小的49帧和16帧数据集,以加快实验速度,因为CogVideoX团队建议的最大帧数限制是49帧。我们将70个视频分成三组,分别为10、25和50个视频。这些视频的概念性质相似。 ++ 25个及以上的视频在训练新概念和风格时效果最佳。 ++ 现使用可以通过 `--id_token` 指定的标识符token进行训练效果更好。这类似于 Dreambooth 训练,但不使用这种token的常规微调也可以工作。 ++ 原始仓库使用 `lora_alpha` 设置为 1。我们发现这个值在多次运行中效果不佳,可能是因为模型后端和训练设置的不同。我们的建议是将 + lora_alpha 设置为与 rank 相同或 rank // 2。 ++ 建议使用 rank 为 64 及以上的设置。 + + + + diff --git a/finetune/accelerate_config_machine_multi.yaml b/finetune/accelerate_config_machine_multi.yaml new file mode 100644 index 0000000..856db57 --- /dev/null +++ b/finetune/accelerate_config_machine_multi.yaml @@ -0,0 +1,26 @@ +compute_environment: LOCAL_MACHINE +debug: true +deepspeed_config: + deepspeed_hostfile: hostfile.txt + deepspeed_multinode_launcher: pdsh + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'yes' +enable_cpu_affinity: true +main_process_ip: 10.250.128.19 +main_process_port: 12355 +main_training_function: main +mixed_precision: bf16 +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/finetune/accelerate_config_machine_single.yaml b/finetune/accelerate_config_machine_single.yaml new file mode 100644 index 0000000..84eb13e --- /dev/null +++ b/finetune/accelerate_config_machine_single.yaml @@ -0,0 +1,24 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +dynamo_backend: 'no' +mixed_precision: 'no' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/finetune/finetune_multi_gpus_1.sh b/finetune/finetune_multi_gpus_1.sh new file mode 100644 index 0000000..37e1fce --- /dev/null +++ b/finetune/finetune_multi_gpus_1.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +export MODEL_PATH="THUDM/CogVideoX-2b" +export CACHE_PATH="~/.cache" +export DATASET_PATH="disney" +export OUTPUT_PATH="cogvideox-lora-multi-gpu" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \ + train_cogvideox_lora.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --cache_dir $CACHE_PATH \ + --enable_tiling \ + --enable_slicing \ + --instance_data_root $DATASET_PATH \ + --caption_column prompts.txt \ + --video_column videos.txt \ + --validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 2 \ + --seed 3407 \ + --rank 128 \ + --mixed_precision bf16 \ + --output_dir $OUTPUT_PATH \ + --height 480 \ + --width 720 \ + --fps 8 \ + --max_num_frames 49 \ + --skip_frames_start 0 \ + --skip_frames_end 0 \ + --train_batch_size 1 \ + --num_train_epochs 10 \ + --checkpointing_steps 500 \ + --gradient_accumulation_steps 1 \ + --learning_rate 1e-4 \ + --optimizer AdamW \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ No newline at end of file diff --git a/finetune/finetune_single_gpu.sh b/finetune/finetune_single_gpu.sh new file mode 100644 index 0000000..3f866fe --- /dev/null +++ b/finetune/finetune_single_gpu.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +export MODEL_PATH="THUDM/CogVideoX-2b" +export CACHE_PATH="~/.cache" +export DATASET_PATH="disney" +export OUTPUT_PATH="cogvideox-lora-single-gpu" +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ + train_cogvideox_lora.py \ + --pretrained_model_name_or_path $MODEL_PATH \ + --cache_dir $CACHE_PATH \ + --enable_tiling \ + --enable_slicing \ + --instance_data_root $DATASET_PATH \ + --caption_column prompts.txt \ + --video_column videos.txt \ + --validation_prompt "Mickey with the captain and friends:::Mickey and the bear" \ + --validation_prompt_separator ::: \ + --num_validation_videos 1 \ + --validation_epochs 2 \ + --seed 3407 \ + --rank 128 \ + --mixed_precision bf16 \ + --output_dir $OUTPUT_PATH \ + --height 480 \ + --width 720 \ + --fps 8 \ + --max_num_frames 49 \ + --skip_frames_start 0 \ + --skip_frames_end 0 \ + --train_batch_size 1 \ + --num_train_epochs 10 \ + --checkpointing_steps 500 \ + --gradient_accumulation_steps 1 \ + --learning_rate 1e-4 \ + --optimizer AdamW \ + --adam_beta1 0.9 \ + --adam_beta2 0.95 \ No newline at end of file diff --git a/finetune/hostfile.txt b/finetune/hostfile.txt new file mode 100644 index 0000000..0dc339a --- /dev/null +++ b/finetune/hostfile.txt @@ -0,0 +1,4 @@ +node1 slots=8 +node2 slots=8 +node3 slots=8 +node4 slots=8 \ No newline at end of file diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py new file mode 100644 index 0000000..12f3e9c --- /dev/null +++ b/finetune/train_cogvideox_lora.py @@ -0,0 +1,1593 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import itertools +import logging +import math +import os +import shutil +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import T5EncoderModel, T5Tokenizer + +import diffusers +from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.optimization import get_scheduler +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + clear_objs_and_retain_memory, +) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + return parser.parse_args() + + +class VideoDataset(Dataset): + def __init__( + self, + instance_data_root: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_config_name: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + height: int = 480, + width: int = 720, + fps: int = 8, + max_num_frames: int = 49, + skip_frames_start: int = 0, + skip_frames_end: int = 0, + cache_dir: Optional[str] = None, + id_token: Optional[str] = None, + ) -> None: + super().__init__() + + self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.caption_column = caption_column + self.video_column = video_column + self.height = height + self.width = width + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.cache_dir = cache_dir + self.id_token = id_token or "" + + if dataset_name is not None: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() + else: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() + + self.num_instance_videos = len(self.instance_video_paths) + if self.num_instance_videos != len(self.instance_prompts): + raise ValueError( + f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.instance_videos = self._preprocess_data() + + def __len__(self): + return self.num_instance_videos + + def __getitem__(self, index): + return { + "instance_prompt": self.id_token + self.instance_prompts[index], + "instance_video": self.instance_videos[index], + } + + def _load_dataset_from_hub(self): + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_root instead." + ) + + # Downloading and loading a dataset from the hub. See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + self.dataset_name, + self.dataset_config_name, + cache_dir=self.cache_dir, + ) + column_names = dataset["train"].column_names + + if self.video_column is None: + video_column = column_names[0] + logger.info(f"`video_column` defaulting to {video_column}") + else: + video_column = self.video_column + if video_column not in column_names: + raise ValueError( + f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if self.caption_column is None: + caption_column = column_names[1] + logger.info(f"`caption_column` defaulting to {caption_column}") + else: + caption_column = self.caption_column + if self.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + instance_prompts = dataset["train"][caption_column] + instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]] + + return instance_prompts, instance_videos + + def _load_dataset_from_local_path(self): + if not self.instance_data_root.exists(): + raise ValueError("Instance videos root folder does not exist") + + prompt_path = self.instance_data_root.joinpath(self.caption_column) + video_path = self.instance_data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + instance_videos = [ + self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0 + ] + + if any(not path.is_file() for path in instance_videos): + raise ValueError( + "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return instance_prompts, instance_videos + + def _preprocess_data(self): + try: + import decord + except ImportError: + raise ImportError( + "The `decord` package is required for loading the video dataset. Install with `pip install dataset`" + ) + + decord.bridge.set_bridge("torch") + + videos = [] + train_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), + ] + ) + + for filename in self.instance_video_paths: + video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) + video_num_frames = len(video_reader) + + start_frame = min(self.skip_frames_start, video_num_frames) + end_frame = max(0, video_num_frames - self.skip_frames_end) + if end_frame <= start_frame: + frames = video_reader.get_batch([start_frame]) + elif end_frame - start_frame <= self.max_num_frames: + frames = video_reader.get_batch(list(range(start_frame, end_frame))) + else: + indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) + frames = video_reader.get_batch(indices) + + # Ensure that we don't go over the limit + frames = frames[: self.max_num_frames] + selected_num_frames = frames.shape[0] + + # Choose first (4k + 1) frames as this is how many is required by the VAE + remainder = (3 + (selected_num_frames % 4)) % 4 + if remainder != 0: + frames = frames[:-remainder] + selected_num_frames = frames.shape[0] + + assert (selected_num_frames - 1) % 4 == 0 + + # Training transforms + frames = frames.float() + frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) + videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] + + return videos + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + train_text_encoder=False, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}} + ) + + model_description = f""" +# CogVideoX LoRA - {repo_id} + + + +## Model description + +These are {repo_id} LoRA weights for {base_model}. + +The weights were trained using the [CogVideoX Diffusers trainer](TODO). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import CogVideoXPipeline +import torch + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors") +video = pipe("{validation_prompt}").frames[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) + pipe = pipe.to(accelerator.device) + # pipe.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + clear_objs_and_retain_memory([pipe]) + + return videos + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def get_optimizer(args, params_to_optimize): + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in ["adam", "adamw", "prodigy"]: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]): + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # Changes the learning rate of text_encoder_parameters to be --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = T5Tokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.float16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder))): + text_encoder_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + CogVideoXPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder))): + text_encoder_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + if args.train_text_encoder: + text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_encoder_parameters_with_lr = { + "params": text_encoder_lora_parameters, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + transformer_parameters_with_lr, + text_encoder_parameters_with_lr, + ] + else: + params_to_optimize = [transformer_parameters_with_lr] + + optimizer = get_optimizer(args, params_to_optimize) + + # Dataset and DataLoader + train_dataset = VideoDataset( + instance_data_root=args.instance_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + caption_column=args.caption_column, + video_column=args.video_column, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + ) + + def encode_video(video): + video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) + video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + latent_dist = vae.encode(video).latent_dist + return latent_dist + + train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] + + def collate_fn(examples): + videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + videos = torch.cat(videos) + videos = videos.to(memory_format=torch.contiguous_format).float() + + return { + "videos": videos, + "prompts": prompts, + } + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + ( + transformer, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + logger.info("***** Running training *****") + logger.info(f" Num trainable parameters = {num_trainable_parameters}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + if args.train_text_encoder: + text_encoder.train() + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True) + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder]) + + with accelerator.accumulate(models_to_accumulate): + model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] + prompts = batch["prompts"] + + # encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + accelerator.device, + weight_dtype, + requires_grad=args.train_text_encoder, + ) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device + ) + timesteps = timesteps.long() + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=args.height, + width=args.width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + alphas_cumprod = scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(transformer.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else transformer.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + # Create pipeline + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + # transformer = transformer.to(torch.float32) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + if args.train_text_encoder: + text_encoder = unwrap_model(text_encoder) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype)) + else: + text_encoder_lora_layers = None + + CogVideoXPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + # Final test inference + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index f89e63a..4b1e1d1 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -77,7 +77,9 @@ def generate_video( # 2. Set Scheduler. # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`. - # We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B and `CogVideoXDPMScheduler` for CogVideoX-5B. + # We recommend using `CogVideoXDDIMScheduler` for CogVideoX-2B. + # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V. + # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") @@ -85,14 +87,16 @@ def generate_video( # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference # and enable to("cuda") - # pipe.enable_sequential_cpu_offload() - pipe.to("cuda") + # pipe.to("cuda") + + pipe.enable_sequential_cpu_offload() + pipe.vae.enable_slicing() pipe.vae.enable_tiling() # 4. Generate the video frames based on the prompt. # `num_frames` is the Number of frames to generate. - # This is the default value for 6 seconds video and 8 fps,so 48 frames and will plus 1 frame for the first frame and 49 frames. + # This is the default value for 6 seconds video and 8 fps and will plus 1 frame for the first frame and 49 frames. if generate_type == "i2v": video_generate = pipe( prompt=prompt, @@ -100,7 +104,7 @@ def generate_video( num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt num_inference_steps=num_inference_steps, # Number of inference steps num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.30.3` and after. - use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False + use_dynamic_cfg=True, # This id used for DPM Sechduler, for DDIM scheduler, it should be False guidance_scale=guidance_scale, generator=torch.Generator().manual_seed(seed), # Set the seed for reproducibility ).frames[0] diff --git a/inference/cli_demo_quantization.py b/inference/cli_demo_quantization.py index 7a702f2..ecaa840 100644 --- a/inference/cli_demo_quantization.py +++ b/inference/cli_demo_quantization.py @@ -6,7 +6,7 @@ Note: Must install the `torchao`,`torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature. Only NVIDIA GPUs like H100 or higher are supported om FP-8 quantization. -ALL quantization schemes must using with NVIDIA GPUs. +ALL quantization schemes must use with NVIDIA GPUs. # Run the script: @@ -83,8 +83,9 @@ def generate_video( # Using with compile will run faster. First time infer will cost ~30min to compile. # pipe.transformer.to(memory_format=torch.channels_last) - # for FP8 should remove pipe.enable_model_cpu_offload() - pipe.enable_model_cpu_offload() + + # for FP8 should remove pipe.enable_sequential_cpu_offload() + pipe.enable_sequential_cpu_offload() # This is not for FP8 and INT8 and should remove this line # pipe.enable_sequential_cpu_offload() @@ -95,7 +96,7 @@ def generate_video( num_videos_per_prompt=num_videos_per_prompt, num_inference_steps=num_inference_steps, num_frames=49, - use_dynamic_cfg=True, ## This id used for DPM Sechduler, for DDIM scheduler, it should be False + use_dynamic_cfg=True, guidance_scale=guidance_scale, generator=torch.Generator(device="cuda").manual_seed(42), ).frames[0] diff --git a/inference/cli_vae_demo.py b/inference/cli_vae_demo.py index 0943597..07e949e 100644 --- a/inference/cli_vae_demo.py +++ b/inference/cli_vae_demo.py @@ -1,18 +1,22 @@ """ This script is designed to demonstrate how to use the CogVideoX-2b VAE model for video encoding and decoding. It allows you to encode a video into a latent representation, decode it back into a video, or perform both operations sequentially. -Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the `{your local diffusers path}` argument to the path of the cloned repository. +Before running the script, make sure to clone the CogVideoX Hugging Face model repository and set the +`{your local diffusers path}` argument to the path of the cloned repository. Command 1: Encoding Video -Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-2b VAE model. -Memory Usage: ~34GB of GPU memory for encoding. -If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder and you can still run the decoding command. +Encodes the video located at ../resources/videos/1.mp4 using the CogVideoX-5b VAE model. +Memory Usage: ~18GB of GPU memory for encoding. + +If you do not have enough GPU memory, we provide a pre-encoded tensor file (encoded.pt) in the resources folder, +and you can still run the decoding command. + $ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --video_path ../resources/videos/1.mp4 --mode encode Command 2: Decoding Video Decodes the latent representation stored in encoded.pt back into a video. -Memory Usage: ~19GB of GPU memory for decoding. +Memory Usage: ~4GB of GPU memory for decoding. $ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/vae/ --encoded_path ./encoded.pt --mode decode Command 3: Encoding and Decoding Video @@ -24,9 +28,9 @@ $ python cli_vae_demo.py --model_path {your local diffusers path}/CogVideoX-2b/v import argparse import torch import imageio -import numpy as np from diffusers import AutoencoderKLCogVideoX from torchvision import transforms +import numpy as np def encode_video(model_path, video_path, dtype, device): @@ -42,7 +46,12 @@ def encode_video(model_path, video_path, dtype, device): Returns: - torch.Tensor: The encoded video frames. """ + model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device) + + model.enable_slicing() + model.enable_tiling() + video_reader = imageio.get_reader(video_path, "ffmpeg") frames = [transforms.ToTensor()(frame) for frame in video_reader] @@ -80,13 +89,13 @@ def save_video(tensor, output_path): Saves the video frames to a video file. Parameters: - - tensor (torch.Tensor): The video frames tensor. + - tensor (torch.Tensor): The video frames' tensor. - output_path (str): The path to save the output video. """ + tensor = tensor.to(dtype=torch.float32) frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy() frames = np.clip(frames, 0, 1) * 255 frames = frames.astype(np.uint8) - writer = imageio.get_writer(output_path + "/output.mp4", fps=8) for frame in frames: writer.append_data(frame) @@ -103,7 +112,7 @@ if __name__ == "__main__": "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both" ) parser.add_argument( - "--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')" + "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" ) parser.add_argument( "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')" @@ -111,7 +120,7 @@ if __name__ == "__main__": args = parser.parse_args() device = torch.device(args.device) - dtype = torch.float16 if args.dtype == "float16" else torch.float32 + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 if args.mode == "encode": assert args.video_path, "Video path must be provided for encoding."