Merge pull request #317 from THUDM/CogVideoX_dev
Finetune and Readme update

Commit 1e5dd975e2
@@ -277,6 +277,7 @@ pipe.vae.enable_tiling()
 We highly welcome contributions from the community and actively contribute to the open-source community. The following
 works have already been adapted for CogVideoX, and we invite everyone to use them:
 
++ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
 + [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference
 framework, allowing you to easily deploy your own models or the latest cutting-edge open-source models with just one
 click.
@@ -261,6 +261,7 @@ pipe.vae.enable_tiling()
 
 We highly welcome contributions from the community and actively contribute to the open-source community. The following works have already been adapted for CogVideoX, and we invite everyone to use them:
 
++ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
 + [Xorbits Inference](https://github.com/xorbitsai/inference):
 A powerful and comprehensive distributed inference framework that lets you deploy your own models or the latest open-source models with a single click.
 + [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper)
@@ -248,7 +248,7 @@ pipe.vae.enable_tiling()
 ## Friendly Links
 
 We highly welcome contributions from the community and actively contribute to the open-source community. The following works have already been adapted for CogVideoX, and we invite everyone to use them:
-
++ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
 + [Xorbits Inference](https://github.com/xorbitsai/inference): A powerful and comprehensive distributed inference framework that deploys your own models or built-in cutting-edge open-source models with one click.
 + [ComfyUI-CogVideoXWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper) Uses the ComfyUI framework to integrate CogVideoX into your workflow.
 + [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys): VideoSys provides easy-to-use, high-performance video generation infrastructure, supporting a complete pipeline and continuously integrating the latest models and techniques.
@@ -9,9 +9,9 @@ see [here](../sat/README_zh.md). The dataset format is different from this version.
 
 ## Hardware Requirements
 
-+ CogVideoX-2B LoRA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V is not supported yet.
++ CogVideoX-2B / 5B LoRA: 1 * A100 (5B needs `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (Working)
++ CogVideoX-5B-I2V is not supported yet.
 
 ## Install Dependencies
 
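For context, a minimal sketch of what `--use_8bit_adam` typically switches on; the `bitsandbytes` dependency and the hyperparameters below are assumptions, not values taken from this commit:

```python
# Hypothetical sketch: 8-bit AdamW via bitsandbytes, the usual backing for --use_8bit_adam.
# Optimizer states are kept in 8-bit, which is what lets 5B LoRA fit on a single A100.
import bitsandbytes as bnb
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the trainer's LoRA parameters
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=1e-4,                 # assumed; matches the script's default learning rate
    betas=(0.9, 0.95),       # assumed values for illustration
    weight_decay=1e-4,
)
```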
@@ -20,8 +20,7 @@ diffusers branch. Please follow the steps below to install dependencies:
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now in Main branch
 pip install -e .
 ```
 
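As a quick sanity check (my own suggestion, not part of the commit), the CogVideoX classes the trainer imports should resolve once the editable install of diffusers `main` finishes:

```python
# Assumed check: these imports only succeed on a diffusers build that ships CogVideoX.
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXPipeline,
    CogVideoXTransformer3DModel,
)

print(CogVideoXPipeline.__module__)  # diffusers.pipelines.cogvideo.pipeline_cogvideox
```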
@@ -124,13 +123,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
 Single GPU fine-tuning:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 Multi-GPU fine-tuning:
 
 ```shell
-bash finetune_multi_gpus_1.sh # Needs to be run on each node
+bash finetune_multi_rank.sh # Needs to be run on each node
 ```
 
 ## Loading the Fine-tuned Model
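To make the renamed scripts' output concrete, a hedged sketch of loading the resulting LoRA; the model id, output directory and prompt are assumptions, while the API calls mirror the model-card snippet changed later in this commit:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")

# "cogvideox-lora-single-gpu" is the OUTPUT_PATH the single-rank script exports.
pipe.load_lora_weights(
    "cogvideox-lora-single-gpu",
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="cogvideox-lora",
)
pipe.set_adapters(["cogvideox-lora"], [32 / 64])  # lora_alpha / rank used in training

video = pipe("A panda playing a guitar in a bamboo forest").frames[0]  # assumed prompt
```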
@@ -9,9 +9,9 @@
 
 ## Hardware Requirements
 
-+ CogVideoX-2B LoRA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V is not supported yet.
++ CogVideoX-2B / 5B T2V LoRA: 1 * A100 (5B needs `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (verified working)
++ CogVideoX-5B-I2V is not supported yet.
 
 ## Install Dependencies
 
@@ -19,8 +19,7 @@
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now in Main branch
 pip install -e .
 ```
 
@@ -120,13 +119,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
 Single-GPU fine-tuning:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 Multi-GPU fine-tuning:
 
 ```shell
-bash finetune_multi_gpus_1.sh # Must be run on each node.
+bash finetune_multi_rank.sh # Must be run on each node.
 ```
 
 ## Loading the Fine-tuned Model
@@ -8,9 +8,9 @@
 
 ## Hardware Requirements
 
-+ CogVideoX-2B LoRA: 1 * A100
-+ CogVideoX-2B SFT: 8 * A100
-+ CogVideoX-5B/5B-I2V is not supported yet.
++ CogVideoX-2B / 5B T2V LoRA: 1 * A100 (5B needs `--use_8bit_adam`)
++ CogVideoX-2B SFT: 8 * A100 (work in progress)
++ CogVideoX-5B-I2V is not supported yet.
 
 ## Install Dependencies
 
@@ -18,8 +18,7 @@
 
 ```shell
 git clone https://github.com/huggingface/diffusers.git
-cd diffusers
-git checkout cogvideox-lora-and-training
+cd diffusers # Now in Main branch
 pip install -e .
 ```
 
@@ -150,13 +149,13 @@ accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu
 Single-GPU fine-tuning:
 
 ```shell
-bash finetune_single_gpu.sh
+bash finetune_single_rank.sh
 ```
 
 Multi-GPU fine-tuning:
 
 ```shell
-bash finetune_multi_gpus_1.sh # Must be run on each node
+bash finetune_multi_rank.sh # Must be run on each node
 ```
 
 ## Loading the Fine-tuned Model
@@ -9,6 +9,7 @@ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
 
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
   train_cogvideox_lora.py \
+  --gradient_checkpointing \
   --pretrained_model_name_or_path $MODEL_PATH \
   --cache_dir $CACHE_PATH \
   --enable_tiling \
@@ -3,12 +3,17 @@
 export MODEL_PATH="THUDM/CogVideoX-2b"
 export CACHE_PATH="~/.cache"
 export DATASET_PATH="Disney-VideoGeneration-Dataset"
-export OUTPUT_PATH="cogvideox-lora-single-gpu"
+export OUTPUT_PATH="cogvideox-lora-multi-gpu"
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
 export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
 
 
+# --use_8bit_adam is necessary for CogVideoX-5B-I2V
+# if you are not using 8 GPUs, change num_processes in `accelerate_config_machine_single.yaml` to your GPU count
 accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
   train_cogvideox_lora.py \
+  --gradient_checkpointing \
+  --use_8bit_adam \
   --pretrained_model_name_or_path $MODEL_PATH \
   --cache_dir $CACHE_PATH \
   --enable_tiling \
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import itertools
 import logging
 import math
 import os
@@ -32,7 +31,7 @@ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import T5EncoderModel, T5Tokenizer
+from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
@@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.training_utils import (
-    _set_state_dict_into_text_encoder,
     cast_training_params,
     clear_objs_and_retain_memory,
 )
@@ -48,6 +46,7 @@ from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft,
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
 
 
 if is_wandb_available():
     import wandb
 
@@ -239,11 +238,6 @@ def get_args():
         action="store_true",
         help="whether to randomly flip videos horizontally",
     )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
-    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -296,12 +290,6 @@ def get_args():
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--text_encoder_lr",
-        type=float,
-        default=5e-6,
-        help="Text encoder learning rate to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -367,9 +355,6 @@ def get_args():
     )
     parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
-    parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
-    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
|
|||||||
|
|
||||||
class VideoDataset(Dataset):
|
class VideoDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
instance_data_root: Optional[str] = None,
|
instance_data_root: Optional[str] = None,
|
||||||
dataset_name: Optional[str] = None,
|
dataset_name: Optional[str] = None,
|
||||||
dataset_config_name: Optional[str] = None,
|
dataset_config_name: Optional[str] = None,
|
||||||
caption_column: str = "text",
|
caption_column: str = "text",
|
||||||
video_column: str = "video",
|
video_column: str = "video",
|
||||||
height: int = 480,
|
height: int = 480,
|
||||||
width: int = 720,
|
width: int = 720,
|
||||||
fps: int = 8,
|
fps: int = 8,
|
||||||
max_num_frames: int = 49,
|
max_num_frames: int = 49,
|
||||||
skip_frames_start: int = 0,
|
skip_frames_start: int = 0,
|
||||||
skip_frames_end: int = 0,
|
skip_frames_end: int = 0,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
id_token: Optional[str] = None,
|
id_token: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -555,7 +540,7 @@ class VideoDataset(Dataset):
|
|||||||
import decord
|
import decord
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The `decord` package is required for loading the video dataset. Install with `pip install dataset`"
|
"The `decord` package is required for loading the video dataset. Install with `pip install decord`"
|
||||||
)
|
)
|
||||||
|
|
||||||
decord.bridge.set_bridge("torch")
|
decord.bridge.set_bridge("torch")
|
||||||
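For readers unfamiliar with `decord`, a minimal sketch of the read path the dataset relies on; the file name is an assumption, while the dimensions and the 49-frame cap mirror the `VideoDataset` defaults above:

```python
import decord

decord.bridge.set_bridge("torch")  # get_batch returns torch tensors instead of numpy

# Assumed sample file; width/height match the dataset defaults (720x480).
reader = decord.VideoReader("sample.mp4", width=720, height=480)
indices = list(range(min(49, len(reader))))   # cap at max_num_frames=49
frames = reader.get_batch(indices)            # uint8 tensor of shape [F, H, W, C]
```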
@@ -602,13 +587,12 @@ class VideoDataset(Dataset):
 
 
 def save_model_card(
     repo_id: str,
     videos=None,
     base_model: str = None,
-    train_text_encoder=False,
     validation_prompt=None,
     repo_folder=None,
     fps=8,
 ):
     widget_dict = []
     if videos is not None:
@@ -627,9 +611,9 @@ def save_model_card(
 
 These are {repo_id} LoRA weights for {base_model}.
 
-The weights were trained using the [CogVideoX Diffusers trainer](TODO).
+The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
 
-Was LoRA for the text encoder enabled? {train_text_encoder}.
+Was LoRA for the text encoder enabled? No.
 
 ## Download model
 
@@ -642,8 +626,15 @@ from diffusers import CogVideoXPipeline
 import torch
 
 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
-pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors")
-video = pipe("{validation_prompt}").frames[0]
+pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
+
+# The LoRA adapter weights are determined by what was used for training.
+# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
+# It can be made lower or higher from what was used in training to decrease or amplify the effect
+# of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.
+pipe.set_adapters(["cogvideox-lora"], [32 / 64])
+
+video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0]
 ```
 
 For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
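The adapter weight passed to `set_adapters` is the standard LoRA scaling factor; a tiny worked example with the hyperparameters the card text assumes:

```python
lora_alpha, rank = 32, 64    # assumed training hyperparameters, as in the card text
scale = lora_alpha / rank    # 0.5
# Raising the weight above 0.5 amplifies the LoRA's effect; lowering it softens it,
# within the tolerance the card's comment describes.
```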
@@ -676,12 +667,12 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
 
 
 def log_validation(
     pipe,
     args,
     accelerator,
     pipeline_args,
     epoch,
     is_final_validation: bool = False,
 ):
     logger.info(
         f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
@@ -741,14 +732,14 @@ def log_validation(
 
 
 def _get_t5_prompt_embeds(
     tokenizer: T5Tokenizer,
     text_encoder: T5EncoderModel,
     prompt: Union[str, List[str]],
     num_videos_per_prompt: int = 1,
     max_sequence_length: int = 226,
     device: Optional[torch.device] = None,
     dtype: Optional[torch.dtype] = None,
     text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
@@ -779,14 +770,14 @@ def _get_t5_prompt_embeds(
 
 
 def encode_prompt(
     tokenizer: T5Tokenizer,
     text_encoder: T5EncoderModel,
     prompt: Union[str, List[str]],
     num_videos_per_prompt: int = 1,
     max_sequence_length: int = 226,
     device: Optional[torch.device] = None,
     dtype: Optional[torch.dtype] = None,
     text_input_ids=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     prompt_embeds = _get_t5_prompt_embeds(
@@ -802,13 +793,16 @@ def encode_prompt(
     return prompt_embeds
 
 
-def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
+def compute_prompt_embeddings(
+    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
+):
     if requires_grad:
         prompt_embeds = encode_prompt(
             tokenizer,
             text_encoder,
             prompt,
             num_videos_per_prompt=1,
+            max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )
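A hedged usage sketch of the new signature, using the trainer's in-scope `tokenizer` and `text_encoder`; the prompt is an assumption, and the 226-token limit is filled in by hand where the trainer itself passes `model_config.max_text_seq_length`:

```python
prompt_embeds = compute_prompt_embeddings(
    tokenizer,
    text_encoder,
    ["A panda playing a guitar"],   # assumed prompt batch
    max_sequence_length=226,        # model_config.max_text_seq_length for CogVideoX
    device=accelerator.device,
    dtype=torch.bfloat16,
    requires_grad=False,            # the text encoder stays frozen in this trainer
)
```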
@@ -819,6 +813,7 @@ def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
             text_encoder,
             prompt,
             num_videos_per_prompt=1,
+            max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )
@@ -826,15 +821,15 @@
 
 
 def prepare_rotary_positional_embeddings(
     height: int,
     width: int,
     num_frames: int,
     vae_scale_factor_spatial: int = 8,
     patch_size: int = 2,
     attention_head_dim: int = 64,
     device: Optional[torch.device] = None,
     base_height: int = 480,
     base_width: int = 720,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     grid_height = height // (vae_scale_factor_spatial * patch_size)
     grid_width = width // (vae_scale_factor_spatial * patch_size)
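As a quick worked example of the grid arithmetic above, using the defaults from the signature:

```python
height, width = 480, 720
vae_scale_factor_spatial, patch_size = 8, 2

grid_height = height // (vae_scale_factor_spatial * patch_size)  # 480 // 16 = 30
grid_width = width // (vae_scale_factor_spatial * patch_size)    # 720 // 16 = 45
# Each latent frame therefore contributes a 30 x 45 grid of patch positions
# to the rotary positional embedding.
```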
@@ -854,10 +849,22 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin
 
 
-def get_optimizer(args, params_to_optimize):
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
+    # Use the DeepSpeed optimizer
+    if use_deepspeed:
+        from accelerate.utils import DummyOptim
+
+        return DummyOptim(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_epsilon,
+            weight_decay=args.adam_weight_decay,
+        )
+
     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
-    if args.optimizer not in ["adam", "adamw", "prodigy"]:
+    if args.optimizer not in supported_optimizers:
         logger.warning(
             f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
         )
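For context (my gloss, not part of the commit): when the DeepSpeed JSON config defines the optimizer, `accelerate` only needs a placeholder that records the parameter groups; DeepSpeed instantiates the real optimizer inside `accelerator.prepare`. A minimal sketch:

```python
from accelerate.utils import DummyOptim

# Placeholder only: the actual optimizer type and hyperparameters come from the
# DeepSpeed config file, not from these arguments. The lr here is an assumed fallback.
optimizer = DummyOptim(params_to_optimize, lr=1e-4)
```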
@@ -907,14 +914,6 @@ def get_optimizer(args, params_to_optimize):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warning(
-                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # Changes the learning rate of text_encoder_parameters to be --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
 
         optimizer = optimizer_class(
             params_to_optimize,
@@ -994,7 +993,7 @@ def main(args):
     ).repo_id
 
     # Prepare models and scheduler
-    tokenizer = T5Tokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
 
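A short aside (mine, not the commit's): `AutoTokenizer` resolves the concrete tokenizer class from the files in the checkpoint, so the trainer no longer hard-codes `T5Tokenizer`. A sketch with an assumed model id:

```python
from transformers import AutoTokenizer

# Resolves to the T5 tokenizer shipped in the CogVideoX checkpoint's
# "tokenizer" subfolder, without naming the class explicitly.
tokenizer = AutoTokenizer.from_pretrained("THUDM/CogVideoX-2b", subfolder="tokenizer")
```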
@@ -1035,13 +1034,13 @@ def main(args):
     if accelerator.state.deepspeed_plugin:
         # DeepSpeed is handling precision, use what's in the DeepSpeed config
         if (
             "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
             and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
         ):
             weight_dtype = torch.float16
         if (
             "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
             and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
         ):
             weight_dtype = torch.float16
     else:
@@ -1062,8 +1061,6 @@ def main(args):
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
-    if args.train_text_encoder:
-        text_encoder.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
@@ -1073,14 +1070,6 @@ def main(args):
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
-    if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            lora_alpha=args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
-        text_encoder.add_adapter(text_lora_config)
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
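After this change only the transformer carries an adapter. For reference, a hedged sketch of the surviving config, using the rank and alpha the model card assumes (the script actually takes them from `--rank` and `--lora_alpha`), applied to the trainer's `transformer`:

```python
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=64,            # assumed --rank
    lora_alpha=32,   # assumed --lora_alpha
    init_lora_weights=True,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention projections
)
transformer.add_adapter(transformer_lora_config)
```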
@@ -1091,13 +1080,10 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1107,22 +1093,18 @@ def main(args):
             CogVideoXPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
             )
 
     def load_model_hook(models, input_dir):
         transformer_ = None
-        text_encoder_ = None
 
         while len(models) > 0:
             model = models.pop()
 
             if isinstance(model, type(unwrap_model(transformer))):
                 transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder))):
-                text_encoder_ = model
             else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                raise ValueError(f"Unexpected save model: {model.__class__}")
 
         lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
 
@@ -1139,19 +1121,13 @@ def main(args):
                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                 f" {unexpected_keys}. "
             )
-        if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
-            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
 
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`. More details:
         # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
-            models = [transformer_]
-            if args.train_text_encoder:
-                models.extend([text_encoder_])
             # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+            cast_training_params([transformer_])
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1163,38 +1139,30 @@ def main(args):
 
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        models = [transformer]
-        if args.train_text_encoder:
-            models.extend([text_encoder])
         # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
+        cast_training_params([transformer], dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-    if args.train_text_encoder:
-        text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
 
     # Optimization parameters
     transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
-    if args.train_text_encoder:
-        # different learning rate for text encoder and unet
-        text_encoder_parameters_with_lr = {
-            "params": text_encoder_lora_parameters,
-            "weight_decay": args.adam_weight_decay_text_encoder,
-            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
-        }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_encoder_parameters_with_lr,
-        ]
-    else:
-        params_to_optimize = [transformer_parameters_with_lr]
+    params_to_optimize = [transformer_parameters_with_lr]
 
-    optimizer = get_optimizer(args, params_to_optimize)
+    use_deepspeed_optimizer = (
+        accelerator.state.deepspeed_plugin is not None
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+    use_deepspeed_scheduler = (
+        accelerator.state.deepspeed_plugin is not None
+        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+
+    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
|
|||||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
overrode_max_train_steps = True
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
if use_deepspeed_scheduler:
|
||||||
args.lr_scheduler,
|
from accelerate.utils import DummyScheduler
|
||||||
optimizer=optimizer,
|
|
||||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
|
||||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
|
||||||
num_cycles=args.lr_num_cycles,
|
|
||||||
power=args.lr_power,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare everything with our `accelerator`.
|
lr_scheduler = DummyScheduler(
|
||||||
if args.train_text_encoder:
|
name=args.lr_scheduler,
|
||||||
(
|
optimizer=optimizer,
|
||||||
transformer,
|
total_num_steps=args.max_train_steps * accelerator.num_processes,
|
||||||
text_encoder,
|
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||||
optimizer,
|
|
||||||
train_dataloader,
|
|
||||||
lr_scheduler,
|
|
||||||
) = accelerator.prepare(
|
|
||||||
transformer,
|
|
||||||
text_encoder,
|
|
||||||
optimizer,
|
|
||||||
train_dataloader,
|
|
||||||
lr_scheduler,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
lr_scheduler = get_scheduler(
|
||||||
transformer, optimizer, train_dataloader, lr_scheduler
|
args.lr_scheduler,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||||
|
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||||
|
num_cycles=args.lr_num_cycles,
|
||||||
|
power=args.lr_power,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prepare everything with our `accelerator`.
|
||||||
|
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
transformer, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
if overrode_max_train_steps:
|
if overrode_max_train_steps:
|
||||||
@@ -1347,15 +1310,9 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-        if args.train_text_encoder:
-            text_encoder.train()
-            # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            if args.train_text_encoder:
-                models_to_accumulate.extend([text_encoder])
 
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
@@ -1366,9 +1323,10 @@ def main(args):
                     tokenizer,
                     text_encoder,
                     prompts,
+                    model_config.max_text_seq_length,
                     accelerator.device,
                     weight_dtype,
-                    requires_grad=args.train_text_encoder,
+                    requires_grad=False,
                 )
 
                 # Sample noise that will be added to the latents
|
|||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
|
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
params_to_clip = (
|
params_to_clip = transformer.parameters()
|
||||||
itertools.chain(transformer.parameters(), text_encoder.parameters())
|
|
||||||
if args.train_text_encoder
|
|
||||||
else transformer.parameters()
|
|
||||||
)
|
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
optimizer.step()
|
if accelerator.state.deepspeed_plugin is None:
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
@ -1507,7 +1463,6 @@ def main(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
transformer = unwrap_model(transformer)
|
transformer = unwrap_model(transformer)
|
||||||
# transformer = transformer.to(torch.float32)
|
|
||||||
dtype = (
|
dtype = (
|
||||||
torch.float16
|
torch.float16
|
||||||
if args.mixed_precision == "fp16"
|
if args.mixed_precision == "fp16"
|
||||||
@ -1518,16 +1473,9 @@ def main(args):
|
|||||||
transformer = transformer.to(dtype)
|
transformer = transformer.to(dtype)
|
||||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||||
|
|
||||||
if args.train_text_encoder:
|
|
||||||
text_encoder = unwrap_model(text_encoder)
|
|
||||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
|
|
||||||
else:
|
|
||||||
text_encoder_lora_layers = None
|
|
||||||
|
|
||||||
CogVideoXPipeline.save_lora_weights(
|
CogVideoXPipeline.save_lora_weights(
|
||||||
save_directory=args.output_dir,
|
save_directory=args.output_dir,
|
||||||
transformer_lora_layers=transformer_lora_layers,
|
transformer_lora_layers=transformer_lora_layers,
|
||||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Final test inference
|
# Final test inference
|
||||||
@ -1539,6 +1487,11 @@ def main(args):
|
|||||||
)
|
)
|
||||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
|
||||||
|
|
||||||
|
if args.enable_slicing:
|
||||||
|
pipe.vae.enable_slicing()
|
||||||
|
if args.enable_tiling:
|
||||||
|
pipe.vae.enable_tiling()
|
||||||
|
|
||||||
# Load LoRA weights
|
# Load LoRA weights
|
||||||
lora_scaling = args.lora_alpha / args.rank
|
lora_scaling = args.lora_alpha / args.rank
|
||||||
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")
|
||||||
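Both switches added above act on the VAE only and trade a little speed for peak-memory headroom during the final test inference; a one-line gloss of each, using the trainer's `pipe` (the methods are standard diffusers VAE API):

```python
pipe.vae.enable_slicing()  # decode the latent batch one sample at a time
pipe.vae.enable_tiling()   # decode each frame in overlapping spatial tiles
```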
@@ -1572,7 +1525,6 @@ def main(args):
                 repo_id,
                 videos=validation_outputs,
                 base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
                 validation_prompt=args.validation_prompt,
                 repo_folder=args.output_dir,
                 fps=args.fps,