diff --git a/finetune/README.md b/finetune/README.md
index d3f0204..740244d 100644
--- a/finetune/README.md
+++ b/finetune/README.md
@@ -9,9 +9,9 @@ see [here](../sat/README_zh.md). The dataset format is different from this versi
 
 ## Hardware Requirements
 
-+ CogVideoX-2B LoRA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V is not supported yet.
++ CogVideoX-2B / 5B LoRA: 1 * A100 (5B requires `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (Working)
++ CogVideoX-5B-I2V is not supported yet.
 
 ## Install Dependencies
 
@@ -20,8 +20,7 @@ diffusers branch. Please follow the steps below to install dependencies:
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now on the main branch
 pip install -e .
 ```
 
@@ -124,13 +123,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
 Single GPU fine-tuning:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 Multi-GPU fine-tuning:
 
 ```shell
-bash finetune_multi_gpus_1.sh # Needs to be run on each node
+bash finetune_multi_rank.sh # Needs to be run on each node
 ```
 
 ## Loading the Fine-tuned Model
diff --git a/finetune/README_ja.md b/finetune/README_ja.md
index 1c0a021..99491b1 100644
--- a/finetune/README_ja.md
+++ b/finetune/README_ja.md
@@ -9,9 +9,9 @@
 
 ## ハードウェア要件
 
-+ CogVideoX-2B LORA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V まだサポートしていません
++ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B requires `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (動作確認済み)
++ CogVideoX-5B-I2V まだサポートしていません
 
 ## 依存関係のインストール
 
@@ -19,8 +19,7 @@
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now on the main branch
 pip install -e .
 ```
 
@@ -120,13 +119,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
 単一GPU微調整:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 複数GPU微調整:
 
 ```shell
-bash finetune_multi_gpus_1.sh # 各ノードで実行する必要があります。
+bash finetune_multi_rank.sh # 各ノードで実行する必要があります。
 ```
 
 ## 微調整済みモデルのロード
diff --git a/finetune/README_zh.md b/finetune/README_zh.md
index 3385e9e..73ec738 100644
--- a/finetune/README_zh.md
+++ b/finetune/README_zh.md
@@ -8,9 +8,9 @@
 
 ## 硬件要求
 
-+ CogVideoX-2B LORA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V 暂未支持
++ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B requires `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (制作中)
++ CogVideoX-5B-I2V 暂未支持
 
 ## 安装依赖
 
@@ -18,8 +18,7 @@
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now on the main branch
 pip install -e .
 ```
 
@@ -150,13 +149,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
 单卡微调:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 多卡微调:
 
 ```shell
-bash finetune_multi_gpus_1.sh #需要在每个节点运行
+bash finetune_multi_rank.sh #需要在每个节点运行
 ```
 
 ## 载入微调的模型
diff --git a/finetune/finetune_multi_gpus_1.sh b/finetune/finetune_multi_rank.sh
similarity index 97%
rename from finetune/finetune_multi_gpus_1.sh
rename to finetune/finetune_multi_rank.sh
index 6ae55c5..71f94e8 100644
--- a/finetune/finetune_multi_gpus_1.sh
+++ b/finetune/finetune_multi_rank.sh
@@ -9,6 +9,7 @@ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
 
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
   train_cogvideox_lora.py \
+  --gradient_checkpointing \
   --pretrained_model_name_or_path $MODEL_PATH \
   --cache_dir $CACHE_PATH \
   --enable_tiling \
diff --git a/finetune/finetune_single_gpu.sh b/finetune/finetune_single_rank.sh
similarity index 81%
rename from finetune/finetune_single_gpu.sh
rename to finetune/finetune_single_rank.sh
index 4accadf..2794a13 100644
--- a/finetune/finetune_single_gpu.sh
+++ b/finetune/finetune_single_rank.sh
@@ -3,12 +3,17 @@ export MODEL_PATH="THUDM/CogVideoX-2b"
 export CACHE_PATH="~/.cache"
 export DATASET_PATH="Disney-VideoGeneration-Dataset"
-export OUTPUT_PATH="cogvideox-lora-single-gpu"
+export OUTPUT_PATH="cogvideox-lora-multi-gpu"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+
+# --use_8bit_adam is necessary for CogVideoX-5B
+# If you are not using 8 GPUs, change `num_processes` in `accelerate_config_machine_single.yaml` to match your GPU count
 
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
   train_cogvideox_lora.py \
+  --gradient_checkpointing \
+  --use_8bit_adam \
   --pretrained_model_name_or_path $MODEL_PATH \
   --cache_dir $CACHE_PATH \
   --enable_tiling \
diff --git a/finetune/train_cogvideox_lora.py b/finetune/train_cogvideox_lora.py
index bba4d06..137f322 100644
--- a/finetune/train_cogvideox_lora.py
+++ b/finetune/train_cogvideox_lora.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import itertools
 import logging
 import math
 import os
@@ -32,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
@@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.training_utils import (
-    _set_state_dict_into_text_encoder,
     cast_training_params,
     clear_objs_and_retain_memory,
 )
@@ -48,6 +46,7 @@ from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft,
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
 
+
 if is_wandb_available():
     import wandb
 
@@ -239,11 +238,6 @@ def get_args():
         action="store_true",
         help="whether to randomly flip videos horizontally",
     )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
-    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -296,12 +290,6 @@ def get_args():
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--text_encoder_lr",
-        type=float,
-        default=5e-6,
-        help="Text encoder learning rate to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -367,9 +355,6 @@ def get_args():
     )
     parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
-    parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
-    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -423,20 +408,20 @@ def get_args():
 
 class VideoDataset(Dataset):
     def __init__(
-        self,
-        instance_data_root: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        dataset_config_name: Optional[str] = None,
-        caption_column: str = "text",
-        video_column: str = "video",
-        height: int = 480,
-        width: int = 720,
-        fps: int = 8,
-        max_num_frames: int = 49,
-        skip_frames_start: int = 0,
-        skip_frames_end: int = 0,
-        cache_dir: Optional[str] = None,
-        id_token: Optional[str] = None,
+            self,
+            instance_data_root: Optional[str] = None,
+            dataset_name: Optional[str] = None,
+            dataset_config_name: Optional[str] = None,
+            caption_column: str = "text",
+            video_column: str = "video",
+            height: int = 480,
+            width: int = 720,
+            fps: int = 8,
+            max_num_frames: int = 49,
+            skip_frames_start: int = 0,
+            skip_frames_end: int = 0,
+            cache_dir: Optional[str] = None,
+            id_token: Optional[str] = None,
     ) -> None:
         super().__init__()
 
@@ -555,7 +540,7 @@ class VideoDataset(Dataset):
             import decord
         except ImportError:
             raise ImportError(
-                "The `decord` package is required for loading the video dataset. Install with `pip install dataset`"
+                "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
             )
 
         decord.bridge.set_bridge("torch")
 
@@ -602,13 +587,12 @@ class VideoDataset(Dataset):
 
 
 def save_model_card(
-    repo_id: str,
-    videos=None,
-    base_model: str = None,
-    train_text_encoder=False,
-    validation_prompt=None,
-    repo_folder=None,
-    fps=8,
+        repo_id: str,
+        videos=None,
+        base_model: str = None,
+        validation_prompt=None,
+        repo_folder=None,
+        fps=8,
 ):
     widget_dict = []
     if videos is not None:
@@ -627,9 +611,9 @@ def save_model_card(
 
 These are {repo_id} LoRA weights for {base_model}.
 
-The weights were trained using the [CogVideoX Diffusers trainer](TODO).
+The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
 
-Was LoRA for the text encoder enabled? {train_text_encoder}.
+Was LoRA for the text encoder enabled? No.
 
 ## Download model
 
@@ -642,8 +626,15 @@ from diffusers import CogVideoXPipeline
 import torch
 
 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
-pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
-video = pipe("{validation_prompt}").frames[0]
+pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
+
+# The LoRA adapter weights are determined by what was used for training.
+# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
+# It can be set lower or higher than what was used in training to decrease or amplify the effect
+# of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.
+pipe.set_adapters(["cogvideox-lora"], [32 / 64])
+
+video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
 ```
 
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
@@ -676,12 +667,12 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
 
 
 def log_validation(
-    pipe,
-    args,
-    accelerator,
-    pipeline_args,
-    epoch,
-    is_final_validation: bool = False,
+        pipe,
+        args,
+        accelerator,
+        pipeline_args,
+        epoch,
+        is_final_validation: bool = False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
@@ -741,14 +732,14 @@ def log_validation(
 
 
 def _get_t5_prompt_embeds(
-    tokenizer: T5Tokenizer,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]],
-    num_videos_per_prompt: int = 1,
-    max_sequence_length: int = 226,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    text_input_ids=None,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
@@ -779,14 +770,14 @@ def _get_t5_prompt_embeds(
 
 
 def encode_prompt(
-    tokenizer: T5Tokenizer,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]],
-    num_videos_per_prompt: int = 1,
-    max_sequence_length: int = 226,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    text_input_ids=None,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     prompt_embeds = _get_t5_prompt_embeds(
@@ -802,13 +793,16 @@ def encode_prompt(
     return prompt_embeds
 
 
-def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
+def compute_prompt_embeddings(
+        tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
+):
     if requires_grad:
         prompt_embeds = encode_prompt(
             tokenizer,
             text_encoder,
             prompt,
             num_videos_per_prompt=1,
+            max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )
@@ -819,6 +813,7 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
             text_encoder,
             prompt,
             num_videos_per_prompt=1,
+            max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )
@@ -826,15 +821,15 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
 
 
 def prepare_rotary_positional_embeddings(
-    height: int,
-    width: int,
-    num_frames: int,
-    vae_scale_factor_spatial: int = 8,
-    patch_size: int = 2,
-    attention_head_dim: int = 64,
-    device: Optional[torch.device] = None,
-    base_height: int = 480,
-    base_width: int = 720,
+        height: int,
+        width: int,
+        num_frames: int,
+        vae_scale_factor_spatial: int = 8,
+        patch_size: int = 2,
+        attention_head_dim: int = 64,
+        device: Optional[torch.device] = None,
+        base_height: int = 480,
+        base_width: int = 720,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     grid_height = height // (vae_scale_factor_spatial * patch_size)
     grid_width = width // (vae_scale_factor_spatial * patch_size)
@@ -854,10 +849,22 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin
 
 
-def get_optimizer(args, params_to_optimize):
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
+    # Use DeepSpeed optimizer
+    if use_deepspeed:
+        from accelerate.utils import DummyOptim
+
+        return DummyOptim(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_epsilon,
+            weight_decay=args.adam_weight_decay,
+        )
+
     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
-    if args.optimizer not in ["adam", "adamw", "prodigy"]:
+    if args.optimizer not in supported_optimizers:
         logger.warning(
             f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
         )
 
@@ -907,14 +914,6 @@ def get_optimizer(args, params_to_optimize):
         logger.warning(
             "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
         )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warning(
-                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # Changes the learning rate of text_encoder_parameters to be --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
 
         optimizer = optimizer_class(
             params_to_optimize,
@@ -994,7 +993,7 @@ def main(args):
     ).repo_id
 
     # Prepare models and scheduler
-    tokenizer = T5Tokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
 
@@ -1035,13 +1034,13 @@ def main(args):
     if accelerator.state.deepspeed_plugin:
         # DeepSpeed is handling precision, use what's in the DeepSpeed config
         if (
-            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
-            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
+                "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
+                and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
         ):
             weight_dtype = torch.float16
         if (
-            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
-            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
+                "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
+                and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
         ):
             weight_dtype = torch.float16
     else:
@@ -1062,8 +1061,6 @@ def main(args):
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
-        if args.train_text_encoder:
-            text_encoder.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=True,
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
-    if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            lora_alpha=args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
-        text_encoder.add_adapter(text_lora_config)
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1091,13 +1080,10 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1107,22 +1093,18 @@ def main(args):
             CogVideoXPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
             )
 
     def load_model_hook(models, input_dir):
         transformer_ = None
-        text_encoder_ = None
 
         while len(models) > 0:
             model = models.pop()
 
             if isinstance(model, type(unwrap_model(transformer))):
                 transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder))):
-                text_encoder_ = model
             else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                raise ValueError(f"Unexpected save model: {model.__class__}")
 
         lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
 
@@ -1139,19 +1121,13 @@ def main(args):
                     f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                     f" {unexpected_keys}. "
                 )
-        if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
-            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
 
         # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
-            models = [transformer_]
-            if args.train_text_encoder:
-                models.extend([text_encoder_])
             # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+            cast_training_params([transformer_])
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1163,38 +1139,30 @@ def main(args):
 
     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        models = [transformer]
-        if args.train_text_encoder:
-            models.extend([text_encoder])
         # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
+        cast_training_params([transformer], dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-    if args.train_text_encoder:
-        text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
 
     # Optimization parameters
     transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
-    if args.train_text_encoder:
-        # different learning rate for text encoder and unet
-        text_encoder_parameters_with_lr = {
-            "params": text_encoder_lora_parameters,
-            "weight_decay": args.adam_weight_decay_text_encoder,
-            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
-        }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_encoder_parameters_with_lr,
-        ]
-    else:
-        params_to_optimize = [transformer_parameters_with_lr]
+    params_to_optimize = [transformer_parameters_with_lr]
 
-    optimizer = get_optimizer(args, params_to_optimize)
+    use_deepspeed_optimizer = (
+        accelerator.state.deepspeed_plugin is not None
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+    use_deepspeed_scheduler = (
+        accelerator.state.deepspeed_plugin is not None
+        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+
+    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
@@ -1248,35 +1216,30 @@ def main(args):
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
-        num_cycles=args.lr_num_cycles,
-        power=args.lr_power,
-    )
+    if use_deepspeed_scheduler:
+        from accelerate.utils import DummyScheduler
 
-    # Prepare everything with our `accelerator`.
-    if args.train_text_encoder:
-        (
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        ) = accelerator.prepare(
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
+        lr_scheduler = DummyScheduler(
+            name=args.lr_scheduler,
+            optimizer=optimizer,
+            total_num_steps=args.max_train_steps * accelerator.num_processes,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
         )
     else:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+            num_training_steps=args.max_train_steps * accelerator.num_processes,
+            num_cycles=args.lr_num_cycles,
+            power=args.lr_power,
         )
 
+    # Prepare everything with our `accelerator`.
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -1347,15 +1310,9 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-        if args.train_text_encoder:
-            text_encoder.train()
-            # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            if args.train_text_encoder:
-                models_to_accumulate.extend([text_encoder])
 
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
@@ -1366,9 +1323,10 @@ def main(args):
                     tokenizer,
                     text_encoder,
                     prompts,
+                    model_config.max_text_seq_length,
                     accelerator.device,
                     weight_dtype,
-                    requires_grad=args.train_text_encoder,
+                    requires_grad=False,
                 )
 
                 # Sample noise that will be added to the latents
@@ -1422,16 +1380,14 @@ def main(args):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(transformer.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
-                        else transformer.parameters()
-                    )
+                    params_to_clip = transformer.parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
-                optimizer.step()
+                if accelerator.state.deepspeed_plugin is None:
+                    optimizer.step()
+                    optimizer.zero_grad()
+
                 lr_scheduler.step()
-                optimizer.zero_grad()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -1507,7 +1463,6 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        # transformer = transformer.to(torch.float32)
         dtype = (
             torch.float16
             if args.mixed_precision == "fp16"
@@ -1518,16 +1473,9 @@ def main(args):
 
         transformer = transformer.to(dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
-        if args.train_text_encoder:
-            text_encoder = unwrap_model(text_encoder)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
-        else:
-            text_encoder_lora_layers = None
-
         CogVideoXPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
         )
 
         # Final test inference
@@ -1539,6 +1487,11 @@ def main(args):
         )
         pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
 
+        if args.enable_slicing:
+            pipe.vae.enable_slicing()
+        if args.enable_tiling:
+            pipe.vae.enable_tiling()
+
         # Load LoRA weights
         lora_scaling = args.lora_alpha / args.rank
         pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
@@ -1572,7 +1525,6 @@ def main(args):
             repo_id,
             videos=validation_outputs,
             base_model=args.pretrained_model_name_or_path,
-            train_text_encoder=args.train_text_encoder,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,
             fps=args.fps,
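For anyone picking up the weights produced by this trainer, the sketch below shows one way to load the saved LoRA for inference. It is a minimal example rather than part of the patch: the base model, output directory path, prompt, and the `--lora_alpha 32` / `--rank 64` values are assumptions to be replaced with whatever was actually used during training; only the `"cogvideox-lora"` adapter name and the `lora_alpha / rank` scaling convention are taken from the trainer's final-inference block and model-card snippet above.

```python
import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXPipeline
from diffusers.utils import export_to_video

# The base pipeline must match the model the LoRA was trained on
# (THUDM/CogVideoX-2b or THUDM/CogVideoX-5b).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

# Load the trained LoRA from the training output directory
# (the OUTPUT_PATH used by finetune_single_rank.sh / finetune_multi_rank.sh).
pipe.load_lora_weights("path/to/output_dir", adapter_name="cogvideox-lora")

# Scale the adapter by lora_alpha / rank, mirroring `lora_scaling` in the trainer.
# Assumed here: --lora_alpha 32 and --rank 64; substitute your own training values.
pipe.set_adapters(["cogvideox-lora"], [32 / 64])

video = pipe(
    "A panda playing a guitar by a riverside",  # placeholder prompt
    guidance_scale=6,
    use_dynamic_cfg=True,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```

Scaling by `lora_alpha / rank` follows the same convention the trainer itself applies (`lora_scaling = args.lora_alpha / args.rank`), so keeping those two values consistent with the training run reproduces the validation behaviour.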