Merge pull request #317 from THUDM/CogVideoX_dev

Finetune and Readme update
Yuxuan.Zhang 2024-09-20 13:39:51 +08:00 committed by GitHub
commit 1e5dd975e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 163 additions and 206 deletions

View File

@@ -277,6 +277,7 @@ pipe.vae.enable_tiling()
We highly welcome contributions from the community and actively contribute to the open-source community. The following
works have already been adapted for CogVideoX, and we invite everyone to use them:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference
framework, allowing you to easily deploy your own models or the latest cutting-edge open-source models with just one
click.

View File

@@ -261,6 +261,7 @@ pipe.vae.enable_tiling()
We highly welcome contributions from the community and actively contribute to the open-source community ourselves. The following works have already been adapted for CogVideoX, and we invite everyone to use them:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference framework that lets you easily deploy your own models or the latest cutting-edge open-source models with a single click.
+ [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper)

View File

@@ -248,7 +248,7 @@ pipe.vae.enable_tiling()
## Friendly Links
We highly welcome contributions from the community and actively contribute to the open-source community ourselves. The following works have already been adapted for CogVideoX, and everyone is welcome to use them:
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a pipeline modified from the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and feature-rich distributed inference framework that lets you deploy your own models or built-in cutting-edge open-source models with a single click.
+ [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper): Uses the ComfyUI framework to integrate CogVideoX into your workflow.
+ [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys): VideoSys provides easy-to-use, high-performance video generation infrastructure, supporting the complete pipeline and continuously integrating the latest models and techniques.

View File

@@ -9,9 +9,9 @@ see [here](../sat/README_zh.md). The dataset format is different from this version.
## Hardware Requirements
+ CogVideoX-2B LoRA: 1 * A100
+ CogVideoX-2B SFT: 8 * A100
+ CogVideoX-5B/5B-I2V is not supported yet.
+ CogVideoX-2B / 5B LoRA: 1 * A100 (the 5B model needs `--use_8bit_adam`; see the sketch after this list)
+ CogVideoX-2B SFT: 8 * A100 (Working)
+ CogVideoX-5B-I2V is not supported yet.
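For a rough sense of why the 5B model fits on a single A100 for LoRA: `--use_8bit_adam` swaps the optimizer for the 8-bit Adam variant from `bitsandbytes`, which stores optimizer state in 8 bits instead of 32. A minimal sketch, assuming `bitsandbytes` is installed:

```python
# Hedged sketch: the kind of optimizer `--use_8bit_adam` selects, not the trainer's exact code.
import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.zeros(16, 16, device="cuda"))]
# 8-bit optimizer state shrinks Adam's memory footprint versus 32-bit AdamW.
optimizer = bnb.optim.AdamW8bit(params, lr=1e-4, weight_decay=1e-4)
```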
## Install Dependencies
@@ -20,8 +20,7 @@ diffusers branch. Please follow the steps below to install dependencies:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers
git checkout cogvideox-lora-and-training
cd diffusers # Now in Main branch
pip install -e .
```
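After the editable install, a quick sanity check (a hedged sketch; exact version strings vary by checkout) confirms Python is importing the source tree:

```python
# Verify that `pip install -e .` pointed Python at the cloned repo.
import diffusers

print(diffusers.__version__)  # a ".dev" suffix indicates a source checkout
print(diffusers.__file__)     # should resolve inside the cloned diffusers/ directory
```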
@@ -124,13 +123,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
Single GPU fine-tuning:
```shell
bash finetune_single_gpu.sh
bash finetune_single_rank.sh
```
Multi-GPU fine-tuning:
```shell
bash finetune_multi_gpus_1.sh # Needs to be run on each node
bash finetune_multi_rank.sh # Needs to be run on each node
```
## Loading the Fine-tuned Model

View File

@@ -9,9 +9,9 @@
## Hardware Requirements
+ CogVideoX-2B LoRA: 1 * A100
+ CogVideoX-2B SFT: 8 * A100
+ CogVideoX-5B/5B-I2V is not supported yet.
+ CogVideoX-2B / 5B T2V LoRA: 1 * A100 (the 5B model needs `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (verified working)
+ CogVideoX-5B-I2V is not supported yet.
## Install Dependencies
@@ -19,8 +19,7 @@
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers
git checkout cogvideox-lora-and-training
cd diffusers # Now in Main branch
pip install -e .
```
@@ -120,13 +119,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
Single GPU fine-tuning:
```shell
bash finetune_single_gpu.sh
bash finetune_single_rank.sh
```
Multi-GPU fine-tuning:
```shell
bash finetune_multi_gpus_1.sh # Must be run on each node.
bash finetune_multi_rank.sh # Must be run on each node.
```
## Loading the Fine-tuned Model

View File

@@ -8,9 +8,9 @@
## Hardware Requirements
+ CogVideoX-2B LoRA: 1 * A100
+ CogVideoX-2B SFT: 8 * A100
+ CogVideoX-5B/5B-I2V is not supported yet.
+ CogVideoX-2B / 5B T2V LoRA: 1 * A100 (the 5B model needs `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (work in progress)
+ CogVideoX-5B-I2V is not supported yet.
## Install Dependencies
@@ -18,8 +18,7 @@
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers
git checkout cogvideox-lora-and-training
cd diffusers # Now in Main branch
pip install -e .
```
@@ -150,13 +149,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
Single GPU fine-tuning:
```shell
bash finetune_single_gpu.sh
bash finetune_single_rank.sh
```
Multi-GPU fine-tuning:
```shell
bash finetune_multi_gpus_1.sh # Must be run on each node
bash finetune_multi_rank.sh # Must be run on each node
```
## Loading the Fine-tuned Model

View File

@@ -9,6 +9,7 @@ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
train_cogvideox_lora.py \
--gradient_checkpointing \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \

View File

@@ -3,12 +3,17 @@
export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-gpu"
export OUTPUT_PATH="cogvideox-lora-multi-gpu"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
# --use_8bit_adam is necessary for CogVideoX-5B-I2V
# If you are not running with 8 GPUs, change `num_processes` in `accelerate_config_machine_single.yaml` to match your GPU count
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
train_cogvideox_lora.py \
--gradient_checkpointing \
--use_8bit_adam \
--pretrained_model_name_or_path $MODEL_PATH \
--cache_dir $CACHE_PATH \
--enable_tiling \

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import argparse
import itertools
import logging
import math
import os
@@ -32,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
import diffusers
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
@@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.optimization import get_scheduler
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
)
@@ -48,6 +46,7 @@ from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft,
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
@@ -239,11 +238,6 @@ def get_args():
action="store_true",
help="whether to randomly flip videos horizontally",
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
@@ -296,12 +290,6 @@ def get_args():
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=5e-6,
help="Text encoder learning rate to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
@@ -367,9 +355,6 @@ def get_args():
)
parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -423,20 +408,20 @@ def get_args():
class VideoDataset(Dataset):
def __init__(
self,
instance_data_root: Optional[str] = None,
dataset_name: Optional[str] = None,
dataset_config_name: Optional[str] = None,
caption_column: str = "text",
video_column: str = "video",
height: int = 480,
width: int = 720,
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
skip_frames_end: int = 0,
cache_dir: Optional[str] = None,
id_token: Optional[str] = None,
self,
instance_data_root: Optional[str] = None,
dataset_name: Optional[str] = None,
dataset_config_name: Optional[str] = None,
caption_column: str = "text",
video_column: str = "video",
height: int = 480,
width: int = 720,
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
skip_frames_end: int = 0,
cache_dir: Optional[str] = None,
id_token: Optional[str] = None,
) -> None:
super().__init__()
@@ -555,7 +540,7 @@ class VideoDataset(Dataset):
import decord
except ImportError:
raise ImportError(
"The `decord` package is required for loading the video dataset. Install with `pip install dataset`"
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
)
decord.bridge.set_bridge("torch")
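For context, the dataset reads clips through decord's torch bridge; a hedged sketch of that loading pattern (the file name and frame count here are illustrative):

```python
# Illustrative only: reading a clip once decord's torch bridge is set.
import decord

decord.bridge.set_bridge("torch")  # frames come back as torch tensors
vr = decord.VideoReader("sample.mp4", width=720, height=480)
frames = vr.get_batch(list(range(min(49, len(vr)))))  # [F, H, W, C], uint8
```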
@@ -602,13 +587,12 @@ class VideoDataset(Dataset):
def save_model_card(
repo_id: str,
videos=None,
base_model: str = None,
train_text_encoder=False,
validation_prompt=None,
repo_folder=None,
fps=8,
repo_id: str,
videos=None,
base_model: str = None,
validation_prompt=None,
repo_folder=None,
fps=8,
):
widget_dict = []
if videos is not None:
@@ -627,9 +611,9 @@ def save_model_card(
These are {repo_id} LoRA weights for {base_model}.
The weights were trained using the [CogVideoX Diffusers trainer](TODO).
The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
Was LoRA for the text encoder enabled? {train_text_encoder}.
Was LoRA for the text encoder enabled? No.
## Download model
@@ -642,8 +626,15 @@ from diffusers import CogVideoXPipeline
import torch
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
video = pipe("{validation_prompt}").frames[0]
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
# The LoRA adapter scale is determined by what was used for training.
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
# The scale can be set lower or higher than the training value to decrease or amplify the effect
# of the LoRA, up to a tolerance beyond which one might notice no effect at all or overflows.
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
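As a follow-up, inference at a fixed scale can also bake the adapter into the base weights via the generic diffusers LoRA API (a hedged sketch; `fuse_lora` is standard diffusers, not something this trainer configures):

```python
# Optional: fuse the LoRA at a fixed scale (here alpha / rank = 32 / 64)
# to skip per-call adapter overhead during repeated inference.
pipe.fuse_lora(lora_scale=32 / 64)
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
```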
@@ -676,12 +667,12 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
def log_validation(
pipe,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation: bool = False,
pipe,
args,
accelerator,
pipeline_args,
epoch,
is_final_validation: bool = False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
@@ -741,14 +732,14 @@ def log_validation(
def _get_t5_prompt_embeds(
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
text_input_ids=None,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
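The rest of this helper (elided by the diff) follows the standard T5 pattern: tokenize with padding and truncation to `max_sequence_length`, then run the encoder. A hedged sketch of that pattern:

```python
# Sketch of the usual T5 embedding flow this helper wraps; not the exact elided body.
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=max_sequence_length,  # 226 by default for CogVideoX
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
prompt_embeds = text_encoder(text_inputs.input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
```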
@@ -779,14 +770,14 @@ def _get_t5_prompt_embeds(
def encode_prompt(
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
text_input_ids=None,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = _get_t5_prompt_embeds(
@@ -802,13 +793,16 @@ def encode_prompt(
return prompt_embeds
def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
def compute_prompt_embeddings(
tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
):
if requires_grad:
prompt_embeds = encode_prompt(
tokenizer,
text_encoder,
prompt,
num_videos_per_prompt=1,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
@@ -819,6 +813,7 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
text_encoder,
prompt,
num_videos_per_prompt=1,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
@@ -826,15 +821,15 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
def prepare_rotary_positional_embeddings(
height: int,
width: int,
num_frames: int,
vae_scale_factor_spatial: int = 8,
patch_size: int = 2,
attention_head_dim: int = 64,
device: Optional[torch.device] = None,
base_height: int = 480,
base_width: int = 720,
height: int,
width: int,
num_frames: int,
vae_scale_factor_spatial: int = 8,
patch_size: int = 2,
attention_head_dim: int = 64,
device: Optional[torch.device] = None,
base_height: int = 480,
base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * patch_size)
grid_width = width // (vae_scale_factor_spatial * patch_size)
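For the default 480x720 training resolution, these grid sizes are easy to verify by hand (a quick check using the defaults shown in the signature above):

```python
# Worked example of the grid computation above, using the function's defaults.
vae_scale_factor_spatial, patch_size = 8, 2
height, width = 480, 720
grid_height = height // (vae_scale_factor_spatial * patch_size)  # 480 // 16 = 30
grid_width = width // (vae_scale_factor_spatial * patch_size)    # 720 // 16 = 45
print(grid_height, grid_width)  # 30 45
```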
@@ -854,10 +849,22 @@ def prepare_rotary_positional_embeddings(
return freqs_cos, freqs_sin
def get_optimizer(args, params_to_optimize):
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
# Use the DeepSpeed optimizer
if use_deepspeed:
from accelerate.utils import DummyOptim
return DummyOptim(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
# Optimizer creation
supported_optimizers = ["adam", "adamw", "prodigy"]
if args.optimizer not in ["adam", "adamw", "prodigy"]:
if args.optimizer not in supported_optimizers:
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
)
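`DummyOptim` is Accelerate's placeholder for the case where DeepSpeed builds the real optimizer from its own config during `accelerator.prepare`. A hypothetical config fragment whose `optimizer` block would trigger that path (the same key the `use_deepspeed_optimizer` check reads later in this file):

```python
# Hypothetical DeepSpeed config: the presence of an "optimizer" block is what
# makes get_optimizer return DummyOptim instead of a concrete optimizer.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4, "weight_decay": 1e-4}},
}
```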
@@ -907,14 +914,6 @@ def get_optimizer(args, params_to_optimize):
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warning(
f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# Changes the learning rate of text_encoder_parameters to be --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
@@ -994,7 +993,7 @@ def main(args):
).repo_id
# Prepare models and scheduler
tokenizer = T5Tokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
@@ -1035,13 +1034,13 @@ def main(args):
if accelerator.state.deepspeed_plugin:
# DeepSpeed is handling precision, use what's in the DeepSpeed config
if (
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
):
weight_dtype = torch.float16
if (
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
):
weight_dtype = torch.bfloat16
else:
@@ -1062,8 +1061,6 @@ def main(args):
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# now we will add new LoRA weights to the attention layers
transformer_lora_config = LoraConfig(
@@ -1073,14 +1070,6 @@ def main(args):
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder.add_adapter(text_lora_config)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
@@ -1091,13 +1080,10 @@ def main(args):
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
elif isinstance(model, type(unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1107,22 +1093,18 @@ def main(args):
CogVideoXPipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
transformer_ = None
text_encoder_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
elif isinstance(model, type(unwrap_model(text_encoder))):
text_encoder_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
raise ValueError(f"Unexpected save model: {model.__class__}")
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
@@ -1139,19 +1121,13 @@ def main(args):
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [transformer_]
if args.train_text_encoder:
models.extend([text_encoder_])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models)
cast_training_params([transformer_])
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1163,38 +1139,30 @@ def main(args):
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
if args.train_text_encoder:
models.extend([text_encoder])
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
cast_training_params([transformer], dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
# Optimization parameters
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
# different learning rate for text encoder and unet
text_encoder_parameters_with_lr = {
"params": text_encoder_lora_parameters,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
transformer_parameters_with_lr,
text_encoder_parameters_with_lr,
]
else:
params_to_optimize = [transformer_parameters_with_lr]
params_to_optimize = [transformer_parameters_with_lr]
optimizer = get_optimizer(args, params_to_optimize)
use_deepspeed_optimizer = (
accelerator.state.deepspeed_plugin is not None
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
)
use_deepspeed_scheduler = (
accelerator.state.deepspeed_plugin is not None
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
)
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
# Dataset and DataLoader
train_dataset = VideoDataset(
@@ -1248,35 +1216,30 @@ def main(args):
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
if use_deepspeed_scheduler:
from accelerate.utils import DummyScheduler
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
(
transformer,
text_encoder,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
transformer,
text_encoder,
optimizer,
train_dataloader,
lr_scheduler,
lr_scheduler = DummyScheduler(
name=args.lr_scheduler,
optimizer=optimizer,
total_num_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
)
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
@@ -1347,15 +1310,9 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
text_encoder.train()
# set top parameter requires_grad = True so that gradient checkpointing works
accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder])
with accelerator.accumulate(models_to_accumulate):
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
@@ -1366,9 +1323,10 @@
tokenizer,
text_encoder,
prompts,
model_config.max_text_seq_length,
accelerator.device,
weight_dtype,
requires_grad=args.train_text_encoder,
requires_grad=False,
)
# Sample noise that will be added to the latents
@@ -1422,16 +1380,14 @@
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(transformer.parameters(), text_encoder.parameters())
if args.train_text_encoder
else transformer.parameters()
)
params_to_clip = transformer.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if accelerator.state.deepspeed_plugin is None:
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1507,7 +1463,6 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
# transformer = transformer.to(torch.float32)
dtype = (
torch.float16
if args.mixed_precision == "fp16"
@@ -1518,16 +1473,9 @@ def main(args):
transformer = transformer.to(dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
text_encoder = unwrap_model(text_encoder)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
else:
text_encoder_lora_layers = None
CogVideoXPipeline.save_lora_weights(
save_directory=args.output_dir,
transformer_lora_layers=transformer_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
# Final test inference
@@ -1539,6 +1487,11 @@
)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
if args.enable_slicing:
pipe.vae.enable_slicing()
if args.enable_tiling:
pipe.vae.enable_tiling()
# Load LoRA weights
lora_scaling = args.lora_alpha / args.rank
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
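Presumably the ratio computed above is then applied as the adapter's scale; a hedged sketch of that step:

```python
# Apply the alpha / rank ratio computed above as the loaded adapter's scale.
pipe.set_adapters(["cogvideox-lora"], [lora_scaling])
```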
@@ -1572,7 +1525,6 @@ def main(args):
repo_id,
videos=validation_outputs,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
fps=args.fps,