mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
finetune requirement change
This commit is contained in:
parent
61cc99720d
commit
2db0453b96
@ -9,9 +9,9 @@ see [here](../sat/README_zh.md). The dataset format is different from this versi
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
+ CogVideoX-2B LoRA: 1 * A100
|
||||
+ CogVideoX-2B SFT: 8 * A100
|
||||
+ CogVideoX-5B/5B-I2V is not supported yet.
|
||||
+ CogVideoX-2B / 5B LoRA: 1 * A100 (5B need to use `--use_8bit_adam`)
|
||||
+ CogVideoX-2B SFT: 8 * A100 (Working)
|
||||
+ CogVideoX-5B-I2V is not supported yet.
|
||||
|
||||
## Install Dependencies
|
||||
|
||||
@ -20,8 +20,7 @@ diffusers branch. Please follow the steps below to install dependencies:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
git checkout cogvideox-lora-and-training
|
||||
cd diffusers # Now in Main branch
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
@ -124,13 +123,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
|
||||
Single GPU fine-tuning:
|
||||
|
||||
```shell
|
||||
bash finetune_single_gpu.sh
|
||||
bash finetune_single_rank.sh
|
||||
```
|
||||
|
||||
Multi-GPU fine-tuning:
|
||||
|
||||
```shell
|
||||
bash finetune_multi_gpus_1.sh # Needs to be run on each node
|
||||
bash finetune_multi_rank.sh # Needs to be run on each node
|
||||
```
|
||||
|
||||
## Loading the Fine-tuned Model
|
||||
|
@ -9,9 +9,9 @@
|
||||
|
||||
## ハードウェア要件
|
||||
|
||||
+ CogVideoX-2B LORA: 1 * A100
|
||||
+ CogVideoX-2B SFT: 8 * A100
|
||||
+ CogVideoX-5B/5B-I2V まだサポートしていません
|
||||
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
|
||||
+ CogVideoX-2B SFT: 8 * A100 (動作確認済み)
|
||||
+ CogVideoX-5B-I2V まだサポートしていません
|
||||
|
||||
## 依存関係のインストール
|
||||
|
||||
@ -19,8 +19,7 @@
|
||||
|
||||
```shell
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
git checkout cogvideox-lora-and-training
|
||||
cd diffusers # Now in Main branch
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
@ -120,13 +119,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
|
||||
単一GPU微調整:
|
||||
|
||||
```shell
|
||||
bash finetune_single_gpu.sh
|
||||
bash finetune_single_rank.sh
|
||||
```
|
||||
|
||||
複数GPU微調整:
|
||||
|
||||
```shell
|
||||
bash finetune_multi_gpus_1.sh # 各ノードで実行する必要があります。
|
||||
bash finetune_multi_rank.sh # 各ノードで実行する必要があります。
|
||||
```
|
||||
|
||||
## 微調整済みモデルのロード
|
||||
|
@ -8,9 +8,9 @@
|
||||
|
||||
## 硬件要求
|
||||
|
||||
+ CogVideoX-2B LORA: 1 * A100
|
||||
+ CogVideoX-2B SFT: 8 * A100
|
||||
+ CogVideoX-5B/5B-I2V 暂未支持
|
||||
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
|
||||
+ CogVideoX-2B SFT: 8 * A100 (制作中)
|
||||
+ CogVideoX-5B-I2V 暂未支持
|
||||
|
||||
## 安装依赖
|
||||
|
||||
@ -18,8 +18,7 @@
|
||||
|
||||
```shell
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
git checkout cogvideox-lora-and-training
|
||||
cd diffusers # Now in Main branch
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
@ -150,13 +149,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gp
|
||||
单卡微调:
|
||||
|
||||
```shell
|
||||
bash finetune_single_gpu.sh
|
||||
bash finetune_single_rank.sh
|
||||
```
|
||||
|
||||
多卡微调:
|
||||
|
||||
```shell
|
||||
bash finetune_multi_gpus_1.sh #需要在每个节点运行
|
||||
bash finetune_multi_rank.sh #需要在每个节点运行
|
||||
```
|
||||
|
||||
## 载入微调的模型
|
||||
|
@ -9,6 +9,7 @@ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
|
||||
|
||||
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
|
||||
train_cogvideox_lora.py \
|
||||
--gradient_checkpointing \
|
||||
--pretrained_model_name_or_path $MODEL_PATH \
|
||||
--cache_dir $CACHE_PATH \
|
||||
--enable_tiling \
|
@ -3,12 +3,17 @@
|
||||
export MODEL_PATH="THUDM/CogVideoX-2b"
|
||||
export CACHE_PATH="~/.cache"
|
||||
export DATASET_PATH="Disney-VideoGeneration-Dataset"
|
||||
export OUTPUT_PATH="cogvideox-lora-single-gpu"
|
||||
export OUTPUT_PATH="cogvideox-lora-multi-gpu"
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
|
||||
|
||||
|
||||
# --use_8bit_adam is necessary for CogVideoX-5B-I2V
|
||||
# if you are not using wth 8 gus, change `accelerate_config_machine_single.yaml` num_processes as your gpu number
|
||||
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
|
||||
train_cogvideox_lora.py \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--pretrained_model_name_or_path $MODEL_PATH \
|
||||
--cache_dir $CACHE_PATH \
|
||||
--enable_tiling \
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@ -32,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
|
||||
@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
)
|
||||
@ -48,6 +46,7 @@ from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft,
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@ -239,11 +238,6 @@ def get_args():
|
||||
action="store_true",
|
||||
help="whether to randomly flip videos horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
@ -296,12 +290,6 @@ def get_args():
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Text encoder learning rate to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
@ -367,9 +355,6 @@ def get_args():
|
||||
)
|
||||
parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
@ -423,20 +408,20 @@ def get_args():
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_root: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
dataset_config_name: Optional[str] = None,
|
||||
caption_column: str = "text",
|
||||
video_column: str = "video",
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
fps: int = 8,
|
||||
max_num_frames: int = 49,
|
||||
skip_frames_start: int = 0,
|
||||
skip_frames_end: int = 0,
|
||||
cache_dir: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
self,
|
||||
instance_data_root: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
dataset_config_name: Optional[str] = None,
|
||||
caption_column: str = "text",
|
||||
video_column: str = "video",
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
fps: int = 8,
|
||||
max_num_frames: int = 49,
|
||||
skip_frames_start: int = 0,
|
||||
skip_frames_end: int = 0,
|
||||
cache_dir: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -555,7 +540,7 @@ class VideoDataset(Dataset):
|
||||
import decord
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The `decord` package is required for loading the video dataset. Install with `pip install dataset`"
|
||||
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
|
||||
)
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
@ -602,13 +587,12 @@ class VideoDataset(Dataset):
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
videos=None,
|
||||
base_model: str = None,
|
||||
train_text_encoder=False,
|
||||
validation_prompt=None,
|
||||
repo_folder=None,
|
||||
fps=8,
|
||||
repo_id: str,
|
||||
videos=None,
|
||||
base_model: str = None,
|
||||
validation_prompt=None,
|
||||
repo_folder=None,
|
||||
fps=8,
|
||||
):
|
||||
widget_dict = []
|
||||
if videos is not None:
|
||||
@ -627,9 +611,9 @@ def save_model_card(
|
||||
|
||||
These are {repo_id} LoRA weights for {base_model}.
|
||||
|
||||
The weights were trained using the [CogVideoX Diffusers trainer](TODO).
|
||||
The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
||||
|
||||
Was LoRA for the text encoder enabled? {train_text_encoder}.
|
||||
Was LoRA for the text encoder enabled? No.
|
||||
|
||||
## Download model
|
||||
|
||||
@ -642,8 +626,15 @@ from diffusers import CogVideoXPipeline
|
||||
import torch
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
|
||||
video = pipe("{validation_prompt}").frames[0]
|
||||
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
|
||||
|
||||
# The LoRA adapter weights are determined by what was used for training.
|
||||
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
|
||||
# It can be made lower or higher from what was used in training to decrease or amplify the effect
|
||||
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
|
||||
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
||||
|
||||
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
||||
```
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
@ -676,12 +667,12 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipe,
|
||||
args,
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation: bool = False,
|
||||
pipe,
|
||||
args,
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation: bool = False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
||||
@ -741,14 +732,14 @@ def log_validation(
|
||||
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
text_input_ids=None,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
text_input_ids=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
@ -779,14 +770,14 @@ def _get_t5_prompt_embeds(
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
text_input_ids=None,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
prompt: Union[str, List[str]],
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
text_input_ids=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = _get_t5_prompt_embeds(
|
||||
@ -802,13 +793,16 @@ def encode_prompt(
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
|
||||
def compute_prompt_embeddings(
|
||||
tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
|
||||
):
|
||||
if requires_grad:
|
||||
prompt_embeds = encode_prompt(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompt,
|
||||
num_videos_per_prompt=1,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
@ -819,6 +813,7 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
|
||||
text_encoder,
|
||||
prompt,
|
||||
num_videos_per_prompt=1,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
@ -826,15 +821,15 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
|
||||
|
||||
|
||||
def prepare_rotary_positional_embeddings(
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
vae_scale_factor_spatial: int = 8,
|
||||
patch_size: int = 2,
|
||||
attention_head_dim: int = 64,
|
||||
device: Optional[torch.device] = None,
|
||||
base_height: int = 480,
|
||||
base_width: int = 720,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
vae_scale_factor_spatial: int = 8,
|
||||
patch_size: int = 2,
|
||||
attention_head_dim: int = 64,
|
||||
device: Optional[torch.device] = None,
|
||||
base_height: int = 480,
|
||||
base_width: int = 720,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
||||
grid_width = width // (vae_scale_factor_spatial * patch_size)
|
||||
@ -854,10 +849,22 @@ def prepare_rotary_positional_embeddings(
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def get_optimizer(args, params_to_optimize):
|
||||
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
# Use DeepSpeed optimzer
|
||||
if use_deepspeed:
|
||||
from accelerate.utils import DummyOptim
|
||||
|
||||
return DummyOptim(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
eps=args.adam_epsilon,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
)
|
||||
|
||||
# Optimizer creation
|
||||
supported_optimizers = ["adam", "adamw", "prodigy"]
|
||||
if args.optimizer not in ["adam", "adamw", "prodigy"]:
|
||||
if args.optimizer not in supported_optimizers:
|
||||
logger.warning(
|
||||
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
|
||||
)
|
||||
@ -907,14 +914,6 @@ def get_optimizer(args, params_to_optimize):
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warning(
|
||||
f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
)
|
||||
# Changes the learning rate of text_encoder_parameters to be --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
@ -994,7 +993,7 @@ def main(args):
|
||||
).repo_id
|
||||
|
||||
# Prepare models and scheduler
|
||||
tokenizer = T5Tokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||
)
|
||||
|
||||
@ -1035,13 +1034,13 @@ def main(args):
|
||||
if accelerator.state.deepspeed_plugin:
|
||||
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
||||
if (
|
||||
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
||||
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
||||
):
|
||||
weight_dtype = torch.float16
|
||||
if (
|
||||
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
||||
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
||||
):
|
||||
weight_dtype = torch.float16
|
||||
else:
|
||||
@ -1062,8 +1061,6 @@ def main(args):
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
@ -1073,14 +1070,6 @@ def main(args):
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
init_lora_weights=True,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder.add_adapter(text_lora_config)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
@ -1091,13 +1080,10 @@ def main(args):
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
text_encoder_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@ -1107,22 +1093,18 @@ def main(args):
|
||||
CogVideoXPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
text_encoder_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder))):
|
||||
text_encoder_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
raise ValueError(f"Unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
||||
|
||||
@ -1139,19 +1121,13 @@ def main(args):
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
# Do we need to call `scale_lora_layers()` here?
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [transformer_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
cast_training_params([transformer_])
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
@ -1163,38 +1139,30 @@ def main(args):
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [transformer]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
cast_training_params([transformer], dtype=torch.float32)
|
||||
|
||||
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
||||
if args.train_text_encoder:
|
||||
text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
|
||||
|
||||
# Optimization parameters
|
||||
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
|
||||
if args.train_text_encoder:
|
||||
# different learning rate for text encoder and unet
|
||||
text_encoder_parameters_with_lr = {
|
||||
"params": text_encoder_lora_parameters,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_encoder_parameters_with_lr,
|
||||
]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize)
|
||||
use_deepspeed_optimizer = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
)
|
||||
use_deepspeed_scheduler = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
)
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||
|
||||
# Dataset and DataLoader
|
||||
train_dataset = VideoDataset(
|
||||
@ -1248,35 +1216,30 @@ def main(args):
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
if use_deepspeed_scheduler:
|
||||
from accelerate.utils import DummyScheduler
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
if args.train_text_encoder:
|
||||
(
|
||||
transformer,
|
||||
text_encoder,
|
||||
optimizer,
|
||||
train_dataloader,
|
||||
lr_scheduler,
|
||||
) = accelerator.prepare(
|
||||
transformer,
|
||||
text_encoder,
|
||||
optimizer,
|
||||
train_dataloader,
|
||||
lr_scheduler,
|
||||
lr_scheduler = DummyScheduler(
|
||||
name=args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
)
|
||||
else:
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
@ -1347,15 +1310,9 @@ def main(args):
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
transformer.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
if args.train_text_encoder:
|
||||
models_to_accumulate.extend([text_encoder])
|
||||
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
|
||||
@ -1366,9 +1323,10 @@ def main(args):
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
prompts,
|
||||
model_config.max_text_seq_length,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
requires_grad=args.train_text_encoder,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Sample noise that will be added to the latents
|
||||
@ -1422,16 +1380,14 @@ def main(args):
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(transformer.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else transformer.parameters()
|
||||
)
|
||||
params_to_clip = transformer.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
if accelerator.state.deepspeed_plugin is None:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@ -1507,7 +1463,6 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
transformer = unwrap_model(transformer)
|
||||
# transformer = transformer.to(torch.float32)
|
||||
dtype = (
|
||||
torch.float16
|
||||
if args.mixed_precision == "fp16"
|
||||
@ -1518,16 +1473,9 @@ def main(args):
|
||||
transformer = transformer.to(dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
|
||||
CogVideoXPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
# Final test inference
|
||||
@ -1539,6 +1487,11 @@ def main(args):
|
||||
)
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
if args.enable_slicing:
|
||||
pipe.vae.enable_slicing()
|
||||
if args.enable_tiling:
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
# Load LoRA weights
|
||||
lora_scaling = args.lora_alpha / args.rank
|
||||
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
||||
@ -1572,7 +1525,6 @@ def main(args):
|
||||
repo_id,
|
||||
videos=validation_outputs,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
validation_prompt=args.validation_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
fps=args.fps,
|
||||
|
Loading…
x
Reference in New Issue
Block a user