|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
@ -32,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dic
|
|
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
from torchvision import transforms
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
from transformers import T5EncoderModel, T5Tokenizer
|
|
|
|
|
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
|
|
|
|
|
|
|
|
|
|
import diffusers
|
|
|
|
|
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
|
|
|
|
@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
|
|
|
|
from diffusers.optimization import get_scheduler
|
|
|
|
|
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
|
|
|
|
from diffusers.training_utils import (
|
|
|
|
|
_set_state_dict_into_text_encoder,
|
|
|
|
|
cast_training_params,
|
|
|
|
|
clear_objs_and_retain_memory,
|
|
|
|
|
)
|
|
|
|
@ -48,6 +46,7 @@ from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft,
|
|
|
|
|
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
|
|
|
|
from diffusers.utils.torch_utils import is_compiled_module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_wandb_available():
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
|
@ -239,11 +238,6 @@ def get_args():
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="whether to randomly flip videos horizontally",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--train_text_encoder",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
|
|
|
|
)
|
|
|
|
@ -296,12 +290,6 @@ def get_args():
|
|
|
|
|
default=1e-4,
|
|
|
|
|
help="Initial learning rate (after the potential warmup period) to use.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--text_encoder_lr",
|
|
|
|
|
type=float,
|
|
|
|
|
default=5e-6,
|
|
|
|
|
help="Text encoder learning rate to use.",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--scale_lr",
|
|
|
|
|
action="store_true",
|
|
|
|
@ -367,9 +355,6 @@ def get_args():
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
|
|
|
|
|
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--adam_epsilon",
|
|
|
|
|
type=float,
|
|
|
|
@ -423,20 +408,20 @@ def get_args():
|
|
|
|
|
|
|
|
|
|
class VideoDataset(Dataset):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
instance_data_root: Optional[str] = None,
|
|
|
|
|
dataset_name: Optional[str] = None,
|
|
|
|
|
dataset_config_name: Optional[str] = None,
|
|
|
|
|
caption_column: str = "text",
|
|
|
|
|
video_column: str = "video",
|
|
|
|
|
height: int = 480,
|
|
|
|
|
width: int = 720,
|
|
|
|
|
fps: int = 8,
|
|
|
|
|
max_num_frames: int = 49,
|
|
|
|
|
skip_frames_start: int = 0,
|
|
|
|
|
skip_frames_end: int = 0,
|
|
|
|
|
cache_dir: Optional[str] = None,
|
|
|
|
|
id_token: Optional[str] = None,
|
|
|
|
|
self,
|
|
|
|
|
instance_data_root: Optional[str] = None,
|
|
|
|
|
dataset_name: Optional[str] = None,
|
|
|
|
|
dataset_config_name: Optional[str] = None,
|
|
|
|
|
caption_column: str = "text",
|
|
|
|
|
video_column: str = "video",
|
|
|
|
|
height: int = 480,
|
|
|
|
|
width: int = 720,
|
|
|
|
|
fps: int = 8,
|
|
|
|
|
max_num_frames: int = 49,
|
|
|
|
|
skip_frames_start: int = 0,
|
|
|
|
|
skip_frames_end: int = 0,
|
|
|
|
|
cache_dir: Optional[str] = None,
|
|
|
|
|
id_token: Optional[str] = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
@ -555,7 +540,7 @@ class VideoDataset(Dataset):
|
|
|
|
|
import decord
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"The `decord` package is required for loading the video dataset. Install with `pip install dataset`"
|
|
|
|
|
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
decord.bridge.set_bridge("torch")
|
|
|
|
@ -602,13 +587,12 @@ class VideoDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_model_card(
|
|
|
|
|
repo_id: str,
|
|
|
|
|
videos=None,
|
|
|
|
|
base_model: str = None,
|
|
|
|
|
train_text_encoder=False,
|
|
|
|
|
validation_prompt=None,
|
|
|
|
|
repo_folder=None,
|
|
|
|
|
fps=8,
|
|
|
|
|
repo_id: str,
|
|
|
|
|
videos=None,
|
|
|
|
|
base_model: str = None,
|
|
|
|
|
validation_prompt=None,
|
|
|
|
|
repo_folder=None,
|
|
|
|
|
fps=8,
|
|
|
|
|
):
|
|
|
|
|
widget_dict = []
|
|
|
|
|
if videos is not None:
|
|
|
|
@ -627,9 +611,9 @@ def save_model_card(
|
|
|
|
|
|
|
|
|
|
These are {repo_id} LoRA weights for {base_model}.
|
|
|
|
|
|
|
|
|
|
The weights were trained using the [CogVideoX Diffusers trainer](TODO).
|
|
|
|
|
The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
|
|
|
|
|
|
|
|
|
|
Was LoRA for the text encoder enabled? {train_text_encoder}.
|
|
|
|
|
Was LoRA for the text encoder enabled? No.
|
|
|
|
|
|
|
|
|
|
## Download model
|
|
|
|
|
|
|
|
|
@ -642,8 +626,15 @@ from diffusers import CogVideoXPipeline
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
|
|
|
|
|
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
|
|
|
|
|
video = pipe("{validation_prompt}").frames[0]
|
|
|
|
|
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
|
|
|
|
|
|
|
|
|
|
# The LoRA adapter weights are determined by what was used for training.
|
|
|
|
|
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
|
|
|
|
|
# It can be made lower or higher from what was used in training to decrease or amplify the effect
|
|
|
|
|
# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows.
|
|
|
|
|
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
|
|
|
|
|
|
|
|
|
video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
|
|
|
@ -676,12 +667,12 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_validation(
|
|
|
|
|
pipe,
|
|
|
|
|
args,
|
|
|
|
|
accelerator,
|
|
|
|
|
pipeline_args,
|
|
|
|
|
epoch,
|
|
|
|
|
is_final_validation: bool = False,
|
|
|
|
|
pipe,
|
|
|
|
|
args,
|
|
|
|
|
accelerator,
|
|
|
|
|
pipeline_args,
|
|
|
|
|
epoch,
|
|
|
|
|
is_final_validation: bool = False,
|
|
|
|
|
):
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
|
|
|
|
@ -741,14 +732,14 @@ def log_validation(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_t5_prompt_embeds(
|
|
|
|
|
tokenizer: T5Tokenizer,
|
|
|
|
|
text_encoder: T5EncoderModel,
|
|
|
|
|
prompt: Union[str, List[str]],
|
|
|
|
|
num_videos_per_prompt: int = 1,
|
|
|
|
|
max_sequence_length: int = 226,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
dtype: Optional[torch.dtype] = None,
|
|
|
|
|
text_input_ids=None,
|
|
|
|
|
tokenizer: T5Tokenizer,
|
|
|
|
|
text_encoder: T5EncoderModel,
|
|
|
|
|
prompt: Union[str, List[str]],
|
|
|
|
|
num_videos_per_prompt: int = 1,
|
|
|
|
|
max_sequence_length: int = 226,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
dtype: Optional[torch.dtype] = None,
|
|
|
|
|
text_input_ids=None,
|
|
|
|
|
):
|
|
|
|
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
|
|
|
|
batch_size = len(prompt)
|
|
|
|
@ -779,14 +770,14 @@ def _get_t5_prompt_embeds(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_prompt(
|
|
|
|
|
tokenizer: T5Tokenizer,
|
|
|
|
|
text_encoder: T5EncoderModel,
|
|
|
|
|
prompt: Union[str, List[str]],
|
|
|
|
|
num_videos_per_prompt: int = 1,
|
|
|
|
|
max_sequence_length: int = 226,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
dtype: Optional[torch.dtype] = None,
|
|
|
|
|
text_input_ids=None,
|
|
|
|
|
tokenizer: T5Tokenizer,
|
|
|
|
|
text_encoder: T5EncoderModel,
|
|
|
|
|
prompt: Union[str, List[str]],
|
|
|
|
|
num_videos_per_prompt: int = 1,
|
|
|
|
|
max_sequence_length: int = 226,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
dtype: Optional[torch.dtype] = None,
|
|
|
|
|
text_input_ids=None,
|
|
|
|
|
):
|
|
|
|
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
|
|
|
|
prompt_embeds = _get_t5_prompt_embeds(
|
|
|
|
@ -802,13 +793,16 @@ def encode_prompt(
|
|
|
|
|
return prompt_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
|
|
|
|
|
def compute_prompt_embeddings(
|
|
|
|
|
tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
|
|
|
|
|
):
|
|
|
|
|
if requires_grad:
|
|
|
|
|
prompt_embeds = encode_prompt(
|
|
|
|
|
tokenizer,
|
|
|
|
|
text_encoder,
|
|
|
|
|
prompt,
|
|
|
|
|
num_videos_per_prompt=1,
|
|
|
|
|
max_sequence_length=max_sequence_length,
|
|
|
|
|
device=device,
|
|
|
|
|
dtype=dtype,
|
|
|
|
|
)
|
|
|
|
@ -819,6 +813,7 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
|
|
|
|
|
text_encoder,
|
|
|
|
|
prompt,
|
|
|
|
|
num_videos_per_prompt=1,
|
|
|
|
|
max_sequence_length=max_sequence_length,
|
|
|
|
|
device=device,
|
|
|
|
|
dtype=dtype,
|
|
|
|
|
)
|
|
|
|
@ -826,15 +821,15 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_rotary_positional_embeddings(
|
|
|
|
|
height: int,
|
|
|
|
|
width: int,
|
|
|
|
|
num_frames: int,
|
|
|
|
|
vae_scale_factor_spatial: int = 8,
|
|
|
|
|
patch_size: int = 2,
|
|
|
|
|
attention_head_dim: int = 64,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
base_height: int = 480,
|
|
|
|
|
base_width: int = 720,
|
|
|
|
|
height: int,
|
|
|
|
|
width: int,
|
|
|
|
|
num_frames: int,
|
|
|
|
|
vae_scale_factor_spatial: int = 8,
|
|
|
|
|
patch_size: int = 2,
|
|
|
|
|
attention_head_dim: int = 64,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
base_height: int = 480,
|
|
|
|
|
base_width: int = 720,
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
|
|
|
|
grid_width = width // (vae_scale_factor_spatial * patch_size)
|
|
|
|
@ -854,10 +849,22 @@ def prepare_rotary_positional_embeddings(
|
|
|
|
|
return freqs_cos, freqs_sin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_optimizer(args, params_to_optimize):
|
|
|
|
|
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
|
|
|
|
# Use DeepSpeed optimzer
|
|
|
|
|
if use_deepspeed:
|
|
|
|
|
from accelerate.utils import DummyOptim
|
|
|
|
|
|
|
|
|
|
return DummyOptim(
|
|
|
|
|
params_to_optimize,
|
|
|
|
|
lr=args.learning_rate,
|
|
|
|
|
betas=(args.adam_beta1, args.adam_beta2),
|
|
|
|
|
eps=args.adam_epsilon,
|
|
|
|
|
weight_decay=args.adam_weight_decay,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Optimizer creation
|
|
|
|
|
supported_optimizers = ["adam", "adamw", "prodigy"]
|
|
|
|
|
if args.optimizer not in ["adam", "adamw", "prodigy"]:
|
|
|
|
|
if args.optimizer not in supported_optimizers:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
|
|
|
|
|
)
|
|
|
|
@ -907,14 +914,6 @@ def get_optimizer(args, params_to_optimize):
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
|
|
|
|
)
|
|
|
|
|
if args.train_text_encoder and args.text_encoder_lr:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
|
|
|
|
|
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
|
|
|
|
f"When using prodigy only learning_rate is used as the initial learning rate."
|
|
|
|
|
)
|
|
|
|
|
# Changes the learning rate of text_encoder_parameters to be --learning_rate
|
|
|
|
|
params_to_optimize[1]["lr"] = args.learning_rate
|
|
|
|
|
|
|
|
|
|
optimizer = optimizer_class(
|
|
|
|
|
params_to_optimize,
|
|
|
|
@ -994,7 +993,7 @@ def main(args):
|
|
|
|
|
).repo_id
|
|
|
|
|
|
|
|
|
|
# Prepare models and scheduler
|
|
|
|
|
tokenizer = T5Tokenizer.from_pretrained(
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -1035,13 +1034,13 @@ def main(args):
|
|
|
|
|
if accelerator.state.deepspeed_plugin:
|
|
|
|
|
# DeepSpeed is handling precision, use what's in the DeepSpeed config
|
|
|
|
|
if (
|
|
|
|
|
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
|
|
|
|
"fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
|
|
|
|
|
):
|
|
|
|
|
weight_dtype = torch.float16
|
|
|
|
|
if (
|
|
|
|
|
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
|
|
|
|
"bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
|
|
|
|
|
):
|
|
|
|
|
weight_dtype = torch.float16
|
|
|
|
|
else:
|
|
|
|
@ -1062,8 +1061,6 @@ def main(args):
|
|
|
|
|
|
|
|
|
|
if args.gradient_checkpointing:
|
|
|
|
|
transformer.enable_gradient_checkpointing()
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
text_encoder.gradient_checkpointing_enable()
|
|
|
|
|
|
|
|
|
|
# now we will add new LoRA weights to the attention layers
|
|
|
|
|
transformer_lora_config = LoraConfig(
|
|
|
|
@ -1073,14 +1070,6 @@ def main(args):
|
|
|
|
|
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
|
|
|
|
)
|
|
|
|
|
transformer.add_adapter(transformer_lora_config)
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
text_lora_config = LoraConfig(
|
|
|
|
|
r=args.rank,
|
|
|
|
|
lora_alpha=args.lora_alpha,
|
|
|
|
|
init_lora_weights=True,
|
|
|
|
|
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
|
|
|
|
)
|
|
|
|
|
text_encoder.add_adapter(text_lora_config)
|
|
|
|
|
|
|
|
|
|
def unwrap_model(model):
|
|
|
|
|
model = accelerator.unwrap_model(model)
|
|
|
|
@ -1091,13 +1080,10 @@ def main(args):
|
|
|
|
|
def save_model_hook(models, weights, output_dir):
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
transformer_lora_layers_to_save = None
|
|
|
|
|
text_encoder_lora_layers_to_save = None
|
|
|
|
|
|
|
|
|
|
for model in models:
|
|
|
|
|
if isinstance(model, type(unwrap_model(transformer))):
|
|
|
|
|
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
|
|
|
|
elif isinstance(model, type(unwrap_model(text_encoder))):
|
|
|
|
|
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"unexpected save model: {model.__class__}")
|
|
|
|
|
|
|
|
|
@ -1107,22 +1093,18 @@ def main(args):
|
|
|
|
|
CogVideoXPipeline.save_lora_weights(
|
|
|
|
|
output_dir,
|
|
|
|
|
transformer_lora_layers=transformer_lora_layers_to_save,
|
|
|
|
|
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def load_model_hook(models, input_dir):
|
|
|
|
|
transformer_ = None
|
|
|
|
|
text_encoder_ = None
|
|
|
|
|
|
|
|
|
|
while len(models) > 0:
|
|
|
|
|
model = models.pop()
|
|
|
|
|
|
|
|
|
|
if isinstance(model, type(unwrap_model(transformer))):
|
|
|
|
|
transformer_ = model
|
|
|
|
|
elif isinstance(model, type(unwrap_model(text_encoder))):
|
|
|
|
|
text_encoder_ = model
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"unexpected save model: {model.__class__}")
|
|
|
|
|
raise ValueError(f"Unexpected save model: {model.__class__}")
|
|
|
|
|
|
|
|
|
|
lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
|
|
|
|
|
|
|
|
|
@ -1139,19 +1121,13 @@ def main(args):
|
|
|
|
|
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
|
|
|
|
f" {unexpected_keys}. "
|
|
|
|
|
)
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
# Do we need to call `scale_lora_layers()` here?
|
|
|
|
|
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
|
|
|
|
|
|
|
|
|
|
# Make sure the trainable params are in float32. This is again needed since the base models
|
|
|
|
|
# are in `weight_dtype`. More details:
|
|
|
|
|
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
|
|
|
|
if args.mixed_precision == "fp16":
|
|
|
|
|
models = [transformer_]
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
models.extend([text_encoder_])
|
|
|
|
|
# only upcast trainable parameters (LoRA) into fp32
|
|
|
|
|
cast_training_params(models)
|
|
|
|
|
cast_training_params([transformer_])
|
|
|
|
|
|
|
|
|
|
accelerator.register_save_state_pre_hook(save_model_hook)
|
|
|
|
|
accelerator.register_load_state_pre_hook(load_model_hook)
|
|
|
|
@ -1163,38 +1139,30 @@ def main(args):
|
|
|
|
|
|
|
|
|
|
if args.scale_lr:
|
|
|
|
|
args.learning_rate = (
|
|
|
|
|
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
|
|
|
|
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Make sure the trainable params are in float32.
|
|
|
|
|
if args.mixed_precision == "fp16":
|
|
|
|
|
models = [transformer]
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
models.extend([text_encoder])
|
|
|
|
|
# only upcast trainable parameters (LoRA) into fp32
|
|
|
|
|
cast_training_params(models, dtype=torch.float32)
|
|
|
|
|
cast_training_params([transformer], dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
|
|
|
|
|
|
|
|
|
|
# Optimization parameters
|
|
|
|
|
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
# different learning rate for text encoder and unet
|
|
|
|
|
text_encoder_parameters_with_lr = {
|
|
|
|
|
"params": text_encoder_lora_parameters,
|
|
|
|
|
"weight_decay": args.adam_weight_decay_text_encoder,
|
|
|
|
|
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
|
|
|
|
}
|
|
|
|
|
params_to_optimize = [
|
|
|
|
|
transformer_parameters_with_lr,
|
|
|
|
|
text_encoder_parameters_with_lr,
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
params_to_optimize = [transformer_parameters_with_lr]
|
|
|
|
|
params_to_optimize = [transformer_parameters_with_lr]
|
|
|
|
|
|
|
|
|
|
optimizer = get_optimizer(args, params_to_optimize)
|
|
|
|
|
use_deepspeed_optimizer = (
|
|
|
|
|
accelerator.state.deepspeed_plugin is not None
|
|
|
|
|
and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
)
|
|
|
|
|
use_deepspeed_scheduler = (
|
|
|
|
|
accelerator.state.deepspeed_plugin is not None
|
|
|
|
|
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
|
|
|
|
|
|
|
|
|
# Dataset and DataLoader
|
|
|
|
|
train_dataset = VideoDataset(
|
|
|
|
@ -1248,35 +1216,30 @@ def main(args):
|
|
|
|
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
|
|
|
|
overrode_max_train_steps = True
|
|
|
|
|
|
|
|
|
|
lr_scheduler = get_scheduler(
|
|
|
|
|
args.lr_scheduler,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
|
|
|
|
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
|
|
|
|
num_cycles=args.lr_num_cycles,
|
|
|
|
|
power=args.lr_power,
|
|
|
|
|
)
|
|
|
|
|
if use_deepspeed_scheduler:
|
|
|
|
|
from accelerate.utils import DummyScheduler
|
|
|
|
|
|
|
|
|
|
# Prepare everything with our `accelerator`.
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
(
|
|
|
|
|
transformer,
|
|
|
|
|
text_encoder,
|
|
|
|
|
optimizer,
|
|
|
|
|
train_dataloader,
|
|
|
|
|
lr_scheduler,
|
|
|
|
|
) = accelerator.prepare(
|
|
|
|
|
transformer,
|
|
|
|
|
text_encoder,
|
|
|
|
|
optimizer,
|
|
|
|
|
train_dataloader,
|
|
|
|
|
lr_scheduler,
|
|
|
|
|
lr_scheduler = DummyScheduler(
|
|
|
|
|
name=args.lr_scheduler,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
|
|
|
|
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
|
|
transformer, optimizer, train_dataloader, lr_scheduler
|
|
|
|
|
lr_scheduler = get_scheduler(
|
|
|
|
|
args.lr_scheduler,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
|
|
|
|
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
|
|
|
|
num_cycles=args.lr_num_cycles,
|
|
|
|
|
power=args.lr_power,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Prepare everything with our `accelerator`.
|
|
|
|
|
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
|
|
|
transformer, optimizer, train_dataloader, lr_scheduler
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
|
|
|
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
|
|
|
|
if overrode_max_train_steps:
|
|
|
|
@ -1347,15 +1310,9 @@ def main(args):
|
|
|
|
|
|
|
|
|
|
for epoch in range(first_epoch, args.num_train_epochs):
|
|
|
|
|
transformer.train()
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
text_encoder.train()
|
|
|
|
|
# set top parameter requires_grad = True for gradient checkpointing works
|
|
|
|
|
accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
|
|
|
|
|
|
|
|
|
|
for step, batch in enumerate(train_dataloader):
|
|
|
|
|
models_to_accumulate = [transformer]
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
models_to_accumulate.extend([text_encoder])
|
|
|
|
|
|
|
|
|
|
with accelerator.accumulate(models_to_accumulate):
|
|
|
|
|
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
|
|
|
|
@ -1366,9 +1323,10 @@ def main(args):
|
|
|
|
|
tokenizer,
|
|
|
|
|
text_encoder,
|
|
|
|
|
prompts,
|
|
|
|
|
model_config.max_text_seq_length,
|
|
|
|
|
accelerator.device,
|
|
|
|
|
weight_dtype,
|
|
|
|
|
requires_grad=args.train_text_encoder,
|
|
|
|
|
requires_grad=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Sample noise that will be added to the latents
|
|
|
|
@ -1422,16 +1380,14 @@ def main(args):
|
|
|
|
|
accelerator.backward(loss)
|
|
|
|
|
|
|
|
|
|
if accelerator.sync_gradients:
|
|
|
|
|
params_to_clip = (
|
|
|
|
|
itertools.chain(transformer.parameters(), text_encoder.parameters())
|
|
|
|
|
if args.train_text_encoder
|
|
|
|
|
else transformer.parameters()
|
|
|
|
|
)
|
|
|
|
|
params_to_clip = transformer.parameters()
|
|
|
|
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
|
|
|
|
|
|
|
|
|
optimizer.step()
|
|
|
|
|
if accelerator.state.deepspeed_plugin is None:
|
|
|
|
|
optimizer.step()
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
|
|
lr_scheduler.step()
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
|
|
|
|
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
|
|
|
if accelerator.sync_gradients:
|
|
|
|
@ -1507,7 +1463,6 @@ def main(args):
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
transformer = unwrap_model(transformer)
|
|
|
|
|
# transformer = transformer.to(torch.float32)
|
|
|
|
|
dtype = (
|
|
|
|
|
torch.float16
|
|
|
|
|
if args.mixed_precision == "fp16"
|
|
|
|
@ -1518,16 +1473,9 @@ def main(args):
|
|
|
|
|
transformer = transformer.to(dtype)
|
|
|
|
|
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
|
|
|
|
|
|
|
|
|
if args.train_text_encoder:
|
|
|
|
|
text_encoder = unwrap_model(text_encoder)
|
|
|
|
|
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
|
|
|
|
|
else:
|
|
|
|
|
text_encoder_lora_layers = None
|
|
|
|
|
|
|
|
|
|
CogVideoXPipeline.save_lora_weights(
|
|
|
|
|
save_directory=args.output_dir,
|
|
|
|
|
transformer_lora_layers=transformer_lora_layers,
|
|
|
|
|
text_encoder_lora_layers=text_encoder_lora_layers,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Final test inference
|
|
|
|
@ -1539,6 +1487,11 @@ def main(args):
|
|
|
|
|
)
|
|
|
|
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
|
|
|
|
|
|
|
|
|
if args.enable_slicing:
|
|
|
|
|
pipe.vae.enable_slicing()
|
|
|
|
|
if args.enable_tiling:
|
|
|
|
|
pipe.vae.enable_tiling()
|
|
|
|
|
|
|
|
|
|
# Load LoRA weights
|
|
|
|
|
lora_scaling = args.lora_alpha / args.rank
|
|
|
|
|
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
|
|
|
@ -1572,7 +1525,6 @@ def main(args):
|
|
|
|
|
repo_id,
|
|
|
|
|
videos=validation_outputs,
|
|
|
|
|
base_model=args.pretrained_model_name_or_path,
|
|
|
|
|
train_text_encoder=args.train_text_encoder,
|
|
|
|
|
validation_prompt=args.validation_prompt,
|
|
|
|
|
repo_folder=args.output_dir,
|
|
|
|
|
fps=args.fps,
|
|
|
|
|