diff --git a/README.md b/README.md index a27782c..4e83ec2 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,10 @@ Experience the CogVideoX-5B model online at CogVideoX-2B CogVideoX-5B CogVideoX-5B-I2V + CogVideoX1.5-5B + CogVideoX1.5-5B-I2V - Model Description - Entry-level model, balancing compatibility. Low cost for running and secondary development. - Larger model with higher video generation quality and better visual effects. - CogVideoX-5B image-to-video version. + Release Date + August 6, 2024 + August 27, 2024 + September 19, 2024 + November 8, 2024 + November 8, 2024 + + + Video Resolution + 720 * 480 + 1360 * 768 + 256 <= W <=1360
256 <= H <=768
W,H % 16 == 0 Inference Precision FP16*(recommended), BF16, FP32, FP8*, INT8, not supported: INT4 - BF16 (recommended), FP16, FP32, FP8*, INT8, not supported: INT4 + BF16(recommended), FP16, FP32, FP8*, INT8, not supported: INT4 + BF16 - Single GPU Memory Usage
-
SAT FP16: 18GB
diffusers FP16: from 4GB*
diffusers INT8 (torchao): from 3.6GB* - SAT BF16: 26GB
diffusers BF16: from 5GB*
diffusers INT8 (torchao): from 4.4GB* + Single GPU Memory Usage + SAT FP16: 18GB
diffusers FP16: from 4GB*
diffusers INT8(torchao): from 3.6GB* + SAT BF16: 26GB
diffusers BF16: from 5GB*<br>
diffusers INT8(torchao): from 4.4GB* + SAT BF16: 66GB
- Multi-GPU Inference Memory Usage + Multi-GPU Memory Usage FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
+ Not supported
Inference Speed
(Step = 50, FP/BF16) Single A100: ~90 seconds
Single H100: ~45 seconds Single A100: ~180 seconds
Single H100: ~90 seconds - - - Fine-tuning Precision - FP16 - BF16 - - - Fine-tuning Memory Usage - 47 GB (bs=1, LORA)
61 GB (bs=2, LORA)
62GB (bs=1, SFT) - 63 GB (bs=1, LORA)
80 GB (bs=2, LORA)
75GB (bs=1, SFT)
- 78 GB (bs=1, LORA)
75GB (bs=1, SFT, 16GPU)
+ Single A100: ~1000 seconds (5-second video)
Single H100: ~550 seconds (5-second video) Prompt Language - English* + English* - Maximum Prompt Length + Prompt Token Limit 226 Tokens + 224 Tokens Video Length - 6 Seconds + 6 seconds + 5 or 10 seconds Frame Rate - 8 Frames / Second + 8 frames / second + 16 frames / second - Video Resolution - 720 x 480, no support for other resolutions (including fine-tuning) - - - Position Encoding + Positional Encoding 3d_sincos_pos_embed 3d_sincos_pos_embed 3d_rope_pos_embed + learnable_pos_embed + 3d_sincos_pos_embed + 3d_rope_pos_embed + learnable_pos_embed Download Link (Diffusers) 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel + Coming Soon Download Link (SAT) - SAT + SAT + 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel @@ -422,7 +430,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese We welcome your contributions! You can click [here](resources/contribute.md) for more information. -## License Agreement +## Model-License The code in this repository is released under the [Apache 2.0 License](LICENSE). diff --git a/README_ja.md b/README_ja.md index 69b46b6..aa7ae37 100644 --- a/README_ja.md +++ b/README_ja.md @@ -1,6 +1,6 @@ # CogVideo & CogVideoX -[Read this in English](./README_zh.md) +[Read this in English](./README.md) [䞭文阅读](./README_zh.md) @@ -22,9 +22,14 @@ ## 曎新ずニュヌス -- 🔥🔥 **ニュヌス**: ```2024/10/13```: コスト削枛のため、単䞀の4090 GPUで`CogVideoX-5B` +- 🔥🔥 ニュヌス: ```2024/11/08```: `CogVideoX1.5` モデルをリリヌスしたした。CogVideoX1.5 は CogVideoX オヌプン゜ヌスモデルのアップグレヌドバヌゞョンです。 +CogVideoX1.5-5B シリヌズモデルは、10秒 長の動画ずより高い解像床をサポヌトしおおり、`CogVideoX1.5-5B-I2V` は任意の解像床での動画生成に察応しおいたす。 +SAT コヌドはすでに曎新されおおり、`diffusers` バヌゞョンは珟圚適応䞭です。 +SAT バヌゞョンのコヌドは [こちら](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) からダりンロヌドできたす。 +- 🔥 **ニュヌス**: ```2024/10/13```: コスト削枛のため、単䞀の4090 GPUで`CogVideoX-5B` を埮調敎できるフレヌムワヌク [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory) - がリリヌスされたした。耇数の解像床での埮調敎に察応しおいたす。ぜひご利甚ください- 🔥**ニュヌス**: ```2024/10/10```: + がリリヌスされたした。耇数の解像床での埮調敎に察応しおいたす。ぜひご利甚ください +- 🔥**ニュヌス**: ```2024/10/10```: 技術報告曞を曎新し、より詳现なトレヌニング情報ずデモを远加したした。 - 🔥 **ニュヌス**: ```2024/10/10```: 技術報告曞を曎新したした。[こちら](https://arxiv.org/pdf/2408.06072) をクリックしおご芧ください。さらにトレヌニングの詳现ずデモを远加したした。デモを芋るには[こちら](https://yzy-thu.github.io/CogVideoX-demo/) @@ -34,7 +39,7 @@ - 🔥**ニュヌス**: ```2024/9/19```: CogVideoXシリヌズの画像生成ビデオモデル **CogVideoX-5B-I2V** をオヌプン゜ヌス化したした。このモデルは、画像を背景入力ずしお䜿甚し、プロンプトワヌドず組み合わせおビデオを生成するこずができ、より高い制埡性を提䟛したす。これにより、CogVideoXシリヌズのモデルは、テキストからビデオ生成、ビデオの継続、画像からビデオ生成の3぀のタスクをサポヌトするようになりたした。オンラむンでの[䜓隓](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space) をお楜しみください。 -- 🔥🔥 **ニュヌス**: ```2024/9/19```: +- 🔥 **ニュヌス**: ```2024/9/19```: CogVideoXのトレヌニングプロセスでビデオデヌタをテキスト蚘述に倉換するために䜿甚されるキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) をオヌプン゜ヌス化したした。ダりンロヌドしおご利甚ください。 - 🔥 ```2024/8/27```: CogVideoXシリヌズのより倧きなモデル **CogVideoX-5B** @@ -63,11 +68,10 @@ - [プロゞェクト構造](#プロゞェクト構造) - [掚論](#掚論) - [sat](#sat) - - [ツヌル](#ツヌル) -- [プロゞェクト蚈画](#プロゞェクト蚈画) -- [モデルラむセンス](#モデルラむセンス) + - [ツヌル](#ツヌル)= - [CogVideo(ICLR'23)モデル玹介](#CogVideoICLR23) - [匕甚](#匕甚) +- [ラむセンス契玄](#ラむセンス契玄) ## クむックスタヌト @@ -156,79 +160,91 @@ pip install -r requirements.txt CogVideoXは、[枅圱](https://chatglm.cn/video?fr=osm_cogvideox) ず同源のオヌプン゜ヌス版ビデオ生成モデルです。 以䞋の衚に、提䟛しおいるビデオ生成モデルの基本情報を瀺したす: - +
- + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + - + - + + + - - - - - + + + + + - + +
モデル名 CogVideoX-2B CogVideoX-5BCogVideoX-5B-I2V CogVideoX-5B-I2VCogVideoX1.5-5BCogVideoX1.5-5B-I2V
リリヌス日 2024幎8月6日 2024幎8月27日 2024幎9月19日 2024幎11月8日 2024幎11月8日<br>
ビデオ解像床 720 * 480 1360 * 768 256 <= W <=1360<br>
256 <= H <=768
W,H % 16 == 0
掚論粟床 FP16*(掚奚), BF16, FP32, FP8*, INT8, INT4は非察応 BF16(掚奚), FP16, FP32, FP8*, INT8, INT4は非察応
単䞀GPUのメモリ消費
SAT FP16: 18GB
diffusers FP16: 4GBから*
diffusers INT8(torchao): 3.6GBから*
SAT BF16: 26GB
diffusers BF16: 5GBから*<br>
diffusers INT8(torchao): 4.4GBから*
マルチGPUのメモリ消費 FP16: 10GB* using diffusers<br>
BF16: 15GB* using diffusers
掚論速床
(ステップ = 50, FP/BF16)
単䞀A100: 箄90秒
単䞀H100: 箄45秒
単䞀A100: 箄180秒
単䞀H100: 箄90秒
ファむンチュヌニング粟床 FP16 BF16<br>
ファむンチュヌニング時のメモリ消費 47 GB (bs=1, LORA)<br>
61 GB (bs=2, LORA)
62GB (bs=1, SFT)
63 GB (bs=1, LORA)
80 GB (bs=2, LORA)
75GB (bs=1, SFT)
78 GB (bs=1, LORA)
75GB (bs=1, SFT, 16GPU)
シングルGPUメモリ消費 SAT FP16: 18GB<br>
diffusers FP16: 4GBから*
diffusers INT8(torchao): 3.6GBから*
SAT BF16: 26GB
diffusers BF16: 5GBから*
diffusers INT8(torchao): 4.4GBから*
SAT BF16: 66GB
マルチGPUメモリ消費 FP16: 10GB* using diffusers<br>
BF16: 15GB* using diffusers
サポヌトなし
掚論速床
(ステップ数 = 50, FP/BF16)
単䞀A100: 箄90秒
単䞀H100: 箄45秒
単䞀A100: 箄180秒
単䞀H100: 箄90秒
単䞀A100: 箄1000秒(5秒動画)
単䞀H100: 箄550秒(5秒動画)
プロンプト蚀語 英語* 英語*<br>
プロンプトの最倧トヌクン数 プロンプトトヌクン制限 226トヌクン 224トヌクン<br>
ビデオの長さ 6秒 5秒たたは10秒<br>
フレヌムレヌト 8フレヌム/秒<br>
ビデオ解像床 720 * 480、他の解像床は非察応(ファむンチュヌニング含む) 8 フレヌム / 秒 16 フレヌム / 秒<br>
䜍眮゚ンコヌディング 3d_sincos_pos_embed 3d_sincos_pos_embed 3d_rope_pos_embed + learnable_pos_embed 3d_sincos_pos_embed 3d_rope_pos_embed + learnable_pos_embed<br>
ダりンロヌドリンク (Diffusers) 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel
🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel
🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel
近日公開
ダりンロヌドリンク (SAT) SAT SAT 🀗 HuggingFace<br>
🀖 ModelScope
🟣 WiseModel
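
The low "diffusers FP16: from 4GB*" and "diffusers INT8(torchao)" figures quoted in the model tables assume the memory optimizations are switched on rather than a plain pipeline call. As a rough, non-authoritative sketch (assuming a recent `diffusers` release that ships `CogVideoXPipeline` and the `torchao` weight-only quantization API; the model id and exact savings depend on your setup):

```python
# Hedged sketch, not part of this PR: combine CPU offload, VAE slicing/tiling and
# torchao INT8 weight-only quantization to approach the "from ~4 GB" table figures.
import torch
from diffusers import CogVideoXPipeline
from torchao.quantization import quantize_, int8_weight_only

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

quantize_(pipe.transformer, int8_weight_only())  # "diffusers INT8(torchao)" row
pipe.enable_sequential_cpu_offload()             # stream weights to the GPU layer by layer
pipe.vae.enable_slicing()                        # decode the latent batch slice by slice
pipe.vae.enable_tiling()                         # decode each frame in spatial tiles

video = pipe("A panda playing guitar in a bamboo forest", num_inference_steps=50).frames[0]
```

Peak memory drops at the cost of speed; without offload and quantization the footprint is much closer to the SAT BF16 numbers in the same tables.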
diff --git a/README_zh.md b/README_zh.md index 9f84f84..3574e7d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -1,10 +1,9 @@ # CogVideo & CogVideoX -[Read this in English](./README_zh.md) +[Read this in English](./README.md) [日本語で読む](./README_ja.md) -
@@ -23,7 +22,9 @@ ## 项目曎新 -- 🔥🔥 **News**: ```2024/10/13```: 成本曎䜎单卡4090可埮调`CogVideoX-5B` +- 🔥🔥 **News**: ```2024/11/08```: 我们发垃 `CogVideoX1.5` 暡型。CogVideoX1.5 是 CogVideoX 匀源暡型的升级版本。 +CogVideoX1.5-5B 系列暡型支持 **10秒** 长床的视频和曎高的分蟚率其䞭 `CogVideoX1.5-5B-I2V` 支持 **任意分蟚率** 的视频生成SAT代码已经曎新。`diffusers`版本还圚适配䞭。SAT版本代码前埀 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 䞋蜜。 +- 🔥**News**: ```2024/10/13```: 成本曎䜎单卡4090可埮调 `CogVideoX-5B` 的埮调框架[cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)已经掚出倚种分蟚率埮调欢迎䜿甚。 - 🔥 **News**: ```2024/10/10```: 我们曎新了我们的技术报告,请点击 [这里](https://arxiv.org/pdf/2408.06072) 查看附䞊了曎倚的训练细节和demo关于demo点击[这里](https://yzy-thu.github.io/CogVideoX-demo/) 查看。 @@ -58,10 +59,9 @@ - [Inference](#inference) - [SAT](#sat) - [Tools](#tools) -- [匀源项目规划](#匀源项目规划) -- [暡型协议](#暡型协议) - [CogVideo(ICLR'23)暡型介绍](#cogvideoiclr23) - [匕甚](#匕甚) +- [暡型协议](#暡型协议) ## 快速匀始 @@ -157,62 +157,72 @@ CogVideoX是 [枅圱](https://chatglm.cn/video?fr=osm_cogvideox) 同源的匀源 CogVideoX-2B CogVideoX-5B CogVideoX-5B-I2V + CogVideoX1.5-5B + CogVideoX1.5-5B-I2V + + + 发垃时闎 + 2024幎8月6日 + 2024幎8月27日 + 2024幎9月19日 + 2024幎11月8日 + 2024幎11月8日 + + + 视频分蟚率 + 720 * 480 + 1360 * 768 + 256 <= W <=1360
256 <= H <=768
W,H % 16 == 0 掚理粟床 FP16*(掚荐), BF16, FP32FP8*INT8䞍支持INT4 BF16(掚荐), FP16, FP32FP8*INT8䞍支持INT4 + BF16 单GPU星存消耗
SAT FP16: 18GB
diffusers FP16: 4GBèµ·*
diffusers INT8(torchao): 3.6Gèµ·* SAT BF16: 26GB
diffusers BF16: 5GBèµ·*<br>
diffusers INT8(torchao): 4.4Gèµ·* + SAT BF16: 66GB
倚GPU掚理星存消耗 FP16: 10GB* using diffusers
BF16: 15GB* using diffusers
+ Not supported<br>
掚理速床
(Step = 50, FP/BF16) 单卡A100: ~90秒
单卡H100: ~45秒 单卡A100: ~180秒
单卡H100: ~90秒 - - - 埮调粟床 - FP16 - BF16 - - - 埮调星存消耗 - 47 GB (bs=1, LORA)
61 GB (bs=2, LORA)
62GB (bs=1, SFT) - 63 GB (bs=1, LORA)
80 GB (bs=2, LORA)
75GB (bs=1, SFT)
- 78 GB (bs=1, LORA)
75GB (bs=1, SFT, 16GPU)
+ 单卡A100: ~1000秒(5秒视频)
单卡H100: ~550秒(5秒视频) 提瀺词语蚀 - English* + English* 提瀺词长床䞊限 226 Tokens + 224 Tokens 视频长床 6 秒 + 5 秒 或 10 秒 垧率 8 垧 / 秒 + 16 垧 / 秒 - 视频分蟚率 - 720 * 480䞍支持其他分蟚率(含埮调) - - 䜍眮猖码 3d_sincos_pos_embed - 3d_sincos_pos_embed + 3d_sincos_pos_embed + 3d_rope_pos_embed + learnable_pos_embed + 3d_sincos_pos_embed 3d_rope_pos_embed + learnable_pos_embed @@ -220,10 +230,13 @@ CogVideoX是 [枅圱](https://chatglm.cn/video?fr=osm_cogvideox) 同源的匀源 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel + 即将掚出 䞋蜜铟接 (SAT) SAT + 🀗 HuggingFace
🀖 ModelScope
🟣 WiseModel + diff --git a/sat/README.md b/sat/README.md index 48c4552..c67e15c 100644 --- a/sat/README.md +++ b/sat/README.md @@ -1,29 +1,39 @@ -# SAT CogVideoX-2B +# SAT CogVideoX -[䞭文阅读](./README_zh.md) +[Read this in English.](./README_zh.md) [日本語で読む](./README_ja.md) -This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the -fine-tuning code for SAT weights. +This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights. -This code is the framework used by the team to train the model. It has few comments and requires careful study. +This code framework was used by our team during model training. There are few comments, so careful study is required. ## Inference Model -### 1. Ensure that you have correctly installed the dependencies required by this folder. +### 1. Make sure you have installed all dependencies in this folder -```shell +``` pip install -r requirements.txt ``` -### 2. Download the model weights +### 2. Download the Model Weights -### 2. Download model weights +First, download the model weights from the SAT mirror. -First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, please download as follows: +#### CogVideoX1.5 Model -```shell +``` +git lfs install +git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT +``` + +This command downloads three models: Transformers, VAE, and T5 Encoder. + +#### CogVideoX Model + +For the CogVideoX-2B model, download as follows: + +``` mkdir CogVideoX-2b-sat cd CogVideoX-2b-sat wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1 @@ -34,13 +44,12 @@ mv 'index.html?dl=1' transformer.zip unzip transformer.zip ``` -For the CogVideoX-5B model, please download the `transformers` file as follows link: -(VAE files are the same as 2B) +Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B): + [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) + [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list) -Next, you need to format the model files as follows: +Arrange the model files in the following structure: ``` . @@ -52,20 +61,24 @@ Next, you need to format the model files as follows: └── 3d-vae.pt ``` -Due to large size of model weight file, using `git lfs` is recommended. Installation of `git lfs` can be -found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) +Since model weight files are large, it’s recommended to use `git lfs`. +See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation. -Next, clone the T5 model, which is not used for training and fine-tuning, but must be used. -> T5 model is available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) as well. +``` +git lfs install +``` -```shell -git clone https://huggingface.co/THUDM/CogVideoX-2b.git +Next, clone the T5 model, which is used as an encoder and doesn’t require training or fine-tuning. +> You may also use the model file location on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b). 
+ +``` +git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface +# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope mkdir t5-v1_1-xxl mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl ``` -By following the above approach, you will obtain a safetensor format T5 file. Ensure that there are no errors when -loading it into Deepspeed in Finetune. +This will yield a safetensor format T5 file that can be loaded without error during Deepspeed fine-tuning. ``` ├── added_tokens.json @@ -80,11 +93,11 @@ loading it into Deepspeed in Finetune. 0 directories, 8 files ``` -### 3. Modify the file in `configs/cogvideox_2b.yaml`. +### 3. Modify `configs/cogvideox_*.yaml` file. ```yaml model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true log_keys: - txt @@ -160,14 +173,14 @@ model: ucg_rate: 0.1 target: sgm.modules.encoders.modules.FrozenT5Embedder params: - model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder + model_dir: "t5-v1_1-xxl" # absolute path to CogVideoX-2b/t5-v1_1-xxl weight folder max_length: 226 first_stage_config: target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper params: cp_size: 1 - ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder + ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to CogVideoX-2b-sat/vae/3d-vae.pt file ignore_keys: [ 'loss' ] loss_config: @@ -239,48 +252,46 @@ model: num_steps: 50 ``` -### 4. Modify the file in `configs/inference.yaml`. +### 4. Modify `configs/inference.yaml` file. ```yaml args: latent_channels: 16 mode: inference - load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder + load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter batch_size: 1 - input_type: txt # You can choose txt for pure text input, or change to cli for command line input - input_file: configs/test.txt # Pure text file, which can be edited - sampling_num_frames: 13 # Must be 13, 11 or 9 + input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input + input_file: configs/test.txt # Plain text file, can be edited + sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9. sampling_fps: 8 fp16: True # For CogVideoX-2B - # bf16: True # For CogVideoX-5B + # bf16: True # For CogVideoX-5B output_dir: outputs/ force_inference: True ``` -+ Modify `configs/test.txt` if multiple prompts is required, in which each line makes a prompt. -+ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the - OPENAI_API_KEY as your environmental variable. -+ Modify `input_type` in `configs/inference.yaml` if use command line as prompt iuput. ++ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement. ++ To use command-line input, modify: -```yaml +``` input_type: cli ``` -This allows input from the command line as prompts. +This allows you to enter prompts from the command line. 
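
As a quick sanity check on the `sampling_num_frames` values above: the decoded frame count follows from the 4x temporal VAE compression (`time_compressed_rate: 4` in these configs). The helper below is illustrative only and assumes the frame rates quoted in the model tables (8 fps for CogVideoX, 16 fps for CogVideoX1.5):

```python
# Rough helper (not part of the repository) relating sampling_num_frames to clip length.
def approx_clip_seconds(sampling_num_frames: int, fps: int, time_compressed_rate: int = 4) -> float:
    decoded_frames = (sampling_num_frames - 1) * time_compressed_rate + 1
    return decoded_frames / fps

print(approx_clip_seconds(13, fps=8))   # CogVideoX-2B/5B:  49 frames, ~6.1 s
print(approx_clip_seconds(22, fps=16))  # CogVideoX1.5-5B:  85 frames, ~5.3 s
print(approx_clip_seconds(42, fps=16))  # CogVideoX1.5-5B: 165 frames, ~10.3 s
```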
-Change `output_dir` if you wish to modify the address of the output video +To modify the output video location, change: -```yaml +``` output_dir: outputs/ ``` -It is saved by default in the `.outputs/` folder. +The default location is the `.outputs/` folder. -### 5. Run the inference code to perform inference. +### 5. Run the Inference Code to Perform Inference -```shell +``` bash inference.sh ``` @@ -288,95 +299,91 @@ bash inference.sh ### Preparing the Dataset -The dataset format should be as follows: +The dataset should be structured as follows: ``` . ├── labels -│   ├── 1.txt -│   ├── 2.txt -│   ├── ... +│ ├── 1.txt +│ ├── 2.txt +│ ├── ... └── videos ├── 1.mp4 ├── 2.mp4 ├── ... ``` -Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels -should be matched one-to-one. Generally, a single video should not be associated with multiple labels. +Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels. -For style fine-tuning, please prepare at least 50 videos and labels with similar styles to ensure proper fitting. +For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting. -### Modifying Configuration Files +### Modifying the Configuration File -We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Please note that both methods only fine-tune -the `transformer` part and do not modify the `VAE` section. `T5` is used solely as an Encoder. Please modify -the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows: +We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder. 
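
For orientation, the sketch below shows what the `Lora` option trains in general terms. It is a generic low-rank-adapter illustration with assumed shapes, not the repository's `LoraMixin`; the rank matches the `r: 256` used in the lora configs, and `matrix_A` / `matrix_B` mirror the names that appear in the SAT checkpoints:

```python
# Generic LoRA illustration (assumptions: a square 3072-d attention projection, rank 256).
import torch

d_model, r = 3072, 256              # hidden size as in the 5B configs, LoRA rank
W = torch.randn(d_model, d_model)   # frozen pretrained projection (not trained)
A = torch.randn(r, d_model) * 0.01  # "matrix_A": trainable, small random init
B = torch.zeros(d_model, r)         # "matrix_B": trainable, zero init so the adapter starts as a no-op

x = torch.randn(1, d_model)
y = x @ (W + B @ A).T               # adapted projection; only A and B receive gradients
trainable = A.numel() + B.numel()   # 2*r*d_model parameters instead of d_model**2
```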
+Modify the files in `configs/sft.yaml` (full fine-tuning) as follows: -``` - # checkpoint_activations: True ## Using gradient checkpointing (Both checkpoint_activations in the config file need to be set to True) +```yaml + # checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True) model_parallel_size: 1 # Model parallel size - experiment_name: lora-disney # Experiment name (do not modify) - mode: finetune # Mode (do not modify) - load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path - no_load_rng: True # Whether to load random seed + experiment_name: lora-disney # Experiment name (do not change) + mode: finetune # Mode (do not change) + load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model + no_load_rng: True # Whether to load random number seed train_iters: 1000 # Training iterations eval_iters: 1 # Evaluation iterations eval_interval: 100 # Evaluation interval eval_batch_size: 1 # Evaluation batch size - save: ckpts # Model save path - save_interval: 100 # Model save interval + save: ckpts # Model save path + save_interval: 100 # Save interval log_interval: 20 # Log output interval train_data: [ "your train data path" ] - valid_data: [ "your val data path" ] # Training and validation datasets can be the same - split: 1,0,0 # Training, validation, and test set ratio - num_workers: 8 # Number of worker threads for data loader - force_train: True # Allow missing keys when loading checkpoint (T5 and VAE are loaded separately) - only_log_video_latents: True # Avoid memory overhead caused by VAE decode + valid_data: [ "your val data path" ] # Training and validation sets can be the same + split: 1,0,0 # Proportion for training, validation, and test sets + num_workers: 8 # Number of data loader workers + force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately) + only_log_video_latents: True # Avoid memory usage from VAE decoding deepspeed: bf16: - enabled: False # For CogVideoX-2B set to False and for CogVideoX-5B set to True + enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True fp16: - enabled: True # For CogVideoX-2B set to True and for CogVideoX-5B set to False + enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False ``` -If you wish to use Lora fine-tuning, you also need to modify the `cogvideox__lora` file: +``` To use Lora fine-tuning, you also need to modify `cogvideox__lora` file: -Here, take `CogVideoX-2B` as a reference: +Here's an example using `CogVideoX-2B`: ``` model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true - not_trainable_prefixes: [ 'all' ] ## Uncomment + not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock log_keys: - - txt' + - txt - lora_config: ## Uncomment + lora_config: ## Uncomment to unlock target: sat.model.finetune.lora2.LoraMixin params: r: 256 ``` -### Modifying Run Scripts +### Modify the Run Script -Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Below are two examples: +Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples: -1. If you want to use the `CogVideoX-2B` model and the `Lora` method, you need to modify `finetune_single_gpu.sh` - or `finetune_multi_gpus.sh`: +1. 
If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows: ``` run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" ``` -2. If you want to use the `CogVideoX-2B` model and the `full-parameter fine-tuning` method, you need to - modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`: +2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows: ``` run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" ``` -### Fine-Tuning and Evaluation +### Fine-tuning and Validation Run the inference code to start fine-tuning. @@ -385,45 +392,42 @@ bash finetune_single_gpu.sh # Single GPU bash finetune_multi_gpus.sh # Multi GPUs ``` -### Using the Fine-Tuned Model +### Using the Fine-tuned Model -The fine-tuned model cannot be merged; here is how to modify the inference configuration file `inference.sh`: +The fine-tuned model cannot be merged. Here’s how to modify the inference configuration file `inference.sh` ``` -run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" +run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" ``` -Then, execute the code: +Then, run the code: ``` bash inference.sh ``` -### Converting to Huggingface Diffusers Supported Weights +### Converting to Huggingface Diffusers-compatible Weights -The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run: +The SAT weight format is different from Huggingface’s format and requires conversion. Run -```shell +``` python ../tools/convert_weight_sat2hf.py ``` -### Exporting Huggingface Diffusers lora LoRA Weights from SAT Checkpoints +### Exporting Lora Weights from SAT to Huggingface Diffusers -After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. You can find the file -at `{args.save}/1000/1000/mp_rank_00_model_states.pt`. +Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format. + After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt -The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`. -After exporting, you can use `load_cogvideox_lora.py` for inference. +The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference. Export command: -```bash -python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ +``` +python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ ``` -This training mainly modified the following model structures. The table below lists the corresponding structure mappings -for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the -model's attention structure. +The following model structures were modified during training. 
Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model. ``` 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', @@ -436,5 +440,5 @@ model's attention structure. 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' ``` -Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format. -![alt text](../resources/hf_lora_weights.png) +Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure. +![alt text](../resources/hf_lora_weights.png) \ No newline at end of file diff --git a/sat/README_ja.md b/sat/README_ja.md index ee1abcd..3685ba3 100644 --- a/sat/README_ja.md +++ b/sat/README_ja.md @@ -1,27 +1,37 @@ -# SAT CogVideoX-2B +# SAT CogVideoX -[Read this in English.](./README_zh) +[Read this in English.](./README.md) [䞭文阅读](./README_zh.md) -このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer) りェむトを䜿甚した掚論コヌドず、SAT -りェむトのファむンチュヌニングコヌドが含たれおいたす。 - -このコヌドは、チヌムがモデルをトレヌニングするために䜿甚したフレヌムワヌクです。コメントが少なく、泚意深く研究する必芁がありたす。 +このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer)の重みを䜿甚した掚論コヌドず、SAT重みのファむンチュヌニングコヌドが含たれおいたす。 +このコヌドは、チヌムがモデルを蚓緎する際に䜿甚したフレヌムワヌクです。コメントが少ないため、泚意深く確認する必芁がありたす。 ## 掚論モデル -### 1. このフォルダに必芁な䟝存関係が正しくむンストヌルされおいるこずを確認しおください。 +### 1. このフォルダ内の必芁な䟝存関係がすべおむンストヌルされおいるこずを確認しおください -```shell +``` pip install -r requirements.txt ``` -### 2. モデルりェむトをダりンロヌドしたす +### 2. モデルの重みをダりンロヌド + たず、SATミラヌからモデルの重みをダりンロヌドしおください。 -たず、SAT ミラヌに移動しおモデルの重みをダりンロヌドしたす。 CogVideoX-2B モデルの堎合は、次のようにダりンロヌドしおください。 +#### CogVideoX1.5 モデル -```shell +``` +git lfs install +git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT +``` + +これにより、Transformers、VAE、T5 Encoderの3぀のモデルがダりンロヌドされたす。 + +#### CogVideoX モデル + +CogVideoX-2B モデルに぀いおは、以䞋のようにダりンロヌドしおください + +``` mkdir CogVideoX-2b-sat cd CogVideoX-2b-sat wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1 @@ -32,12 +42,12 @@ mv 'index.html?dl=1' transformer.zip unzip transformer.zip ``` -CogVideoX-5B モデルの `transformers` ファむルを以䞋のリンクからダりンロヌドしおください VAE ファむルは 2B ず同じです +CogVideoX-5B モデルの `transformers` ファむルをダりンロヌドしおくださいVAEファむルは2Bず同じです + [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list) + [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list) -次に、モデルファむルを以䞋の圢匏にフォヌマットする必芁がありたす +モデルファむルを以䞋のように配眮しおください ``` . 
@@ -49,24 +59,24 @@ CogVideoX-5B モデルの `transformers` ファむルを以䞋のリンクから └── 3d-vae.pt ``` -モデルの重みファむルが倧きいため、`git lfs`を䜿甚するこずをお勧めいたしたす。`git lfs` -のむンストヌルに぀いおは、[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)をご参照ください。 +モデルの重みファむルが倧きいため、`git lfs`の䜿甚をお勧めしたす。 +`git lfs`のむンストヌル方法は[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)を参照しおください。 -```shell +``` git lfs install ``` -次に、T5 モデルをクロヌンしたす。これはトレヌニングやファむンチュヌニングには䜿甚されたせんが、䜿甚する必芁がありたす。 -> モデルを耇補する際には、[Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)のモデルファむルの堎所もご䜿甚いただけたす。 +次に、T5モデルをクロヌンしたす。このモデルはEncoderずしおのみ䜿甚され、蚓緎やファむンチュヌニングは必芁ありたせん。 +> [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)䞊のモデルファむルも䜿甚可胜です。 -```shell -git clone https://huggingface.co/THUDM/CogVideoX-2b.git #ハギングフェむス(huggingface.org)からモデルをダりンロヌドいただきたす -# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #Modelscopeからモデルをダりンロヌドいただきたす +``` +git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Huggingfaceからモデルをダりンロヌド +# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Modelscopeからダりンロヌド mkdir t5-v1_1-xxl mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl ``` -䞊蚘の方法に埓うこずで、safetensor 圢匏の T5 ファむルを取埗できたす。これにより、Deepspeed でのファむンチュヌニング䞭に゚ラヌが発生しないようにしたす。 +これにより、Deepspeedファむンチュヌニング䞭に゚ラヌなくロヌドできるsafetensor圢匏のT5ファむルが䜜成されたす。 ``` ├── added_tokens.json @@ -81,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl 0 directories, 8 files ``` -### 3. `configs/cogvideox_2b.yaml` ファむルを倉曎したす。 +### 3. `configs/cogvideox_*.yaml`ファむルを線集 ```yaml model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true log_keys: - txt @@ -123,7 +133,7 @@ model: num_attention_heads: 30 transformer_args: - checkpoint_activations: True ## グラデヌション チェックポむントを䜿甚する + checkpoint_activations: True ## using gradient checkpointing vocab_size: 1 max_sequence_length: 64 layernorm_order: pre @@ -161,14 +171,14 @@ model: ucg_rate: 0.1 target: sgm.modules.encoders.modules.FrozenT5Embedder params: - model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxlフォルダの絶察パス + model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxl 重みフォルダの絶察パス max_length: 226 first_stage_config: target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper params: cp_size: 1 - ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptフォルダの絶察パス + ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptファむルの絶察パス ignore_keys: [ 'loss' ] loss_config: @@ -240,7 +250,7 @@ model: num_steps: 50 ``` -### 4. `configs/inference.yaml` ファむルを倉曎したす。 +### 4. 
`configs/inference.yaml`ファむルを線集 ```yaml args: @@ -250,38 +260,36 @@ args: # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter batch_size: 1 - input_type: txt #TXTのテキストファむルを入力ずしお遞択されたり、CLIコマンドラむンを入力ずしお倉曎されたりいただけたす - input_file: configs/test.txt #テキストファむルのパスで、これに察しお線集がさせおいただけたす - sampling_num_frames: 13 # Must be 13, 11 or 9 + input_type: txt # "txt"でプレヌンテキスト入力、"cli"でコマンドラむン入力を遞択可胜 + input_file: configs/test.txt # プレヌンテキストファむル、線集可胜 + sampling_num_frames: 13 # CogVideoX1.5-5Bでは42たたは22、CogVideoX-5B / 2Bでは13, 11, たたは9 sampling_fps: 8 - fp16: True # For CogVideoX-2B - # bf16: True # For CogVideoX-5B + fp16: True # CogVideoX-2B甹 + # bf16: True # CogVideoX-5B甹 output_dir: outputs/ force_inference: True ``` -+ 耇数のプロンプトを保存するために txt を䜿甚する堎合は、`configs/test.txt` - を参照しお倉曎しおください。1行に1぀のプロンプトを蚘述したす。プロンプトの曞き方がわからない堎合は、最初に [このコヌド](../inference/convert_demo.py) - を䜿甚しお LLM によるリファむンメントを呌び出すこずができたす。 -+ コマンドラむンを入力ずしお䜿甚する堎合は、次のように倉曎したす。 ++ 耇数のプロンプトを含むテキストファむルを䜿甚する堎合、`configs/test.txt`を適宜線集しおください。1行に぀き1プロンプトです。プロンプトの曞き方が分からない堎合は、[こちらのコヌド](../inference/convert_demo.py)を䜿甚しおLLMで補正できたす。 ++ コマンドラむン入力を䜿甚する堎合、以䞋のように倉曎したす -```yaml +``` input_type: cli ``` これにより、コマンドラむンからプロンプトを入力できたす。 -出力ビデオのディレクトリを倉曎したい堎合は、次のように倉曎できたす +出力ビデオの保存堎所を倉曎する堎合は、以䞋を線集しおください -```yaml +``` output_dir: outputs/ ``` -デフォルトでは `.outputs/` フォルダに保存されたす。 +デフォルトでは`.outputs/`フォルダに保存されたす。 -### 5. 掚論コヌドを実行しお掚論を開始したす。 +### 5. 掚論コヌドを実行しお掚論を開始 -```shell +``` bash inference.sh ``` @@ -289,7 +297,7 @@ bash inference.sh ### デヌタセットの準備 -デヌタセットの圢匏は次のようになりたす +デヌタセットは以䞋の構造である必芁がありたす ``` . @@ -303,123 +311,215 @@ bash inference.sh ├── ... ``` -各 txt ファむルは察応するビデオファむルず同じ名前であり、そのビデオのラベルを含んでいたす。各ビデオはラベルず䞀察䞀で察応する必芁がありたす。通垞、1぀のビデオに耇数のラベルを持たせるこずはありたせん。 +各txtファむルは察応するビデオファむルず同じ名前で、ビデオのラベルを含んでいたす。ビデオずラベルは䞀察䞀で察応させる必芁がありたす。通垞、1぀のビデオに耇数のラベルを䜿甚するこずは避けおください。 -スタむルファむンチュヌニングの堎合、少なくずも50本のスタむルが䌌たビデオずラベルを準備し、フィッティングを容易にしたす。 +スタむルのファむンチュヌニングの堎合、スタむルが䌌たビデオずラベルを少なくずも50本準備し、フィッティングを促進したす。 -### 蚭定ファむルの倉曎 +### 蚭定ファむルの線集 -`Lora` ずフルパラメヌタ埮調敎の2぀の方法をサポヌトしおいたす。䞡方の埮調敎方法は、`transformer` 郚分のみを埮調敎し、`VAE` -郚分には倉曎を加えないこずに泚意しおください。`T5` ぱンコヌダヌずしおのみ䜿甚されたす。以䞋のように `configs/sft.yaml` ( -フルパラメヌタ埮調敎甚) ファむルを倉曎しおください。 +``` `Lora`ず党パラメヌタのファむンチュヌニングの2皮類をサポヌトしおいたす。どちらも`transformer`郚分のみをファむンチュヌニングし、`VAE`郚分は倉曎されず、`T5`ぱンコヌダヌずしおのみ䜿甚されたす。 +``` 以䞋のようにしお`configs/sft.yaml`党量ファむンチュヌニングファむルを線集しおください ``` - # checkpoint_activations: True ## 募配チェックポむントを䜿甚する堎合 (蚭定ファむル内の2぀の checkpoint_activations を True に蚭定する必芁がありたす) + # checkpoint_activations: True ## using gradient checkpointing (configファむル内の2぀の`checkpoint_activations`ã‚’äž¡æ–¹Trueに蚭定) model_parallel_size: 1 # モデル䞊列サむズ - experiment_name: lora-disney # 実隓名 (倉曎しないでください) - mode: finetune # モヌド (倉曎しないでください) - load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer モデルのパス - no_load_rng: True # 乱数シヌドを読み蟌むかどうか + experiment_name: lora-disney # 実隓名倉曎䞍芁 + mode: finetune # モヌド倉曎䞍芁 + load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformerモデルのパス + no_load_rng: True # 乱数シヌドをロヌドするかどうか train_iters: 1000 # トレヌニングむテレヌション数 - eval_iters: 1 # 評䟡むテレヌション数 - eval_interval: 100 # 評䟡間隔 - eval_batch_size: 1 # 評䟡バッチサむズ - save: ckpts # モデル保存パス - save_interval: 100 # モデル保存間隔 + eval_iters: 1 # 怜蚌むテレヌション数 + eval_interval: 100 # 怜蚌間隔 + eval_batch_size: 1 # 怜蚌バッチサむズ + save: ckpts # モデル保存パス + save_interval: 100 # 保存間隔 log_interval: 20 # ログ出力間隔 train_data: [ "your train data path" ] - valid_data: [ "your val data path" ] # トレヌニングデヌタず評䟡デヌタは同じでも構いたせん - split: 1,0,0 # トレヌニングセット、評䟡セット、テストセットの割合 - num_workers: 8 # 
デヌタロヌダヌのワヌカヌスレッド数 - force_train: True # チェックポむントをロヌドするずきに欠萜したキヌを蚱可 (T5 ず VAE は別々にロヌドされたす) - only_log_video_latents: True # VAE のデコヌドによるメモリオヌバヌヘッドを回避 + valid_data: [ "your val data path" ] # トレヌニングセットず怜蚌セットは同じでも構いたせん + split: 1,0,0 # トレヌニングセット、怜蚌セット、テストセットの割合 + num_workers: 8 # デヌタロヌダヌのワヌカヌ数 + force_train: True # チェックポむントをロヌドする際に`missing keys`を蚱可T5ずVAEは別途ロヌド + only_log_video_latents: True # VAEのデコヌドによるメモリ䜿甚量を抑える deepspeed: bf16: - enabled: False # CogVideoX-2B の堎合は False に蚭定し、CogVideoX-5B の堎合は True に蚭定 + enabled: False # CogVideoX-2B 甚は False、CogVideoX-5B 甚は True に蚭定 fp16: - enabled: True # CogVideoX-2B の堎合は True に蚭定し、CogVideoX-5B の堎合は False に蚭定 + enabled: True # CogVideoX-2B 甚は True、CogVideoX-5B 甚は False に蚭定 +``` +```yaml +args: + latent_channels: 16 + mode: inference + load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder + # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter + + batch_size: 1 + input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input + input_file: configs/test.txt # Plain text file, can be edited + sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9. + sampling_fps: 8 + fp16: True # For CogVideoX-2B + # bf16: True # For CogVideoX-5B + output_dir: outputs/ + force_inference: True ``` -Lora 埮調敎を䜿甚したい堎合は、`cogvideox__lora` ファむルも倉曎する必芁がありたす。 - -ここでは、`CogVideoX-2B` を参考にしたす。 ++ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement. ++ To use command-line input, modify: ``` +input_type: cli +``` + +This allows you to enter prompts from the command line. + +To modify the output video location, change: + +``` +output_dir: outputs/ +``` + +The default location is the `.outputs/` folder. + +### 5. Run the Inference Code to Perform Inference + +``` +bash inference.sh +``` + +## Fine-tuning the Model + +### Preparing the Dataset + +The dataset should be structured as follows: + +``` +. +├── labels +│ ├── 1.txt +│ ├── 2.txt +│ ├── ... +└── videos + ├── 1.mp4 + ├── 2.mp4 + ├── ... +``` + +Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels. + +For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting. + +### Modifying the Configuration File + +We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder. 
+Modify the files in `configs/sft.yaml` (full fine-tuning) as follows: + +```yaml + # checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True) + model_parallel_size: 1 # Model parallel size + experiment_name: lora-disney # Experiment name (do not change) + mode: finetune # Mode (do not change) + load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model + no_load_rng: True # Whether to load random number seed + train_iters: 1000 # Training iterations + eval_iters: 1 # Evaluation iterations + eval_interval: 100 # Evaluation interval + eval_batch_size: 1 # Evaluation batch size + save: ckpts # Model save path + save_interval: 100 # Save interval + log_interval: 20 # Log output interval + train_data: [ "your train data path" ] + valid_data: [ "your val data path" ] # Training and validation sets can be the same + split: 1,0,0 # Proportion for training, validation, and test sets + num_workers: 8 # Number of data loader workers + force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately) + only_log_video_latents: True # Avoid memory usage from VAE decoding + deepspeed: + bf16: + enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True + fp16: + enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False +``` + +``` To use Lora fine-tuning, you also need to modify `cogvideox__lora` file: + +Here's an example using `CogVideoX-2B`: + +```yaml model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true - not_trainable_prefixes: [ 'all' ] ## コメントを解陀 + not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock log_keys: - - txt' + - txt - lora_config: ## コメントを解陀 + lora_config: ## Uncomment to unlock target: sat.model.finetune.lora2.LoraMixin params: r: 256 ``` -### 実行スクリプトの倉曎 +### Modify the Run Script -蚭定ファむルを遞択するために `finetune_single_gpu.sh` たたは `finetune_multi_gpus.sh` を線集したす。以䞋に2぀の䟋を瀺したす。 +Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples: -1. `CogVideoX-2B` モデルを䜿甚し、`Lora` 手法を利甚する堎合は、`finetune_single_gpu.sh` たたは `finetune_multi_gpus.sh` - を倉曎する必芁がありたす。 +1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows: ``` run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" ``` -2. `CogVideoX-2B` モデルを䜿甚し、`フルパラメヌタ埮調敎` 手法を利甚する堎合は、`finetune_single_gpu.sh` - たたは `finetune_multi_gpus.sh` を倉曎する必芁がありたす。 +2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows: ``` run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" ``` -### 埮調敎ず評䟡 +### Fine-tuning and Validation -掚論コヌドを実行しお埮調敎を開始したす。 +Run the inference code to start fine-tuning. ``` -bash finetune_single_gpu.sh # シングルGPU -bash finetune_multi_gpus.sh # マルチGPU +bash finetune_single_gpu.sh # Single GPU +bash finetune_multi_gpus.sh # Multi GPUs ``` -### 埮調敎埌のモデルの䜿甚 +### Using the Fine-tuned Model -埮調敎されたモデルは統合できたせん。ここでは、掚論蚭定ファむル `inference.sh` を倉曎する方法を瀺したす。 +The fine-tuned model cannot be merged. 
Here’s how to modify the inference configuration file `inference.sh` ``` -run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" +run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42" ``` -その埌、次のコヌドを実行したす。 +Then, run the code: ``` bash inference.sh ``` -### Huggingface Diffusers サポヌトのりェむトに倉換 +### Converting to Huggingface Diffusers-compatible Weights -SAT りェむト圢匏は Huggingface のりェむト圢匏ず異なり、倉換が必芁です。次のコマンドを実行しおください +The SAT weight format is different from Huggingface’s format and requires conversion. Run -```shell +``` python ../tools/convert_weight_sat2hf.py ``` -### SATチェックポむントからHuggingface Diffusers lora LoRAりェむトを゚クスポヌト +### Exporting Lora Weights from SAT to Huggingface Diffusers -䞊蚘のステップを完了するず、LoRAりェむト付きのSATチェックポむントが埗られたす。ファむルは `{args.save}/1000/1000/mp_rank_00_model_states.pt` にありたす。 +Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format. +After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt -LoRAりェむトを゚クスポヌトするためのスクリプトは、CogVideoXリポゞトリの `tools/export_sat_lora_weight.py` にありたす。゚クスポヌト埌、`load_cogvideox_lora.py` を䜿甚しお掚論を行うこずができたす。 +The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference. -゚クスポヌトコマンド: +Export command: -```bash -python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ +``` +python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ ``` -このトレヌニングでは䞻に以䞋のモデル構造が倉曎されたした。以䞋の衚は、HF (Hugging Face) 圢匏のLoRA構造に倉換する際の察応関係を瀺しおいたす。ご芧の通り、LoRAはモデルの泚意メカニズムに䜎ランクの重みを远加しおいたす。 +The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model. ``` 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', @@ -431,8 +531,6 @@ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_nam 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' ``` - -export_sat_lora_weight.py を䜿甚しお、SATチェックポむントをHF LoRA圢匏に倉換できたす。 - -![alt text](../resources/hf_lora_weights.png) +Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure. +![alt text](../resources/hf_lora_weights.png) \ No newline at end of file diff --git a/sat/README_zh.md b/sat/README_zh.md index c605da8..c25c6b7 100644 --- a/sat/README_zh.md +++ b/sat/README_zh.md @@ -1,6 +1,6 @@ -# SAT CogVideoX-2B +# SAT CogVideoX -[Read this in English.](./README_zh) +[Read this in English.](./README.md) [日本語で読む](./README_ja.md) @@ -20,6 +20,15 @@ pip install -r requirements.txt 銖先前埀 SAT 镜像䞋蜜暡型权重。 +#### CogVideoX1.5 æš¡åž‹ + +```shell +git lfs install +git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT +``` +歀操䜜䌚䞋蜜 Transformers, VAE, T5 Encoder 这䞉䞪暡型。 + +#### CogVideoX æš¡åž‹ 对于 CogVideoX-2B 暡型请按照劂䞋方匏䞋蜜: ```shell @@ -82,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl 0 directories, 8 files ``` -### 3. 修改`configs/cogvideox_2b.yaml`䞭的文件。 +### 3. 
修改`configs/cogvideox_*.yaml`䞭的文件。 ```yaml model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true log_keys: - txt @@ -253,7 +262,7 @@ args: batch_size: 1 input_type: txt #可以选择txt纯文字档䜜䞺蟓入或者改成cli呜什行䜜䞺蟓入 input_file: configs/test.txt #纯文字档可以对歀做猖蟑 - sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_num_frames: 13 #CogVideoX1.5-5B 必须是 42 或 22。 CogVideoX-5B / 2B 必须是 13 11 或 9。 sampling_fps: 8 fp16: True # For CogVideoX-2B # bf16: True # For CogVideoX-5B @@ -346,7 +355,7 @@ Encoder 䜿甚。 ```yaml model: - scale_factor: 1.15258426 + scale_factor: 1.55258426 disable_first_stage_autocast: true not_trainable_prefixes: [ 'all' ] ## 解陀泚释 log_keys: diff --git a/sat/arguments.py b/sat/arguments.py index 44767d3..9b0a1bb 100644 --- a/sat/arguments.py +++ b/sat/arguments.py @@ -36,6 +36,7 @@ def add_sampling_config_args(parser): group.add_argument("--input-dir", type=str, default=None) group.add_argument("--input-type", type=str, default="cli") group.add_argument("--input-file", type=str, default="input.txt") + group.add_argument("--sampling-image-size", type=list, default=[768, 1360]) group.add_argument("--final-size", type=int, default=2048) group.add_argument("--sdedit", action="store_true") group.add_argument("--grid-num-rows", type=int, default=1) diff --git a/sat/configs/cogvideox1.5_5b.yaml b/sat/configs/cogvideox1.5_5b.yaml new file mode 100644 index 0000000..0000ec2 --- /dev/null +++ b/sat/configs/cogvideox1.5_5b.yaml @@ -0,0 +1,149 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 81 + time_compressed_rate: 4 + latent_width: 300 + latent_height: 300 + num_layers: 42 + patch_size: [2, 2, 2] + in_channels: 16 + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + hidden_size_head: 64 + text_length: 224 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "google/t5-v1_1-xxl" + max_length: 224 + + + first_stage_config: + target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt" + ignore_keys: ['loss'] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + 
+ encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + offset_noise_level: 0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + group_num: 40 + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 diff --git a/sat/configs/cogvideox1.5_5b_i2v.yaml b/sat/configs/cogvideox1.5_5b_i2v.yaml new file mode 100644 index 0000000..c65f0b7 --- /dev/null +++ b/sat/configs/cogvideox1.5_5b_i2v.yaml @@ -0,0 +1,160 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + latent_input: false + noised_image_input: true + noised_image_all_concat: false + noised_image_dropout: 0.05 + augmentation_dropout: 0.15 + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + network_config: + target: dit_video_concat.DiffusionTransformer + params: +# space_interpolation: 1.875 + ofs_embed_dim: 512 + time_embed_dim: 512 + elementwise_affine: True + num_frames: 81 + time_compressed_rate: 4 + latent_width: 300 + latent_height: 300 + num_layers: 42 + patch_size: [2, 2, 2] + in_channels: 32 + out_channels: 16 + hidden_size: 3072 + adm_in_channels: 256 + num_attention_heads: 48 + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin + params: + hidden_size_head: 64 + text_length: 224 + + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "google/t5-v1_1-xxl" + max_length: 224 + + + first_stage_config: + target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: 
"cogvideox-5b-i2v-sat/vae/3d-vae.pt" + ignore_keys: ['loss'] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 4] + attn_resolutions: [] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss + params: + fixed_frames: 0 + offset_noise_level: 0.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + uniform_sampling: True + group_num: 40 + num_idx: 1000 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + fixed_frames: 0 + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 963038b..10635b4 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -179,14 +179,31 @@ class SATVideoDiffusionEngine(nn.Module): n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) n_rounds = math.ceil(z.shape[0] / n_samples) all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - if isinstance(self.first_stage_model.decoder, VideoDecoder): - kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} - else: - kwargs = {} - out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) - all_out.append(out) + for n in range(n_rounds): + z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:] + latent_time = z_now.shape[2] # check the time latent + temporal_compress_times = 4 + + fake_cp_size = min(10, latent_time // 2) + start_frame = 0 + + recons = [] + start_frame = 0 + for i in range(fake_cp_size): + end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0) + + fake_cp_rank0 = True if i == 0 else False + clear_fake_cp_cache = True if i == fake_cp_size - 1 else False + with torch.no_grad(): + recon = self.first_stage_model.decode( + z_now[:, :, start_frame:end_frame].contiguous(), + clear_fake_cp_cache=clear_fake_cp_cache, + fake_cp_rank0=fake_cp_rank0, + ) + recons.append(recon) + start_frame = end_frame + recons = torch.cat(recons, dim=2) + all_out.append(recons) out = torch.cat(all_out, dim=0) return out @@ -218,6 +235,7 @@ class SATVideoDiffusionEngine(nn.Module): shape: Union[None, Tuple, List] = None, prefix=None, concat_images=None, + ofs=None, **kwargs, ): randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) @@ -241,7 +259,7 @@ class SATVideoDiffusionEngine(nn.Module): self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs ) - samples = self.sampler(denoiser, randn, 
cond, uc=uc, scale=scale, scale_emb=scale_emb) + samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs) samples = samples.to(self.dtype) return samples diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py index 7692116..b55a3f1 100644 --- a/sat/dit_video_concat.py +++ b/sat/dit_video_concat.py @@ -1,5 +1,7 @@ from functools import partial from einops import rearrange, repeat +from functools import reduce +from operator import mul import numpy as np import torch @@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear from sgm.util import instantiate_from_config from sgm.modules.diffusionmodules.openaimodel import Timestep -from sgm.modules.diffusionmodules.util import ( - linear, - timestep_embedding, -) +from sgm.modules.diffusionmodules.util import linear, timestep_embedding from sat.ops.layernorm import LayerNorm, RMSNorm class ImagePatchEmbeddingMixin(BaseMixin): - def __init__( - self, - in_channels, - hidden_size, - patch_size, - bias=True, - text_hidden_size=None, - ): + def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None): super().__init__() - self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) + self.patch_size = patch_size + self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size) if text_hidden_size is not None: self.text_proj = nn.Linear(text_hidden_size, hidden_size) else: self.text_proj = None def word_embedding_forward(self, input_ids, **kwargs): - # now is 3d patch images = kwargs["images"] # (b,t,c,h,w) - B, T = images.shape[:2] - emb = images.view(-1, *images.shape[2:]) - emb = self.proj(emb) # ((b t),d,h/2,w/2) - emb = emb.view(B, T, *emb.shape[1:]) - emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) - emb = rearrange(emb, "b t n d -> b (t n) d") + emb = rearrange(images, "b t c h w -> b (t h w) c") + emb = rearrange( + emb, + "b (t o h p w q) c -> b (t h w) (c o p q)", + t=kwargs["rope_T"], + h=kwargs["rope_H"], + w=kwargs["rope_W"], + o=self.patch_size[0], + p=self.patch_size[1], + q=self.patch_size[2], + ) + emb = self.proj(emb) if self.text_proj is not None: text_emb = self.text_proj(kwargs["encoder_outputs"]) @@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed( grid_size: int of the grid height and width t_size: int of the temporal size return: - pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim] + (w/ or w/o cls_token) """ assert embed_dim % 4 == 0 embed_dim_spatial = embed_dim // 4 * 3 @@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed( pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) - # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] return pos_embed # [T, H*W, D] @@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): text_length, theta=10000, rot_v=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, learnable_pos_embed=False, ): super().__init__() @@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - freqs = rearrange(freqs, "t h w d -> (t h w) d") freqs = freqs.contiguous() - freqs_sin = freqs.sin() - freqs_cos = freqs.cos() - self.register_buffer("freqs_sin", freqs_sin) - self.register_buffer("freqs_cos", freqs_cos) - + self.freqs_sin = freqs.sin().cuda() + self.freqs_cos = freqs.cos().cuda() self.text_length = text_length if learnable_pos_embed: num_patches = height * width * compressed_num_frames + text_length @@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): self.pos_embedding = None def rotary(self, t, **kwargs): - seq_len = t.shape[2] - freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) - freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) + def reshape_freq(freqs): + freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous() + freqs = rearrange(freqs, "t h w d -> (t h w) d") + freqs = freqs.unsqueeze(0).unsqueeze(0) + return freqs + + freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype) + freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype) return t * freqs_cos + rotate_half(t) * freqs_sin def position_embedding_forward(self, position_ids, **kwargs): if self.pos_embedding is not None: - return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] + return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]] else: return None @@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): ): attention_fn_default = HOOKS_DEFAULT["attention_fn"] - query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) - key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) + query_layer = torch.cat( + ( + query_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + query_layer[ + :, + :, + kwargs["text_length"] :, + ], + **kwargs, + ), + ), + dim=2, + ) + key_layer = torch.cat( + ( + key_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + key_layer[ + :, + :, + kwargs["text_length"] :, + ], + **kwargs, + ), + ), + dim=2, + ) if self.rot_v: - value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) + value_layer = torch.cat( + ( + value_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + value_layer[ + :, + :, + kwargs["text_length"] :, + ], + **kwargs, + ), + ), + dim=2, + ) return attention_fn_default( query_layer, @@ -347,21 +400,25 @@ def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): +def unpatchify(x, c, patch_size, w, h, **kwargs): """ x: (N, T/2 * S, patch_size**3 * C) imgs: (N, T, H, W, C) + + patch_size 被拆解䞺䞉䞪䞍同的绎床 (o, p, q)分别对应了深床o、高床p和宜床q。这䜿埗 patch 倧小圚䞍同绎床䞊可以䞍盞等增加了灵掻性。 """ - if rope_position_ids is not None: - assert NotImplementedError - # do pix2struct unpatchify - L = x.shape[1] - x = x.reshape(shape=(x.shape[0], L, p, p, c)) - x = torch.einsum("nlpqc->ncplq", x) - imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) - else: - b = x.shape[0] - imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) + + imgs = rearrange( + x, + "b (t h w) (c o p q) -> b (t o) c (h p) (w q)", + c=c, + o=patch_size[0], + p=patch_size[1], + q=patch_size[2], + t=kwargs["rope_T"], + h=kwargs["rope_H"], + w=kwargs["rope_W"], + ) return imgs @@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin): 
self.patch_size = patch_size self.out_channels = out_channels self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) - self.spatial_length = latent_width * latent_height // patch_size**2 - self.latent_width = latent_width - self.latent_height = latent_height - def final_forward(self, logits, **kwargs): - x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d) + x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d),只取了x䞭后面images的郚分 shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return unpatchify( - x, - c=self.out_channels, - p=self.patch_size, - w=self.latent_width // self.patch_size, - h=self.latent_height // self.patch_size, - rope_position_ids=kwargs.get("rope_position_ids", None), - **kwargs, + x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs ) def reinit(self, parent_model=None): @@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin): class AdaLNMixin(BaseMixin): def __init__( self, - width, - height, hidden_size, num_layers, time_embed_dim, @@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin): ): super().__init__() self.num_layers = num_layers - self.width = width - self.height = height self.compressed_num_frames = compressed_num_frames self.adaLN_modulations = nn.ModuleList( @@ -611,7 +654,7 @@ class DiffusionTransformer(BaseModel): time_interpolation=1.0, use_SwiGLU=False, use_RMSNorm=False, - zero_init_y_embed=False, + ofs_embed_dim=None, **kwargs, ): self.latent_width = latent_width @@ -619,12 +662,13 @@ class DiffusionTransformer(BaseModel): self.patch_size = patch_size self.num_frames = num_frames self.time_compressed_rate = time_compressed_rate - self.spatial_length = latent_width * latent_height // patch_size**2 + self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:]) self.in_channels = in_channels self.out_channels = out_channels self.hidden_size = hidden_size self.model_channels = hidden_size self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size + self.ofs_embed_dim = ofs_embed_dim self.num_classes = num_classes self.adm_in_channels = adm_in_channels self.input_time = input_time @@ -636,7 +680,6 @@ class DiffusionTransformer(BaseModel): self.width_interpolation = width_interpolation self.time_interpolation = time_interpolation self.inner_hidden_size = hidden_size * 4 - self.zero_init_y_embed = zero_init_y_embed try: self.dtype = str_to_dtype[kwargs.pop("dtype")] except: @@ -669,7 +712,6 @@ class DiffusionTransformer(BaseModel): def _build_modules(self, module_configs): model_channels = self.hidden_size - # time_embed_dim = model_channels * 4 time_embed_dim = self.time_embed_dim self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), @@ -677,6 +719,13 @@ class DiffusionTransformer(BaseModel): linear(time_embed_dim, time_embed_dim), ) + if self.ofs_embed_dim is not None: + self.ofs_embed = nn.Sequential( + linear(self.ofs_embed_dim, self.ofs_embed_dim), + nn.SiLU(), + linear(self.ofs_embed_dim, self.ofs_embed_dim), + ) + if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = 
nn.Embedding(self.num_classes, time_embed_dim) @@ -701,9 +750,6 @@ class DiffusionTransformer(BaseModel): linear(time_embed_dim, time_embed_dim), ) ) - if self.zero_init_y_embed: - nn.init.constant_(self.label_emb[0][2].weight, 0) - nn.init.constant_(self.label_emb[0][2].bias, 0) else: raise ValueError() @@ -712,10 +758,13 @@ class DiffusionTransformer(BaseModel): "pos_embed", instantiate_from_config( pos_embed_config, - height=self.latent_height // self.patch_size, - width=self.latent_width // self.patch_size, + height=self.latent_height // self.patch_size[1], + width=self.latent_width // self.patch_size[2], compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, hidden_size=self.hidden_size, + height_interpolation=self.height_interpolation, + width_interpolation=self.width_interpolation, + time_interpolation=self.time_interpolation, ), reinit=True, ) @@ -737,8 +786,6 @@ class DiffusionTransformer(BaseModel): "adaln_layer", instantiate_from_config( adaln_layer_config, - height=self.latent_height // self.patch_size, - width=self.latent_width // self.patch_size, hidden_size=self.hidden_size, num_layers=self.num_layers, compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, @@ -749,7 +796,6 @@ class DiffusionTransformer(BaseModel): ) else: raise NotImplementedError - final_layer_config = module_configs["final_layer_config"] self.add_mixin( "final_layer", @@ -766,25 +812,18 @@ class DiffusionTransformer(BaseModel): reinit=True, ) - if "lora_config" in module_configs: - lora_config = module_configs["lora_config"] - self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) - return def forward(self, x, timesteps=None, context=None, y=None, **kwargs): b, t, d, h, w = x.shape if x.dtype != self.dtype: x = x.to(self.dtype) - - # This is not use in inference if "concat_images" in kwargs and kwargs["concat_images"] is not None: if kwargs["concat_images"].shape[0] != x.shape[0]: concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) else: concat_images = kwargs["concat_images"] x = torch.cat([x, concat_images], dim=2) - assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -792,17 +831,25 @@ class DiffusionTransformer(BaseModel): emb = self.time_embed(t_emb) if self.num_classes is not None: - # assert y.shape[0] == x.shape[0] assert x.shape[0] % y.shape[0] == 0 y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) emb = emb + self.label_emb(y) - kwargs["seq_length"] = t * h * w // (self.patch_size**2) + if self.ofs_embed_dim is not None: + ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype) + ofs_emb = self.ofs_embed(ofs_emb) + emb = emb + ofs_emb + + kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size) kwargs["images"] = x kwargs["emb"] = emb kwargs["encoder_outputs"] = context kwargs["text_length"] = context.shape[1] + kwargs["rope_T"] = t // self.patch_size[0] + kwargs["rope_H"] = h // self.patch_size[1] + kwargs["rope_W"] = w // self.patch_size[2] + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) output = super().forward(**kwargs)[0] return output diff --git a/sat/inference.sh b/sat/inference.sh index c798fa5..a22ef87 100755 --- a/sat/inference.sh +++ b/sat/inference.sh @@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" 
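The rope_T / rope_H / rope_W bookkeeping added to DiffusionTransformer.forward above is plain integer division by the three patch edges; a quick sanity check with made-up shapes (every number below is illustrative, not a value fixed by the configs):

from functools import reduce
from operator import mul

patch_size = (2, 2, 2)   # (temporal, height, width) patch edges
t, h, w = 16, 96, 170    # latent frames and spatial latent size

rope_T = t // patch_size[0]  # 8
rope_H = h // patch_size[1]  # 48
rope_W = w // patch_size[2]  # 85
seq_length = t * h * w // reduce(mul, patch_size)
assert seq_length == rope_T * rope_H * rope_W == 32640  # vision tokens seen by the transformer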
-run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM" +run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM" echo ${run_cmd} eval ${run_cmd} diff --git a/sat/requirements.txt b/sat/requirements.txt index 75b4649..3c1c501 100644 --- a/sat/requirements.txt +++ b/sat/requirements.txt @@ -1,16 +1,11 @@ -SwissArmyTransformer==0.4.12 -omegaconf==2.3.0 -torch==2.4.0 -torchvision==0.19.0 -pytorch_lightning==2.3.3 -kornia==0.7.3 -beartype==0.18.5 -numpy==2.0.1 -fsspec==2024.5.0 -safetensors==0.4.3 -imageio-ffmpeg==0.5.1 -imageio==2.34.2 -scipy==1.14.0 -decord==0.6.0 -wandb==0.17.5 -deepspeed==0.14.4 \ No newline at end of file +SwissArmyTransformer>=0.4.12 +omegaconf>=2.3.0 +pytorch_lightning>=2.4.0 +kornia>=0.7.3 +beartype>=0.19.0 +fsspec>=2024.2.0 +safetensors>=0.4.5 +scipy>=1.14.1 +decord>=0.6.0 +wandb>=0.18.5 +deepspeed>=0.15.3 \ No newline at end of file diff --git a/sat/sample_video.py b/sat/sample_video.py index 49cfcac..c34e6a7 100644 --- a/sat/sample_video.py +++ b/sat/sample_video.py @@ -4,24 +4,20 @@ import argparse from typing import List, Union from tqdm import tqdm from omegaconf import ListConfig +from PIL import Image import imageio import torch import numpy as np -from einops import rearrange +from einops import rearrange, repeat import torchvision.transforms as TT - from sat.model.base_model import get_model from sat.training.model_io import load_checkpoint from sat import mpu from diffusion_video import SATVideoDiffusionEngine from arguments import get_args -from torchvision.transforms.functional import center_crop, resize -from torchvision.transforms import InterpolationMode -from PIL import Image - def read_from_cli(): cnt = 0 @@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda if key == "txt": batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1) + ) + elif key == "fps": + batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) + elif key == "fps_id": + batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) + elif key == "motion_bucket_id": + batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N)) + elif key == "pool_image": + batch[key] = repeat(value_dict[key], "1 ... 
-> b ...", b=math.prod(N)).to(device, dtype=torch.half) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to("cuda"), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames": + batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) + elif key == "cond_frames_without_noise": + batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]) else: batch[key] = value_dict[key] @@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i writer.append_data(frame) -def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) - - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def sampling_main(args, model_cls): if isinstance(model_cls, type): model = get_model(args, model_cls) @@ -127,45 +128,62 @@ def sampling_main(args, model_cls): data_iter = read_from_cli() elif args.input_type == "txt": rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size() - print("rank and world_size", rank, world_size) data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size) else: raise NotImplementedError - image_size = [480, 720] - - if args.image2video: - chained_trainsforms = [] - chained_trainsforms.append(TT.ToTensor()) - transform = TT.Compose(chained_trainsforms) - sample_func = model.sample - T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 num_samples = [1] force_uc_zero_embeddings = ["txt"] - device = model.device + T, C = args.sampling_num_frames, args.latent_channels with torch.no_grad(): for text, cnt in tqdm(data_iter): if args.image2video: + # use with input image shape text, image_path = text.split("@@") assert os.path.exists(image_path), image_path image = Image.open(image_path).convert("RGB") + (img_W, img_H) = image.size + + def nearest_multiple_of_16(n): + lower_multiple = (n // 16) * 16 + upper_multiple = (n // 16 + 1) * 16 + if abs(n - lower_multiple) < abs(n - upper_multiple): + return lower_multiple + else: + return upper_multiple + + if img_H < img_W: + H = 96 + W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8 + else: + W = 96 + H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8 + chained_trainsforms = [] + chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)) + chained_trainsforms.append(TT.ToTensor()) + transform = TT.Compose(chained_trainsforms) image = transform(image).unsqueeze(0).to("cuda") - image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0) image = image * 2.0 - 1.0 image = image.unsqueeze(2).to(torch.bfloat16) image = model.encode_first_stage(image, None) + image = image / 
model.scale_factor image = image.permute(0, 2, 1, 3, 4).contiguous() - pad_shape = (image.shape[0], T - 1, C, H // F, W // F) + pad_shape = (image.shape[0], T - 1, C, H, W) image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) else: + image_size = args.sampling_image_size + H, W = image_size[0], image_size[1] + F = 8 # 8x downsampled image = None - value_dict = { - "prompt": text, - "negative_prompt": "", - "num_frames": torch.tensor(T).unsqueeze(0), - } + text_cast = [text] + mp_size = mpu.get_model_parallel_world_size() + global_rank = torch.distributed.get_rank() // mp_size + src = global_rank * mp_size + torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group()) + text = text_cast[0] + value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)} batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples @@ -187,57 +205,42 @@ def sampling_main(args, model_cls): if not k == "crossattn": c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) - if args.image2video and image is not None: + if args.image2video: c["concat"] = image uc["concat"] = image for index in range(args.batch_size): - # reload model on GPU - model.to(device) - samples_z = sample_func( - c, - uc=uc, - batch_size=1, - shape=(T, C, H // F, W // F), - ) + if args.image2video: + samples_z = sample_func( + c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda") + ) + else: + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ).to("cuda") + samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() - - # Unload the model from GPU to save GPU memory - model.to("cpu") - torch.cuda.empty_cache() - first_stage_model = model.first_stage_model - first_stage_model = first_stage_model.to(device) - - latent = 1.0 / model.scale_factor * samples_z - - # Decode latent serial to save GPU memory - recons = [] - loop_num = (T - 1) // 2 - for i in range(loop_num): - if i == 0: - start_frame, end_frame = 0, 3 - else: - start_frame, end_frame = i * 2 + 1, i * 2 + 3 - if i == loop_num - 1: - clear_fake_cp_cache = True - else: - clear_fake_cp_cache = False - with torch.no_grad(): - recon = first_stage_model.decode( - latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache - ) - - recons.append(recon) - - recon = torch.cat(recons, dim=2).to(torch.float32) - samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() - - save_path = os.path.join( - args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) - ) - if mpu.get_model_parallel_rank() == 0: - save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) + if args.only_save_latents: + samples_z = 1.0 / model.scale_factor * samples_z + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + os.makedirs(save_path, exist_ok=True) + torch.save(samples_z, os.path.join(save_path, "latent.pt")) + with open(os.path.join(save_path, "text.txt"), "w") as f: + f.write(text) + else: + samples_x = model.decode_first_stage(samples_z).to(torch.float32) + samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + 
text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + if mpu.get_model_parallel_rank() == 0: + save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) if __name__ == "__main__": diff --git a/sat/sgm/modules/diffusionmodules/sampling.py b/sat/sgm/modules/diffusionmodules/sampling.py index f0f1830..6efd154 100644 --- a/sat/sgm/modules/diffusionmodules/sampling.py +++ b/sat/sgm/modules/diffusionmodules/sampling.py @@ -1,7 +1,8 @@ """ -Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py + Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py """ + from typing import Dict, Union import torch @@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import ( to_sigma, ) from ...util import append_dims, default, instantiate_from_config -from ...util import SeededNoise from .guiders import DynamicCFG @@ -44,7 +44,9 @@ class BaseDiffusionSampler: self.device = device def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): - sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) uc = default(uc, cond) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) @@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler): class EDMSampler(SingleStepDiffusionSampler): - def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): + def __init__( + self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs + ): super().__init__(*args, **kwargs) self.s_churn = s_churn @@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler): dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler): class DDIMSampler(SingleStepDiffusionSampler): - def __init__(self, s_noise=0.1, *args, **kwargs): + def __init__( + self, s_noise=0.1, *args, **kwargs + ): super().__init__(*args, **kwargs) self.s_noise = s_noise def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): + denoised = self.denoise(x, denoiser, sigma, cond, uc) d = to_d(x, sigma, denoised) - dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) + dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) - x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) return x def __call__(self, denoiser, x, cond, uc=None, 
num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler): return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler): self.order = order def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) ds = [] sigmas_cpu = sigmas.detach().cpu().numpy() for i in self.get_sigma_gen(num_sigmas): sigma = s_in * sigmas[i] - denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) + denoised = denoiser( + *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs + ) denoised = self.guider(denoised, sigma) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > self.order: ds.pop(0) cur_order = min(i + 1, self.order) - coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class EulerEDMSampler(EDMSampler): - def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): return euler_step class HeunEDMSampler(EDMSampler): - def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): if torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 return euler_step @@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler): d_prime = (d + d_new) / 2.0 # apply correction if noise level is not 0 - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step + ) return x @@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler): x = x_euler else: h, s, t, t_next = self.get_variables(sigma, sigma_down) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] + mult = [ + append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) + ] x2 = mult[0] * x - mult[1] * denoised denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) @@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] x_standard = mult[0] * x - mult[1] * denoised if old_denoised is None or torch.sum(next_sigma) < 1e-14: @@ -343,12 +380,16 @@ class 
DPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d # apply correction if noise level is not 0 and not first step - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler): return x - class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_variables(self, sigma, next_sigma, previous_sigma=None): t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] @@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_mult(self, h, r, t, t_next, previous_sigma): mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() - mult2 = (-2 * h).expm1() + mult2 = (-2*h).expm1() if previous_sigma is not None: mult3 = 1 + 1 / (2 * r) @@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] - mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] + mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_sigma) < 1e-14: @@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) # apply correction if noise level is not 0 and not first step - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): return x - class SdeditEDMSampler(EulerEDMSampler): def __init__(self, edit_ratio=0.5, *args, **kwargs): super().__init__(*args, **kwargs) @@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler): def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): randn_unit = randn.clone() - randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) + randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + randn, cond, uc, num_steps + ) if num_steps is None: num_steps = self.num_steps @@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler): x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2**0.5 
- 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler): return x - class VideoDDIMSampler(BaseDiffusionSampler): + def __init__(self, fixed_frames=0, sdedit=False, **kwargs): super().__init__(**kwargs) self.fixed_frames = fixed_frames @@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): alpha_cumprod_sqrt, timesteps = self.discretization( - self.num_steps if num_steps is None else num_steps, - device=self.device, - return_idx=True, - do_append_zero=False, + self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False ) alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) - timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) + timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))]) uc = default(uc, cond) @@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler): return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps - def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None): additional_model_inputs = {} + if ofs is not None: + additional_model_inputs['ofs'] = ofs + if isinstance(scale, torch.Tensor) == False and scale == 1: - additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep + additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep if scale_emb is not None: - additional_model_inputs["scale_emb"] = scale_emb + additional_model_inputs['scale_emb'] = scale_emb denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) else: - additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) - denoised = denoiser( - *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs - ).to(torch.float32) + additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32) if isinstance(self.guider, DynamicCFG): - denoised = self.guider( - denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale - ) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale) else: - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale) return denoised - def sampler_step( - self, - alpha_cumprod_sqrt, - next_alpha_cumprod_sqrt, - denoiser, - x, - cond, - uc=None, - idx=None, - timestep=None, - scale=None, - scale_emb=None, - ): - denoised = self.denoise( - x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb - ).to(torch.float32) + def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None): + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 - a_t = ((1 - 
next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised return x - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) @@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler): cond, uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)], scale=scale, scale_emb=scale_emb, + ofs=ofs # 1020 ) return x +class Image2VideoDDIMSampler(BaseDiffusionSampler): + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + alpha_cumprod_sqrt, timesteps = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True + ) + uc = default(uc, cond) + + num_sigmas = len(alpha_cumprod_sqrt) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps + + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None): + additional_model_inputs = {} + additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to( + torch.float32) + if isinstance(self.guider, DynamicCFG): + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep) + else: + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5) + return denoised + + def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, + timestep=None): + # 歀倄的sigma实际䞊是alpha_cumprod_sqrt + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32) + if idx == 1: + return denoised + + a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 + b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t + + x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised + return x + + def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)] + ) + + return x + class VPSDEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 - lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 - lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 - 
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next @@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() - mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt + mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp() + mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: mult3 = 1 + 1 / (2 * r) @@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): timestep=None, scale=None, scale_emb=None, + ofs=None # 1020 ): - denoised = self.denoise( - x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb - ).to(torch.float32) + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables( - alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt - ) + h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) ] - mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: @@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return x, denoised - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) if self.fixed_frames > 0: - prefix_frames = x[:, : self.fixed_frames] + prefix_frames = x[:, :self.fixed_frames] old_denoised = None for i in self.get_sigma_gen(num_sigmas): + if self.fixed_frames > 0: if self.sdedit: rd = torch.randn_like(prefix_frames) - noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( - s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) - ) - x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape)) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) else: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], @@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - 
timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)], scale=scale, scale_emb=scale_emb, + ofs=ofs # 1020 ) if self.fixed_frames > 0: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) return x class VPODEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 - lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 - lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 - lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next @@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 mult2 = (-h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: @@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=None, idx=None, - timestep=None, + timestep=None ): denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables( - alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt - ) + h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) @@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)] ) return x + +class VideoDDPMSampler(VideoDDIMSampler): + def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None): + # 歀倄的sigma实际䞊是alpha_cumprod_sqrt + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32) + if idx == 1: + return denoised + + alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt + x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \ + + append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \ + + append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x) + + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in 
self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i + ) + + return x \ No newline at end of file diff --git a/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/sat/sgm/modules/diffusionmodules/sigma_sampling.py index 770de42..8bb623e 100644 --- a/sat/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -17,23 +17,20 @@ class EDMSampling: class DiscreteSampling: - def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0): self.num_idx = num_idx - self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) world_size = mpu.get_data_parallel_world_size() + if world_size <= 8: + uniform_sampling = False self.uniform_sampling = uniform_sampling + self.group_num = group_num if self.uniform_sampling: - i = 1 - while True: - if world_size % i != 0 or num_idx % (world_size // i) != 0: - i += 1 - else: - self.group_num = world_size // i - break - assert self.group_num > 0 - assert world_size % self.group_num == 0 - self.group_width = world_size // self.group_num # the number of rank in one group + assert world_size % group_num == 0 + self.group_width = world_size // group_num # the number of rank in one group self.sigma_interval = self.num_idx // self.group_num def idx_to_sigma(self, idx): @@ -45,9 +42,7 @@ class DiscreteSampling: group_index = rank // self.group_width idx = default( rand, - torch.randint( - group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) - ), + torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)), ) else: idx = default( @@ -59,7 +54,6 @@ class DiscreteSampling: else: return self.idx_to_sigma(idx) - class PartialDiscreteSampling: def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): self.total_num_idx = total_num_idx diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py index 7c0cc80..9642fb4 100644 --- a/sat/vae_modules/autoencoder.py +++ b/sat/vae_modules/autoencoder.py @@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): unregularized: bool = False, input_cp: bool = False, output_cp: bool = False, + use_cp: bool = True, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.cp_size > 0 and not input_cp: + if self.cp_size <= 1: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: if not is_context_parallel_initialized: initialize_context_parallel(self.cp_size) @@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): x = _conv_split(x, dim=2, kernel_size=1) if return_reg_log: - z, reg_log = super().encode(x, return_reg_log, unregularized) + z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) else: - z = super().encode(x, return_reg_log, unregularized) + z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) - if self.cp_size > 0 and not output_cp: + if self.cp_size > 0 and use_cp and not output_cp: z = _conv_gather(z, dim=2, kernel_size=1) if return_reg_log: @@ -619,23 
+622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): z: torch.Tensor, input_cp: bool = False, output_cp: bool = False, - split_kernel_size: int = 1, + use_cp: bool = True, **kwargs, ): - if self.cp_size > 0 and not input_cp: + if self.cp_size <= 1: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: if not is_context_parallel_initialized: initialize_context_parallel(self.cp_size) global_src_rank = get_context_parallel_group_rank() * self.cp_size torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) - z = _conv_split(z, dim=2, kernel_size=split_kernel_size) + z = _conv_split(z, dim=2, kernel_size=1) - x = super().decode(z, **kwargs) - - if self.cp_size > 0 and not output_cp: - x = _conv_gather(x, dim=2, kernel_size=split_kernel_size) + x = super().decode(z, use_cp=use_cp, **kwargs) + if self.cp_size > 0 and use_cp and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=1) return x def forward( diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index d50720d..1d9c34f 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -16,11 +16,7 @@ from sgm.util import ( get_context_parallel_group_rank, ) -# try: from vae_modules.utils import SafeConv3d as Conv3d -# except: -# # Degrade to normal Conv3d if SafeConv3d is not available -# from torch.nn import Conv3d def cast_tuple(t, length=1): @@ -81,8 +77,6 @@ def _split(input_, dim): cp_rank = get_context_parallel_rank() - # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) - inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() dim_size = input_.size()[dim] // cp_world_size @@ -94,8 +88,6 @@ def _split(input_, dim): output = torch.cat([inpu_first_frame_, output], dim=dim) output = output.contiguous() - # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) - return output @@ -382,19 +374,6 @@ class ContextParallelCausalConv3d(nn.Module): self.cache_padding = None def forward(self, input_, clear_cache=True): - # if input_.shape[2] == 1: # handle image - # # first frame padding - # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) - # else: - # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) - - # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0) - - # output_parallel = self.conv(input_parallel) - # output = output_parallel - # return output - input_parallel = fake_cp_pass_from_previous_rank( input_, self.temporal_dim, self.time_kernel_size, self.cache_padding ) @@ -441,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm): return output -def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D +def Normalize(in_channels, gather=False, **kwargs): + # same for 3D and 2D if gather: return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) else: @@ -488,24 +468,34 @@ class SpatialNorm3D(nn.Module): kernel_size=1, ) - def forward(self, f, zq, clear_fake_cp_cache=True): - if f.shape[2] > 1 and f.shape[2] % 2 == 1: + def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True): + if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] 
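The split / interpolate / concatenate pattern introduced around this point (SpatialNorm3D and Upsample3D below; DownSample3D applies the same chunking to avg_pool1d) can be read as one small helper. The 32-channel chunk size matches the hunks; the helper name is only illustrative:

import torch
import torch.nn.functional as F

def interpolate_in_chunks(x, size, chunk=32, mode="nearest"):
    # x: (B, C, T, H, W); resizing 32 channels at a time bounds the peak memory of
    # F.interpolate on long, high-resolution videos; chunks are re-concatenated along C
    splits = torch.split(x, chunk, dim=1)
    return torch.cat([F.interpolate(s, size=size, mode=mode) for s in splits], dim=1)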
             zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
             zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
-            zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+
+            zq_rest_splits = torch.split(zq_rest, 32, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
+            ]
+
+            zq_rest = torch.cat(interpolated_splits, dim=1)
+            # zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
             zq = torch.cat([zq_first, zq_rest], dim=2)
         else:
-            zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+            f_size = f.shape[-3:]
+
+            zq_splits = torch.split(zq, 32, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
+            ]
+            zq = torch.cat(interpolated_splits, dim=1)

         if self.add_conv:
             zq = self.conv(zq, clear_cache=clear_fake_cp_cache)

-        # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
         norm_f = self.norm_layer(f)
-        # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
-
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
@@ -541,23 +531,44 @@ class Upsample3D(nn.Module):
             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time

-    def forward(self, x):
+    def forward(self, x, fake_cp_rank0=True):
         if self.compress_time and x.shape[2] > 1:
-            if x.shape[2] % 2 == 1:
+            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+                # print(x.shape)
                 # split first frame
                 x_first, x_rest = x[:, :, 0], x[:, :, 1:]

                 x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
-                x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+
+                splits = torch.split(x_rest, 32, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                ]
+                x_rest = torch.cat(interpolated_splits, dim=1)
+
+                # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
             else:
-                x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+                splits = torch.split(x, 32, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                ]
+                x = torch.cat(interpolated_splits, dim=1)
+
+                # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
         else:
             # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+            splits = torch.split(x, 32, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+            ]
+            x = torch.cat(interpolated_splits, dim=1)
+
             x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

         if self.with_conv:
@@ -579,21 +590,30 @@ class DownSample3D(nn.Module):
             self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time

-    def forward(self, x):
+    def forward(self, x, fake_cp_rank0=True):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
             x = rearrange(x, "b c t h w -> (b h w) c t")

-            if x.shape[-1] % 2 == 1:
+            if get_context_parallel_rank() == 0 and fake_cp_rank0:
                 # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]

                 if x_rest.shape[-1] > 0:
-                    x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+                    splits = torch.split(x_rest, 32, dim=1)
+                    interpolated_splits = [
+                        torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                    ]
+                    x_rest = torch.cat(interpolated_splits, dim=1)
                 x = torch.cat([x_first[..., None], x_rest], dim=-1)
                 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
             else:
-                x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+                # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+                splits = torch.split(x, 32, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                ]
+                x = torch.cat(interpolated_splits, dim=1)
                 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

         if self.with_conv:
@@ -673,13 +693,13 @@ class ContextParallelResnetBlock3D(nn.Module):
                 padding=0,
             )

-    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
+    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
         h = x

         # if isinstance(self.norm1, torch.nn.GroupNorm):
         #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         else:
             h = self.norm1(h)
         # if isinstance(self.norm1, torch.nn.GroupNorm):
@@ -694,7 +714,7 @@ class ContextParallelResnetBlock3D(nn.Module):
         # if isinstance(self.norm2, torch.nn.GroupNorm):
         #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         else:
             h = self.norm2(h)
         # if isinstance(self.norm2, torch.nn.GroupNorm):
@@ -807,23 +827,24 @@ class ContextParallelEncoder3D(nn.Module):
             kernel_size=3,
         )

-    def forward(self, x, **kwargs):
+    def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
         # timestep embedding
         temb = None

         # downsampling
-        h = self.conv_in(x)
+        h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
+                h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
                 if len(self.down[i_level].attn) > 0:
+                    print("Attention not implemented")
                     h = self.down[i_level].attn[i_block](h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h)
+                h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.block_2(h, temb)
+        h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)

         # end
         # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
@@ -831,7 +852,7 @@ class ContextParallelEncoder3D(nn.Module):
         # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

         h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)

         return h
@@ -934,6 +955,11 @@ class ContextParallelDecoder3D(nn.Module):
             up.block = block
             up.attn = attn
             if i_level != 0:
+                # # Symmetrical enc-dec
+                if i_level <= self.temporal_compress_level:
+                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+                else:
+                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 if i_level < self.num_resolutions - self.temporal_compress_level:
                     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
                 else:
@@ -948,7 +974,7 @@ class ContextParallelDecoder3D(nn.Module):
             kernel_size=3,
         )

-    def forward(self, z, clear_fake_cp_cache=True, **kwargs):
+    def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
         self.last_z_shape = z.shape

         # timestep embedding
@@ -961,23 +987,25 @@ class ContextParallelDecoder3D(nn.Module):
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

         # middle
-        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
-        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
-                h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+                h = self.up[i_level].block[i_block](
+                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+                )
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
             if i_level != 0:
-                h = self.up[i_level].upsample(h)
+                h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)

         # end
         if self.give_pre_end:
             return h

-        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         h = nonlinearity(h)
         h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index 183be62..f325018 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -1,22 +1,15 @@
 """
-This script demonstrates how to convert and generate video from a text prompt
-using CogVideoX with 🀗Huggingface Diffusers Pipeline.
-This script requires the `diffusers>=0.30.2` library to be installed.
-
-Functions:
-    - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
-    - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
-    - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
-    - remove_keys_inplace: Removes specified keys from the state_dict in-place.
-    - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
-    - get_state_dict: Extracts the state_dict from a saved checkpoint.
-    - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
-    - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
-    - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
-    - get_args: Parses command-line arguments for the script.
-    - generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
-"""
+The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
+This script supports the conversion of the following models:
+- CogVideoX-2B
+- CogVideoX-5B, CogVideoX-5B-I2V
+- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
+
+Original Script:
+https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
+
+"""
 import argparse
 from typing import Any, Dict
@@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:


 def convert_transformer(
-    ckpt_path: str,
-    num_layers: int,
-    num_attention_heads: int,
-    use_rotary_positional_embeddings: bool,
-    i2v: bool,
-    dtype: torch.dtype,
+        ckpt_path: str,
+        num_layers: int,
+        num_attention_heads: int,
+        use_rotary_positional_embeddings: bool,
+        i2v: bool,
+        dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
@@ -172,7 +165,7 @@ def convert_transformer(
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY) :]
+        new_key = key[len(PREFIX_KEY):]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
@@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+    )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@@ -259,9 +253,10 @@ if __name__ == "__main__":
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

-    text_encoder_id = "google/t5-v1_1-xxl"
+    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+    # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
@@ -301,4 +296,7 @@ if __name__ == "__main__":
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
-    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+    # This is necessary for users with insufficient memory, such as those using Colab and notebooks,
+    # as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
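A note on the recurring pattern in the VAE hunks above: each full-tensor `torch.nn.functional.interpolate` (and `avg_pool1d`) call is replaced by splitting the input into chunks of 32 channels, processing each chunk, and concatenating the results along `dim=1`. Nearest-neighbour interpolation and temporal average pooling act on each channel independently, so the chunked result matches the single-call result while keeping each intermediate tensor and kernel launch smaller, which appears to be the point for the larger CogVideoX1.5 inputs. Below is a minimal sketch of the equivalent helper under that assumption; `chunked_interpolate` is an illustrative name, not a function defined in this patch.

```python
import torch
import torch.nn.functional as F


def chunked_interpolate(x: torch.Tensor, chunk: int = 32, **interp_kwargs) -> torch.Tensor:
    # Interpolate channel chunks independently and re-concatenate. For "nearest" mode the
    # result is identical to interpolating the full tensor, because each output channel
    # depends only on the corresponding input channel.
    splits = torch.split(x, chunk, dim=1)
    return torch.cat([F.interpolate(s, **interp_kwargs) for s in splits], dim=1)


# Quick equivalence check on a 4D tensor (batch, channels, height, width).
x = torch.randn(1, 128, 32, 32)
a = chunked_interpolate(x, scale_factor=2.0, mode="nearest")
b = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert torch.equal(a, b)
```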
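The final hunk of `tools/convert_weight_sat2hf.py` saves the converted pipeline with `max_shard_size="5GB"`, so the output directory contains sharded weights that low-memory environments such as Colab can load piece by piece. The following is a usage sketch for the converted output, not part of the patch itself; the path is a placeholder for whatever `--output_path` was passed, and it assumes the result is a standard `diffusers` CogVideoX pipeline.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Placeholder path: point this at the directory given to --output_path during conversion.
pipe = CogVideoXPipeline.from_pretrained("./cogvideox-5b-diffusers", torch_dtype=torch.bfloat16)

# Optional: stream submodules to the GPU on demand to keep peak memory low on small cards.
pipe.enable_sequential_cpu_offload()

frames = pipe(prompt="A panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
export_to_video(frames, "output.mp4", fps=8)
```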