Merge pull request #642 from THUDM/CogVideoX_dev

New Lora 20250108
This commit is contained in:
Yuxuan Zhang 2025-01-08 09:51:39 +08:00 committed by GitHub
commit 8f1829f1cd
GPG Key ID: B5690EEEBB952194
36 changed files with 812 additions and 635 deletions

.gitignore vendored

@ -22,3 +22,4 @@ venv
**/results
**/*.mp4
**/validation_set
CogVideo-1.0


@ -22,8 +22,9 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
## Project Updates
- 🔥🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
- 🔥 News: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
- 🔥🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
- 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
- 🔥 **News**: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single


@ -22,7 +22,8 @@
## 更新とニュース
- 🔥🔥 **ニュース**: ```2024/11/15```: `CogVideoX1.5`モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
- 🔥🔥 **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAMビデオメモリで動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
- 🔥 **ニュース**: ```2024/11/15```: `CogVideoX1.5`モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
- 🔥 **ニュース**: ```2024/11/08```: `CogVideoX1.5` モデルをリリースしました。CogVideoX1.5 は CogVideoX オープンソースモデルのアップグレードバージョンです。
CogVideoX1.5-5B シリーズモデルは、10秒長の動画とより高い解像度をサポートしており、`CogVideoX1.5-5B-I2V` は任意の解像度での動画生成に対応しています。
SAT コードはすでに更新されており、`diffusers` バージョンは現在適応中です。


@ -22,7 +22,8 @@
## 项目更新
- 🔥🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本仅需调整部分参数即可沿用之前的代码。
- 🔥🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。
- 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本仅需调整部分参数即可沿用之前的代码。
- 🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。
CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨率,其中 `CogVideoX1.5-5B-I2V` 支持 **任意分辨率** 的视频生成SAT代码已经更新。`diffusers`版本还在适配中。SAT版本代码前往 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 下载。
- 🔥**News**: ```2024/10/13```: 成本更低单卡4090可微调 `CogVideoX-5B`


@ -1,126 +1,102 @@
# CogVideoX diffusers Fine-tuning Guide
# CogVideoX Diffusers Fine-tuning Guide
[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
This feature is not fully complete yet. If you want to check the fine-tuning for the SAT version, please
see [here](../sat/README_zh.md). The dataset format is different from this version.
If you are looking for the fine-tuning instructions for the SAT version, please check [here](../sat/README_zh.md). Note that its dataset format differs from the one used in this version.
## Hardware Requirements
+ CogVideoX-2B / 5B LoRA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (Working)
+ CogVideoX-5B-I2V is not supported yet.
| Model | Training Type | Mixed Precision | Training Resolution (frames x height x width) | Hardware Requirements |
|---------------------|-----------------|----------------|---------------------------------------------|------------------------|
| cogvideox-t2v-2b | lora (rank128) | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-t2v-5b | lora (rank128) | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox-i2v-5b | lora (rank128) | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-t2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox1.5-i2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
## Install Dependencies
Since the related code has not been merged into the diffusers release, you need to base your fine-tuning on the
diffusers branch. Please follow the steps below to install dependencies:
Since the relevant code has not yet been merged into the official `diffusers` release, you need to fine-tune based on the diffusers branch. Follow the steps below to install the dependencies:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
cd diffusers # Now on the Main branch
pip install -e .
```
## Prepare the Dataset
First, you need to prepare the dataset. The dataset format should be as follows, with `videos.txt` containing the list
of videos in the `videos` directory:
First, you need to prepare your dataset. Depending on your task type (T2V or I2V), the dataset format will vary slightly:
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (Optional) For I2V; if not provided, the first frame of each video is extracted as the reference image
└── images.txt # (Optional) For I2V; if not provided, the first frame of each video is extracted as the reference image
```
You can download
the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) dataset from
here.
Where:
- `prompts.txt`: Contains the prompts
- `videos/`: Contains the .mp4 video files
- `videos.txt`: Contains the list of video files in the `videos/` directory
- `images/`: (Optional) Contains the .png reference image files
- `images.txt`: (Optional) Contains the list of reference image files
This video fine-tuning dataset is used as a test for fine-tuning.
You can download a sample dataset (T2V) [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset).
## Configuration Files and Execution
If you need to use a validation dataset during training, make sure to provide a validation dataset with the same format as the training dataset.
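Before launching a run, it can help to sanity-check the metadata files against the layout above. The snippet below is a hypothetical helper, not part of this repository; it assumes paths in `videos.txt` resolve relative to the directory that contains it, as the loaders in this PR do.
```python
from pathlib import Path

def check_dataset(data_root: str, caption_column: str = "prompts.txt", video_column: str = "videos.txt") -> None:
    root = Path(data_root)
    # Non-empty lines only, mirroring how the dataset loaders read these files
    prompts = [l.strip() for l in (root / caption_column).read_text(encoding="utf-8").splitlines() if l.strip()]
    videos = [root / l.strip() for l in (root / video_column).read_text(encoding="utf-8").splitlines() if l.strip()]

    missing = [str(p) for p in videos if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"{len(missing)} video(s) listed in {video_column} do not exist, e.g. {missing[:3]}")
    if len(prompts) != len(videos):
        raise ValueError(f"{len(prompts)} prompts vs. {len(videos)} videos -- the two files must line up 1:1")
    print(f"OK: {len(videos)} (prompt, video) pairs found under {root}")

check_dataset("/path/to/your/dataset")  # e.g. the downloaded Disney-VideoGeneration-Dataset folder
```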
The `accelerate` configuration files are as follows:
## Run the Script to Start Fine-tuning
+ `accelerate_config_machine_multi.yaml`: Suitable for multi-GPU use
+ `accelerate_config_machine_single.yaml`: Suitable for single-GPU use
Before starting the training, please note the following resolution requirements:
The configuration for the `finetune` script is as follows:
1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), such as 49, 81, etc.
2. The recommended resolution for videos is:
- CogVideoX: 480x720 (Height x Width)
- CogVideoX1.5: 768x1360 (Height x Width)
3. For samples that do not meet the required resolution (videos or images), the code will automatically resize them. This may distort the aspect ratio and impact training results. We recommend preprocessing the samples (e.g., using crop + resize to maintain aspect ratio) before training.
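For point 3 above, a minimal preprocessing sketch (not the repository's code) that keeps the aspect ratio by scaling each clip until it covers the target resolution and then center-cropping, using torchvision:
```python
import torch
from torchvision.transforms import functional as TF

def resize_and_center_crop(frames: torch.Tensor, target_h: int = 480, target_w: int = 720) -> torch.Tensor:
    """frames: float video tensor of shape [F, C, H, W]; returns [F, C, target_h, target_w]."""
    _, _, h, w = frames.shape
    scale = max(target_h / h, target_w / w)              # scale so both sides cover the target
    new_h, new_w = round(h * scale), round(w * scale)
    frames = TF.resize(frames, [new_h, new_w], antialias=True)
    return TF.center_crop(frames, [target_h, target_w])  # crop the overflow instead of distorting
```
Running every clip through a step like this before writing the dataset avoids the automatic resize distorting the aspect ratio during training.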
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Use accelerate to launch multi-GPU training with the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Training script train_cogvideox_lora.py for LoRA fine-tuning on CogVideoX model
--gradient_checkpointing \ # Enable gradient checkpointing to reduce memory usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling technique to process videos in chunks, saving memory
--enable_slicing \ # Enable slicing to further optimize memory by slicing inputs
--instance_data_root $DATASET_PATH \ # Dataset path specified by $DATASET_PATH
--caption_column prompts.txt \ # Specify the file prompts.txt for video descriptions used in training
--video_column videos.txt \ # Specify the file videos.txt for video paths used in training
--validation_prompt "" \ # Prompt used for generating validation videos during training
--validation_prompt_separator ::: \ # Set ::: as the separator for validation prompts
--num_validation_videos 1 \ # Generate 1 validation video per validation round
--validation_epochs 100 \ # Perform validation every 100 training epochs
--seed 42 \ # Set random seed to 42 for reproducibility
--rank 128 \ # Set the rank for LoRA parameters to 128
--lora_alpha 64 \ # Set the alpha parameter for LoRA to 64, adjusting LoRA learning rate
--mixed_precision bf16 \ # Use bf16 mixed precision for training to save memory
--output_dir $OUTPUT_PATH \ # Specify the output directory for the model, defined by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of the video
--skip_frames_end 0 \ # Skip 0 frames at the end of the video
--train_batch_size 4 \ # Set training batch size to 4
--num_train_epochs 30 \ # Total number of training epochs set to 30
--checkpointing_steps 1000 \ # Save model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, updating after each batch
--learning_rate 1e-3 \ # Set learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate for the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set maximum gradient clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights and Biases (wandb) for logging and monitoring the training
> **Important Note**: To improve training efficiency, we will automatically encode videos and cache the results on disk. If you modify the data after training has begun, please delete the `latent` directory under the `videos/` folder to ensure that the latest data is used.
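If you do change the data mid-experiment, clearing the cached latents is straightforward. The paths below are assumptions that cover both the location mentioned above and the `cache/` directory the dataset code in this PR writes under `--data_root`:
```python
import shutil
from pathlib import Path

data_root = Path("/path/to/your/dataset")  # your --data_root
for stale in (data_root / "videos" / "latent", data_root / "cache"):
    if stale.exists():
        shutil.rmtree(stale)
        print(f"Removed stale cache: {stale}")
```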
### Text-to-Video (T2V) Fine-tuning
```bash
# Modify the configuration parameters in accelerate_train_t2v.sh
# The main parameters to modify are:
# --output_dir: Output directory
# --data_root: Root directory of the dataset
# --caption_column: Path to the prompt file
# --video_column: Path to the video list file
# --train_resolution: Training resolution (frames x height x width)
# Refer to the start script for other important parameters
bash accelerate_train_t2v.sh
```
## Running the Script to Start Fine-tuning
### Image-to-Video (I2V) Fine-tuning
Single Node (One GPU or Multi GPU) fine-tuning:
```bash
# Modify the configuration parameters in accelerate_train_i2v.sh
# In addition to modifying the same parameters as for T2V, you also need to set:
# --image_column: Path to the reference image list file (optional; if you have no reference images, remove this parameter and the first frame of each video will be used)
# Refer to the start script for other important parameters
```shell
bash finetune_single_rank.sh
bash accelerate_train_i2v.sh
```
Multi-Node fine-tuning:
## Load the Fine-tuned Model
```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```
## Loading the Fine-tuned Model
+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.
+ Please refer to [cli_demo.py](../inference/cli_demo.py) for instructions on how to load the fine-tuned model.
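For orientation, loading the resulting LoRA into a `diffusers` pipeline typically looks like the sketch below; the model ID, output path, and generation arguments are placeholders, and `cli_demo.py` remains the authoritative reference.
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("/path/to/output_dir", adapter_name="cogvideox-lora")  # directory produced by --output_dir (assumption)
pipe.to("cuda")

frames = pipe(
    prompt="A corgi surfing a wave at sunset, cinematic lighting",
    num_frames=81,        # 8N + 1, matching the training resolution
    guidance_scale=6.0,
).frames[0]
export_to_video(frames, "output.mp4", fps=16)
```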
## Best Practices
+ Includes 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames in
the data preprocessing, we created two smaller datasets with 49 and 16 frames to speed up experimentation, as the
maximum frame limit recommended by the CogVideoX team is 49 frames. We split the 70 videos into three groups of 10,
25, and 50 videos, with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ It works better to train using identifier tokens specified with `--id_token`. This is similar to Dreambooth training,
but regular fine-tuning without such tokens also works.
+ The original repository used `lora_alpha` set to 1. We found this value ineffective across multiple runs, likely due
to differences in the backend and training setup. Our recommendation is to set `lora_alpha` equal to rank or rank //
2.
+ We recommend using a rank of 64 or higher.
+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). By skipping frames during data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experiments, since the maximum frame count recommended by the CogVideoX team is 49. The 70 videos were divided into three groups of 10, 25, and 50 conceptually similar videos.
+ Using 25 or more videos works best when training new concepts and styles.
+ It's recommended to use an identifier token, which can be specified using `--id_token`, for better training results. This is similar to Dreambooth training, though regular fine-tuning without using this token will still work.
+ The original repository uses `lora_alpha` set to 1. We found that this value performed poorly in several runs, possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` to be equal to the rank or `rank // 2`.
+ It's advised to use a rank of 64 or higher.
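The rank and `lora_alpha` advice above, expressed as a `peft` `LoraConfig` for illustration only; the target modules are an assumption, and the training scripts in this repository configure LoRA themselves.
```python
from peft import LoraConfig

rank = 128
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank // 2,  # recommended: equal to rank or rank // 2, rather than 1
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # typical attention projections (assumption)
)
```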


@ -1,116 +1,100 @@
# CogVideoX diffusers 微調整方法
[Read this in English.](./README_zh)
# CogVideoX Diffusers ファインチューニングガイド
[中文阅读](./README_zh.md)
[Read in English](./README.md)
この機能はまだ完全に完成していません。SATバージョンの微調整を確認したい場合は、[こちら](../sat/README_ja.md)を参照してください。本バージョンとは異なるデータセット形式を使用しています。
SATバージョンのファインチューニング手順については、[こちら](../sat/README_zh.md)をご確認ください。このバージョンのデータセットフォーマットは、こちらのバージョンとは異なります。
## ハードウェア要件
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (動作確認済み)
+ CogVideoX-5B-I2V まだサポートしていません
| モデル | トレーニングタイプ | 混合精度学習 | トレーニング解像度(フレーム数x高さx幅) | ハードウェア要件 |
|----------------------|-----------------|------------|----------------------------------|----------------|
| cogvideox-t2v-2b | lora (rank128) | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-t2v-5b | lora (rank128) | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox-i2v-5b | lora (rank128) | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-t2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox1.5-i2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
## 依存関係のインストール
関連コードはまだdiffusersのリリース版に統合されていないため、diffusersブランチを使用して微調整を行う必要があります。以下の手順に従って依存関係をインストールしてください
関連するコードがまだ `diffusers` の公式リリースに統合されていないため、`diffusers` ブランチを基にファインチューニングを行う必要があります。以下の手順に従って依存関係をインストールしてください:
```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now in Main branch
cd diffusers # 現在は Main ブランチ
pip install -e .
```
## データセットの準備
まず、データセットを準備する必要があります。データセットの形式は以下のようになります。
まず、データセットを準備する必要があります。タスクの種類T2V または I2Vによって、データセットのフォーマットが少し異なります
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (オプション) I2Vの場合。提供されない場合、動画の最初のフレームが参照画像として使用されます
└── images.txt # (オプション) I2Vの場合。提供されない場合、動画の最初のフレームが参照画像として使用されます
```
[ディズニースチームボートウィリー](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)をここからダウンロードできます。
各ファイルの役割は以下の通りです:
- `prompts.txt`: プロンプトを格納
- `videos/`: .mp4 動画ファイルを格納
- `videos.txt`: `videos/` フォルダ内の動画ファイルリストを格納
- `images/`: (オプション) .png 形式の参照画像ファイル
- `images.txt`: (オプション) 参照画像ファイルリスト
ビデオ微調整データセットはテスト用として使用されます。
トレーニング中に検証データセットを使用する場合は、トレーニングデータセットと同じフォーマットで検証データセットを提供する必要があります。
## 設定ファイルと実行
## スクリプトを実行してファインチューニングを開始
`accelerate` 設定ファイルは以下の通りです:
トレーニングを開始する前に、以下の解像度設定に関する要件に注意してください:
+ accelerate_config_machine_multi.yaml 複数GPU向け
+ accelerate_config_machine_single.yaml 単一GPU向け
1. フレーム数は8の倍数 **+1** (つまり8N+1)でなければなりません。例えば49、81など。
2. 推奨される動画の解像度は次の通りです:
- CogVideoX: 480x720高さ x 幅)
- CogVideoX1.5: 768x1360高さ x 幅)
3. 解像度が要求される基準に合わないサンプル(動画や画像)については、コード内で自動的にリサイズされます。この処理により、アスペクト比が歪む可能性があり、トレーニング結果に影響を与える可能性があります。解像度については、トレーニング前にサンプルを前処理(例えば、アスペクト比を維持するためにクロップとリサイズを使用)しておくことをお勧めします。
`finetune` スクリプト設定ファイルの例:
> **重要な注意**:トレーニング効率を向上させるために、動画はトレーニング前に自動的にエンコードされ、結果がディスクにキャッシュされます。トレーニング後にデータを変更した場合は、`videos/` フォルダ内の `latent` フォルダを削除して、最新のデータが使用されるようにしてください。
```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # accelerateを使用してmulti-GPUトレーニングを起動、設定ファイルはaccelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # LoRAの微調整用のトレーニングスクリプトtrain_cogvideox_lora.pyを実行
--gradient_checkpointing \ # メモリ使用量を減らすためにgradient checkpointingを有効化
--pretrained_model_name_or_path $MODEL_PATH \ # 事前学習済みモデルのパスを$MODEL_PATHで指定
--cache_dir $CACHE_PATH \ # モデルファイルのキャッシュディレクトリを$CACHE_PATHで指定
--enable_tiling \ # メモリ節約のためにタイル処理を有効化し、動画をチャンク分けして処理
--enable_slicing \ # 入力をスライスしてさらにメモリ最適化
--instance_data_root $DATASET_PATH \ # データセットのパスを$DATASET_PATHで指定
--caption_column prompts.txt \ # トレーニングで使用する動画の説明ファイルをprompts.txtで指定
--video_column videos.txt \ # トレーニングで使用する動画のパスファイルをvideos.txtで指定
--validation_prompt "" \ # トレーニング中に検証用の動画を生成する際のプロンプト
--validation_prompt_separator ::: \ # 検証プロンプトの区切り文字を:::に設定
--num_validation_videos 1 \ # 各検証ラウンドで1本の動画を生成
--validation_epochs 100 \ # 100エポックごとに検証を実施
--seed 42 \ # 再現性を保証するためにランダムシードを42に設定
--rank 128 \ # LoRAのパラメータのランクを128に設定
--lora_alpha 64 \ # LoRAのalphaパラメータを64に設定し、LoRAの学習率を調整
--mixed_precision bf16 \ # bf16混合精度でトレーニングし、メモリを節約
--output_dir $OUTPUT_PATH \ # モデルの出力ディレクトリを$OUTPUT_PATHで指定
--height 480 \ # 動画の高さを480ピクセルに設定
--width 720 \ # 動画の幅を720ピクセルに設定
--fps 8 \ # 動画のフレームレートを1秒あたり8フレームに設定
--max_num_frames 49 \ # 各動画の最大フレーム数を49に設定
--skip_frames_start 0 \ # 動画の最初のフレームを0スキップ
--skip_frames_end 0 \ # 動画の最後のフレームを0スキップ
--train_batch_size 4 \ # トレーニングのバッチサイズを4に設定
--num_train_epochs 30 \ # 総トレーニングエポック数を30に設定
--checkpointing_steps 1000 \ # 1000ステップごとにモデルのチェックポイントを保存
--gradient_accumulation_steps 1 \ # 1ステップの勾配累積を行い、各バッチ後に更新
--learning_rate 1e-3 \ # 学習率を0.001に設定
--lr_scheduler cosine_with_restarts \ # リスタート付きのコサイン学習率スケジューラを使用
--lr_warmup_steps 200 \ # トレーニングの最初の200ステップで学習率をウォームアップ
--lr_num_cycles 1 \ # 学習率のサイクル数を1に設定
--optimizer AdamW \ # AdamWオプティマイザーを使用
--adam_beta1 0.9 \ # Adamオプティマイザーのbeta1パラメータを0.9に設定
--adam_beta2 0.95 \ # Adamオプティマイザーのbeta2パラメータを0.95に設定
--max_grad_norm 1.0 \ # 勾配クリッピングの最大値を1.0に設定
--allow_tf32 \ # トレーニングを高速化するためにTF32を有効化
--report_to wandb # Weights and Biasesを使用してトレーニングの記録とモニタリングを行う
### テキストから動画生成T2Vのファインチューニング
```bash
# accelerate_train_t2v.sh の設定パラメータを変更します
# 主に変更が必要なパラメータ:
# --output_dir: 出力先ディレクトリ
# --data_root: データセットのルートディレクトリ
# --caption_column: プロンプトファイルのパス
# --video_column: 動画リストファイルのパス
# --train_resolution: トレーニング解像度(フレーム数 x 高さ x 幅)
# その他の重要なパラメータについては、起動スクリプトを参照してください
bash accelerate_train_t2v.sh
```
## 微調整を開始
### 画像から動画生成I2Vのファインチューニング
単一マシン (シングルGPU、マルチGPU) での微調整:
```bash
# accelerate_train_i2v.sh の設定パラメータを変更します
# T2Vと同様に変更が必要なパラメータに加えて、以下のパラメータも設定する必要があります
# --image_column: 参照画像リストファイルのパス(オプション)
# その他の重要なパラメータについては、起動スクリプトを参照してください
```shell
bash finetune_single_rank.sh
bash accelerate_train_i2v.sh
```
複数マシン・マルチGPUでの微調整
## ファインチューニングしたモデルの読み込み
```shell
bash finetune_multi_rank.sh # 各ノードで実行する必要があります。
```
## 微調整済みモデルのロード
+ 微調整済みのモデルをロードする方法については、[cli_demo.py](../inference/cli_demo.py) を参照してください。
+ ファインチューニングしたモデルを読み込む方法については、[cli_demo.py](../inference/cli_demo.py)を参照してください。
## ベストプラクティス
+ 解像度が `200 x 480 x 720`(フレーム数 x 高さ x 幅)のトレーニングビデオが70本含まれています。データ前処理でフレームをスキップすることで、49フレームと16フレームの小さなデータセットを作成しました。これは実験を加速するためのもので、CogVideoXチームが推奨する最大フレーム数制限は49フレームです。
+ 25本以上のビデオが新しい概念やスタイルのトレーニングに最適です。
+ 現在、`--id_token` を指定して識別トークンを使用してトレーニングする方が効果的です。これはDreamboothトレーニングに似ていますが、通常の微調整でも機能します。
+ 元のリポジトリでは `lora_alpha` を1に設定していましたが、複数の実行でこの値が効果的でないことがわかりました。モデルのバックエンドやトレーニング設定によるかもしれません。私たちの提案は、lora_alphaをrankと同じか、rank // 2に設定することです。
+ Rank 64以上の設定を推奨します。
+ 解像度が `200 x 480 x 720`(フレーム数 x 高さ x 幅の70本のトレーニング動画を使用しました。データ前処理でフレームスキップを行い、49フレームおよび16フレームの2つの小さなデータセットを作成して実験速度を向上させました。CogVideoXチームの推奨最大フレーム数制限は49フレームです。これらの70本の動画は、10、25、50本の3つのグループに分け、概念的に類似した性質のものです。
+ 25本以上の動画を使用することで、新しい概念やスタイルのトレーニングが最適です。
+ `--id_token` で指定できる識別子トークンを使用すると、トレーニング効果がより良くなります。これはDreamboothトレーニングに似ていますが、このトークンを使用しない通常のファインチューニングでも問題なく動作します。
+ 元のリポジトリでは `lora_alpha` が1に設定されていますが、この値は多くの実行で効果が悪かったため、モデルのバックエンドやトレーニング設定の違いが影響している可能性があります。私たちの推奨は、`lora_alpha` を rank と同じか、`rank // 2` に設定することです。
+ rank は64以上に設定することをお勧めします。


@ -1,16 +1,24 @@
# CogVideoX diffusers 微调方案
[Read this in English](./README_zh.md)
[Read this in English](./README.md)
[日本語で読む](./README_ja.md)
本功能尚未完全完善,如果您想查看SAT版本微调请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。
如果您想查看SAT版本微调请查看[这里](../sat/README_zh.md)。其数据集格式与本版本不同。
## 硬件要求
+ CogVideoX-2B / 5B T2V LORA: 1 * A100 (5B need to use `--use_8bit_adam`)
+ CogVideoX-2B SFT: 8 * A100 (制作中)
+ CogVideoX-5B-I2V 暂未支持
| 模型 | 训练类型 | 混合训练精度 | 训练分辨率(帧数x高x宽) | 硬件要求 |
|----------------------|----------------|------------|----------------------|-----------------------|
| cogvideox-t2v-2b | lora (rank128) | fp16 | 49x480x720 | 16G显存 (NVIDIA 4080) |
| cogvideox-t2v-5b | lora (rank128) | bf16 | 49x480x720 | 24G显存 (NVIDIA 4090) |
| cogvideox-i2v-5b | lora (rank128) | bf16 | 49x480x720 | 24G显存 (NVIDIA 4090) |
| cogvideox1.5-t2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35G显存 (NVIDIA A100) |
| cogvideox1.5-i2v-5b | lora (rank128) | bf16 | 81x768x1360 | 35G显存 (NVIDIA A100) |
<!-- | cogvideox-t2v-5b | sft | bf16 | 49x480x720 | |
| cogvideox-i2v-5b | sft | bf16 | 49x480x720 | |
| cogvideox1.5-t2v-5b | sft | bf16 | 81x768x1360 | |
| cogvideox1.5-i2v-5b | sft | bf16 | 81x768x1360 | | -->
## 安装依赖
@ -24,83 +32,64 @@ pip install -e .
## 准备数据集
首先,你需要准备数据集数据集格式如下其中videos.txt 存放 videos 中的视频。
首先,你需要准备数据集。根据你的任务类型T2V 或 I2V数据集格式略有不同
```
.
├── prompts.txt
├── videos
└── videos.txt
├── videos.txt
├── images # (可选) 对于I2V若不提供则从视频中提取第一帧作为参考图像
└── images.txt # (可选) 对于I2V若不提供则从视频中提取第一帧作为参考图像
```
你可以从这里下载 [迪士尼汽船威利号](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
其中:
- `prompts.txt`: 存放提示词
- `videos/`: 存放.mp4视频文件
- `videos.txt`: 存放 videos 目录中的视频文件列表
- `images/`: (可选) 存放.png参考图像文件
- `images.txt`: (可选) 存放参考图像文件列表
视频微调数据集作为测试微调。
你可以从这里下载示例数据集(T2V) [迪士尼汽船威利号](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)
## 配置文件和运行
`accelerate` 配置文件如下:
+ accelerate_config_machine_multi.yaml 适合多GPU使用
+ accelerate_config_machine_single.yaml 适合单GPU使用
`finetune` 脚本配置文件如下:
```shell
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # 使用 accelerate 启动多GPU训练配置文件为 accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # 运行的训练脚本为 train_cogvideox_lora.py用于在 CogVideoX 模型上进行 LoRA 微调
--gradient_checkpointing \ # 启用梯度检查点功能,以减少显存使用
--pretrained_model_name_or_path $MODEL_PATH \ # 预训练模型路径,通过 $MODEL_PATH 指定
--cache_dir $CACHE_PATH \ # 模型缓存路径,由 $CACHE_PATH 指定
--enable_tiling \ # 启用tiling技术以分片处理视频节省显存
--enable_slicing \ # 启用slicing技术将输入切片以进一步优化内存
--instance_data_root $DATASET_PATH \ # 数据集路径,由 $DATASET_PATH 指定
--caption_column prompts.txt \ # 指定用于训练的视频描述文件,文件名为 prompts.txt
--video_column videos.txt \ # 指定用于训练的视频路径文件,文件名为 videos.txt
--validation_prompt "" \ # 验证集的提示语 (prompt),用于在训练期间生成验证视频
--validation_prompt_separator ::: \ # 设置验证提示语的分隔符为 :::
--num_validation_videos 1 \ # 每个验证回合生成 1 个视频
--validation_epochs 100 \ # 每 100 个训练epoch进行一次验证
--seed 42 \ # 设置随机种子为 42以保证结果的可复现性
--rank 128 \ # 设置 LoRA 参数的秩 (rank) 为 128
--lora_alpha 64 \ # 设置 LoRA 的 alpha 参数为 64用于调整LoRA的学习率
--mixed_precision bf16 \ # 使用 bf16 混合精度进行训练,减少显存使用
--output_dir $OUTPUT_PATH \ # 指定模型输出目录,由 $OUTPUT_PATH 定义
--height 480 \ # 视频高度为 480 像素
--width 720 \ # 视频宽度为 720 像素
--fps 8 \ # 视频帧率设置为 8 帧每秒
--max_num_frames 49 \ # 每个视频的最大帧数为 49 帧
--skip_frames_start 0 \ # 跳过视频开头的帧数为 0
--skip_frames_end 0 \ # 跳过视频结尾的帧数为 0
--train_batch_size 4 \ # 训练时的 batch size 设置为 4
--num_train_epochs 30 \ # 总训练epoch数为 30
--checkpointing_steps 1000 \ # 每 1000 步保存一次模型检查点
--gradient_accumulation_steps 1 \ # 梯度累计步数为 1即每个 batch 后都会更新梯度
--learning_rate 1e-3 \ # 学习率设置为 0.001
--lr_scheduler cosine_with_restarts \ # 使用带重启的余弦学习率调度器
--lr_warmup_steps 200 \ # 在训练的前 200 步进行学习率预热
--lr_num_cycles 1 \ # 学习率周期设置为 1
--optimizer AdamW \ # 使用 AdamW 优化器
--adam_beta1 0.9 \ # 设置 Adam 优化器的 beta1 参数为 0.9
--adam_beta2 0.95 \ # 设置 Adam 优化器的 beta2 参数为 0.95
--max_grad_norm 1.0 \ # 最大梯度裁剪值设置为 1.0
--allow_tf32 \ # 启用 TF32 以加速训练
--report_to wandb # 使用 Weights and Biases 进行训练记录与监控
```
如果需要在训练过程中进行validation则需要额外提供验证数据集其中数据格式与训练集相同。
## 运行脚本,开始微调
单机(单卡,多卡)微调
在开始训练之前,请注意以下分辨率设置要求:
```shell
bash finetune_single_rank.sh
1. 帧数必须是8的倍数 **+1** (即8N+1), 例如49, 81 ...
2. 视频分辨率建议使用模型的默认大小:
- CogVideoX: 480x720 (高x宽)
- CogVideoX1.5: 768x1360 (高x宽)
3. 对于不满足训练分辨率的样本视频或图片在代码中会直接进行resize。这可能会导致样本的宽高比发生形变从而影响训练效果。建议用户提前对样本在分辨率上进行处理例如使用crop + resize来维持宽高比再进行训练。
> **重要提示**为了提高训练效率我们会在训练前自动对video进行encode并将结果缓存在磁盘。如果在训练后修改了数据请删除video目录下的latent目录以确保使用最新的数据。
### 文本生成视频 (T2V) 微调
```bash
# 修改 accelerate_train_t2v.sh 中的配置参数
# 主要需要修改以下参数:
# --output_dir: 输出目录
# --data_root: 数据集根目录
# --caption_column: 提示词文件路径
# --video_column: 视频文件列表路径
# --train_resolution: 训练分辨率 (帧数x高x宽)
# 其他重要参数请参考启动脚本
bash accelerate_train_t2v.sh
```
多机多卡微调:
### 图像生成视频 (I2V) 微调
```shell
bash finetune_multi_rank.sh #需要在每个节点运行
```bash
# 修改 accelerate_train_i2v.sh 中的配置参数
# 除了需要修改与T2V相同的参数外还需要额外设置:
# --image_column: 参考图像文件列表路径(可选;如果没有自己的参考图片,请移除该参数,训练时将默认使用视频第一帧)
# 其他重要参数请参考启动脚本
bash accelerate_train_i2v.sh
```
## 载入微调的模型


@ -6,7 +6,7 @@ export TOKENIZERS_PARALLELISM=false
# Model Configuration
MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B-I2V"
--model_name "cogvideox1.5-i2v"
--model_name "cogvideox1.5-i2v" # ["cogvideox-i2v"]
--model_type "i2v"
--training_type "lora"
)
@ -23,7 +23,7 @@ DATA_ARGS=(
--caption_column "prompt.txt"
--video_column "videos.txt"
--image_column "images.txt"
--train_resolution "80x768x1360"
--train_resolution "81x768x1360"
)
# Training Configuration
@ -31,7 +31,7 @@ TRAIN_ARGS=(
--train_epochs 10
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16"
--mixed_precision "bf16" # ["no", "fp16"]
--seed 42
)
@ -55,7 +55,7 @@ VALIDATION_ARGS=(
--validation_steps 400
--validation_prompts "prompts.txt"
--validation_images "images.txt"
--gen_fps 15
--gen_fps 16
)
# Combine all arguments and launch training


@ -6,7 +6,7 @@ export TOKENIZERS_PARALLELISM=false
# Model Configuration
MODEL_ARGS=(
--model_path "THUDM/CogVideoX1.5-5B"
--model_name "cogvideox1.5-t2v"
--model_name "cogvideox1.5-t2v" # ["cogvideox-t2v"]
--model_type "t2v"
--training_type "lora"
)
@ -22,7 +22,7 @@ DATA_ARGS=(
--data_root "/path/to/data/dir"
--caption_column "prompt.txt"
--video_column "videos.txt"
--train_resolution "80x768x1360"
--train_resolution "81x768x1360"
)
# Training Configuration
@ -30,7 +30,7 @@ TRAIN_ARGS=(
--train_epochs 10
--batch_size 1
--gradient_accumulation_steps 1
--mixed_precision "bf16"
--mixed_precision "bf16" # ["no", "fp16"]
--seed 42
)
@ -53,7 +53,7 @@ VALIDATION_ARGS=(
--validation_dir "/path/to/validation/dir"
--validation_steps 400
--validation_prompts "prompts.txt"
--gen_fps 15
--gen_fps 16
)
# Combine all arguments and launch training


@ -1,2 +1,2 @@
LOG_NAME = "trainer"
LOG_LEVEL = "INFO"
LOG_LEVEL = "INFO"


@ -1,6 +1,6 @@
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
__all__ = [
@ -8,5 +8,5 @@ __all__ = [
"I2VDatasetWithBuckets",
"T2VDatasetWithResize",
"T2VDatasetWithBuckets",
"BucketSampler"
"BucketSampler",
]


@ -1,8 +1,8 @@
import random
import logging
import random
from torch.utils.data import Dataset, Sampler
from torch.utils.data import Sampler
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)
@ -37,7 +37,6 @@ class BucketSampler(Sampler):
self._raised_warning_for_drop_last = False
def __len__(self):
if self.drop_last and not self._raised_warning_for_drop_last:
self._raised_warning_for_drop_last = True
@ -46,7 +45,6 @@ class BucketSampler(Sampler):
)
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]


@ -1,22 +1,30 @@
import torch
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from finetune.constants import LOG_NAME, LOG_LEVEL
from typing_extensions import override
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_prompts, load_videos, load_images,
load_images,
load_images_from_videos,
load_prompts,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_buckets,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
@ -40,26 +48,32 @@ class BaseI2VDataset(Dataset):
device (torch.device): Device to load the data on
trainer (Trainer): Trainer instance that provides the encode_video / encode_text functions used for caching
"""
def __init__(
self,
data_root: str,
caption_column: str,
video_column: str,
image_column: str,
image_column: str | None,
device: torch.device,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
data_root = Path(data_root)
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.images = load_images(data_root / image_column)
if image_column is not None:
self.images = load_images(data_root / image_column)
else:
self.images = load_images_from_videos(self.videos)
self.trainer = trainer
self.device = device
self.encode_video_fn = encode_video_fn
self.encode_video = trainer.encode_video
self.encode_text = trainer.encode_text
# Check if number of prompts matches number of videos and images
if not (len(self.videos) == len(self.prompts) == len(self.images)):
@ -98,34 +112,66 @@ class BaseI2VDataset(Dataset):
prompt = self.prompts[index]
video = self.videos[index]
image = self.images[index]
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
video_latent_dir = video.parent / "latent"
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = video_latent_dir / (video.stem + ".pt")
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
if encoded_video_path.exists():
encoded_video = torch.load(encoded_video_path, weights_only=True)
# encoded_video = torch.load(encoded_video_path, weights_only=True)
encoded_video = load_file(encoded_video_path)["encoded_video"]
logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
# shape of image: [C, H, W]
_, image = self.preprocess(None, self.images[index])
else:
frames, image = self.preprocess(video, image)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
image = image.to(self.device)
# Current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Add image into the first frame.
# Note, **this operation maybe model-specific**, and maybe change in the future.
frames = torch.cat([image.unsqueeze(0), frames], dim=0)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
encoded_video = self.encode_video(frames)
# [1, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
image = image.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
# shape of image: [C, H, W]
return {
"prompt": prompt,
"image": image,
"prompt_embedding": prompt_embedding,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],
@ -150,7 +196,7 @@ class BaseI2VDataset(Dataset):
- image(torch.Tensor) of shape [C, H, W]
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
@ -160,14 +206,14 @@ class BaseI2VDataset(Dataset):
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
torch.Tensor: The transformed video tensor
"""
raise NotImplementedError("Subclass must implement this method")
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to an image.
@ -176,7 +222,7 @@ class BaseI2VDataset(Dataset):
image (torch.Tensor): A 3D tensor representing an image
with shape [C, H, W] where:
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
@ -198,6 +244,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
height (int): Target height for resizing videos and images
width (int): Target width for resizing videos and images
"""
def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -205,11 +252,7 @@ class I2VDatasetWithResize(BaseI2VDataset):
self.height = height
self.width = width
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override
@ -223,25 +266,25 @@ class I2VDatasetWithResize(BaseI2VDataset):
else:
image = None
return video, image
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)
class I2VDatasetWithBuckets(BaseI2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
@ -253,23 +296,19 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
return video, image
@override
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
@override
def image_transform(self, image: torch.Tensor) -> torch.Tensor:
return self.__image_transforms(image)


@ -1,20 +1,21 @@
import torch
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable
from typing_extensions import override
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import (
load_prompts, load_videos,
preprocess_video_with_resize,
preprocess_video_with_buckets
)
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
if TYPE_CHECKING:
from finetune.trainer import Trainer
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
@ -45,9 +46,9 @@ class BaseT2VDataset(Dataset):
caption_column: str,
video_column: str,
device: torch.device = None,
encode_video_fn: Callable[[torch.Tensor], torch.Tensor] = None,
trainer: "Trainer" = None,
*args,
**kwargs
**kwargs,
) -> None:
super().__init__()
@ -55,7 +56,9 @@ class BaseT2VDataset(Dataset):
self.prompts = load_prompts(data_root / caption_column)
self.videos = load_videos(data_root / video_column)
self.device = device
self.encode_video_fn = encode_video_fn
self.encode_video = trainer.encode_video
self.encode_text = trainer.encode_text
self.trainer = trainer
# Check if all video files exist
if any(not path.is_file() for path in self.videos):
@ -87,30 +90,56 @@ class BaseT2VDataset(Dataset):
prompt = self.prompts[index]
video = self.videos[index]
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
latent_dir = video.parent / "latent"
latent_dir.mkdir(parents=True, exist_ok=True)
encoded_video_path = latent_dir / (video.stem + ".pt")
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
if prompt_embedding_path.exists():
prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
logger.debug(
f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
main_process_only=False,
)
else:
prompt_embedding = self.encode_text(prompt)
prompt_embedding = prompt_embedding.to("cpu")
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
if encoded_video_path.exists():
# shape of encoded_video: [C, F, H, W]
encoded_video = torch.load(encoded_video_path, weights_only=True)
# encoded_video = torch.load(encoded_video_path, weights_only=True)
encoded_video = load_file(encoded_video_path)["encoded_video"]
logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
# shape of image: [C, H, W]
else:
frames = self.preprocess(video)
frames = frames.to(self.device)
# current shape of frames: [F, C, H, W]
# Current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)
frames = frames.permute(0, 2, 1, 3, 4).contiguous()
encoded_video = self.encode_video_fn(frames)
# [B, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0].cpu()
torch.save(encoded_video, encoded_video_path)
encoded_video = self.encode_video(frames)
# [1, C, F, H, W] -> [C, F, H, W]
encoded_video = encoded_video[0]
encoded_video = encoded_video.to("cpu")
save_file({"encoded_video": encoded_video}, encoded_video_path)
logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
# shape of encoded_video: [C, F, H, W]
return {
"prompt": prompt,
"prompt_embedding": prompt_embedding,
"encoded_video": encoded_video,
"video_metadata": {
"num_frames": encoded_video.shape[1],
@ -134,7 +163,7 @@ class BaseT2VDataset(Dataset):
- W is width
"""
raise NotImplementedError("Subclass must implement this method")
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to a video.
@ -144,7 +173,7 @@ class BaseT2VDataset(Dataset):
with shape [F, C, H, W] where:
- F is number of frames
- C is number of channels (3 for RGB)
- H is height
- H is height
- W is width
Returns:
@ -173,36 +202,33 @@ class T2VDatasetWithResize(BaseT2VDataset):
self.height = height
self.width = width
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width,
video_path,
self.max_num_frames,
self.height,
self.width,
)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
class T2VDatasetWithBuckets(BaseT2VDataset):
def __init__(
self,
video_resolution_buckets: List[Tuple[int, int, int]],
vae_temporal_compression_ratio: int,
vae_height_compression_ratio: int,
vae_width_compression_ratio: int,
*args, **kwargs
*args,
**kwargs,
) -> None:
"""
"""
""" """
super().__init__(*args, **kwargs)
self.video_resolution_buckets = [
@ -214,18 +240,12 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose(
[
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
]
)
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
return preprocess_video_with_buckets(
video_path, self.video_resolution_buckets
)
return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
@override
def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
return torch.stack([self.__frame_transform(f) for f in frames], dim=0)


@ -1,11 +1,12 @@
import torch
import cv2
from typing import List, Tuple
import logging
from pathlib import Path
from torchvision import transforms
from typing import List, Tuple
import cv2
import torch
from torchvision.transforms.functional import resize
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort:skip
@ -15,6 +16,7 @@ decord.bridge.set_bridge("torch")
########## loaders ##########
def load_prompts(prompt_path: Path) -> List[str]:
with open(prompt_path, "r", encoding="utf-8") as file:
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
@ -30,8 +32,40 @@ def load_images(image_path: Path) -> List[Path]:
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
first_frames_dir = videos_path[0].parent.parent / "first_frames"
first_frames_dir.mkdir(exist_ok=True)
first_frame_paths = []
for video_path in videos_path:
frame_path = first_frames_dir / f"{video_path.stem}.png"
if frame_path.exists():
first_frame_paths.append(frame_path)
continue
# Open video
cap = cv2.VideoCapture(str(video_path))
# Read first frame
ret, frame = cap.read()
if not ret:
raise RuntimeError(f"Failed to read video: {video_path}")
# Save frame as PNG with same name as video
cv2.imwrite(str(frame_path), frame)
logging.info(f"Saved first frame to {frame_path}")
# Release video capture
cap.release()
first_frame_paths.append(frame_path)
return first_frame_paths
########## preprocessors ##########
def preprocess_image_with_resize(
image_path: Path | str,
height: int,
@ -92,11 +126,11 @@ def preprocess_video_with_resize(
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
video_num_frames = len(video_reader)
if video_num_frames < max_num_frames:
raise ValueError(f"video's frames is less than {max_num_frames}.")
raise ValueError(f"video frame count in {video_path} is less than {max_num_frames}.")
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
frames = video_reader.get_batch(indices)
frames = frames[: max_num_frames].float()
frames = frames[:max_num_frames].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
return frames
@ -144,4 +178,4 @@ def preprocess_video_with_buckets(
nearest_res = (nearest_res[1], nearest_res[2])
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
return frames
return frames


@ -5,8 +5,8 @@ from pathlib import Path
package_dir = Path(__file__).parent
for subdir in package_dir.iterdir():
if subdir.is_dir() and not subdir.name.startswith('_'):
for module_path in subdir.glob('*.py'):
if subdir.is_dir() and not subdir.name.startswith("_"):
for module_path in subdir.glob("*.py"):
module_name = module_path.stem
full_module_name = f".{subdir.name}.{module_name}"
importlib.import_module(full_module_name, package=__name__)


@ -1,5 +1,5 @@
from ..utils import register
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register
class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):


@ -1,27 +1,26 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import torch
from diffusers import (
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXI2VLoraTrainer(Trainer):
UNLOAD_LIST = ["text_encoder"]
@override
def load_components(self) -> Dict[str, Any]:
@ -30,28 +29,29 @@ class CogVideoXI2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXImageToVideoPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@override
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
pipe = CogVideoXImageToVideoPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler,
)
return pipe
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
@ -61,51 +61,62 @@ class CogVideoXI2VLoraTrainer(Trainer):
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def encode_text(self, prompt: str) -> torch.Tensor:
prompt_token_ids = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
return prompt_embedding
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_token_ids": [],
"images": []
}
ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt = sample["prompt"]
prompt_embedding = sample["prompt_embedding"]
image = sample["image"]
# tokenize prompt
text_inputs = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
ret["encoded_videos"].append(encoded_video)
ret["prompt_token_ids"].append(text_input_ids[0])
ret["prompt_embedding"].append(prompt_embedding)
ret["images"].append(image)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
ret["images"] = torch.stack(ret["images"])
return ret
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_token_ids = batch["prompt_token_ids"]
prompt_embedding = batch["prompt_embedding"]
latent = batch["encoded_videos"]
images = batch["images"]
# Shape of prompt_embedding: [B, seq_len, hidden_size]
# Shape of latent: [B, C, F, H, W]
# Shape of images: [B, C, H, W]
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None:
ncopy = latent.shape[2] % patch_size_t
# Copy the first frame ncopy times to match patch_size_t
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
assert latent.shape[2] % patch_size_t == 0
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
_, seq_len, _ = prompt_embedding.shape
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1)
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
images = images.unsqueeze(2)
@ -113,13 +124,12 @@ class CogVideoXI2VLoraTrainer(Trainer):
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
image_latent_dist = self.components.vae.encode(noisy_images).latent_dist
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
@ -157,10 +167,12 @@ class CogVideoXI2VLoraTrainer(Trainer):
)
# Predict noise
ofs_emb = None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
ofs_emb = (
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embedding,
timestep=timesteps,
ofs=ofs_emb,
image_rotary_emb=rotary_emb,
@ -182,7 +194,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
@override
def validation_step(
self, eval_data: Dict[str, Any]
self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
"""
Return the data that needs to be saved. For videos, the data format is List[PIL],
@ -190,20 +202,13 @@ class CogVideoXI2VLoraTrainer(Trainer):
"""
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
pipe = self.components.pipeline_cls(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
video_generate = pipe(
num_frames=self.state.train_frames,
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
image=image,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -214,7 +219,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
@ -237,4 +242,4 @@ class CogVideoXI2VLoraTrainer(Trainer):
return freqs_cos, freqs_sin
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)


@ -1,28 +1,26 @@
import torch
from typing_extensions import override
from typing import Any, Dict, List, Tuple
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
from diffusers.models.embeddings import get_3d_rotary_pos_embed
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXTransformer3DModel,
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override
from finetune.trainer import Trainer
from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model
from ..utils import register
class CogVideoXT2VLoraTrainer(Trainer):
UNLOAD_LIST = ["text_encoder", "vae"]
@override
def load_components(self) -> Components:
@ -31,28 +29,29 @@ class CogVideoXT2VLoraTrainer(Trainer):
components.pipeline_cls = CogVideoXPipeline
components.tokenizer = AutoTokenizer.from_pretrained(
model_path, subfolder="tokenizer"
)
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.vae = AutoencoderKLCogVideoX.from_pretrained(
model_path, subfolder="vae"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
return components
@override
def initialize_pipeline(self) -> CogVideoXPipeline:
pipe = CogVideoXPipeline(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler,
)
return pipe
@override
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W]
@ -61,59 +60,77 @@ class CogVideoXT2VLoraTrainer(Trainer):
latent_dist = vae.encode(video).latent_dist
latent = latent_dist.sample() * vae.config.scaling_factor
return latent
@override
def encode_text(self, prompt: str) -> torch.Tensor:
prompt_token_ids = self.components.tokenizer(
prompt,
padding="max_length",
max_length=self.state.transformer_config.max_text_seq_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
return prompt_embedding
@override
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
ret = {
"encoded_videos": [],
"prompt_token_ids": []
}
ret = {"encoded_videos": [], "prompt_embedding": []}
for sample in samples:
encoded_video = sample["encoded_video"]
prompt = sample["prompt"]
# tokenize prompt
text_inputs = self.components.tokenizer(
prompt,
padding="max_length",
max_length=226,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_embedding = sample["prompt_embedding"]
ret["encoded_videos"].append(encoded_video)
ret["prompt_token_ids"].append(text_input_ids[0])
ret["prompt_embedding"].append(prompt_embedding)
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
ret["prompt_token_ids"] = torch.stack(ret["prompt_token_ids"])
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
return ret
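For reference, a minimal sketch of the batch this collate_fn now produces after the precomputation change: tensors are stacked per key, with video latents and prompt embeddings already encoded. All sizes below are hypothetical placeholders; only the dimension layout follows the shape comments in compute_loss.
import torch
B = 2                                          # hypothetical batch size
C_lat, F_lat, H_lat, W_lat = 16, 13, 60, 90    # hypothetical latent sizes from encode_video
seq_len, hidden = 226, 4096                    # hypothetical text-encoder sizes
batch = {
    "encoded_videos": torch.randn(B, C_lat, F_lat, H_lat, W_lat),  # precomputed VAE latents
    "prompt_embedding": torch.randn(B, seq_len, hidden),           # precomputed T5 embeddings
}
assert batch["encoded_videos"].ndim == 5       # [B, C, F, H, W]
assert batch["prompt_embedding"].ndim == 3     # [B, seq_len, hidden_size]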
@override
def compute_loss(self, batch) -> torch.Tensor:
prompt_token_ids = batch["prompt_token_ids"]
prompt_embedding = batch["prompt_embedding"]
latent = batch["encoded_videos"]
# Shape of prompt_embedding: [B, seq_len, hidden_size]
# Shape of latent: [B, C, F, H, W]
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
raise ValueError(
"Number of frames in latent must be divisible by patch size, please check your args for training."
)
# Add 2 random noise frames at the beginning of frame dimension
noise_frames = torch.randn(
latent.shape[0],
latent.shape[1],
2,
latent.shape[3],
latent.shape[4],
device=latent.device,
dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)
batch_size, num_channels, num_frames, height, width = latent.shape
# Get prompt embeddings
prompt_embeds = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
assert prompt_embeds.requires_grad is False
_, seq_len, _ = prompt_embedding.shape
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1)
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps,
(batch_size,), device=self.accelerator.device
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
)
timesteps = timesteps.long()
# Add noise to latent
latent = latent.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
latent = latent.permute(0, 2, 1, 3, 4) # from [B, C, F, H, W] to [B, F, C, H, W]
noise = torch.randn_like(latent)
latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)
@ -136,7 +153,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
# Predict noise
predicted_noise = self.components.transformer(
hidden_states=latent_added_noise,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embedding,
timestep=timesteps,
image_rotary_emb=rotary_emb,
return_dict=False,
@ -157,7 +174,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
@override
def validation_step(
self, eval_data: Dict[str, Any]
self, eval_data: Dict[str, Any], pipe: CogVideoXPipeline
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
"""
Return the data that needs to be saved. For videos, the data format is List[PIL],
@ -165,19 +182,12 @@ class CogVideoXT2VLoraTrainer(Trainer):
"""
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
pipe = self.components.pipeline_cls(
tokenizer=self.components.tokenizer,
text_encoder=self.components.text_encoder,
vae=self.components.vae,
transformer=unwrap_model(self.accelerator, self.components.transformer),
scheduler=self.components.scheduler
)
video_generate = pipe(
num_frames=self.state.train_frames,
num_frames=self.state.train_frames, # since we pad 2 frames in latent, we still use train_frames
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,
generator=self.state.generator
generator=self.state.generator,
).frames[0]
return [("video", video_generate)]
@ -188,7 +198,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
num_frames: int,
transformer_config: Dict,
vae_scale_factor_spatial: int,
device: torch.device
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
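As a quick sanity check of the grid arithmetic above, a hedged worked example: the values vae_scale_factor_spatial = 8 and patch_size = 2 are assumptions about a typical CogVideoX configuration, not taken from this diff.
height, width = 480, 720                 # e.g. the 480x720 resolution used by the 5B models
vae_scale_factor_spatial = 8             # assumed VAE spatial downscale factor
patch_size = 2                           # assumed transformer patch size
grid_height = height // (vae_scale_factor_spatial * patch_size)
grid_width = width // (vae_scale_factor_spatial * patch_size)
print(grid_height, grid_width)           # 30 45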

View File

@ -1,4 +1,4 @@
from typing import Literal, Dict
from typing import Dict, Literal
from finetune.trainer import Trainer

View File

@ -1,5 +1,6 @@
from .args import Args
from .state import State
from .components import Components
from .state import State
__all__ = ["Args", "State", "Components"]
__all__ = ["Args", "State", "Components"]

View File

@ -1,9 +1,10 @@
import datetime
import argparse
from typing import Dict, Any, Literal, List, Tuple
from pydantic import BaseModel, field_validator, ValidationInfo
import datetime
import logging
from pathlib import Path
from typing import Any, List, Literal, Tuple
from pydantic import BaseModel, ValidationInfo, field_validator
class Args(BaseModel):
@ -78,10 +79,10 @@ class Args(BaseModel):
########## Validation ##########
do_validation: bool = False
validation_steps: int | None = None # if set, should be a multiple of checkpointing_steps
validation_dir: Path | None  # if set do_validation, should not be None
validation_prompts: str | None  # if set do_validation, should not be None
validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
validation_videos: str | None  # if set do_validation and model_type == v2v, should not be None
gen_fps: int = 15
#### deprecated args: gen_video_resolution
@ -97,7 +98,9 @@ class Args(BaseModel):
def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("model_type") == "i2v" and not v:
raise ValueError("image_column must be specified when using i2v model")
logging.warning(
"No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
)
return v
@field_validator("validation_dir", "validation_prompts")
@ -115,7 +118,7 @@ class Args(BaseModel):
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
return v
@field_validator("validation_videos")
@field_validator("validation_videos")
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
@ -132,6 +135,39 @@ class Args(BaseModel):
raise ValueError("validation_steps must be a multiple of checkpointing_steps")
return v
@field_validator("train_resolution")
def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> str:
try:
frames, height, width = v
# Check if (frames - 1) is multiple of 8
if (frames - 1) % 8 != 0:
raise ValueError("Number of frames - 1 must be a multiple of 8")
# Check resolution for cogvideox-5b models
model_name = info.data.get("model_name", "")
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
if (height, width) != (480, 720):
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
return v
except ValueError as e:
if (
str(e) == "not enough values to unpack (expected 3, got 0)"
or str(e) == "invalid literal for int() with base 10"
):
raise ValueError("train_resolution must be in format 'frames x height x width'")
raise e
@field_validator("mixed_precision")
def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
logging.warning(
"All CogVideoX models except cogvideox-2b were trained with bfloat16. "
"Using fp16 precision may lead to training instability."
)
return v
@classmethod
def parse_args(cls):
@ -185,8 +221,7 @@ class Args(BaseModel):
# LoRA parameters
parser.add_argument("--rank", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--target_modules", type=str, nargs="+",
default=["to_q", "to_k", "to_v", "to_out.0"])
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
# Checkpointing
parser.add_argument("--checkpointing_steps", type=int, default=200)
@ -203,7 +238,7 @@ class Args(BaseModel):
parser.add_argument("--gen_fps", type=int, default=15)
args = parser.parse_args()
# Convert video_resolution_buckets string to list of tuples
frames, height, width = args.train_resolution.split("x")
args.train_resolution = (int(frames), int(height), int(width))
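To make the resolution constraints concrete, a small illustrative check mirroring validate_train_resolution and the parsing above; the value 49x480x720 is a hypothetical example, not a requirement of this code.
res = "49x480x720"                                   # hypothetical train_resolution string
frames, height, width = (int(x) for x in res.split("x"))
assert (frames - 1) % 8 == 0, "frames - 1 must be a multiple of 8"
# For cogvideox-5b models the validator additionally enforces 480x720:
assert (height, width) == (480, 720)
print(frames, height, width)                         # 49 480 720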

View File

@ -1,4 +1,5 @@
from typing import Any
from pydantic import BaseModel

View File

@ -1,13 +1,14 @@
import torch
from pathlib import Path
from typing import List, Dict, Any
from pydantic import BaseModel, field_validator
from typing import Any, Dict, List
import torch
from pydantic import BaseModel
class State(BaseModel):
model_config = {"arbitrary_types_allowed": True}
train_frames: int
train_frames: int # user-defined training frames, **containing one image padding frame**
train_height: int
train_width: int

View File

@ -1,13 +1,18 @@
import argparse
import os
from pathlib import Path
import cv2
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory")
parser.add_argument(
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
)
return parser.parse_args()
args = parse_args()
# Create data/images directory if it doesn't exist
@ -24,24 +29,24 @@ with open(videos_file, "r") as f:
image_paths = []
for video_rel_path in video_paths:
video_path = data_dir / video_rel_path
# Open video
cap = cv2.VideoCapture(str(video_path))
# Read first frame
ret, frame = cap.read()
if not ret:
print(f"Failed to read video: {video_path}")
continue
# Save frame as PNG with same name as video
image_name = f"images/{video_path.stem}.png"
image_path = data_dir / image_name
cv2.imwrite(str(image_path), frame)
# Release video capture
cap.release()
print(f"Extracted first frame from {video_path} to {image_path}")
image_paths.append(image_name)
@ -49,4 +54,4 @@ for video_rel_path in video_paths:
images_file = data_dir / "images.txt"
with open(images_file, "w") as f:
for path in image_paths:
f.write(f"{path}\n")
f.write(f"{path}\n")

View File

@ -1,10 +1,11 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from finetune.schemas import Args
from finetune.models.utils import get_model_cls
from finetune.schemas import Args
def main():

View File

@ -1,76 +1,74 @@
import os
import json
import logging
import math
import json
import torch
import transformers
import diffusers
import wandb
from datetime import timedelta
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List, Tuple
from PIL import Image
from typing import Any, Dict, List, Tuple
from torch.utils.data import Dataset, DataLoader
from accelerate.logging import get_logger
import diffusers
import torch
import transformers
import wandb
from accelerate.accelerator import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed,
gather_object,
set_seed,
)
from diffusers.optimization import get_scheduler
from diffusers.pipelines import DiffusionPipeline
from diffusers.utils.export_utils import export_to_video
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from finetune.schemas import Args, State, Components
from finetune.utils import (
unwrap_model, cast_training_params,
get_optimizer,
get_memory_statistics,
free_memory,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_intermediate_ckpt_path,
string_to_filename
)
from finetune.constants import LOG_LEVEL, LOG_NAME
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
from finetune.datasets.utils import (
load_prompts, load_images, load_videos,
preprocess_image_with_resize, preprocess_video_with_resize
load_images,
load_prompts,
load_videos,
preprocess_image_with_resize,
preprocess_video_with_resize,
)
from finetune.schemas import Args, Components, State
from finetune.utils import (
cast_training_params,
free_memory,
get_intermediate_ckpt_path,
get_latest_ckpt_path_to_resume_from,
get_memory_statistics,
get_optimizer,
string_to_filename,
unload_model,
unwrap_model,
)
from finetune.constants import LOG_NAME, LOG_LEVEL
logger = get_logger(LOG_NAME, LOG_LEVEL)
_DTYPE_MAP = {
"fp32": torch.float32,
"fp16": torch.float16,
"fp16": torch.float16, # FP16 is Only Support for CogVideoX-2B
"bf16": torch.bfloat16,
}
class Trainer:
# If set, should be a list of components to unload (refer to `Components`)
UNLOAD_LIST: List[str] = None
def __init__(self, args: Args) -> None:
self.args = args
self.state = State(
weight_dtype=self.__get_training_dtype(),
train_frames=self.args.train_resolution[0],
train_height=self.args.train_resolution[1],
train_width=self.args.train_resolution[2]
train_width=self.args.train_resolution[2],
)
self.components = Components()
@ -133,6 +131,17 @@ class Trainer:
self.args.output_dir = Path(self.args.output_dir)
self.args.output_dir.mkdir(parents=True, exist_ok=True)
def check_setting(self) -> None:
# Check for unload_list
if self.UNLOAD_LIST is None:
logger.warning(
"\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
)
else:
for name in self.UNLOAD_LIST:
if name not in self.components.model_fields:
raise ValueError(f"Invalid component name in unload_list: {name}")
def prepare_models(self) -> None:
logger.info("Initializing models")
@ -150,33 +159,41 @@ class Trainer:
def prepare_dataset(self) -> None:
logger.info("Initializing dataset and dataloader")
# self.state.train_frames includes one padding frame for image conditioning
# so we only sample train_frames - 1 frames from the actual video
sample_frames = self.state.train_frames - 1
if self.args.model_type == "i2v":
self.dataset = I2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width
width=self.state.train_width,
trainer=self,
)
elif self.args.model_type == "t2v":
self.dataset = T2VDatasetWithResize(
**(self.args.model_dump()),
device=self.accelerator.device,
encode_video_fn=self.encode_video,
max_num_frames=self.state.train_frames,
max_num_frames=sample_frames,
height=self.state.train_height,
width=self.state.train_width
width=self.state.train_width,
trainer=self,
)
else:
raise ValueError(f"Invalid model type: {self.args.model_type}")
# Prepare VAE for encoding
self.components.vae = self.components.vae.to(self.accelerator.device)
# Prepare VAE and text encoder for encoding
self.components.vae.requires_grad_(False)
self.components.text_encoder.requires_grad_(False)
self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
self.components.text_encoder = self.components.text_encoder.to(
self.accelerator.device, dtype=self.state.weight_dtype
)
# Precompute latent for video
logger.info("Precomputing latent for video ...")
# Precompute latent for video and prompt embedding
logger.info("Precomputing latent for video and prompt embedding ...")
tmp_data_loader = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
@ -185,8 +202,14 @@ class Trainer:
pin_memory=self.args.pin_memory,
)
tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
for _ in tmp_data_loader: ...
logger.info("Precomputing latent for video ... Done")
for _ in tmp_data_loader:
...
self.accelerator.wait_for_everyone()
logger.info("Precomputing latent for video and prompt embedding ... Done")
unload_model(self.components.vae)
unload_model(self.components.text_encoder)
free_memory()
self.data_loader = torch.utils.data.DataLoader(
self.dataset,
@ -194,16 +217,15 @@ class Trainer:
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=self.args.pin_memory,
shuffle=True
shuffle=True,
)
def prepare_trainable_parameters(self):
logger.info("Initializing trainable parameters")
# For now only lora is supported
for attr_name, component in vars(self.components).items():
if hasattr(component, 'requires_grad_'):
if hasattr(component, "requires_grad_"):
component.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
@ -216,7 +238,7 @@ class Trainer:
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
self.__move_components_to_device()
self.__load_components()
if self.args.gradient_checkpointing:
self.components.transformer.enable_gradient_checkpointing()
@ -234,7 +256,7 @@ class Trainer:
logger.info("Initializing optimizer and lr scheduler")
# Make sure the trainable params are in float32
if self.args.mixed_precision == "fp16":
if self.args.mixed_precision != "no":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
@ -308,7 +330,7 @@ class Trainer:
# Afterwards we recalculate our number of training epochs
self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
def prepare_for_validation(self):
validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)
@ -423,17 +445,17 @@ class Trainer:
global_step += 1
self.__maybe_save_checkpoint(global_step)
# Maybe run validation
should_run_validation = (
self.args.do_validation
and global_step % self.args.validation_steps == 0
)
if should_run_validation:
self.validate(global_step)
logs["loss"] = loss.detach().item()
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
# Maybe run validation
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
if should_run_validation:
del loss
free_memory()
self.validate(global_step)
accelerator.log(logs, step=global_step)
if global_step >= self.args.train_steps:
@ -445,6 +467,7 @@ class Trainer:
accelerator.wait_for_everyone()
self.__maybe_save_checkpoint(global_step, must_save=True)
if self.args.do_validation:
free_memory()
self.validate(global_step)
del self.components
@ -465,10 +488,22 @@ class Trainer:
return
self.components.transformer.eval()
torch.set_grad_enabled(False)
memory_statistics = get_memory_statistics()
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
##### Initialize pipeline #####
pipe = self.initialize_pipeline()
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
pipe.enable_model_cpu_offload(device=self.accelerator.device)
# Convert all model weights to training dtype
# Note: this also casts the LoRA weights in self.components.transformer to the training dtype instead of keeping them in fp32
pipe = pipe.to(dtype=self.state.weight_dtype)
#################################
all_processes_artifacts = []
for i in range(num_validation_samples):
# Skip current validation on all processes but one
@ -480,9 +515,7 @@ class Trainer:
video = self.state.validation_videos[i]
if image is not None:
image = preprocess_image_with_resize(
image, self.state.train_height, self.state.train_width
)
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
image = image.permute(1, 2, 0).cpu().numpy()
@ -494,17 +527,13 @@ class Trainer:
)
# Convert video tensor (F, C, H, W) to list of PIL images
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
video = [Image.fromarray(frame.permute(1,2,0).cpu().numpy()) for frame in video]
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
logger.debug(
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.validation_step({
"prompt": prompt,
"image": image,
"video": video
})
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
prompt_filename = string_to_filename(prompt)[:25]
artifacts = {
"image": {"type": "image", "value": image},
@ -555,6 +584,15 @@ class Trainer:
step=step,
)
pipe.remove_all_hooks()
del pipe
# Offload the components that are not needed for training (those in UNLOAD_LIST)
self.__unload_components()
# Move the components needed for training back to the GPU
self.__load_components()
# Change LoRA weights back to fp32
cast_training_params([self.components.transformer], dtype=torch.float32)
accelerator.wait_for_everyone()
free_memory()
@ -562,9 +600,11 @@ class Trainer:
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
torch.set_grad_enabled(True)
self.components.transformer.train()
def fit(self):
self.check_setting()
self.prepare_models()
self.prepare_dataset()
self.prepare_trainable_parameters()
@ -577,14 +617,22 @@ class Trainer:
def collate_fn(self, examples: List[Dict[str, Any]]):
raise NotImplementedError
def load_components(self) -> Components:
raise NotImplementedError
def initialize_pipeline(self) -> DiffusionPipeline:
raise NotImplementedError
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
# shape of input video: [B, C, F, H, W], where B = 1
# shape of output video: [B, C', F', H', W'], where B = 1
raise NotImplementedError
def encode_text(self, text: str) -> torch.Tensor:
# shape of output text: [batch size, sequence length, embedding dimension]
raise NotImplementedError
def compute_loss(self, batch) -> torch.Tensor:
raise NotImplementedError
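These hooks define the full Trainer interface. The sketch below is a hypothetical subclass showing how a new model trainer plugs into it, mirroring the CogVideoX trainers earlier in this diff; it is illustrative only and not part of the commit (it assumes the same imports those trainers use: Trainer, Components, DiffusionPipeline, torch, and register).
class MyLoraTrainer(Trainer):                 # hypothetical example
    # Components listed here are offloaded to CPU once precomputation is finished.
    UNLOAD_LIST = ["text_encoder", "vae"]
    def load_components(self) -> Components: ...
    def initialize_pipeline(self) -> DiffusionPipeline: ...
    def encode_video(self, video: torch.Tensor) -> torch.Tensor: ...   # [B, C, F, H, W] -> latents
    def encode_text(self, text: str) -> torch.Tensor: ...              # -> [B, seq_len, hidden]
    def collate_fn(self, samples): ...
    def compute_loss(self, batch) -> torch.Tensor: ...
    def validation_step(self, eval_data, pipe): ...
register("my-model", "lora", MyLoraTrainer)   # hypothetical registry key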
@ -601,11 +649,21 @@ class Trainer:
else:
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
def __move_components_to_device(self):
def __load_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, 'to'):
setattr(self.components, name, component.to(self.accelerator.device))
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
continue
# setattr(self.components, name, component.to(self.accelerator.device))
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
def __unload_components(self):
components = self.components.model_dump()
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
if name in self.UNLOAD_LIST:
setattr(self.components, name, component.to("cpu"))
def __prepare_saving_loading_hooks(self, transformer_lora_config):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@ -642,9 +700,7 @@ class Trainer:
):
transformer_ = unwrap_model(self.accelerator, model)
else:
raise ValueError(
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
)
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
else:
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
self.args.model_path, subfolder="transformer"

View File

@ -1,5 +1,5 @@
from .torch_utils import *
from .optimizer_utils import *
from .memory_utils import *
from .checkpointing import *
from .file_utils import *
from .memory_utils import *
from .optimizer_utils import *
from .torch_utils import *

View File

@ -1,10 +1,12 @@
import os
from pathlib import Path
from typing import Tuple
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from ..utils.file_utils import find_files, delete_files
from finetune.constants import LOG_LEVEL, LOG_NAME
from ..utils.file_utils import delete_files, find_files
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,11 +1,12 @@
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Union
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,10 +1,10 @@
import gc
import torch
from typing import Any, Dict, Union
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)
@ -51,6 +51,10 @@ def free_memory() -> None:
# TODO(aryan): handle non-cuda devices
def unload_model(model):
model.to("cpu")
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if isinstance(x, torch.Tensor):
return x.contiguous()

View File

@ -1,9 +1,9 @@
import inspect
import torch
import torch
from accelerate.logging import get_logger
from finetune.constants import LOG_NAME, LOG_LEVEL
from finetune.constants import LOG_LEVEL, LOG_NAME
logger = get_logger(LOG_NAME, LOG_LEVEL)

View File

@ -1,4 +1,4 @@
from typing import Dict, Optional, Union, List
from typing import Dict, List, Optional, Union
import torch
from accelerate import Accelerator
@ -49,4 +49,4 @@ def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], d
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)

View File

@ -17,20 +17,20 @@ $ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVide
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
"""
import logging
import argparse
import logging
from typing import Literal, Optional
import torch
from diffusers import (
CogVideoXPipeline,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
)
from diffusers.utils import export_to_video, load_image, load_video
logging.basicConfig(level=logging.INFO)
# Recommended resolution for each model (width, height)
@ -38,7 +38,6 @@ RESOLUTION_MAP = {
# cogvideox1.5-*
"cogvideox1.5-5b-i2v": (1360, 768),
"cogvideox1.5-5b": (1360, 768),
# cogvideox-*
"cogvideox-5b-i2v": (720, 480),
"cogvideox-5b": (720, 480),
@ -100,10 +99,14 @@ def generate_video(
elif (width, height) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
logging.warning(f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m")
logging.warning(
f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
)
else:
# Otherwise, use the recommended width and height
logging.warning(f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m")
logging.warning(
f"\033[1;31m{model_name} is not supported for custom resolution. Setting back to default resolution {desired_resolution}.\033[0m"
)
width, height = desired_resolution
if generate_type == "i2v":

View File

@ -10,7 +10,6 @@ import torch
from torch import nn
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
default,
@ -90,27 +89,37 @@ class SATVideoDiffusionEngine(nn.Module):
self.no_cond_log = no_cond_log
self.device = args.device
# put lora add here
def disable_untrainable_params(self):
total_trainable = 0
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
flag = False
for prefix in self.not_trainable_prefixes:
if n.startswith(prefix) or prefix == "all":
flag = True
break
if self.lora_train:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
if 'lora_layer' not in n:
p.lr_scale = 0
else:
total_trainable += p.numel()
else:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
flag = False
for prefix in self.not_trainable_prefixes:
if n.startswith(prefix) or prefix == "all":
flag = True
break
lora_prefix = ["matrix_A", "matrix_B"]
for prefix in lora_prefix:
if prefix in n:
flag = False
break
lora_prefix = ['matrix_A', 'matrix_B']
for prefix in lora_prefix:
if prefix in n:
flag = False
break
if flag:
p.requires_grad_(False)
else:
total_trainable += p.numel()
if flag:
p.requires_grad_(False)
else:
total_trainable += p.numel()
print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
@ -182,11 +191,7 @@ class SATVideoDiffusionEngine(nn.Module):
for n in range(n_rounds):
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
latent_time = z_now.shape[2] # check the time latent
temporal_compress_times = 4
fake_cp_size = min(10, latent_time // 2)
start_frame = 0
recons = []
start_frame = 0
for i in range(fake_cp_size):

View File

@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin):
def word_embedding_forward(self, input_ids, **kwargs):
images = kwargs["images"] # (b,t,c,h,w)
emb = rearrange(images, "b t c h w -> b (t h w) c")
# emb = rearrange(images, "b c t h w -> b (t h w) c")
emb = rearrange(
emb,
"b (t o h p w q) c -> b (t h w) (c o p q)",
@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel):
),
reinit=True,
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):