Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 03:04:56 +08:00)

Commit aa12ed37f5: Merge branch 'main' into moviepy-v2
@@ -4,22 +4,34 @@

[日本語で読む](./README_ja.md)

If you're looking for the fine-tuning instructions for the SAT version, please check [here](../sat/README_zh.md). The dataset format for this version differs from the one used here.

## Hardware Requirements

| Model | Training Type | Distribution Strategy | Mixed Precision | Training Resolution (FxHxW) | Hardware Requirements |
|----------------------------|----------------|--------------------------------------|------|-------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + opt offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |
## Install Dependencies

Since the relevant code has not yet been merged into the official `diffusers` release, you need to fine-tune based on the diffusers branch. Follow the steps below to install the dependencies:

```shell
git clone https://github.com/huggingface/diffusers.git
```
@@ -29,7 +41,8 @@ pip install -e .

## Prepare the Dataset

First, you need to prepare your dataset. Depending on your task type (T2V or I2V), the dataset format will vary slightly:

```
.
```

@@ -41,62 +54,93 @@ First, you need to prepare your dataset. Depending on your task type (T2V or I2V

Where:

- `prompts.txt`: Contains the prompts
- `videos/`: Contains the .mp4 video files
- `videos.txt`: Contains the list of video files in the `videos/` directory
- `images/`: (Optional) Contains the .png reference image files
- `images.txt`: (Optional) Contains the list of reference image files

You can download a sample dataset (T2V): [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset).

If you need to use a validation dataset during training, make sure to provide one with the same format as the training dataset.
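As a quick sanity check of this layout, the following minimal Python sketch (not part of the repository) verifies that the prompt and video lists line up and that every listed file exists. It assumes the file names shown above (`prompts.txt`, `videos.txt`, `images.txt`) and relative paths inside the dataset root; adjust the names to whatever you actually pass to `--caption_column` / `--video_column` / `--image_column`.

```python
# Hypothetical helper, not part of the repo: validate the dataset layout described above.
from pathlib import Path


def check_dataset(data_root: str, with_images: bool = False) -> None:
    root = Path(data_root)
    prompts = (root / "prompts.txt").read_text(encoding="utf-8").splitlines()
    videos = (root / "videos.txt").read_text(encoding="utf-8").splitlines()
    # Each prompt must correspond line-by-line to a video file.
    assert len(prompts) == len(videos), "prompts.txt and videos.txt must have the same number of lines"
    for rel in videos:
        assert (root / rel).is_file(), f"missing video file: {rel}"
    if with_images:
        images = (root / "images.txt").read_text(encoding="utf-8").splitlines()
        assert len(images) == len(videos), "images.txt must align line-by-line with videos.txt"
        for rel in images:
            assert (root / rel).is_file(), f"missing image file: {rel}"


if __name__ == "__main__":
    check_dataset("/absolute/path/to/your/data_root", with_images=False)
```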
## Running Scripts to Start Fine-tuning

Before starting training, please note the following resolution requirements:

1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), such as 49, 81, etc.
2. Recommended video resolutions for each model:
    - CogVideoX: 480x720 (height x width)
    - CogVideoX1.5: 768x1360 (height x width)
3. For samples (videos or images) that don't match the training resolution, the code will resize them directly. This may distort the aspect ratio and affect training results. We recommend preprocessing your samples (e.g., crop + resize to maintain the aspect ratio) before training; a preprocessing sketch follows the note below.

> **Important Note**: To improve training efficiency, we automatically encode videos and cache the results on disk before training. If you modify the data after training has begun, please delete the `latent` directory under the video directory to ensure that the latest data is used.
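The preprocessing sketch referenced above is a minimal illustration (not the repository's code) of the 8N+1 frame rule and an aspect-preserving center-crop followed by a resize; the 480x720 target is just the CogVideoX default mentioned above.

```python
# Illustrative only: check the 8N+1 frame rule and crop-then-resize an image
# so the aspect ratio is preserved (the training code itself resizes directly).
from PIL import Image


def is_valid_frame_count(frames: int) -> bool:
    # Valid counts are 8N+1: 9, 17, ..., 49, 81, ...
    return frames >= 1 and (frames - 1) % 8 == 0


def crop_and_resize(img: Image.Image, height: int = 480, width: int = 720) -> Image.Image:
    target_ratio = width / height
    w, h = img.size
    if w / h > target_ratio:
        # Too wide: crop the width down to the target aspect ratio.
        new_w = int(h * target_ratio)
        left = (w - new_w) // 2
        img = img.crop((left, 0, left + new_w, h))
    else:
        # Too tall: crop the height down to the target aspect ratio.
        new_h = int(w / target_ratio)
        top = (h - new_h) // 2
        img = img.crop((0, top, w, top + new_h))
    return img.resize((width, height))


assert is_valid_frame_count(49) and is_valid_frame_count(81)
```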
### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh (or train_ddp_i2v.sh)
# Main parameters to modify:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to prompt file
# --image_column: Optional for I2V, path to reference image file list (remove this parameter to use the first frame of each video as the image condition)
# --video_column: Path to video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, please refer to the launch script

bash train_ddp_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh  # Image-to-Video (I2V) fine-tuning
```
### SFT

We provide several zero configuration templates in the `configs/` directory. Please choose the appropriate training configuration based on your needs (configure the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh  # Image-to-Video (I2V) fine-tuning
```

In addition to setting the bash script parameters, you need to set the relevant training options in the zero configuration file and ensure that the zero training configuration matches the parameters in the bash script, such as `batch_size`, `gradient_accumulation_steps`, and `mixed_precision`. For details, please refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/); a consistency-check sketch follows the notes below.

When using SFT training, please note:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB of VRAM, it's recommended to disable validation.

2. Validation is slow when zero-3 is enabled, so it's recommended to disable validation when using zero-3.
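The consistency check mentioned above could look like the following minimal sketch (not part of the repository). It assumes you run it from the `finetune/` directory and that the zero templates keep the JSON structure shown later in this commit (the `configs/*.yaml` files are JSON-formatted despite the extension).

```python
# Hypothetical check, not repo code: compare a zero config against the launch-script values.
import json


def check_zero_config(config_path: str, batch_size: int, grad_accum: int, mixed_precision: str) -> None:
    with open(config_path) as f:
        cfg = json.load(f)  # the templates in configs/ are JSON despite the .yaml extension
    assert cfg.get("train_micro_batch_size_per_gpu") in (batch_size, "auto"), "batch_size mismatch"
    assert cfg.get("gradient_accumulation_steps") in (grad_accum, "auto"), "gradient_accumulation_steps mismatch"
    assert cfg.get("bf16", {}).get("enabled", False) == (mixed_precision == "bf16"), "mixed_precision mismatch"


if __name__ == "__main__":
    check_zero_config("configs/zero2.yaml", batch_size=1, grad_accum=1, mixed_precision="bf16")
```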
## Load the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) for instructions on how to load the fine-tuned model.

+ For SFT-trained models, please first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.
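For LoRA checkpoints, a minimal loading sketch with `diffusers` might look like the following. This is not the repository's `cli_demo.py`; the model id and output path are placeholders, it assumes the diffusers branch installed above, and `cli_demo.py` remains the authoritative reference.

```python
# Hypothetical sketch: load LoRA weights produced by the fine-tuning run into a diffusers pipeline.
# For SFT checkpoints, merge the weights with zero_to_fp32.py first, as described above.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("/absolute/path/to/your/output_dir", adapter_name="cogvideox-lora")
pipe.enable_model_cpu_offload()

video = pipe(prompt="A little dog playing the piano", num_frames=49).frames[0]
```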
## Best Practices

+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). Through frame skipping in the data preprocessing, we created two smaller datasets with 49 and 16 frames to speed up experiments. The maximum frame count recommended by the CogVideoX team is 49 frames. These 70 videos were divided into three groups of 10, 25, and 50 videos, with similar conceptual nature.
+ Videos with 25 or more frames work best for training new concepts and styles.
+ It's recommended to use an identifier token, which can be specified using `--id_token`, for better training results. This is similar to Dreambooth training, though regular fine-tuning without this token will still work.
+ The original repository uses `lora_alpha` set to 1. We found that this value performed poorly in several runs, possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` equal to the rank or `rank // 2` (a concrete sketch follows below).
+ It's advised to use a rank of 64 or higher.
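To make the `lora_alpha` recommendation concrete, here is a hedged `peft` sketch; the `target_modules` names are placeholders and should match whatever attention projections the training script actually targets.

```python
# Illustrative peft config following the recommendation above (lora_alpha == rank, rank >= 64).
from peft import LoraConfig

rank = 128
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # or rank // 2
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # placeholder module names
)
```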
@@ -8,13 +8,25 @@ If you're looking for the fine-tuning instructions for the SAT version, please check [here

## Hardware Requirements

| Model | Training Type | Distribution Strategy | Mixed Precision | Training Resolution (frames x height x width) | Hardware Requirements |
|----------------------------|----------------|--------------------------------------|------|-------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + opt offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |
## Install Dependencies

@@ -51,46 +63,58 @@ pip install -e .

## Run the Scripts to Start Fine-tuning

Before starting training, please note the following resolution requirements:

1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), e.g., 49, 81, ...
2. It is recommended to use the model's default video resolution:
    - CogVideoX: 480x720 (height x width)
    - CogVideoX1.5: 768x1360 (height x width)
3. Samples (videos or images) that do not match the training resolution are automatically resized in the code. This may distort the aspect ratio and affect training results. We recommend preprocessing your samples (e.g., crop + resize to preserve the aspect ratio) before training.

> **Important Note**: To improve training efficiency, videos are encoded before training and the results are cached on disk. If you modify the data after training, please delete the `latent` directory under the video directory so that the latest data is used.

### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh
# The main parameters to modify are:
# --output_dir: Output directory
# --data_root: Root directory of the dataset
# --caption_column: Path to the text prompt file
# --image_column: For I2V, path to the reference image file list (remove this parameter to use the first frame of each video as the image condition by default)
# --video_column: Path to the video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, please refer to the launch script

bash train_ddp_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh  # Image-to-Video (I2V) fine-tuning
```

### SFT

Several zero configuration templates are provided in the `configs/` directory. Choose an appropriate training configuration for your needs (set the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh  # Image-to-Video (I2V) fine-tuning
```

In addition to setting the parameters in the bash script, you need to set the relevant training options in the zero configuration file and make sure the zero training configuration matches the parameters in the bash script, such as `batch_size`, `gradient_accumulation_steps`, and `mixed_precision`. For details, please refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/).

When using SFT training, please note the following:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB of VRAM, it is recommended to disable validation.

2. Validation is slow when zero-3 is enabled, so it is recommended to disable validation under zero-3.

## Load the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.

+ For SFT-trained models, please first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.

## Best Practices

+ We used 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). Through frame skipping in the data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experiments. The maximum frame count recommended by the CogVideoX team is 49 frames. The 70 videos were divided into three groups of 10, 25, and 50 videos that are conceptually similar.
@@ -8,17 +8,25 @@

## Hardware Requirements

| Model | Training Type | Distribution Strategy | Mixed Precision | Training Resolution (frames x height x width) | Hardware Requirements |
|----------------------------|----------------|--------------------------------------|------|-------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + opt offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |

## Install Dependencies
@@ -66,36 +74,49 @@ pip install -e .

> **Important Note**: To improve training efficiency, we automatically encode videos before training and cache the results on disk. If you modify the data after training, please delete the `latent` directory under the video directory to ensure that the latest data is used.

### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh
# The main parameters to modify are:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to the prompt file
# --image_column: Optional for I2V, path to the reference image file list (remove this parameter to use the first frame of each video as the image condition by default)
# --video_column: Path to the video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, please refer to the launch script

bash train_ddp_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh  # Image-to-Video (I2V) fine-tuning
```

### SFT

We provide several zero configuration templates in the `configs/` directory; choose the training configuration that fits your needs (set the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh  # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh  # Image-to-Video (I2V) fine-tuning
```

In addition to setting the parameters in the bash script, you need to set the relevant training options in the zero configuration file and make sure the zero training configuration matches the parameters in the bash script, such as `batch_size`, `gradient_accumulation_steps`, and `mixed_precision`. For details, please refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/).

When using SFT training, please note the following:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB of VRAM, it is recommended to disable validation.

2. Validation is slow when zero-3 is enabled, so it is recommended to disable validation under zero-3.

## Load the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) to learn how to load the fine-tuned model.

+ For SFT-trained models, please first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.

## Best Practices

+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x

@@ -105,4 +126,3 @@ bash accelerate_train_i2v.sh

+ The original repository sets `lora_alpha` to 1. We found that this value performed poorly in several runs, possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` equal to the rank or `rank // 2`.
+ It is recommended to use a rank of 64 or higher.
finetune/accelerate_config.yaml (new file, 21 lines)
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE

gpu_ids: "0,1,2,3,4,5,6,7"
num_processes: 8  # should be the same as the number of GPUs

debug: false
deepspeed_config:
  deepspeed_config_file: configs/zero2.yaml  # e.g. configs/zero2.yaml; an absolute path is required
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
finetune/configs/zero2.yaml (new file, 38 lines)
@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero2_offload.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3.yaml (new file, 43 lines)
@@ -0,0 +1,43 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3_offload.yaml (new file, 51 lines)
@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
@@ -139,22 +139,19 @@ class BaseI2VDataset(Dataset):
        logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

        if encoded_video_path.exists():
            # encoded_video = torch.load(encoded_video_path, weights_only=True)
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # shape of image: [C, H, W]
            _, image = self.preprocess(None, self.images[index])
            image = self.image_transform(image)
        else:
            frames, image = self.preprocess(video, image)
            frames = frames.to(self.device)
            image = image.to(self.device)
            image = self.image_transform(image)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)

            # Add the image as the first frame.
            # Note: this operation may be model-specific and may change in the future.
            frames = torch.cat([image.unsqueeze(0), frames], dim=0)

            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
@@ -126,8 +126,15 @@ def preprocess_video_with_resize(
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)
    if video_num_frames < max_num_frames:
        # Get all frames first
        frames = video_reader.get_batch(list(range(video_num_frames)))
        # Repeat the last frame until we reach max_num_frames
        last_frame = frames[-1:]
        num_repeats = max_num_frames - video_num_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
        frames = torch.cat([frames, repeated_frames], dim=0)
        return frames.float().permute(0, 3, 1, 2).contiguous()
    else:
        indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
        frames = video_reader.get_batch(indices)
        frames = frames[:max_num_frames].float()
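A standalone sketch of the padding behavior introduced above: clips shorter than the requested length no longer raise, their last frame is repeated instead. This is an illustration with plain torch tensors, not the decord-backed code.

```python
# Illustration only: repeat the last frame so a short clip reaches max_num_frames.
import torch


def pad_frames(frames: torch.Tensor, max_num_frames: int) -> torch.Tensor:
    # frames: [F, H, W, C], as returned by the video reader
    if frames.shape[0] >= max_num_frames:
        return frames[:max_num_frames]
    repeats = max_num_frames - frames.shape[0]
    return torch.cat([frames, frames[-1:].repeat(repeats, 1, 1, 1)], dim=0)


demo = torch.zeros(45, 8, 8, 3)
assert pad_frames(demo, 49).shape[0] == 49
```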
finetune/models/cogvideox1_5_i2v/lora_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoX1_5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox1.5-i2v", "lora", CogVideoX1_5I2VLoraTrainer)

finetune/models/cogvideox1_5_i2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register


class CogVideoX1_5I2VSftTrainer(CogVideoXI2VSftTrainer):
    pass


register("cogvideox1.5-i2v", "sft", CogVideoX1_5I2VSftTrainer)

finetune/models/cogvideox1_5_t2v/lora_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoX1_5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
    pass


register("cogvideox1.5-t2v", "lora", CogVideoX1_5T2VLoraTrainer)

finetune/models/cogvideox1_5_t2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register


class CogVideoX1_5T2VSftTrainer(CogVideoXT2VSftTrainer):
    pass


register("cogvideox1.5-t2v", "sft", CogVideoX1_5T2VSftTrainer)
@@ -1,9 +0,0 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoX1dot5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox1.5-i2v", "lora", CogVideoX1dot5I2VLoraTrainer)

@@ -1,9 +0,0 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoX1dot5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
    pass


register("cogvideox1.5-t2v", "lora", CogVideoX1dot5T2VLoraTrainer)
@@ -9,6 +9,7 @@ from diffusers import (
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from numpy import dtype
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override

@@ -116,7 +117,7 @@ class CogVideoXI2VLoraTrainer(Trainer):

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
        images = images.unsqueeze(2)

@@ -166,7 +167,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
            else None
        )

        # Predict noise (the ofs embedding below is for CogVideoX1.5 only)
        ofs_emb = (
            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
        )
finetune/models/cogvideox_i2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
@@ -100,28 +100,18 @@ class CogVideoXT2VLoraTrainer(Trainer):
        # Shape of latent: [B, C, F, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Sample a random timestep for each sample
        timesteps = torch.randint(
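A small standalone check of why the first-frame padding above lands on a multiple of `patch_size_t`: assuming `patch_size_t` is 2 (our reading of the CogVideoX1.5 setting), prepending `F % 2` copies always yields an even frame count, which the assert in the code above also guards.

```python
# Standalone check (assumption: patch_size_t == 2, as used by CogVideoX1.5).
import torch


def pad_to_patch_multiple(latent: torch.Tensor, patch_size_t: int) -> torch.Tensor:
    # latent: [B, C, F, H, W]; prepend copies of the first frame, mirroring the code above.
    ncopy = latent.shape[2] % patch_size_t
    first_frame = latent[:, :, :1, :, :]
    latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
    assert latent.shape[2] % patch_size_t == 0
    return latent


for frames in range(1, 20):
    latent = torch.zeros(1, 4, frames, 2, 2)
    assert pad_to_patch_multiple(latent, 2).shape[2] % 2 == 0
```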
@@ -183,7 +173,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
        prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,

@@ -207,7 +197,6 @@ class CogVideoXT2VLoraTrainer(Trainer):
            base_num_frames = num_frames
        else:
            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
finetune/models/cogvideox_t2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    pass


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
@@ -15,16 +15,11 @@ def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls
        trainer_cls (Trainer): Trainer class to register.
    """

    # Check if model_name and training_type already exist in SUPPORTED_MODELS
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}

    if training_type in SUPPORTED_MODELS[model_name]:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    SUPPORTED_MODELS[model_name][training_type] = trainer_cls
@@ -78,7 +78,7 @@ class Args(BaseModel):

    ########## Validation ##########
    do_validation: bool = False
    validation_steps: int | None  # if set, should be a multiple of checkpointing_steps
    validation_dir: Path | None  # if set do_validation, should not be None
    validation_prompts: str | None  # if set do_validation, should not be None
    validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
@@ -229,7 +229,7 @@ class Args(BaseModel):
        parser.add_argument("--resume_from_checkpoint", type=str, default=None)

        # Validation
        parser.add_argument("--do_validation", type=lambda x: x.lower() == 'true', default=False)
        parser.add_argument("--validation_steps", type=int, default=None)
        parser.add_argument("--validation_dir", type=str, default=None)
        parser.add_argument("--validation_prompts", type=str, default=None)
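The switch from `type=bool` to the lambda above matters because argparse calls `bool()` on the raw string, and any non-empty string (including "False") is truthy. A short demonstration, independent of the repo:

```python
# Demonstrates why type=bool is a trap for boolean CLI flags.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--buggy", type=bool, default=False)
parser.add_argument("--fixed", type=lambda x: x.lower() == "true", default=False)

args = parser.parse_args(["--buggy", "False", "--fixed", "False"])
assert args.buggy is True    # bool("False") -> True, not what the user meant
assert args.fixed is False   # the lambda parses the string correctly
```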
@@ -8,13 +8,13 @@ from pydantic import BaseModel
class State(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    train_frames: int
    train_height: int
    train_width: int

    transformer_config: Dict[str, Any] = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0

@@ -25,3 +25,5 @@ class State(BaseModel):
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    using_deepspeed: bool = False
@@ -13,26 +13,26 @@ MODEL_ARGS=(

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # commenting this line out will use the first frame of each video as the image conditioning
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  Only CogVideoX-2B supports fp16 training
)

# System Configuration
@@ -44,15 +44,16 @@ SYSTEM_ARGS=(

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
@@ -13,25 +13,25 @@ MODEL_ARGS=(

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  Only CogVideoX-2B supports fp16 training
)

# System Configuration
@@ -43,15 +43,16 @@ SYSTEM_ARGS=(

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)
finetune/train_zero_i2v.sh (new file, 73 lines)
@@ -0,0 +1,73 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B-I2V"
    --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
    --model_type "i2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # commenting this line out will use the first frame of each video as the image conditioning
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height/width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  Only CogVideoX-2B supports fp16 training
    ########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
finetune/train_zero_t2v.sh (new file, 71 lines)
@@ -0,0 +1,71 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B"
    --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
    --model_type "t2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height/width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  Only CogVideoX-2B supports fp16 training
    ########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep, after which the oldest one is deleted
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # if you want to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
@@ -1,3 +1,4 @@
import hashlib
import json
import logging
import math

@@ -71,7 +72,7 @@ class Trainer:
            train_width=self.args.train_resolution[2],
        )

        self.components: Components = self.load_components()
        self.accelerator: Accelerator = None
        self.dataset: Dataset = None
        self.data_loader: DataLoader = None
@@ -83,6 +84,8 @@ class Trainer:
        self._init_logging()
        self._init_directories()

        self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None

    def _init_distributed(self):
        logging_dir = Path(self.args.output_dir, "logs")
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)

@@ -145,9 +148,6 @@
    def prepare_models(self) -> None:
        logger.info("Initializing models")

        if self.components.vae is not None:
            if self.args.enable_slicing:
                self.components.vae.enable_slicing()

@@ -159,15 +159,11 @@
    def prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        if self.args.model_type == "i2v":
            self.dataset = I2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,

@@ -176,7 +172,7 @@ class Trainer:
            self.dataset = T2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,

@@ -223,12 +219,7 @@
    def prepare_trainable_parameters(self):
        logger.info("Initializing trainable parameters")

        # For mixed precision training we cast all non-trainable weights to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        weight_dtype = self.state.weight_dtype
@@ -238,11 +229,16 @@
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        # For LoRA, we freeze all the parameters
        # For SFT, we train all the parameters in the transformer model
        for attr_name, component in vars(self.components).items():
            if hasattr(component, "requires_grad_"):
                if self.args.training_type == "sft" and attr_name == "transformer":
                    component.requires_grad_(True)
                else:
                    component.requires_grad_(False)

        if self.args.training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,

@@ -252,21 +248,28 @@
            self.components.transformer.add_adapter(transformer_lora_config)
            self.__prepare_saving_loading_hooks(transformer_lora_config)

        # Load components needed for training to GPU (except the transformer), and cast them to the specified data type
        ignore_list = ["transformer"] + self.UNLOAD_LIST
        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

        if self.args.gradient_checkpointing:
            self.components.transformer.enable_gradient_checkpointing()

    def prepare_optimizer(self) -> None:
        logger.info("Initializing optimizer and lr scheduler")

        # Make sure the trainable params are in float32
        if self.args.mixed_precision != "no":
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params([self.components.transformer], dtype=torch.float32)

        # For LoRA, we only want to train the LoRA weights
        # For SFT, we want to train all the parameters
        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
        transformer_parameters_with_lr = {
            "params": trainable_parameters,
            "lr": self.args.learning_rate,
        }
        params_to_optimize = [transformer_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

        use_deepspeed_opt = (
            self.accelerator.state.deepspeed_plugin is not None

@@ -405,6 +408,7 @@
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        free_memory()
        for epoch in range(first_epoch, self.args.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")
@ -496,16 +500,25 @@ class Trainer:
|
||||
##### Initialize pipeline #####
|
||||
pipe = self.initialize_pipeline()
|
||||
|
||||
if self.state.using_deepspeed:
# Can't use model_cpu_offload with DeepSpeed,
# so we need to move all components in the pipe to the device
# pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
else:
# If not using DeepSpeed, use model_cpu_offload to further reduce memory usage
# Or use pipe.enable_sequential_cpu_offload() to reduce memory usage even further
pipe.enable_model_cpu_offload(device=self.accelerator.device)
|
||||
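Validation therefore has two memory strategies: with DeepSpeed, offload hooks cannot be used, so the frozen components are moved to the device explicitly; without DeepSpeed, diffusers' offload hooks do the paging. A hedged sketch of that decision (the helper name and the component list are assumptions, not the trainer's real code):

```python
# Sketch only: pick the validation-time memory strategy for a diffusers pipeline.
def prepare_pipeline_for_validation(pipe, using_deepspeed: bool, device):
    if using_deepspeed:
        # DeepSpeed owns the transformer, so CPU-offload hooks can't be installed;
        # move the frozen components to the device by hand instead.
        for name in ("vae", "text_encoder"):
            module = getattr(pipe, name, None)
            if module is not None:
                module.to(device)
    else:
        # Without DeepSpeed, let diffusers page whole components in and out of GPU memory.
        pipe.enable_model_cpu_offload(device=device)
        # pipe.enable_sequential_cpu_offload()  # even lower memory, but slower
    return pipe
```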
# Convert all model weights to training dtype
|
||||
# Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
|
||||
pipe = pipe.to(dtype=self.state.weight_dtype)
|
||||
|
||||
#################################
|
||||
|
||||
all_processes_artifacts = []
|
||||
for i in range(num_validation_samples):
|
||||
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
||||
# Skip current validation on all processes but one
|
||||
if i % accelerator.num_processes != accelerator.process_index:
|
||||
continue
|
||||
@ -526,7 +539,7 @@ class Trainer:
|
||||
video, self.state.train_frames, self.state.train_height, self.state.train_width
|
||||
)
|
||||
# Convert video tensor (F, C, H, W) to list of PIL images
|
||||
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
|
||||
video = video.round().clamp(0, 255).to(torch.uint8)
|
||||
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
|
||||
|
||||
logger.debug(
|
||||
@ -534,7 +547,19 @@ class Trainer:
|
||||
main_process_only=False,
|
||||
)
|
||||
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
|
||||
|
||||
if (
|
||||
self.state.using_deepspeed
|
||||
and self.accelerator.deepspeed_plugin.zero_stage == 3
|
||||
and not accelerator.is_main_process
|
||||
):
|
||||
continue
|
||||
|
||||
prompt_filename = string_to_filename(prompt)[:25]
|
||||
# Calculate hash of reversed prompt as a unique identifier
|
||||
reversed_prompt = prompt[::-1]
|
||||
hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
|
||||
|
||||
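The reversed-prompt hash above gives each validation artifact a short, stable suffix, so two prompts that share the same first 25 characters no longer overwrite each other's files. A stand-alone version of that naming logic (the `string_to_filename` helper is simplified here and is an assumption):

```python
# Stand-alone sketch of the validation artifact naming.
import hashlib
import re

def string_to_filename(s: str) -> str:
    # simplified stand-in for the finetune utility of the same name
    return re.sub(r"[^a-zA-Z0-9]+", "-", s).strip("-")

prompt = "A girl riding a bike."
prompt_filename = string_to_filename(prompt)[:25]
hash_suffix = hashlib.md5(prompt[::-1].encode()).hexdigest()[:5]

step, process_index, extension = 500, 0, "mp4"
print(f"validation-{step}-{process_index}-{prompt_filename}-{hash_suffix}.{extension}")
# e.g. validation-500-0-A-girl-riding-a-bike-xxxxx.mp4
```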
artifacts = {
|
||||
"image": {"type": "image", "value": image},
|
||||
"video": {"type": "video", "value": video},
|
||||
@ -553,7 +578,7 @@ class Trainer:
|
||||
continue
|
||||
|
||||
extension = "png" if artifact_type == "image" else "mp4"
|
||||
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}.{extension}"
|
||||
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
|
||||
validation_path = self.args.output_dir / "validation_res"
|
||||
validation_path.mkdir(parents=True, exist_ok=True)
|
||||
filename = str(validation_path / filename)
|
||||
@ -584,18 +609,25 @@ class Trainer:
|
||||
step=step,
|
||||
)
|
||||
|
||||
pipe.remove_all_hooks()
|
||||
########## Clean up ##########
|
||||
if self.state.using_deepspeed:
|
||||
del pipe
|
||||
# Unload models except those needed for training
|
||||
self.__unload_components()
|
||||
self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
|
||||
else:
|
||||
pipe.remove_all_hooks()
|
||||
del pipe
|
||||
# Load models except those not needed for training
|
||||
self.__load_components()
|
||||
# Change LoRA weights back to fp32
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
|
||||
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
|
||||
# Cast the trainable weights back to fp32 to stay consistent with the dtype set when the model was prepared
cast_training_params([self.components.transformer], dtype=torch.float32)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
free_memory()
|
||||
accelerator.wait_for_everyone()
|
||||
################################
|
||||
|
||||
memory_statistics = get_memory_statistics()
|
||||
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
|
||||
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
||||
@ -649,20 +681,20 @@ class Trainer:
|
||||
else:
|
||||
raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")
|
||||
|
||||
def __load_components(self):
|
||||
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
|
||||
ignore_list = set(ignore_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
continue
|
||||
# setattr(self.components, name, component.to(self.accelerator.device))
|
||||
setattr(self.components, name, component.to(self.accelerator.device, dtype=self.state.weight_dtype))
|
||||
if name not in ignore_list:
|
||||
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
|
||||
|
||||
def __unload_components(self):
|
||||
def __move_components_to_cpu(self, unload_list: List[str] = []):
|
||||
unload_list = set(unload_list)
|
||||
components = self.components.model_dump()
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name in self.UNLOAD_LIST:
|
||||
if name in unload_list:
|
||||
setattr(self.components, name, component.to("cpu"))
|
||||
|
||||
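The renamed helpers above take the target list explicitly instead of hard-coding `UNLOAD_LIST`, which makes the move-to-GPU and move-to-CPU paths symmetric. A compact sketch of that pattern (a plain namespace stands in for the trainer's pydantic components model, so `model_dump()` becomes `vars()`; names are illustrative):

```python
from types import SimpleNamespace

import torch
import torch.nn as nn

components = SimpleNamespace(transformer=nn.Linear(2, 2), vae=nn.Linear(2, 2), text_encoder=nn.Linear(2, 2))

def move_components_to_device(components, device, dtype, ignore_list=()):
    ignore = set(ignore_list)
    for name, component in vars(components).items():
        if not isinstance(component, type) and hasattr(component, "to") and name not in ignore:
            setattr(components, name, component.to(device, dtype=dtype))

def move_components_to_cpu(components, unload_list=()):
    unload = set(unload_list)
    for name, component in vars(components).items():
        if not isinstance(component, type) and hasattr(component, "to") and name in unload:
            setattr(components, name, component.to("cpu"))

move_components_to_device(components, "cpu", torch.float16, ignore_list=["transformer"])
move_components_to_cpu(components, unload_list=["text_encoder"])
print(components.vae.weight.dtype)  # torch.float16
```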
def __prepare_saving_loading_hooks(self, transformer_lora_config):
|
||||
@ -723,22 +755,16 @@ class Trainer:
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if self.args.mixed_precision == "fp16":
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params([transformer_])
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
|
||||
if must_save or global_step % self.args.checkpointing_steps == 0:
|
||||
# for training
|
||||
save_path = get_intermediate_ckpt_path(
|
||||
checkpointing_limit=self.args.checkpointing_limit,
|
||||
step=global_step,
|
||||
output_dir=self.args.output_dir,
|
||||
)
|
||||
self.accelerator.save_state(save_path)
|
||||
self.accelerator.save_state(save_path, safe_serialization=True)
|
||||
|
@ -14,7 +14,10 @@ To run the script, use the following command with appropriate arguments:
|
||||
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
|
||||
```
|
||||
|
||||
You can change `pipe.enable_sequential_cpu_offload()` to `pipe.enable_model_cpu_offload()` to speed up inference, but this will use more GPU memory.
|
||||
Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
|
||||
|
||||
"""
|
||||
|
||||
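The note above is a speed/memory trade-off: sequential offload streams individual submodules through the GPU (lowest memory, slowest), while model offload keeps one whole component on the GPU at a time. A hedged sketch of how the choice could be exposed (the `low_memory` flag is made up for illustration and is not an argument of the real `cli_demo.py`):

```python
# Sketch only: choose the CPU-offload strategy for a CogVideoX pipeline.
import torch
from diffusers import CogVideoXPipeline

def load_pipeline(model_path: str, low_memory: bool = True) -> CogVideoXPipeline:
    pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    if low_memory:
        pipe.enable_sequential_cpu_offload()   # smallest VRAM footprint, slowest
    else:
        pipe.enable_model_cpu_offload()        # faster, needs more VRAM
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    return pipe
```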
import argparse
|
||||
@ -22,6 +25,7 @@ import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
CogVideoXDPMScheduler,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
@ -36,12 +40,12 @@ logging.basicConfig(level=logging.INFO)
|
||||
# Recommended resolution for each model (width, height)
|
||||
RESOLUTION_MAP = {
|
||||
# cogvideox1.5-*
|
||||
"cogvideox1.5-5b-i2v": (1360, 768),
|
||||
"cogvideox1.5-5b": (1360, 768),
|
||||
"cogvideox1.5-5b-i2v": (768, 1360),
|
||||
"cogvideox1.5-5b": (768, 1360),
|
||||
# cogvideox-*
|
||||
"cogvideox-5b-i2v": (720, 480),
|
||||
"cogvideox-5b": (720, 480),
|
||||
"cogvideox-2b": (720, 480),
|
||||
"cogvideox-5b-i2v": (480, 720),
|
||||
"cogvideox-5b": (480, 720),
|
||||
"cogvideox-2b": (480, 720),
|
||||
}
|
||||
|
||||
|
||||
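The change above flips `RESOLUTION_MAP` from `(width, height)` to `(height, width)` so the values can be unpacked directly as `height, width = desired_resolution` in the next hunk of `generate_video`. A tiny check of that convention (values copied from the map above):

```python
# The map now stores (height, width), matching the unpacking order used below.
RESOLUTION_MAP = {
    "cogvideox1.5-5b": (768, 1360),
    "cogvideox-5b": (480, 720),
}

model_name = "cogvideox1.5-5b"
height, width = RESOLUTION_MAP[model_name]
print(f"{model_name}: {width}x{height} (WxH)")  # cogvideox1.5-5b: 1360x768 (WxH)
```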
@ -94,7 +98,7 @@ def generate_video(
|
||||
model_name = model_path.split("/")[-1].lower()
|
||||
desired_resolution = RESOLUTION_MAP[model_name]
|
||||
if width is None or height is None:
|
||||
width, height = desired_resolution
|
||||
height, width = desired_resolution
|
||||
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
|
||||
elif (width, height) != desired_resolution:
|
||||
if generate_type == "i2v":
|
||||
@ -121,7 +125,7 @@ def generate_video(
|
||||
# If you're using with lora, add this code
|
||||
if lora_path:
|
||||
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
|
||||
pipe.fuse_lora(lora_scale=1 / lora_rank)
|
||||
pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
|
||||
|
||||
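The updated call above restricts LoRA fusion to the transformer. For context, a hedged sketch of the full load-and-fuse sequence as used in this demo (model ID, path and adapter name are placeholders):

```python
# Sketch: load fine-tuned LoRA weights and fuse them into the transformer only.
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/lora_dir", weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
lora_rank = 128
pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
```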
# 2. Set Scheduler.
|
||||
# Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
|
||||
@ -134,8 +138,9 @@ def generate_video(
|
||||
# 3. Enable CPU offload for the model.
|
||||
# Turn this off if you have multiple GPUs or enough GPU memory (such as an H100); inference will then take less time,
# and enable pipe.to("cuda") instead
|
||||
# pipe.to("cuda")
|
||||
|
||||
# pipe.enable_model_cpu_offload()
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
|
@ -1,4 +1,4 @@
|
||||
diffusers>=0.31.0
|
||||
diffusers>=0.32.1
|
||||
accelerate>=1.1.1
|
||||
transformers>=4.46.2
|
||||
numpy==1.26.0
|
||||
@ -12,3 +12,4 @@ imageio-ffmpeg>=0.5.1
|
||||
openai>=1.54.0
|
||||
moviepy>=2.0.0
|
||||
scikit-video>=1.1.11
|
||||
pydantic>=2.10.3
|
||||
|
@ -4,9 +4,12 @@
|
||||
|
||||
[日本語で読む](./README_ja.md)
|
||||
|
||||
This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.
|
||||
This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with
|
||||
fine-tuning code for SAT weights.
|
||||
|
||||
This code framework was used by our team during model training. There are few comments, so careful study is required.
|
||||
If you are interested in the `CogVideoX1.0` version of the model, please check the SAT
|
||||
folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0). This branch only supports the `CogVideoX1.5` series
|
||||
models.
|
||||
|
||||
## Inference Model
|
||||
|
||||
@ -272,7 +275,8 @@ args:
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
|
||||
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ To use command-line input, modify:
|
||||
|
||||
```
|
||||
@ -313,13 +317,15 @@ The dataset should be structured as follows:
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos
|
||||
and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
|
||||
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
|
||||
|
||||
### Modifying the Configuration File
|
||||
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the
|
||||
`transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
|
||||
|
||||
```yaml
|
||||
@ -371,13 +377,15 @@ model:
|
||||
|
||||
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
|
||||
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`
|
||||
as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or
|
||||
`finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
|
||||
@ -417,9 +425,11 @@ python ../tools/convert_weight_sat2hf.py
|
||||
### Exporting Lora Weights from SAT to Huggingface Diffusers
|
||||
|
||||
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in
|
||||
{args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting,
|
||||
use `load_cogvideox_lora.py` for inference.
|
||||
|
||||
Export command:
|
||||
|
||||
@ -427,7 +437,8 @@ Export command:
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures.
|
||||
Lora adds a low-rank weight to the attention structure of the model.
|
||||
|
||||
```
|
||||
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
|
||||
|
@ -5,7 +5,8 @@
|
||||
[中文阅读](./README_zh.md)
|
||||
|
||||
This folder contains inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for the SAT weights.
This code is the framework our team used when training the model. It has few comments, so it requires careful study.
If you are interested in the `CogVideoX1.0` version of the model, please check the SAT folder
[here](https://github.com/THUDM/CogVideo/releases/tag/v1.0). This branch only supports the `CogVideoX1.5` series models.

## Inference Model
|
||||
@ -16,7 +17,8 @@ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. Download the Model Weights

First, download the model weights from the SAT mirror.

First, download the model weights from the SAT mirror.

#### CogVideoX1.5 Models
|
||||
|
||||
@ -270,7 +272,9 @@ args:
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
+ To use command-line input, modify:
|
||||
|
||||
```
|
||||
@ -346,6 +350,7 @@ bash inference.sh
|
||||
fp16:
|
||||
enabled: True # Set to True for CogVideoX-2B, False for CogVideoX-5B
```
|
||||
|
||||
```yaml
|
||||
args:
|
||||
latent_channels: 16
|
||||
@ -364,7 +369,8 @@ args:
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
|
||||
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ To use command-line input, modify:
|
||||
|
||||
```
|
||||
@ -405,13 +411,15 @@ The dataset should be structured as follows:
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos
|
||||
and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
|
||||
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
|
||||
|
||||
### Modifying the Configuration File
|
||||
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the
|
||||
`transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
|
||||
|
||||
```yaml
|
||||
@ -463,13 +471,15 @@ model:
|
||||
|
||||
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
|
||||
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`
|
||||
as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or
|
||||
`finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
|
||||
@ -509,9 +519,11 @@ python ../tools/convert_weight_sat2hf.py
|
||||
### Exporting Lora Weights from SAT to Huggingface Diffusers
|
||||
|
||||
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in
|
||||
{args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting,
|
||||
use `load_cogvideox_lora.py` for inference.
|
||||
|
||||
Export command:
|
||||
|
||||
@ -519,7 +531,8 @@ Export command:
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures.
|
||||
Lora adds a low-rank weight to the attention structure of the model.
|
||||
|
||||
```
|
||||
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
|
||||
|
@ -5,8 +5,7 @@
|
||||
[日本語で読む](./README_ja.md)
|
||||
|
||||
This folder contains inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for the SAT weights.

This code is the framework our team used when training the model. It has few comments, so it requires careful study.
If you are interested in the `CogVideoX1.0` version of the model, please check the SAT folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0). This branch only supports the `CogVideoX1.5` series models.

## Inference Model
|
||||
|
848 tools/convert_weight_deepspeed2hf.py (new file)
@ -0,0 +1,848 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
|
||||
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
||||
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
||||
# application.
|
||||
#
|
||||
# example:
|
||||
# python convert_weight_deepspeed2hf.py . output_dir/
# or
# python convert_weight_deepspeed2hf.py . output_dir/ --safe_serialization
|
||||
import argparse
|
||||
import torch
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import gc
|
||||
import json
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
||||
# DeepSpeed data structures it has to be available in the current python environment.
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
||||
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
||||
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
||||
|
||||
|
||||
@dataclass
|
||||
class zero_model_state:
|
||||
buffers: dict()
|
||||
param_shapes: dict()
|
||||
shared_params: list
|
||||
ds_version: int
|
||||
frozen_param_shapes: dict()
|
||||
frozen_param_fragments: dict()
|
||||
|
||||
|
||||
debug = 0
|
||||
|
||||
# load to cpu
|
||||
device = torch.device('cpu')
|
||||
|
||||
|
||||
def atoi(text):
|
||||
return int(text) if text.isdigit() else text
|
||||
|
||||
|
||||
def natural_keys(text):
|
||||
'''
|
||||
alist.sort(key=natural_keys) sorts in human order
|
||||
http://nedbatchelder.com/blog/200712/human_sorting.html
|
||||
(See Toothy's implementation in the comments)
|
||||
'''
|
||||
return [atoi(c) for c in re.split(r'(\d+)', text)]
|
||||
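`natural_keys` is what keeps sharded checkpoint files in human order when globbing, so `rank_2` sorts before `rank_10`. A quick illustration using the helper defined above:

```python
# Natural (human) sort vs. plain lexicographic sort for sharded checkpoint files.
files = ["zero_pp_rank_10_mp_rank_00_optim_states.pt",
         "zero_pp_rank_2_mp_rank_00_optim_states.pt"]
print(sorted(files))                    # rank_10 comes first (plain string order)
print(sorted(files, key=natural_keys))  # rank_2 correctly comes before rank_10
```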
|
||||
|
||||
def get_model_state_file(checkpoint_dir, zero_stage):
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
||||
|
||||
# there should be only one file
|
||||
if zero_stage <= 2:
|
||||
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
||||
elif zero_stage == 3:
|
||||
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
||||
|
||||
if not os.path.exists(file):
|
||||
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def get_checkpoint_files(checkpoint_dir, glob_pattern):
|
||||
# XXX: need to test that this simple glob rule works for multi-node setup too
|
||||
ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
|
||||
|
||||
if len(ckpt_files) == 0:
|
||||
raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
|
||||
|
||||
return ckpt_files
|
||||
|
||||
|
||||
def get_optim_files(checkpoint_dir):
|
||||
return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
|
||||
|
||||
|
||||
def get_model_state_files(checkpoint_dir):
|
||||
return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
|
||||
|
||||
|
||||
def parse_model_states(files):
|
||||
zero_model_states = []
|
||||
for file in files:
|
||||
state_dict = torch.load(file, map_location=device, weights_only=False)
|
||||
|
||||
if BUFFER_NAMES not in state_dict:
|
||||
raise ValueError(f"{file} is not a model state checkpoint")
|
||||
buffer_names = state_dict[BUFFER_NAMES]
|
||||
if debug:
|
||||
print("Found buffers:", buffer_names)
|
||||
|
||||
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
||||
buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
|
||||
param_shapes = state_dict[PARAM_SHAPES]
|
||||
|
||||
# collect parameters that are included in param_shapes
|
||||
param_names = []
|
||||
for s in param_shapes:
|
||||
for name in s.keys():
|
||||
param_names.append(name)
|
||||
|
||||
# update with frozen parameters
|
||||
frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
|
||||
if frozen_param_shapes is not None:
|
||||
if debug:
|
||||
print(f"Found frozen_param_shapes: {frozen_param_shapes}")
|
||||
param_names += list(frozen_param_shapes.keys())
|
||||
|
||||
# handle shared params
|
||||
shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
|
||||
|
||||
ds_version = state_dict.get(DS_VERSION, None)
|
||||
|
||||
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
||||
|
||||
z_model_state = zero_model_state(buffers=buffers,
|
||||
param_shapes=param_shapes,
|
||||
shared_params=shared_params,
|
||||
ds_version=ds_version,
|
||||
frozen_param_shapes=frozen_param_shapes,
|
||||
frozen_param_fragments=frozen_param_fragments)
|
||||
zero_model_states.append(z_model_state)
|
||||
|
||||
return zero_model_states
|
||||
|
||||
|
||||
def parse_optim_states(files, ds_checkpoint_dir):
|
||||
total_files = len(files)
|
||||
state_dicts = []
|
||||
for f in tqdm(files, desc='Loading checkpoint shards'):
|
||||
state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
|
||||
# immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
|
||||
# and also handle the case where it was already removed by another helper script
|
||||
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
||||
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
||||
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
||||
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
||||
|
||||
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
||||
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
||||
# use the max of the partition_count to get the dp world_size.
|
||||
|
||||
if type(world_size) is list:
|
||||
world_size = max(world_size)
|
||||
|
||||
if world_size != total_files:
|
||||
raise ValueError(
|
||||
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
||||
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
||||
)
|
||||
|
||||
# the groups are named differently in each stage
|
||||
if zero_stage <= 2:
|
||||
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
|
||||
elif zero_stage == 3:
|
||||
fp32_groups_key = FP32_FLAT_GROUPS
|
||||
else:
|
||||
raise ValueError(f"unknown zero stage {zero_stage}")
|
||||
|
||||
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
||||
return zero_stage, world_size, fp32_flat_groups
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
|
||||
"""
|
||||
Returns fp32 state_dict reconstructed from ds checkpoint
|
||||
|
||||
Args:
|
||||
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
||||
|
||||
"""
|
||||
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
||||
|
||||
optim_files = get_optim_files(ds_checkpoint_dir)
|
||||
zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
||||
print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
||||
|
||||
model_files = get_model_state_files(ds_checkpoint_dir)
|
||||
|
||||
zero_model_states = parse_model_states(model_files)
|
||||
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
||||
|
||||
if zero_stage <= 2:
|
||||
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters)
|
||||
elif zero_stage == 3:
|
||||
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters)
|
||||
|
||||
|
||||
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
return
|
||||
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
frozen_param_fragments = zero_model_states[0].frozen_param_fragments
|
||||
|
||||
if debug:
|
||||
num_elem = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
||||
|
||||
wanted_params = len(frozen_param_shapes)
|
||||
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
|
||||
print(f'Frozen params: Have {avail_numel} numels to process.')
|
||||
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
||||
|
||||
total_params = 0
|
||||
total_numel = 0
|
||||
for name, shape in frozen_param_shapes.items():
|
||||
total_params += 1
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
|
||||
state_dict[name] = frozen_param_fragments[name]
|
||||
|
||||
if debug:
|
||||
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
||||
|
||||
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _has_callable(obj, fn):
|
||||
attr = getattr(obj, fn, None)
|
||||
return callable(attr)
|
||||
|
||||
|
||||
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
||||
param_shapes = zero_model_states[0].param_shapes
|
||||
|
||||
# Reconstruction protocol:
|
||||
#
|
||||
# XXX: document this
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
for j in range(len(fp32_flat_groups[0])):
|
||||
print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
||||
|
||||
# XXX: memory usage doubles here (zero2)
|
||||
num_param_groups = len(fp32_flat_groups[0])
|
||||
merged_single_partition_of_fp32_groups = []
|
||||
for i in range(num_param_groups):
|
||||
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
||||
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
||||
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
||||
avail_numel = sum(
|
||||
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
||||
|
||||
if debug:
|
||||
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
||||
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
print(f"Have {avail_numel} numels to process.")
|
||||
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
||||
offset = 0
|
||||
avail_numel = full_single_fp32_vector.numel()
|
||||
for name, shape in shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
if debug:
|
||||
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
||||
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
||||
offset += unpartitioned_numel
|
||||
|
||||
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
||||
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
||||
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
||||
# live optimizer object, so we are checking that the numbers are within the right range
|
||||
align_to = 2 * world_size
|
||||
|
||||
def zero2_align(x):
|
||||
return align_to * math.ceil(x / align_to)
|
||||
|
||||
if debug:
|
||||
print(f"original offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
offset = zero2_align(offset)
|
||||
avail_numel = zero2_align(avail_numel)
|
||||
|
||||
if debug:
|
||||
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
buffers = zero_model_states[0].buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
if not exclude_frozen_parameters:
|
||||
_zero2_merge_frozen_params(state_dict, zero_model_states)
|
||||
|
||||
_zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
||||
|
||||
# recover shared parameters
|
||||
for pair in zero_model_states[0].shared_params:
|
||||
if pair[1] in state_dict:
|
||||
state_dict[pair[0]] = state_dict[pair[1]]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
||||
remainder = unpartitioned_numel % world_size
|
||||
padding_numel = (world_size - remainder) if remainder else 0
|
||||
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
||||
return partitioned_numel, padding_numel
|
||||
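A worked example of the padding arithmetic above: a parameter of 10 elements split across 4 ranks is stored as 4 partitions of ceil(10/4) = 3 elements, so 2 padding elements are added.

```python
# 10 elements across 4 ranks -> each rank holds 3, with 2 padding elements overall.
print(zero3_partitioned_param_info(10, 4))  # (3, 2)
```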
|
||||
|
||||
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
return
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
|
||||
print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
|
||||
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
wanted_params = len(frozen_param_shapes)
|
||||
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
||||
print(f'Frozen params: Have {avail_numel} numels to process.')
|
||||
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
||||
|
||||
total_params = 0
|
||||
total_numel = 0
|
||||
for name, shape in zero_model_states[0].frozen_param_shapes.items():
|
||||
total_params += 1
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
|
||||
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
||||
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
||||
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
||||
)
|
||||
|
||||
print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
class GatheredTensor:
|
||||
"""
|
||||
A pseudo tensor that collects partitioned weights.
|
||||
It is more memory efficient when there are multiple groups.
|
||||
"""
|
||||
|
||||
def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
|
||||
self.flat_groups = flat_groups
|
||||
self.flat_groups_offset = flat_groups_offset
|
||||
self.offset = offset
|
||||
self.partitioned_numel = partitioned_numel
|
||||
self.shape = shape
|
||||
self.dtype = self.flat_groups[0][0].dtype
|
||||
|
||||
def contiguous(self):
|
||||
"""
|
||||
Merge partitioned weights from flat_groups into a single tensor.
|
||||
"""
|
||||
end_idx = self.offset + self.partitioned_numel
|
||||
world_size = len(self.flat_groups)
|
||||
pad_flat_param_chunks = []
|
||||
|
||||
for rank_i in range(world_size):
|
||||
# for each rank, we need to collect weights from related group/groups
|
||||
flat_groups_at_rank_i = self.flat_groups[rank_i]
|
||||
start_group_id = None
|
||||
end_group_id = None
|
||||
for group_id in range(len(self.flat_groups_offset)):
|
||||
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
|
||||
start_group_id = group_id
|
||||
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
|
||||
end_group_id = group_id
|
||||
break
|
||||
# collect weights from related group/groups
|
||||
for group_id in range(start_group_id, end_group_id + 1):
|
||||
flat_tensor = flat_groups_at_rank_i[group_id]
|
||||
start_offset = self.offset - self.flat_groups_offset[group_id]
|
||||
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
|
||||
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
|
||||
|
||||
# collect weights from all ranks
|
||||
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
|
||||
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
|
||||
return param
|
||||
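A tiny synthetic check of `GatheredTensor` (relying on the class and imports above): one 5-element parameter, world size 2, each rank holding a padded flat partition of 3 elements. `.contiguous()` stitches the two partitions together and trims the padding.

```python
import torch

# One 5-element parameter, ZeRO-3 partitioned across 2 ranks (3 elements each, 1 padding).
rank0 = [torch.tensor([0., 1., 2.])]
rank1 = [torch.tensor([3., 4., 0.])]          # trailing 0. is padding
flat_groups = [rank0, rank1]
flat_groups_offset = [0, 3]                   # cumulative numel of the (single) group

gathered = GatheredTensor(flat_groups, flat_groups_offset, offset=0,
                          partitioned_numel=3, shape=torch.Size([5]))
print(gathered.contiguous())                  # tensor([0., 1., 2., 3., 4.])
```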
|
||||
|
||||
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
|
||||
param_shapes = zero_model_states[0].param_shapes
|
||||
avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
|
||||
|
||||
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
||||
# param, re-consolidating each param, while dealing with padding if any
|
||||
|
||||
# merge list of dicts, preserving order
|
||||
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
|
||||
|
||||
wanted_params = len(param_shapes)
|
||||
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
avail_numel = fp32_flat_groups[0].numel() * world_size
|
||||
print(f"Trainable params: Have {avail_numel} numels to process.")
|
||||
print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
offset = 0
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
|
||||
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
||||
)
|
||||
|
||||
# memory efficient tensor
|
||||
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
|
||||
state_dict[name] = tensor
|
||||
offset += partitioned_numel
|
||||
|
||||
offset *= world_size
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
buffers = zero_model_states[0].buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
if not exclude_frozen_parameters:
|
||||
_zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
|
||||
|
||||
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
|
||||
|
||||
# recover shared parameters
|
||||
for pair in zero_model_states[0].shared_params:
|
||||
if pair[1] in state_dict:
|
||||
state_dict[pair[0]] = state_dict[pair[1]]
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def to_torch_tensor(state_dict, return_empty_tensor=False):
|
||||
"""
|
||||
Convert state_dict of GatheredTensor to torch tensor
|
||||
"""
|
||||
torch_state_dict = {}
|
||||
converted_tensors = {}
|
||||
for name, tensor in state_dict.items():
|
||||
tensor_id = id(tensor)
|
||||
if tensor_id in converted_tensors: # shared tensors
|
||||
shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
|
||||
torch_state_dict[name] = shared_tensor
|
||||
else:
|
||||
converted_tensors[tensor_id] = name
|
||||
if return_empty_tensor:
|
||||
torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
|
||||
else:
|
||||
torch_state_dict[name] = tensor.contiguous()
|
||||
return torch_state_dict
|
||||
|
||||
|
||||
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False,
|
||||
lazy_mode=False):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
||||
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
||||
via a model hub.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
||||
- ``exclude_frozen_parameters``: exclude frozen parameters
|
||||
- ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
Convert a pseudo tensor to a torch tensor with ``.contiguous()``
|
||||
Returns:
|
||||
- pytorch ``state_dict``
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
||||
# do the training and checkpoint saving
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
||||
model = model.cpu() # move to cpu
|
||||
model.load_state_dict(state_dict)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
||||
application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
||||
|
||||
Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
|
||||
You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
||||
the checkpoint. Or you can load state_dict in lazy mode ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
|
||||
for name, lazy_tensor in state_dict.items():
tensor = lazy_tensor.contiguous() # to cpu
print(name, tensor)
# del tensor to release memory if it is no longer in use
"""
|
||||
if tag is None:
|
||||
latest_path = os.path.join(checkpoint_dir, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
|
||||
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
||||
|
||||
if not os.path.isdir(ds_checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
||||
|
||||
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
|
||||
if lazy_mode:
|
||||
return state_dict
|
||||
else:
|
||||
return to_torch_tensor(state_dict)
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=False,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
||||
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``output_dir``: directory to the pytorch fp32 state_dict output files
|
||||
- ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
|
||||
- ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
- ``exclude_frozen_parameters``: exclude frozen parameters
|
||||
"""
|
||||
|
||||
# Dependency pre-check
|
||||
if safe_serialization:
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
except ImportError:
|
||||
print('If you want to use `safe_serialization`, please `pip install safetensors`')
|
||||
raise
|
||||
if max_shard_size is not None:
|
||||
try:
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
except ImportError:
|
||||
print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
|
||||
raise
|
||||
|
||||
# Convert zero checkpoint to state_dict
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
||||
tag,
|
||||
exclude_frozen_parameters,
|
||||
lazy_mode=True)
|
||||
|
||||
# Shard the model if it is too big.
|
||||
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
|
||||
if max_shard_size is not None:
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||
# a memory-efficient approach for sharding
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
||||
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
|
||||
filename_pattern=filename_pattern,
|
||||
max_shard_size=max_shard_size)
|
||||
else:
|
||||
from collections import namedtuple
|
||||
StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
|
||||
state_dict_split = StateDictSplit(is_sharded=False,
|
||||
filename_to_tensors={weights_name: list(state_dict.keys())})
|
||||
|
||||
# Save the model by shard
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
|
||||
shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
|
||||
shard_state_dict = to_torch_tensor(shard_state_dict)
|
||||
output_path = os.path.join(output_dir, shard_file)
|
||||
if safe_serialization:
|
||||
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard_state_dict, output_path)
|
||||
# release the memory of current shard
|
||||
for tensor_name in list(shard_state_dict.keys()):
|
||||
del state_dict[tensor_name]
|
||||
del shard_state_dict[tensor_name]
|
||||
del shard_state_dict
|
||||
gc.collect()
|
||||
|
||||
# Save index if sharded
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
|
||||
save_index_file = os.path.join(output_dir, save_index_file)
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
|
||||
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
||||
"""
|
||||
1. Put the provided model to cpu
|
||||
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
||||
3. Load it into the provided model
|
||||
|
||||
Args:
|
||||
- ``model``: the model object to update
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
|
||||
Returns:
|
||||
- ``model``: the modified model
|
||||
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
||||
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
||||
conveniently placed for you in the checkpoint folder.
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
||||
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
||||
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
"""
|
||||
logger.info(f"Extracting fp32 weights")
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
||||
|
||||
logger.info(f"Overwriting model with fp32 weights")
|
||||
model = model.cpu()
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=True,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False):
|
||||
"""
|
||||
Convert a ZeRO 2 or ZeRO 3 DeepSpeed checkpoint to BF16 and write it to the given output directory, using the following naming scheme:
- If there is only one shard:
diffusion_pytorch_model.safetensors
- If there is more than one shard:
diffusion_pytorch_model-00001-of-0000X.safetensors
diffusion_pytorch_model-00002-of-0000X.safetensors
...
diffusion_pytorch_model.safetensors.index.json
"""
|
||||
|
||||
if safe_serialization:
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
except ImportError:
|
||||
raise ImportError("You need `pip install safetensors` to use safetensors.")
|
||||
if max_shard_size is not None:
|
||||
try:
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
except ImportError:
|
||||
raise ImportError("You need `pip install huggingface_hub` to use the sharding feature.")
|
||||
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(
|
||||
checkpoint_dir,
|
||||
tag=tag,
|
||||
exclude_frozen_parameters=exclude_frozen_parameters,
|
||||
lazy_mode=True
|
||||
)
|
||||
|
||||
state_dict = to_torch_tensor(state_dict, return_empty_tensor=False)
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
state_dict[key] = tensor.to(torch.bfloat16)
|
||||
|
||||
if safe_serialization:
|
||||
filename_pattern = "diffusion_pytorch_model{suffix}.safetensors"
|
||||
else:
|
||||
filename_pattern = "diffusion_pytorch_model{suffix}.bin"
|
||||
|
||||
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
empty_state_dict,
|
||||
filename_pattern=filename_pattern,
|
||||
max_shard_size=max_shard_size
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
filename_to_tensors = list(state_dict_split.filename_to_tensors.items())
|
||||
for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
|
||||
shard_state_dict = {t_name: state_dict[t_name] for t_name in tensors}
|
||||
shard_state_dict = to_torch_tensor(shard_state_dict)
|
||||
|
||||
# Save
|
||||
output_path = os.path.join(output_dir, shard_file)
|
||||
if safe_serialization:
|
||||
save_file(shard_state_dict, output_path, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard_state_dict, output_path)
|
||||
for t_name in shard_state_dict.keys():
|
||||
del state_dict[t_name]
|
||||
del shard_state_dict
|
||||
gc.collect()
|
||||
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
index_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors.index.json")
|
||||
with open(index_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
|
||||
else:
|
||||
only_filename = list(state_dict_split.filename_to_tensors.keys())[0]
|
||||
old_path = os.path.join(output_dir, only_filename)
|
||||
new_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors" if safe_serialization
|
||||
else "diffusion_pytorch_model.bin")
|
||||
if old_path != new_path:
|
||||
os.rename(old_path, new_path)
|
||||
|
||||
|
||||
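For reference, a hedged example of calling the BF16 converter above directly from Python instead of via the CLI entry point below (paths are placeholders; `tag` defaults to whatever the `latest` file points at):

```python
# Example (placeholder paths): convert a DeepSpeed ZeRO checkpoint produced during
# fine-tuning into sharded bf16 safetensors that diffusers can load.
convert_zero_checkpoint_to_bf16_state_dict(
    checkpoint_dir="outputs/checkpoint-1000",   # contains the `latest` file and tag folder
    output_dir="outputs/checkpoint-1000-hf",
    max_shard_size="5GB",
    safe_serialization=True,
)
```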
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("checkpoint_dir",
|
||||
type=str,
|
||||
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
||||
parser.add_argument("output_dir",
|
||||
type=str,
|
||||
help="directory to the pytorch fp32 state_dict output files"
|
||||
"(e.g. path/checkpoint-12-output/)")
|
||||
parser.add_argument(
|
||||
"--max_shard_size",
|
||||
type=str,
|
||||
default="5GB",
|
||||
help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
|
||||
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
|
||||
"We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
|
||||
"without CPU OOM issues.")
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
|
||||
parser.add_argument("-t",
|
||||
"--tag",
|
||||
type=str,
|
||||
default=None,
|
||||
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
||||
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
||||
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
||||
args = parser.parse_args()
|
||||
|
||||
debug = args.debug
|
||||
|
||||
convert_zero_checkpoint_to_bf16_state_dict(args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
max_shard_size=args.max_shard_size,
|
||||
safe_serialization=args.safe_serialization,
|
||||
tag=args.tag,
|
||||
exclude_frozen_parameters=args.exclude_frozen_parameters)