Merge pull request #469 from THUDM/CogVideoX_dev

CogVideoX1.5-SAT

Yuxuan.Zhang 2024-11-08 13:50:19 +08:00 committed by GitHub
commit ddd3dcd7eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1387 additions and 725 deletions
@@ -22,7 +22,10 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
## Project Updates

- 🔥🔥 **News**: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
  The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
  The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
  4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports
  fine-tuning at multiple resolutions. Feel free to use it!
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please
@@ -68,7 +71,6 @@ Jump to a specific section:
- [Tools](#tools)
- [Introduction to CogVideo(ICLR'23) Model](#cogvideoiclr23)
- [Citations](#Citation)
- [Model License](#Model-License)

## Quick Start
@@ -172,79 +174,85 @@ models we currently offer, along with their foundational information.
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">Release Date</td>
<th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <= 1360<br>256 <= H <= 768<br>W, H % 16 == 0</td>
</tr>
<tr>
<td style="text-align: center;">Inference Precision</td>
<td style="text-align: center;"><b>FP16* (recommended)</b>, BF16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">Single GPU Memory Usage</td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: from 5GB*</b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB</td>
</tr>
<tr>
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>Not supported</b></td>
</tr>
<tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Language</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Token Limit</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
<td colspan="2" style="text-align: center;">224 Tokens</td>
</tr>
<tr>
<td style="text-align: center;">Video Length</td>
<td colspan="3" style="text-align: center;">6 seconds</td>
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Frame Rate</td>
<td colspan="3" style="text-align: center;">8 frames / second</td>
<td colspan="2" style="text-align: center;">16 frames / second</td>
</tr>
<tr>
<td style="text-align: center;">Positional Encoding</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;">Coming Soon</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
</tr>
</table>
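The low single-GPU diffusers numbers above assume the pipeline's memory optimizations are enabled. Below is a minimal text-to-video sketch for CogVideoX-5B; it assumes a recent `diffusers` release that provides `CogVideoXPipeline`, so verify the method names against the version you have installed.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Load the 5B text-to-video checkpoint in BF16 (the recommended precision for 5B).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Memory savers behind the "from ~5GB" single-GPU figure.
pipe.enable_sequential_cpu_offload()  # stream weights to the GPU module by module
pipe.vae.enable_tiling()              # decode video latents in tiles
pipe.vae.enable_slicing()             # decode frames in slices

video = pipe(
    prompt="A panda playing a wooden guitar in a sunlit bamboo forest",
    num_frames=49,            # 6 seconds at 8 fps (+1 latent frame)
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

export_to_video(video, "output.mp4", fps=8)
```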
@@ -422,7 +430,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
We welcome your contributions! You can click [here](resources/contribute.md) for more information.

## Model License

The code in this repository is released under the [Apache 2.0 License](LICENSE).
@@ -1,6 +1,6 @@
# CogVideo & CogVideoX

[Read this in English](./README.md)
[中文阅读](./README_zh.md)

@@ -22,9 +22,14 @@
## Updates and News

- 🔥🔥 **News**: ```2024/11/08```: We have released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
  The CogVideoX1.5-5B series supports 10-second videos and higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at any resolution.
  The SAT code has already been updated; the `diffusers` version is still being adapted.
  Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), a lower-cost fine-tuning framework that can fine-tune `CogVideoX-5B` on a single 4090 GPU, has been released. It supports fine-tuning at multiple resolutions. Feel free to use it!
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please click [here](https://arxiv.org/pdf/2408.06072) to view it; more training details and a demo have been added. To see the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/)
@@ -34,7 +39,7 @@
- 🔥 **News**: ```2024/9/19```: We have open-sourced **CogVideoX-5B-I2V**, the image-to-video model in the CogVideoX series. It takes an image as a background input and combines it with a prompt to generate a video, offering greater controllability. With this, the CogVideoX series now supports three tasks: text-to-video generation, video continuation, and image-to-video generation. Try it [online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
- 🔥 **News**: ```2024/9/19```: We have open-sourced [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), the caption model used during CogVideoX training to convert video data into text descriptions. Feel free to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced **CogVideoX-5B**, a larger model in the CogVideoX series
@@ -63,11 +68,10 @@
- [Project Structure](#プロジェクト構造)
- [Inference](#推論)
- [sat](#sat)
- [Tools](#ツール)
- [Introduction to the CogVideo (ICLR'23) Model](#CogVideoICLR23)
- [Citations](#引用)
- [License Agreement](#ライセンス契約)

## Quick Start

@@ -156,79 +160,91 @@ pip install -r requirements.txt
CogVideoX is the open-source version of the video generation model that shares the same origin as [清影 (Qingying)](https://chatglm.cn/video?fr=osm_cogvideox).
The table below lists basic information about the video generation models we provide:
<table style="border-collapse: collapse; width: 100%;">
<tr>
<th style="text-align: center;">Model Name</th>
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">Release Date</td>
<th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <= 1360<br>256 <= H <= 768<br>W, H % 16 == 0</td>
</tr>
<tr>
<td style="text-align: center;">Inference Precision</td>
<td style="text-align: center;"><b>FP16* (recommended)</b>, BF16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">Single GPU Memory Usage</td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: from 5GB*</b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB</td>
</tr>
<tr>
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>Not supported</b></td>
</tr>
<tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Language</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Token Limit</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
<td colspan="2" style="text-align: center;">224 Tokens</td>
</tr>
<tr>
<td style="text-align: center;">Video Length</td>
<td colspan="3" style="text-align: center;">6 seconds</td>
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Frame Rate</td>
<td colspan="3" style="text-align: center;">8 frames / second</td>
<td colspan="2" style="text-align: center;">16 frames / second</td>
</tr>
<tr>
<td style="text-align: center;">Positional Encoding</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;">Coming Soon</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_ja.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
</tr>
</table>
@@ -1,10 +1,9 @@
# CogVideo & CogVideoX

[Read this in English](./README.md)
[日本語で読む](./README_ja.md)

<div align="center">
<img src=resources/logo.svg width="50%"/>
</div>

@@ -23,7 +22,9 @@
## Project Updates

- 🔥🔥 **News**: ```2024/11/08```: We have released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
  The CogVideoX1.5-5B series supports **10-second** videos and higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at **any resolution**. The SAT code has been updated; the `diffusers` version is still being adapted. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), a lower-cost fine-tuning framework that can fine-tune `CogVideoX-5B` on a single 4090 GPU, has been released. It supports fine-tuning at multiple resolutions; feel free to use it.
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please click [here](https://arxiv.org/pdf/2408.06072) to view it; more training details and a demo have been added. For the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/).
@@ -58,10 +59,9 @@
- [Inference](#inference)
- [SAT](#sat)
- [Tools](#tools)
- [Introduction to the CogVideo (ICLR'23) Model](#cogvideoiclr23)
- [Citations](#引用)
- [Model License](#模型协议)

## Quick Start

@@ -157,62 +157,72 @@ CogVideoX is the open-source video generation model from the same origin as [清影 (Qingying)](https://chatglm.cn/video?fr=osm_cogvideox)
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">Release Date</td>
<th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <= 1360<br>256 <= H <= 768<br>W, H % 16 == 0</td>
</tr>
<tr>
<td style="text-align: center;">Inference Precision</td>
<td style="text-align: center;"><b>FP16* (recommended)</b>, BF16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8; not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">Single GPU Memory Usage</td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: from 5GB*</b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB</td>
</tr>
<tr>
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b></td>
<td colspan="2" style="text-align: center;"><b>Not supported</b></td>
</tr>
<tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Language</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Token Limit</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
<td colspan="2" style="text-align: center;">224 Tokens</td>
</tr>
<tr>
<td style="text-align: center;">Video Length</td>
<td colspan="3" style="text-align: center;">6 seconds</td>
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Frame Rate</td>
<td colspan="3" style="text-align: center;">8 frames / second</td>
<td colspan="2" style="text-align: center;">16 frames / second</td>
</tr>
<tr>
<td style="text-align: center;">Positional Encoding</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
@@ -220,10 +230,13 @@ CogVideoX is the open-source video generation model from the same origin as [清影 (Qingying)](https://chatglm.cn/video?fr=osm_cogvideox)
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;">Coming Soon</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
</tr>
</table>
@@ -1,29 +1,39 @@
# SAT CogVideoX

[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)

This folder contains inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.

This is the framework our team used to train the model. It has few comments, so it requires careful study.

## Inference Model

### 1. Make sure you have installed all dependencies in this folder

```shell
pip install -r requirements.txt
```

### 2. Download the Model Weights

First, download the model weights from the SAT mirror.

#### CogVideoX1.5 Model

```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```

This command downloads the Transformer, VAE, and T5 Encoder models.

#### CogVideoX Model

For the CogVideoX-2B model, download as follows:

```shell
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@@ -34,13 +44,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```

Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):

+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)

Arrange the model files in the following structure:

```
.
@@ -52,20 +61,24 @@ Next, you need to format the model files as follows:
└── 3d-vae.pt
```

Since the model weight files are large, using `git lfs` is recommended. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for how to install `git lfs`.

```shell
git lfs install
```

Next, clone the T5 model, which is used only as an encoder and does not require training or fine-tuning.
> The model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) can also be used.

```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git  # Download the model from Huggingface
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git  # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```

This yields a T5 file in safetensor format that can be loaded without errors during Deepspeed fine-tuning.

```
├── added_tokens.json
@@ -80,11 +93,11 @@ loading it into Deepspeed in Finetune.
0 directories, 8 files
```
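Before editing the configs, it can help to sanity-check that the `t5-v1_1-xxl` folder ended up with the expected contents. A small sketch based only on the listing above (8 files, including `added_tokens.json` and the safetensors weights):

```python
from pathlib import Path

t5_dir = Path("t5-v1_1-xxl")
assert t5_dir.is_dir(), "t5-v1_1-xxl not found - did the mv step run?"

files = sorted(p.name for p in t5_dir.iterdir() if p.is_file())

# The listing above shows 8 files, including added_tokens.json and the
# safetensors weights that Deepspeed fine-tuning will load.
assert len(files) == 8, f"expected 8 files, found {len(files)}: {files}"
assert "added_tokens.json" in files, "tokenizer files appear to be missing"
assert any(name.endswith(".safetensors") for name in files), "no safetensors weights found"
print("t5-v1_1-xxl looks complete:", files)
```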
### 3. Modify the `configs/cogvideox_*.yaml` file

```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  log_keys:
    - txt
@@ -160,14 +173,14 @@ model:
      ucg_rate: 0.1
      target: sgm.modules.encoders.modules.FrozenT5Embedder
      params:
        model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
        max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
      ignore_keys: [ 'loss' ]

      loss_config:
@@ -239,48 +252,46 @@ model:
        num_steps: 50
```
### 4. Modify the `configs/inference.yaml` file

```yaml
args:
  latent_channels: 16
  mode: inference
  load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for the full model without the LoRA adapter

  batch_size: 1
  input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
  input_file: configs/test.txt # Plain text file, which can be edited
  sampling_num_frames: 13 # For CogVideoX1.5-5B this must be 42 or 22; for CogVideoX-5B / 2B it must be 13, 11, or 9
  sampling_fps: 8
  fp16: True # For CogVideoX-2B
  # bf16: True # For CogVideoX-5B
  output_dir: outputs/
  force_inference: True
```
+ If you use a text file to hold multiple prompts, edit `configs/test.txt` as needed, with one prompt per line. If you are unsure how to write prompts, you can first use [this code](../inference/convert_demo.py) to have an LLM refine them.
+ To use command-line input instead, change:

```yaml
input_type: cli
```

This allows prompts to be entered from the command line.

To change where the output video is saved, modify:

```yaml
output_dir: outputs/
```

The default save location is the `outputs/` folder.

### 5. Run the Inference Code

```shell
bash inference.sh
```
@@ -288,95 +299,91 @@ bash inference.sh
### Preparing the Dataset

The dataset should be structured as follows:

```
.
├── labels
│   ├── 1.txt
│   ├── 2.txt
│   ├── ...
└── videos
    ├── 1.mp4
    ├── 2.mp4
    ├── ...
```

Each txt file should have the same name as its corresponding video file and contain that video's label. Videos and labels should correspond one-to-one; in general, avoid giving one video multiple labels.

For style fine-tuning, prepare at least 50 videos and labels with a similar style to make fitting easier.
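Because training pairs labels and videos by file name, a quick consistency check before launching a run can save a failed job. A minimal sketch, assuming the `labels/` and `videos/` layout shown above (the dataset path is a placeholder):

```python
from pathlib import Path

dataset_root = Path("path/to/your/dataset")  # adjust to your dataset location
label_ids = {p.stem for p in (dataset_root / "labels").glob("*.txt")}
video_ids = {p.stem for p in (dataset_root / "videos").glob("*.mp4")}

# Every video should have exactly one label file with the same name, and vice versa.
missing_labels = sorted(video_ids - label_ids)
missing_videos = sorted(label_ids - video_ids)
if missing_labels or missing_videos:
    raise SystemExit(f"videos without labels: {missing_labels}\nlabels without videos: {missing_videos}")

print(f"{len(video_ids)} paired samples found")
if len(video_ids) < 50:
    print("warning: fewer than 50 clips; style fine-tuning may not fit well")
```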
### Modifying the Configuration File

We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods fine-tune only the `transformer` part; the `VAE` is not modified, and `T5` is used only as an encoder.
Modify the `configs/sft.yaml` (full-parameter fine-tuning) file as follows:

```yaml
# checkpoint_activations: True ## Use gradient checkpointing (both `checkpoint_activations` entries in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to the Transformer model
no_load_rng: True # Whether to load the random seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Ratio of training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
only_log_video_latents: True # Avoid the memory overhead of VAE decoding
deepspeed:
  bf16:
    enabled: False # For CogVideoX-2B set to False; for CogVideoX-5B set to True
  fp16:
    enabled: True # For CogVideoX-2B set to True; for CogVideoX-5B set to False
```
To use Lora fine-tuning, you also need to modify the `cogvideox_<model_parameters>_lora` file.

Here is an example using `CogVideoX-2B`:

```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
  log_keys:
    - txt

  lora_config: ## Uncomment to unlock
    target: sat.model.finetune.lora2.LoraMixin
    params:
      r: 256
```
### Modifying the Run Script

Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Two examples:

1. To use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```shell
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```

2. To use the `CogVideoX-2B` model with full-parameter fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```shell
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```
### Fine-tuning and Validation

Run the fine-tuning script to start fine-tuning.

@@ -385,45 +392,42 @@ bash finetune_single_gpu.sh # Single GPU
```shell
bash finetune_multi_gpus.sh # Multi GPUs
```

### Using the Fine-tuned Model

The fine-tuned model cannot be merged; modify the inference configuration file `inference.sh` as follows:

```shell
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
```

Then run the code:

```shell
bash inference.sh
```
### Converting to Huggingface Diffusers-Compatible Weights

The SAT weight format differs from Huggingface's weight format and needs to be converted. Run:

```shell
python ../tools/convert_weight_sat2hf.py
```

### Exporting LoRA Weights from SAT to Huggingface Diffusers

Exporting LoRA weights from SAT checkpoints to the Huggingface Diffusers format is supported.
After training with the steps above, the SAT checkpoint with LoRA weights can be found at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, you can use `load_cogvideox_lora.py` for inference.

Export command:

```shell
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

The training mainly modified the following model structures. The list below shows the mapping to the HF (Hugging Face) LoRA structure; as you can see, LoRA adds a low-rank weight to the model's attention layers.

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@@ -436,5 +440,5 @@ model's attention structure.
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using `export_sat_lora_weight.py` converts these into the HF-format LoRA structure.

![alt text](../resources/hf_lora_weights.png)
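To make the mapping concrete, the sketch below applies such a key rename to a SAT checkpoint in plain PyTorch. It is only an illustration: the dictionary repeats just the pairs listed above (the omitted entries are not filled in), and `tools/export_sat_lora_weight.py` remains the authoritative implementation.

```python
import torch

# Only the pairs shown above; the full mapping lives in tools/export_sat_lora_weight.py.
SAT_TO_HF = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}

def rename_lora_keys(sat_state: dict) -> dict:
    """Map SAT LoRA parameter names onto HF/diffusers-style LoRA names."""
    hf_state = {}
    for key, tensor in sat_state.items():
        for sat_suffix, hf_suffix in SAT_TO_HF.items():
            if key.endswith(sat_suffix):
                # Keep any leading layer prefix and swap only the LoRA suffix.
                hf_state[key[: -len(sat_suffix)] + hf_suffix] = tensor
                break
    return hf_state

ckpt = torch.load("mp_rank_00_model_states.pt", map_location="cpu")
state = ckpt.get("module", ckpt)  # Deepspeed checkpoints usually nest weights under "module"
hf_lora = rename_lora_keys(state)
```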

View File

@@ -1,27 +1,37 @@
# SAT CogVideoX

[Read this in English.](./README.md)
[中文阅读](./README_zh.md)

This folder contains inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.

This is the framework our team used to train the model. It has few comments, so it requires careful study.

## Inference Model

### 1. Make sure all required dependencies in this folder are installed

```shell
pip install -r requirements.txt
```

### 2. Download the Model Weights

First, download the model weights from the SAT mirror.

#### CogVideoX1.5 Model

```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```

This downloads the Transformer, VAE, and T5 Encoder models.

#### CogVideoX Model

For the CogVideoX-2B model, download as follows:

```shell
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@@ -32,12 +42,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```

Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):

+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)

Arrange the model files as follows:

```
.
@@ -49,24 +59,24 @@ Download the `transformers` file for the CogVideoX-5B model from the links above
└── 3d-vae.pt
```

Since the model weight files are large, using `git lfs` is recommended. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for how to install `git lfs`.

```shell
git lfs install
```

Next, clone the T5 model. It is used only as an encoder and does not require training or fine-tuning.
> The model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) can also be used.

```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git  # Download the model from Huggingface
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git  # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```

This produces a T5 file in safetensor format that can be loaded without errors during Deepspeed fine-tuning.

```
├── added_tokens.json
@@ -81,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```
### 3. `configs/cogvideox_2b.yaml` ファイルを変更します。 ### 3. `configs/cogvideox_*.yaml`ファイルを編集
```yaml ```yaml
model: model:
scale_factor: 1.15258426 scale_factor: 1.55258426
disable_first_stage_autocast: true disable_first_stage_autocast: true
log_keys: log_keys:
- txt - txt
@ -123,7 +133,7 @@ model:
num_attention_heads: 30 num_attention_heads: 30
transformer_args: transformer_args:
checkpoint_activations: True ## グラデーション チェックポイントを使用する checkpoint_activations: True ## using gradient checkpointing
vocab_size: 1 vocab_size: 1
max_sequence_length: 64 max_sequence_length: 64
layernorm_order: pre layernorm_order: pre
@ -161,14 +171,14 @@ model:
ucg_rate: 0.1 ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder target: sgm.modules.encoders.modules.FrozenT5Embedder
params: params:
model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxlフォルダの絶対パス model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxl 重みフォルダの絶対パス
max_length: 226 max_length: 226
first_stage_config: first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params: params:
cp_size: 1 cp_size: 1
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
ignore_keys: [ 'loss' ] ignore_keys: [ 'loss' ]
loss_config: loss_config:
@ -240,7 +250,7 @@ model:
num_steps: 50 num_steps: 50
``` ```
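Since `model_dir` and `ckpt_path` must be absolute paths, a small check like the following can catch path mistakes before launching. This is a sketch only: it uses `omegaconf` (already in the requirements), assumes it is run from the `sat/` directory, and assumes the key layout shown in the excerpt above; adjust the file name to whichever `configs/cogvideox_*.yaml` you edited.

```python
# Minimal path check for an edited config file.
import os
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/cogvideox_2b.yaml")  # or another cogvideox_*.yaml
t5_embedder = cfg.model.conditioner_config.params.emb_models[0]
paths = {
    "T5 model_dir": t5_embedder.params.model_dir,
    "VAE ckpt_path": cfg.model.first_stage_config.params.ckpt_path,
}
for name, path in paths.items():
    print(f"{name}: {path} -> {'OK' if os.path.exists(path) else 'MISSING'}")
```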
### 4. Edit the `configs/inference.yaml` file
```yaml ```yaml
args: args:
@ -250,38 +260,36 @@ args:
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1 batch_size: 1
input_type: txt # choose "txt" for plain-text file input, or change to "cli" for command-line input
input_file: configs/test.txt # plain-text file, can be edited
sampling_num_frames: 13 # must be 42 or 22 for CogVideoX1.5-5B; 13, 11, or 9 for CogVideoX-5B / 2B
sampling_fps: 8 sampling_fps: 8
fp16: True # For CogVideoX-2B fp16: True # CogVideoX-2B
# bf16: True # For CogVideoX-5B # bf16: True # CogVideoX-5B
output_dir: outputs/ output_dir: outputs/
force_inference: True force_inference: True
``` ```
+ If you use a text file to store multiple prompts, edit `configs/test.txt` accordingly, one prompt per line. If you are unsure how to write prompts, first use [this code](../inference/convert_demo.py) to have an LLM refine them.
+ To use command-line input instead, change:
```yaml ```
input_type: cli input_type: cli
``` ```
This lets you enter prompts from the command line.
To change where the output videos are saved, edit:
```yaml ```
output_dir: outputs/ output_dir: outputs/
``` ```
By default, videos are saved to the `outputs/` folder.
### 5. Run the Inference Code to Start Inference
```shell ```
bash inference.sh bash inference.sh
``` ```
@ -289,7 +297,7 @@ bash inference.sh
### Preparing the Dataset
The dataset must be organized as follows:
``` ```
. .
@ -303,123 +311,215 @@ bash inference.sh
├── ... ├── ...
``` ```
Each txt file must have the same name as its corresponding video file and contain that video's caption. Videos and captions must correspond one to one; in general, avoid giving one video multiple captions.
For style fine-tuning, prepare at least 50 videos and captions with a similar style to make fitting easier.
### Editing the Configuration File
Two fine-tuning methods are supported: `Lora` and full-parameter fine-tuning. Both fine-tune only the `transformer` part; the `VAE` is left unchanged, and `T5` is used only as an encoder. Edit the `configs/sft.yaml` (full-parameter fine-tuning) file as follows:
``` ```
# checkpoint_activations: True ## enable gradient checkpointing (both `checkpoint_activations` entries in the config file must be set to True)
model_parallel_size: 1 # model parallel size
experiment_name: lora-disney # experiment name (do not change)
mode: finetune # mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## path to the Transformer model
no_load_rng: True # whether to load the random seed
train_iters: 1000 # number of training iterations
eval_iters: 1 # number of validation iterations
eval_interval: 100 # validation interval
eval_batch_size: 1 # validation batch size
save: ckpts # model save path
save_interval: 100 # save interval
log_interval: 20 # log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # the training and validation sets may be the same
split: 1,0,0 # ratio of the training, validation, and test sets
num_workers: 8 # number of data-loader workers
force_train: True # allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
only_log_video_latents: True # avoid the memory overhead of VAE decoding
deepspeed:
bf16:
enabled: False # set to False for CogVideoX-2B, True for CogVideoX-5B
fp16:
enabled: True # set to True for CogVideoX-2B, False for CogVideoX-5B
```
```yaml
args:
latent_channels: 16
mode: inference
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
input_file: configs/test.txt # Plain text file, can be edited
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
output_dir: outputs/
force_inference: True
``` ```
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed, one prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement (a short sketch of writing such a file follows at the end of this section).
+ To use command-line input, modify:
``` ```
input_type: cli
```
This allows you to enter prompts from the command line.
To modify the output video location, change:
```
output_dir: outputs/
```
The default location is the `outputs/` folder.
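As a concrete illustration of the prompt-file format: for text-to-video, each line is one plain prompt; for image-to-video, the sampling code in this PR (`sample_video.py`, shown further below) splits each line on `@@` into a prompt and an image path. The sketch below simply writes such a file; all prompts and paths are placeholders.

```python
# Write an example configs/test.txt (placeholder prompts and image path).
prompts = [
    # text-to-video: one plain prompt per line
    "A golden retriever runs through shallow waves on a sunny beach, slow motion.",
    # image-to-video: prompt and reference image separated by "@@"
    "The camera slowly zooms in on the painting.@@assets/painting.png",
]
with open("configs/test.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(prompts) + "\n")
```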
### 5. Run the Inference Code to Perform Inference
```
bash inference.sh
```
## Fine-tuning the Model
### Preparing the Dataset
The dataset should be structured as follows:
```
.
├── labels
│ ├── 1.txt
│ ├── 2.txt
│ ├── ...
└── videos
├── 1.mp4
├── 2.mp4
├── ...
```
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
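Before training, it can help to verify the one-to-one correspondence between videos and labels. A short sketch (the dataset root is a placeholder):

```python
# Check that every video has exactly one caption file and vice versa.
from pathlib import Path

root = Path("path/to/dataset")  # placeholder: your dataset root
video_stems = {p.stem for p in (root / "videos").glob("*.mp4")}
label_stems = {p.stem for p in (root / "labels").glob("*.txt")}

print("videos without labels:", sorted(video_stems - label_stems))
print("labels without videos:", sorted(label_stems - video_stems))
print("matched pairs:", len(video_stems & label_stems))
```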
### Modifying the Configuration File
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
Modify the `configs/sft.yaml` file (full-parameter fine-tuning) as follows:
```yaml
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
no_load_rng: True # Whether to load random number seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Proportion for training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
only_log_video_latents: True # Avoid memory usage from VAE decoding
deepspeed:
bf16:
enabled: False # set to False for CogVideoX-2B and True for CogVideoX-5B
fp16:
enabled: True # set to True for CogVideoX-2B and False for CogVideoX-5B
```
To use Lora fine-tuning, you also need to modify the `cogvideox_<model parameters>_lora` file:
Here's an example using `CogVideoX-2B`:
```yaml
model: model:
scale_factor: 1.15258426 scale_factor: 1.55258426
disable_first_stage_autocast: true disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
log_keys: log_keys:
- txt
lora_config: ## Uncomment to unlock
target: sat.model.finetune.lora2.LoraMixin target: sat.model.finetune.lora2.LoraMixin
params: params:
r: 256 r: 256
``` ```
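For intuition about the `r: 256` setting: a Lora adapter on a weight of shape (d_out, d_in) adds roughly r·(d_in + d_out) trainable parameters. The back-of-envelope sketch below uses the 3072 hidden size from the CogVideoX1.5-5B config included later in this PR; the set of adapted projections is an assumption for illustration, not a statement about what LoraMixin actually wraps.

```python
# Rough Lora parameter count for one transformer layer's attention projections.
r = 256
hidden = 3072  # hidden_size from the CogVideoX1.5-5B config; other models differ
adapted = ["to_q", "to_k", "to_v", "to_out"]  # assumed set of adapted projections

per_matrix = r * (hidden + hidden)   # lora_A (r x d_in) plus lora_B (d_out x r)
per_layer = per_matrix * len(adapted)
print(f"about {per_layer / 1e6:.1f}M trainable params per layer at r={r}")
```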
### Modify the Run Script
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the config file. Two examples follow:
1. To use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
``` ```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM" run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
``` ```
2. To use the `CogVideoX-2B` model with full-parameter fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
``` ```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM" run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
``` ```
### Fine-tuning and Validation
Run the fine-tuning script to start training:
``` ```
bash finetune_single_gpu.sh # single GPU
bash finetune_multi_gpus.sh # multiple GPUs
``` ```
### Using the Fine-tuned Model
The fine-tuned weights cannot be merged into the base model. Instead, modify the inference script `inference.sh` as follows:
``` ```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42" run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
``` ```
Then, run:
``` ```
bash inference.sh bash inference.sh
``` ```
### Converting to Huggingface Diffusers-compatible Weights
The SAT weight format differs from the Huggingface format and requires conversion. Run:
```shell ```
python ../tools/convert_weight_sat2hf.py python ../tools/convert_weight_sat2hf.py
``` ```
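After conversion, the weights can be used from the `diffusers` side. The snippet below is a hedged sketch of the standard `CogVideoXPipeline` text-to-video call; it loads the published `THUDM/CogVideoX-2b` pipeline, since where `convert_weight_sat2hf.py` writes its output and how to splice a locally converted transformer into a pipeline are not specified here.

```python
# Sketch: run a CogVideoX checkpoint through diffusers after conversion.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # keeps VRAM usage modest

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
    num_frames=49,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```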
### Exporting Lora Weights from SAT to Huggingface Diffusers
Exporting Lora weights from a SAT checkpoint to the Huggingface Diffusers format is supported.
After training with the steps above, the SAT checkpoint with Lora weights is located at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
Export command:
```bash ```
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
``` ```
The following model structures were modified during training. The table below shows the mapping from the SAT Lora structure to the HF (Hugging Face) format; Lora adds low-rank weights to the model's attention layers.
``` ```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@ -431,8 +531,6 @@ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_nam
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
``` ```
Using `export_sat_lora_weight.py` converts these to the HF-format Lora structure.
![alt text](../resources/hf_lora_weights.png) ![alt text](../resources/hf_lora_weights.png)
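For readers who want to see what the mapping does mechanically, here is an illustrative sketch, not the actual `export_sat_lora_weight.py` logic, that renames Lora keys in a SAT state dict according to a table like the one above; the assumed key prefix (`transformer.layers.<i>.`) is a hypothetical example.

```python
# Illustration only: rename SAT Lora keys to HF/diffusers-style names.
# The real export is performed by tools/export_sat_lora_weight.py.
KEY_MAP = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.dense.matrix_A.0": "attn1.to_out.0.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}

def rename_lora_keys(sat_state_dict: dict) -> dict:
    hf_state_dict = {}
    for key, tensor in sat_state_dict.items():
        for sat_suffix, hf_suffix in KEY_MAP.items():
            if key.endswith(sat_suffix):
                # assumption: keys look like "transformer.layers.<i>.<sat_suffix>"
                prefix = key[: -len(sat_suffix)]
                hf_state_dict[prefix + hf_suffix] = tensor
                break
        else:
            hf_state_dict[key] = tensor  # pass non-Lora keys through unchanged
    return hf_state_dict
```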

View File

@ -1,6 +1,6 @@
# SAT CogVideoX-2B # SAT CogVideoX
[Read this in English.](./README_zh) [Read this in English.](./README.md)
[日本語で読む](./README_ja.md) [日本語で読む](./README_ja.md)
@ -20,6 +20,15 @@ pip install -r requirements.txt
First, go to the SAT mirror and download the model weights.
#### CogVideoX1.5 Models
```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```
This downloads the three models: Transformers, VAE, and T5 Encoder.
#### CogVideoX Models
For the CogVideoX-2B model, download it as follows:
```shell ```shell
@ -82,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files 0 directories, 8 files
``` ```
### 3. Edit the `configs/cogvideox_*.yaml` file.
```yaml ```yaml
model: model:
scale_factor: 1.15258426 scale_factor: 1.55258426
disable_first_stage_autocast: true disable_first_stage_autocast: true
log_keys: log_keys:
- txt - txt
@ -253,7 +262,7 @@ args:
batch_size: 1 batch_size: 1
input_type: txt # choose "txt" for plain-text file input, or change to "cli" for command-line input
input_file: configs/test.txt # plain-text file, can be edited
sampling_num_frames: 13 # must be 42 or 22 for CogVideoX1.5-5B; 13, 11, or 9 for CogVideoX-5B / 2B
sampling_fps: 8 sampling_fps: 8
fp16: True # For CogVideoX-2B fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B # bf16: True # For CogVideoX-5B
@ -346,7 +355,7 @@ Encoder 使用。
```yaml ```yaml
model: model:
scale_factor: 1.15258426 scale_factor: 1.55258426
disable_first_stage_autocast: true disable_first_stage_autocast: true
not_trainable_prefixes: [ 'all' ] ## uncomment to enable
log_keys: log_keys:

View File

@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
group.add_argument("--input-dir", type=str, default=None) group.add_argument("--input-dir", type=str, default=None)
group.add_argument("--input-type", type=str, default="cli") group.add_argument("--input-type", type=str, default="cli")
group.add_argument("--input-file", type=str, default="input.txt") group.add_argument("--input-file", type=str, default="input.txt")
group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
group.add_argument("--final-size", type=int, default=2048) group.add_argument("--final-size", type=int, default=2048)
group.add_argument("--sdedit", action="store_true") group.add_argument("--sdedit", action="store_true")
group.add_argument("--grid-num-rows", type=int, default=1) group.add_argument("--grid-num-rows", type=int, default=1)

View File

@ -0,0 +1,149 @@
model:
scale_factor: 0.7
disable_first_stage_autocast: true
latent_input: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
network_config:
target: dit_video_concat.DiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 81
time_compressed_rate: 4
latent_width: 300
latent_height: 300
num_layers: 42
patch_size: [2, 2, 2]
in_channels: 16
out_channels: 16
hidden_size: 3072
adm_in_channels: 256
num_attention_heads: 48
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
params:
hidden_size_head: 64
text_length: 224
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 224
first_stage_config:
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt"
ignore_keys: ['loss']
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
group_num: 40
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50

View File

@ -0,0 +1,160 @@
model:
scale_factor: 0.7
disable_first_stage_autocast: true
latent_input: false
noised_image_input: true
noised_image_all_concat: false
noised_image_dropout: 0.05
augmentation_dropout: 0.15
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
network_config:
target: dit_video_concat.DiffusionTransformer
params:
# space_interpolation: 1.875
ofs_embed_dim: 512
time_embed_dim: 512
elementwise_affine: True
num_frames: 81
time_compressed_rate: 4
latent_width: 300
latent_height: 300
num_layers: 42
patch_size: [2, 2, 2]
in_channels: 32
out_channels: 16
hidden_size: 3072
adm_in_channels: 256
num_attention_heads: 48
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
params:
hidden_size_head: 64
text_length: 224
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "google/t5-v1_1-xxl"
max_length: 224
first_stage_config:
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
ignore_keys: ['loss']
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 2, 4]
attn_resolutions: []
num_res_blocks: 3
dropout: 0.0
gather_norm: True
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
params:
fixed_frames: 0
offset_noise_level: 0.0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
group_num: 40
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
fixed_frames: 0
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50

View File

@ -179,14 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples) n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = [] all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds):
for n in range(n_rounds): z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
if isinstance(self.first_stage_model.decoder, VideoDecoder): latent_time = z_now.shape[2] # check the time latent
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} temporal_compress_times = 4
else:
kwargs = {} fake_cp_size = min(10, latent_time // 2)
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) start_frame = 0
all_out.append(out)
recons = []
start_frame = 0
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
fake_cp_rank0 = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad():
recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache,
fake_cp_rank0=fake_cp_rank0,
)
recons.append(recon)
start_frame = end_frame
recons = torch.cat(recons, dim=2)
all_out.append(recons)
out = torch.cat(all_out, dim=0) out = torch.cat(all_out, dim=0)
return out return out
@ -218,6 +235,7 @@ class SATVideoDiffusionEngine(nn.Module):
shape: Union[None, Tuple, List] = None, shape: Union[None, Tuple, List] = None,
prefix=None, prefix=None,
concat_images=None, concat_images=None,
ofs=None,
**kwargs, **kwargs,
): ):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@ -241,7 +259,7 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
) )
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb) samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
samples = samples.to(self.dtype) samples = samples.to(self.dtype)
return samples return samples

View File

@ -1,5 +1,7 @@
from functools import partial from functools import partial
from einops import rearrange, repeat from einops import rearrange, repeat
from functools import reduce
from operator import mul
import numpy as np import numpy as np
import torch import torch
@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.openaimodel import Timestep from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import ( from sgm.modules.diffusionmodules.util import linear, timestep_embedding
linear,
timestep_embedding,
)
from sat.ops.layernorm import LayerNorm, RMSNorm from sat.ops.layernorm import LayerNorm, RMSNorm
class ImagePatchEmbeddingMixin(BaseMixin): class ImagePatchEmbeddingMixin(BaseMixin):
def __init__( def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
self,
in_channels,
hidden_size,
patch_size,
bias=True,
text_hidden_size=None,
):
super().__init__() super().__init__()
self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) self.patch_size = patch_size
self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
if text_hidden_size is not None: if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size) self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else: else:
self.text_proj = None self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs): def word_embedding_forward(self, input_ids, **kwargs):
# now is 3d patch
images = kwargs["images"] # (b,t,c,h,w) images = kwargs["images"] # (b,t,c,h,w)
B, T = images.shape[:2] emb = rearrange(images, "b t c h w -> b (t h w) c")
emb = images.view(-1, *images.shape[2:]) emb = rearrange(
emb = self.proj(emb) # ((b t),d,h/2,w/2) emb,
emb = emb.view(B, T, *emb.shape[1:]) "b (t o h p w q) c -> b (t h w) (c o p q)",
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) t=kwargs["rope_T"],
emb = rearrange(emb, "b t n d -> b (t n) d") h=kwargs["rope_H"],
w=kwargs["rope_W"],
o=self.patch_size[0],
p=self.patch_size[1],
q=self.patch_size[2],
)
emb = self.proj(emb)
if self.text_proj is not None: if self.text_proj is not None:
text_emb = self.text_proj(kwargs["encoder_outputs"]) text_emb = self.text_proj(kwargs["encoder_outputs"])
@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
grid_size: int of the grid height and width grid_size: int of the grid height and width
t_size: int of the temporal size t_size: int of the temporal size
return: return:
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
(w/ or w/o cls_token)
""" """
assert embed_dim % 4 == 0 assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3 embed_dim_spatial = embed_dim // 4 * 3
@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed(
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
return pos_embed # [T, H*W, D] return pos_embed # [T, H*W, D]
@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
text_length, text_length,
theta=10000, theta=10000,
rot_v=False, rot_v=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
learnable_pos_embed=False, learnable_pos_embed=False,
): ):
super().__init__() super().__init__()
@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.contiguous() freqs = freqs.contiguous()
freqs_sin = freqs.sin() self.freqs_sin = freqs.sin().cuda()
freqs_cos = freqs.cos() self.freqs_cos = freqs.cos().cuda()
self.register_buffer("freqs_sin", freqs_sin)
self.register_buffer("freqs_cos", freqs_cos)
self.text_length = text_length self.text_length = text_length
if learnable_pos_embed: if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length num_patches = height * width * compressed_num_frames + text_length
@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.pos_embedding = None self.pos_embedding = None
def rotary(self, t, **kwargs): def rotary(self, t, **kwargs):
seq_len = t.shape[2] def reshape_freq(freqs):
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
return t * freqs_cos + rotate_half(t) * freqs_sin return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs): def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None: if self.pos_embedding is not None:
return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else: else:
return None return None
@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
): ):
attention_fn_default = HOOKS_DEFAULT["attention_fn"] attention_fn_default = HOOKS_DEFAULT["attention_fn"]
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) query_layer = torch.cat(
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) (
query_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
query_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
key_layer = torch.cat(
(
key_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
key_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
if self.rot_v: if self.rot_v:
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) value_layer = torch.cat(
(
value_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
value_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
return attention_fn_default( return attention_fn_default(
query_layer, query_layer,
@ -347,21 +400,25 @@ def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): def unpatchify(x, c, patch_size, w, h, **kwargs):
""" """
x: (N, T/2 * S, patch_size**3 * C) x: (N, T/2 * S, patch_size**3 * C)
imgs: (N, T, H, W, C) imgs: (N, T, H, W, C)
patch_size is split into three dimensions (o, p, q) for depth, height, and width, so the patch size can differ along each dimension for extra flexibility
""" """
if rope_position_ids is not None:
assert NotImplementedError imgs = rearrange(
# do pix2struct unpatchify x,
L = x.shape[1] "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
x = x.reshape(shape=(x.shape[0], L, p, p, c)) c=c,
x = torch.einsum("nlpqc->ncplq", x) o=patch_size[0],
imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) p=patch_size[1],
else: q=patch_size[2],
b = x.shape[0] t=kwargs["rope_T"],
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) h=kwargs["rope_H"],
w=kwargs["rope_W"],
)
return imgs return imgs
@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin):
self.patch_size = patch_size self.patch_size = patch_size
self.out_channels = out_channels self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.spatial_length = latent_width * latent_height // patch_size**2
self.latent_width = latent_width
self.latent_height = latent_height
def final_forward(self, logits, **kwargs): def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x: (b,(t n),d); only the image tokens after the text tokens are taken
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale) x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x) x = self.linear(x)
return unpatchify( return unpatchify(
x, x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
c=self.out_channels,
p=self.patch_size,
w=self.latent_width // self.patch_size,
h=self.latent_height // self.patch_size,
rope_position_ids=kwargs.get("rope_position_ids", None),
**kwargs,
) )
def reinit(self, parent_model=None): def reinit(self, parent_model=None):
@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin):
class AdaLNMixin(BaseMixin): class AdaLNMixin(BaseMixin):
def __init__( def __init__(
self, self,
width,
height,
hidden_size, hidden_size,
num_layers, num_layers,
time_embed_dim, time_embed_dim,
@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin):
): ):
super().__init__() super().__init__()
self.num_layers = num_layers self.num_layers = num_layers
self.width = width
self.height = height
self.compressed_num_frames = compressed_num_frames self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList( self.adaLN_modulations = nn.ModuleList(
@ -611,7 +654,7 @@ class DiffusionTransformer(BaseModel):
time_interpolation=1.0, time_interpolation=1.0,
use_SwiGLU=False, use_SwiGLU=False,
use_RMSNorm=False, use_RMSNorm=False,
zero_init_y_embed=False, ofs_embed_dim=None,
**kwargs, **kwargs,
): ):
self.latent_width = latent_width self.latent_width = latent_width
@ -619,12 +662,13 @@ class DiffusionTransformer(BaseModel):
self.patch_size = patch_size self.patch_size = patch_size
self.num_frames = num_frames self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate self.time_compressed_rate = time_compressed_rate
self.spatial_length = latent_width * latent_height // patch_size**2 self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.model_channels = hidden_size self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.ofs_embed_dim = ofs_embed_dim
self.num_classes = num_classes self.num_classes = num_classes
self.adm_in_channels = adm_in_channels self.adm_in_channels = adm_in_channels
self.input_time = input_time self.input_time = input_time
@ -636,7 +680,6 @@ class DiffusionTransformer(BaseModel):
self.width_interpolation = width_interpolation self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4 self.inner_hidden_size = hidden_size * 4
self.zero_init_y_embed = zero_init_y_embed
try: try:
self.dtype = str_to_dtype[kwargs.pop("dtype")] self.dtype = str_to_dtype[kwargs.pop("dtype")]
except: except:
@ -669,7 +712,6 @@ class DiffusionTransformer(BaseModel):
def _build_modules(self, module_configs): def _build_modules(self, module_configs):
model_channels = self.hidden_size model_channels = self.hidden_size
# time_embed_dim = model_channels * 4
time_embed_dim = self.time_embed_dim time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential( self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), linear(model_channels, time_embed_dim),
@ -677,6 +719,13 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim),
) )
if self.ofs_embed_dim is not None:
self.ofs_embed = nn.Sequential(
linear(self.ofs_embed_dim, self.ofs_embed_dim),
nn.SiLU(),
linear(self.ofs_embed_dim, self.ofs_embed_dim),
)
if self.num_classes is not None: if self.num_classes is not None:
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@ -701,9 +750,6 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim), linear(time_embed_dim, time_embed_dim),
) )
) )
if self.zero_init_y_embed:
nn.init.constant_(self.label_emb[0][2].weight, 0)
nn.init.constant_(self.label_emb[0][2].bias, 0)
else: else:
raise ValueError() raise ValueError()
@ -712,10 +758,13 @@ class DiffusionTransformer(BaseModel):
"pos_embed", "pos_embed",
instantiate_from_config( instantiate_from_config(
pos_embed_config, pos_embed_config,
height=self.latent_height // self.patch_size, height=self.latent_height // self.patch_size[1],
width=self.latent_width // self.patch_size, width=self.latent_width // self.patch_size[2],
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
), ),
reinit=True, reinit=True,
) )
@ -737,8 +786,6 @@ class DiffusionTransformer(BaseModel):
"adaln_layer", "adaln_layer",
instantiate_from_config( instantiate_from_config(
adaln_layer_config, adaln_layer_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_layers=self.num_layers, num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@ -749,7 +796,6 @@ class DiffusionTransformer(BaseModel):
) )
else: else:
raise NotImplementedError raise NotImplementedError
final_layer_config = module_configs["final_layer_config"] final_layer_config = module_configs["final_layer_config"]
self.add_mixin( self.add_mixin(
"final_layer", "final_layer",
@ -766,25 +812,18 @@ class DiffusionTransformer(BaseModel):
reinit=True, reinit=True,
) )
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
return return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape b, t, d, h, w = x.shape
if x.dtype != self.dtype: if x.dtype != self.dtype:
x = x.to(self.dtype) x = x.to(self.dtype)
# This is not use in inference
if "concat_images" in kwargs and kwargs["concat_images"] is not None: if "concat_images" in kwargs and kwargs["concat_images"] is not None:
if kwargs["concat_images"].shape[0] != x.shape[0]: if kwargs["concat_images"].shape[0] != x.shape[0]:
concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
else: else:
concat_images = kwargs["concat_images"] concat_images = kwargs["concat_images"]
x = torch.cat([x, concat_images], dim=2) x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
@ -792,17 +831,25 @@ class DiffusionTransformer(BaseModel):
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
# assert y.shape[0] == x.shape[0]
assert x.shape[0] % y.shape[0] == 0 assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
kwargs["seq_length"] = t * h * w // (self.patch_size**2) if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
kwargs["images"] = x kwargs["images"] = x
kwargs["emb"] = emb kwargs["emb"] = emb
kwargs["encoder_outputs"] = context kwargs["encoder_outputs"] = context
kwargs["text_length"] = context.shape[1] kwargs["text_length"] = context.shape[1]
kwargs["rope_T"] = t // self.patch_size[0]
kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
output = super().forward(**kwargs)[0] output = super().forward(**kwargs)[0]
return output return output

View File

@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM" run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM"
echo ${run_cmd} echo ${run_cmd}
eval ${run_cmd} eval ${run_cmd}

View File

@ -1,16 +1,11 @@
SwissArmyTransformer==0.4.12 SwissArmyTransformer>=0.4.12
omegaconf==2.3.0 omegaconf>=2.3.0
torch==2.4.0 pytorch_lightning>=2.4.0
torchvision==0.19.0 kornia>=0.7.3
pytorch_lightning==2.3.3 beartype>=0.19.0
kornia==0.7.3 fsspec>=2024.2.0
beartype==0.18.5 safetensors>=0.4.5
numpy==2.0.1 scipy>=1.14.1
fsspec==2024.5.0 decord>=0.6.0
safetensors==0.4.3 wandb>=0.18.5
imageio-ffmpeg==0.5.1 deepspeed>=0.15.3
imageio==2.34.2
scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4

View File

@ -4,24 +4,20 @@ import argparse
from typing import List, Union from typing import List, Union
from tqdm import tqdm from tqdm import tqdm
from omegaconf import ListConfig from omegaconf import ListConfig
from PIL import Image
import imageio import imageio
import torch import torch
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange, repeat
import torchvision.transforms as TT import torchvision.transforms as TT
from sat.model.base_model import get_model from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint from sat.training.model_io import load_checkpoint
from sat import mpu from sat import mpu
from diffusion_video import SATVideoDiffusionEngine from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from PIL import Image
def read_from_cli(): def read_from_cli():
cnt = 0 cnt = 0
@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
if key == "txt": if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
)
elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id":
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
else: else:
batch[key] = value_dict[key] batch[key] = value_dict[key]
@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
writer.append_data(frame) writer.append_data(frame)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
def sampling_main(args, model_cls): def sampling_main(args, model_cls):
if isinstance(model_cls, type): if isinstance(model_cls, type):
model = get_model(args, model_cls) model = get_model(args, model_cls)
@ -127,45 +128,62 @@ def sampling_main(args, model_cls):
data_iter = read_from_cli() data_iter = read_from_cli()
elif args.input_type == "txt": elif args.input_type == "txt":
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size() rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
print("rank and world_size", rank, world_size)
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size) data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else: else:
raise NotImplementedError raise NotImplementedError
image_size = [480, 720]
if args.image2video:
chained_trainsforms = []
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
sample_func = model.sample sample_func = model.sample
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
num_samples = [1] num_samples = [1]
force_uc_zero_embeddings = ["txt"] force_uc_zero_embeddings = ["txt"]
device = model.device T, C = args.sampling_num_frames, args.latent_channels
with torch.no_grad(): with torch.no_grad():
for text, cnt in tqdm(data_iter): for text, cnt in tqdm(data_iter):
if args.image2video: if args.image2video:
# use with input image shape
text, image_path = text.split("@@") text, image_path = text.split("@@")
assert os.path.exists(image_path), image_path assert os.path.exists(image_path), image_path
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
(img_W, img_H) = image.size
def nearest_multiple_of_16(n):
lower_multiple = (n // 16) * 16
upper_multiple = (n // 16 + 1) * 16
if abs(n - lower_multiple) < abs(n - upper_multiple):
return lower_multiple
else:
return upper_multiple
if img_H < img_W:
H = 96
W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
else:
W = 96
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
chained_trainsforms = []
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
image = transform(image).unsqueeze(0).to("cuda") image = transform(image).unsqueeze(0).to("cuda")
image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
image = image * 2.0 - 1.0 image = image * 2.0 - 1.0
image = image.unsqueeze(2).to(torch.bfloat16) image = image.unsqueeze(2).to(torch.bfloat16)
image = model.encode_first_stage(image, None) image = model.encode_first_stage(image, None)
image = image / model.scale_factor
image = image.permute(0, 2, 1, 3, 4).contiguous() image = image.permute(0, 2, 1, 3, 4).contiguous()
pad_shape = (image.shape[0], T - 1, C, H // F, W // F) pad_shape = (image.shape[0], T - 1, C, H, W)
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
else: else:
image_size = args.sampling_image_size
H, W = image_size[0], image_size[1]
F = 8 # 8x downsampled
image = None image = None
value_dict = { text_cast = [text]
"prompt": text, mp_size = mpu.get_model_parallel_world_size()
"negative_prompt": "", global_rank = torch.distributed.get_rank() // mp_size
"num_frames": torch.tensor(T).unsqueeze(0), src = global_rank * mp_size
} torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
text = text_cast[0]
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
batch, batch_uc = get_batch( batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@ -187,57 +205,42 @@ def sampling_main(args, model_cls):
if not k == "crossattn": if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
if args.image2video and image is not None: if args.image2video:
c["concat"] = image c["concat"] = image
uc["concat"] = image uc["concat"] = image
for index in range(args.batch_size): for index in range(args.batch_size):
# reload model on GPU if args.image2video:
model.to(device) samples_z = sample_func(
samples_z = sample_func( c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
c, )
uc=uc, else:
batch_size=1, samples_z = sample_func(
shape=(T, C, H // F, W // F), c,
) uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
).to("cuda")
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
if args.only_save_latents:
# Unload the model from GPU to save GPU memory samples_z = 1.0 / model.scale_factor * samples_z
model.to("cpu") save_path = os.path.join(
torch.cuda.empty_cache() args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
first_stage_model = model.first_stage_model )
first_stage_model = first_stage_model.to(device) os.makedirs(save_path, exist_ok=True)
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
latent = 1.0 / model.scale_factor * samples_z with open(os.path.join(save_path, "text.txt"), "w") as f:
f.write(text)
# Decode latent serial to save GPU memory else:
recons = [] samples_x = model.decode_first_stage(samples_z).to(torch.float32)
loop_num = (T - 1) // 2 samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
for i in range(loop_num): samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
if i == 0: save_path = os.path.join(
start_frame, end_frame = 0, 3 args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
else: )
start_frame, end_frame = i * 2 + 1, i * 2 + 3 if mpu.get_model_parallel_rank() == 0:
if i == loop_num - 1: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,7 +1,8 @@
""" """
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
""" """
from typing import Dict, Union from typing import Dict, Union
import torch import torch
@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import (
to_sigma, to_sigma,
) )
from ...util import append_dims, default, instantiate_from_config from ...util import append_dims, default, instantiate_from_config
from ...util import SeededNoise
from .guiders import DynamicCFG from .guiders import DynamicCFG
@ -44,7 +44,9 @@ class BaseDiffusionSampler:
self.device = device self.device = device
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) sigmas = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device
)
uc = default(uc, cond) uc = default(uc, cond)
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):
class EDMSampler(SingleStepDiffusionSampler): class EDMSampler(SingleStepDiffusionSampler):
def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): def __init__(
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.s_churn = s_churn self.s_churn = s_churn
@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler):
dt = append_dims(next_sigma - sigma_hat, x.ndim) dt = append_dims(next_sigma - sigma_hat, x.ndim)
euler_step = self.euler_step(x, d, dt) euler_step = self.euler_step(x, d, dt)
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
gamma = ( gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
) )
x = self.sampler_step( x = self.sampler_step(
s_in * sigmas[i], s_in * sigmas[i],
@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler):
class DDIMSampler(SingleStepDiffusionSampler): class DDIMSampler(SingleStepDiffusionSampler):
def __init__(self, s_noise=0.1, *args, **kwargs): def __init__(
self, s_noise=0.1, *args, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.s_noise = s_noise self.s_noise = s_noise
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
denoised = self.denoise(x, denoiser, sigma, cond, uc) denoised = self.denoise(x, denoiser, sigma, cond, uc)
d = to_d(x, sigma, denoised) d = to_d(x, sigma, denoised)
dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) x = self.possible_correction_step(
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step( x = self.sampler_step(
@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler):
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step( x = self.sampler_step(
@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler):
self.order = order self.order = order
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
ds = [] ds = []
sigmas_cpu = sigmas.detach().cpu().numpy() sigmas_cpu = sigmas.detach().cpu().numpy()
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
sigma = s_in * sigmas[i] sigma = s_in * sigmas[i]
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) denoised = denoiser(
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
)
denoised = self.guider(denoised, sigma) denoised = self.guider(denoised, sigma)
d = to_d(x, sigma, denoised) d = to_d(x, sigma, denoised)
ds.append(d) ds.append(d)
if len(ds) > self.order: if len(ds) > self.order:
ds.pop(0) ds.pop(0)
cur_order = min(i + 1, self.order) cur_order = min(i + 1, self.order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] coeffs = [
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
for j in range(cur_order)
]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x return x
class EulerEDMSampler(EDMSampler): class EulerEDMSampler(EDMSampler):
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
return euler_step return euler_step
class HeunEDMSampler(EDMSampler): class HeunEDMSampler(EDMSampler):
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): def possible_correction_step(
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
if torch.sum(next_sigma) < 1e-14: if torch.sum(next_sigma) < 1e-14:
# Save a network evaluation if all noise levels are 0 # Save a network evaluation if all noise levels are 0
return euler_step return euler_step
@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler):
d_prime = (d + d_new) / 2.0 d_prime = (d + d_new) / 2.0
# apply correction if noise level is not 0 # apply correction if noise level is not 0
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
)
return x return x
@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler):
x = x_euler x = x_euler
else: else:
h, s, t, t_next = self.get_variables(sigma, sigma_down) h, s, t, t_next = self.get_variables(sigma, sigma_down)
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] mult = [
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
]
x2 = mult[0] * x - mult[1] * denoised x2 = mult[0] * x - mult[1] * denoised
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler):
denoised = self.denoise(x, denoiser, sigma, cond, uc) denoised = self.denoise(x, denoiser, sigma, cond, uc)
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
]
x_standard = mult[0] * x - mult[1] * denoised x_standard = mult[0] * x - mult[1] * denoised
if old_denoised is None or torch.sum(next_sigma) < 1e-14: if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d x_advanced = mult[0] * x - mult[1] * denoised_d
# apply correction if noise level is not 0 and not first step # apply correction if noise level is not 0 and not first step
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler):
return x return x
class SDEDPMPP2MSampler(BaseDiffusionSampler): class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None): def get_variables(self, sigma, next_sigma, previous_sigma=None):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_mult(self, h, r, t, t_next, previous_sigma): def get_mult(self, h, r, t, t_next, previous_sigma):
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
mult2 = (-2 * h).expm1() mult2 = (-2*h).expm1()
if previous_sigma is not None: if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r) mult3 = 1 + 1 / (2 * r)
@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
denoised = self.denoise(x, denoiser, sigma, cond, uc) denoised = self.denoise(x, denoiser, sigma, cond, uc)
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] mult = [
mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
]
mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_sigma) < 1e-14: if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
# apply correction if noise level is not 0 and not first step # apply correction if noise level is not 0 and not first step
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) x = torch.where(
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
return x return x
class SdeditEDMSampler(EulerEDMSampler): class SdeditEDMSampler(EulerEDMSampler):
def __init__(self, edit_ratio=0.5, *args, **kwargs): def __init__(self, edit_ratio=0.5, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler):
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
randn_unit = randn.clone() randn_unit = randn.clone()
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
randn, cond, uc, num_steps
)
if num_steps is None: if num_steps is None:
num_steps = self.num_steps num_steps = self.num_steps
@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler):
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
gamma = ( gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
) )
x = self.sampler_step( x = self.sampler_step(
s_in * sigmas[i], s_in * sigmas[i],
@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler):
return x return x
class VideoDDIMSampler(BaseDiffusionSampler): class VideoDDIMSampler(BaseDiffusionSampler):
def __init__(self, fixed_frames=0, sdedit=False, **kwargs): def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.fixed_frames = fixed_frames self.fixed_frames = fixed_frames
@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler):
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
alpha_cumprod_sqrt, timesteps = self.discretization( alpha_cumprod_sqrt, timesteps = self.discretization(
self.num_steps if num_steps is None else num_steps, self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
device=self.device,
return_idx=True,
do_append_zero=False,
) )
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])
uc = default(uc, cond) uc = default(uc, cond)
@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler):
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
additional_model_inputs = {} additional_model_inputs = {}
if ofs is not None:
additional_model_inputs['ofs'] = ofs
if isinstance(scale, torch.Tensor) == False and scale == 1: if isinstance(scale, torch.Tensor) == False and scale == 1:
additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
if scale_emb is not None: if scale_emb is not None:
additional_model_inputs["scale_emb"] = scale_emb additional_model_inputs['scale_emb'] = scale_emb
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
else: else:
additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
denoised = denoiser( denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
).to(torch.float32)
if isinstance(self.guider, DynamicCFG): if isinstance(self.guider, DynamicCFG):
denoised = self.guider( denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
)
else: else:
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
return denoised return denoised
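For context, `prepare_inputs` doubles the batch into conditional and unconditional halves, and the guider then blends the two denoised predictions. Assuming the guider implements standard classifier-free guidance (with `DynamicCFG` additionally varying the scale $s$ with the step index), the blend has the form:

```latex
\hat{x}_0 = \hat{x}_0^{\mathrm{uncond}} + s\,\bigl(\hat{x}_0^{\mathrm{cond}} - \hat{x}_0^{\mathrm{uncond}}\bigr)
```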
def sampler_step( def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
self, denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
alpha_cumprod_sqrt,
next_alpha_cumprod_sqrt,
denoiser,
x,
cond,
uc=None,
idx=None,
timestep=None,
scale=None,
scale_emb=None,
):
denoised = self.denoise(
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
).to(torch.float32)
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
return x return x
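Writing `alpha_cumprod_sqrt` $=\sqrt{\bar\alpha_t}$, `next_alpha_cumprod_sqrt` $=\sqrt{\bar\alpha_{t'}}$ and `denoised` $=\hat{x}_0$, the `sampler_step` above is the deterministic DDIM update:

```latex
a_t = \sqrt{\tfrac{1-\bar\alpha_{t'}}{1-\bar\alpha_{t}}}, \qquad
b_t = \sqrt{\bar\alpha_{t'}} - \sqrt{\bar\alpha_{t}}\,a_t, \qquad
x_{t'} = a_t\,x_t + b_t\,\hat{x}_0
```

which is the $\eta = 0$ DDIM step $x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t'}}\,\hat\varepsilon$ with $\hat\varepsilon = (x_t - \sqrt{\bar\alpha_t}\,\hat{x}_0)/\sqrt{1-\bar\alpha_t}$.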
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps x, cond, uc, num_steps
) )
@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler):
cond, cond,
uc, uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i + 1)], timestep=timesteps[-(i+1)],
scale=scale, scale=scale,
scale_emb=scale_emb, scale_emb=scale_emb,
ofs=ofs # 1020
) )
return x return x
class Image2VideoDDIMSampler(BaseDiffusionSampler):
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
alpha_cumprod_sqrt, timesteps = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
)
uc = default(uc, cond)
num_sigmas = len(alpha_cumprod_sqrt)
s_in = x.new_ones([x.shape[0]])
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
additional_model_inputs = {}
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
torch.float32)
if isinstance(self.guider, DynamicCFG):
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
else:
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
return denoised
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
timestep=None):
# The "sigma" here is actually alpha_cumprod_sqrt

denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
if idx == 1:
return denoised
a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
return x
def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * alpha_cumprod_sqrt[i],
s_in * alpha_cumprod_sqrt[i + 1],
denoiser,
x,
cond,
uc,
idx=self.num_steps - i,
timestep=timesteps[-(i + 1)]
)
return x
class VPSDEDPMPP2MSampler(VideoDDIMSampler): class VPSDEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
alpha_cumprod = alpha_cumprod_sqrt**2 alpha_cumprod = alpha_cumprod_sqrt ** 2
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
next_alpha_cumprod = next_alpha_cumprod_sqrt**2 next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
h = lamb_next - lamb h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
h_last = lamb - lamb_previous h_last = lamb - lamb_previous
r = h_last / h r = h_last / h
return h, r, lamb, lamb_next return h, r, lamb, lamb_next
@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return h, None, lamb, lamb_next return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
mult3 = 1 + 1 / (2 * r) mult3 = 1 + 1 / (2 * r)
@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
timestep=None, timestep=None,
scale=None, scale=None,
scale_emb=None, scale_emb=None,
ofs=None # 1020
): ):
denoised = self.denoise( denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
).to(torch.float32)
if idx == 1: if idx == 1:
return denoised, denoised return denoised, denoised
h, r, lamb, lamb_next = self.get_variables( h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
mult = [ mult = [
append_dims(mult, x.ndim) append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
] ]
mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return x, denoised return x, denoised
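With $\alpha_t=\sqrt{\bar\alpha_t}$, $\sigma_t=\sqrt{1-\bar\alpha_t}$ and $\lambda_t=\log(\alpha_t/\sigma_t)$ (so $h=\lambda_{t'}-\lambda_t$ as computed in `get_variables`), the `x_standard` branch above is the first-order SDE variant of the DPM-Solver++ update; `x_advanced` adds the second-order multistep correction through `denoised_d` and $r = h_{\mathrm{last}}/h$:

```latex
x_{t'} = \frac{\sigma_{t'}}{\sigma_t}\,e^{-h}\,x_t
       + \alpha_{t'}\bigl(1 - e^{-2h}\bigr)\,\hat{x}_0
       + \sigma_{t'}\sqrt{1 - e^{-2h}}\;\varepsilon,
\qquad \varepsilon \sim \mathcal{N}(0, I)
```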
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps x, cond, uc, num_steps
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
prefix_frames = x[:, : self.fixed_frames] prefix_frames = x[:, :self.fixed_frames]
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
if self.fixed_frames > 0: if self.fixed_frames > 0:
if self.sdedit: if self.sdedit:
rd = torch.randn_like(prefix_frames) rd = torch.randn_like(prefix_frames)
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
)
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
else: else:
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
x, old_denoised = self.sampler_step( x, old_denoised = self.sampler_step(
old_denoised, old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=uc, uc=uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i + 1)], timestep=timesteps[-(i+1)],
scale=scale, scale=scale,
scale_emb=scale_emb, scale_emb=scale_emb,
ofs=ofs # 1020
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
return x return x
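When `fixed_frames > 0`, the first latent frames act as the image/video condition: with `sdedit` they are re-noised to the current noise level before every step (plain forward diffusion of the prefix, shown below), otherwise they are pasted in clean, and after the loop the clean prefix is restored either way.

```latex
x_{\mathrm{prefix}}^{(i)} = \sqrt{\bar\alpha_i}\;x_{\mathrm{prefix}} + \sqrt{1-\bar\alpha_i}\;\varepsilon,
\qquad \varepsilon \sim \mathcal{N}(0, I)
```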
class VPODEDPMPP2MSampler(VideoDDIMSampler): class VPODEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
alpha_cumprod = alpha_cumprod_sqrt**2 alpha_cumprod = alpha_cumprod_sqrt ** 2
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
next_alpha_cumprod = next_alpha_cumprod_sqrt**2 next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
h = lamb_next - lamb h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
h_last = lamb - lamb_previous h_last = lamb - lamb_previous
r = h_last / h r = h_last / h
return h, r, lamb, lamb_next return h, r, lamb, lamb_next
@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
return h, None, lamb, lamb_next return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=None, uc=None,
idx=None, idx=None,
timestep=None, timestep=None
): ):
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
if idx == 1: if idx == 1:
return denoised, denoised return denoised, denoised
h, r, lamb, lamb_next = self.get_variables( h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
mult = [ mult = [
append_dims(mult, x.ndim) append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=uc, uc=uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i + 1)], timestep=timesteps[-(i+1)]
) )
return x return x
class VideoDDPMSampler(VideoDDIMSampler):
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
# The "sigma" here is actually alpha_cumprod_sqrt
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
if idx == 1:
return denoised
alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
+ append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
+ append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
return x
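Here `alpha_sqrt` $=\sqrt{\bar\alpha_t/\bar\alpha_{t'}}=\sqrt{\alpha_t}$ is the per-step alpha (so $1-\alpha_t=\beta_t$), and the update above is the standard DDPM posterior $q(x_{t'}\mid x_t,\hat{x}_0)$:

```latex
x_{t'} = \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t'})}{1-\bar\alpha_t}\,x_t
       + \frac{\sqrt{\bar\alpha_{t'}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0
       + \sqrt{\frac{(1-\bar\alpha_{t'})\,\beta_t}{1-\bar\alpha_t}}\;\varepsilon
```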
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
s_in * alpha_cumprod_sqrt[i],
s_in * alpha_cumprod_sqrt[i + 1],
denoiser,
x,
cond,
uc,
idx=self.num_steps - i
)
return x

View File

@ -17,23 +17,20 @@ class EDMSampling:
class DiscreteSampling: class DiscreteSampling:
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
self.num_idx = num_idx self.num_idx = num_idx
self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) self.sigmas = instantiate_from_config(discretization_config)(
num_idx, do_append_zero=do_append_zero, flip=flip
)
world_size = mpu.get_data_parallel_world_size() world_size = mpu.get_data_parallel_world_size()
if world_size <= 8:
uniform_sampling = False
self.uniform_sampling = uniform_sampling self.uniform_sampling = uniform_sampling
self.group_num = group_num
if self.uniform_sampling: if self.uniform_sampling:
i = 1
while True:
if world_size % i != 0 or num_idx % (world_size // i) != 0:
i += 1
else:
self.group_num = world_size // i
break
assert self.group_num > 0 assert self.group_num > 0
assert world_size % self.group_num == 0 assert world_size % group_num == 0
self.group_width = world_size // self.group_num # the number of rank in one group self.group_width = world_size // group_num # the number of rank in one group
self.sigma_interval = self.num_idx // self.group_num self.sigma_interval = self.num_idx // self.group_num
def idx_to_sigma(self, idx): def idx_to_sigma(self, idx):
@ -45,9 +42,7 @@ class DiscreteSampling:
group_index = rank // self.group_width group_index = rank // self.group_width
idx = default( idx = default(
rand, rand,
torch.randint( torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
),
) )
else: else:
idx = default( idx = default(
@ -59,7 +54,6 @@ class DiscreteSampling:
else: else:
return self.idx_to_sigma(idx) return self.idx_to_sigma(idx)
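The `uniform_sampling` path above partitions the data-parallel ranks into `group_num` groups and gives each group its own disjoint slice of the `num_idx` timesteps, so a single global batch covers the noise levels more evenly. A small standalone illustration of the rank-to-interval mapping (the numbers are made up for the example and are not repository defaults):

```python
world_size, group_num, num_idx = 16, 4, 1000   # example values only
assert world_size % group_num == 0
group_width = world_size // group_num          # ranks per group -> 4
sigma_interval = num_idx // group_num          # timesteps per group -> 250

for rank in range(world_size):
    group_index = rank // group_width
    lo, hi = group_index * sigma_interval, (group_index + 1) * sigma_interval
    # ranks 0-3 draw idx from [0, 250), ranks 4-7 from [250, 500), and so on
    print(f"rank {rank:2d} samples timestep idx in [{lo}, {hi})")
```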
class PartialDiscreteSampling: class PartialDiscreteSampling:
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
self.total_num_idx = total_num_idx self.total_num_idx = total_num_idx

View File

@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
unregularized: bool = False, unregularized: bool = False,
input_cp: bool = False, input_cp: bool = False,
output_cp: bool = False, output_cp: bool = False,
use_cp: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size > 0 and not input_cp: if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized: if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size) initialize_context_parallel(self.cp_size)
@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
x = _conv_split(x, dim=2, kernel_size=1) x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log: if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized) z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
else: else:
z = super().encode(x, return_reg_log, unregularized) z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
if self.cp_size > 0 and not output_cp: if self.cp_size > 0 and use_cp and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1) z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log: if return_reg_log:
@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
z: torch.Tensor, z: torch.Tensor,
input_cp: bool = False, input_cp: bool = False,
output_cp: bool = False, output_cp: bool = False,
split_kernel_size: int = 1, use_cp: bool = True,
**kwargs, **kwargs,
): ):
if self.cp_size > 0 and not input_cp: if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized: if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size) initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=split_kernel_size) z = _conv_split(z, dim=2, kernel_size=1)
x = super().decode(z, **kwargs) x = super().decode(z, use_cp=use_cp, **kwargs)
if self.cp_size > 0 and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
if self.cp_size > 0 and use_cp and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=1)
return x return x
def forward( def forward(

View File

@ -16,11 +16,7 @@ from sgm.util import (
get_context_parallel_group_rank, get_context_parallel_group_rank,
) )
# try:
from vae_modules.utils import SafeConv3d as Conv3d from vae_modules.utils import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
def cast_tuple(t, length=1): def cast_tuple(t, length=1):
@ -81,8 +77,6 @@ def _split(input_, dim):
cp_rank = get_context_parallel_rank() cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size dim_size = input_.size()[dim] // cp_world_size
@ -94,8 +88,6 @@ def _split(input_, dim):
output = torch.cat([inpu_first_frame_, output], dim=dim) output = torch.cat([inpu_first_frame_, output], dim=dim)
output = output.contiguous() output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output return output
@ -382,19 +374,6 @@ class ContextParallelCausalConv3d(nn.Module):
self.cache_padding = None self.cache_padding = None
def forward(self, input_, clear_cache=True): def forward(self, input_, clear_cache=True):
# if input_.shape[2] == 1: # handle image
# # first frame padding
# input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
# else:
# input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
# padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
# input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
# output_parallel = self.conv(input_parallel)
# output = output_parallel
# return output
input_parallel = fake_cp_pass_from_previous_rank( input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
) )
@ -441,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
return output return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D def Normalize(in_channels, gather=False, **kwargs):
# same for 3D and 2D
if gather: if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else: else:
@ -488,24 +468,34 @@ class SpatialNorm3D(nn.Module):
kernel_size=1, kernel_size=1,
) )
def forward(self, f, zq, clear_fake_cp_cache=True): def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
if f.shape[2] > 1 and f.shape[2] % 2 == 1: if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
# zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2) zq = torch.cat([zq_first, zq_rest], dim=2)
else: else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv: if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache) zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f) norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f return new_f
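The `torch.split(..., 32, dim=1)` pattern above (and the analogous changes in `Upsample3D` / `DownSample3D` below) runs `F.interpolate` on channel slices of 32 and concatenates the results, presumably to cap the peak temporary memory of interpolating very large activations. A minimal standalone sketch of the same idea (not the repository's exact function):

```python
import torch
import torch.nn.functional as F


def interpolate_in_channel_chunks(x: torch.Tensor, chunk: int = 32, **interp_kwargs) -> torch.Tensor:
    """Apply F.interpolate over channel slices to limit peak temporary memory."""
    splits = torch.split(x, chunk, dim=1)
    return torch.cat([F.interpolate(s, **interp_kwargs) for s in splits], dim=1)


# Example: 2x nearest upsampling of a (B, C, T, H, W) feature map.
x = torch.randn(1, 128, 5, 60, 90)
y = interpolate_in_channel_chunks(x, chunk=32, scale_factor=2.0, mode="nearest")
print(y.shape)  # torch.Size([1, 128, 10, 120, 180])
```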
@ -541,23 +531,44 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x): def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1: if self.compress_time and x.shape[2] > 1:
if x.shape[2] % 2 == 1: if get_context_parallel_rank() == 0 and fake_cp_rank0:
# print(x.shape)
# split first frame # split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:] x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else: else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else: else:
# only interpolate 2D # only interpolate 2D
t = x.shape[2] t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t) x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv: if self.with_conv:
@ -579,21 +590,30 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x): def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1: if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:] h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t") x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1: if get_context_parallel_rank() == 0 and fake_cp_rank0:
# split first frame # split first frame
x_first, x_rest = x[..., 0], x[..., 1:] x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0: if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
x = torch.cat([x_first[..., None], x_rest], dim=-1) x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else: else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv: if self.with_conv:
@ -673,13 +693,13 @@ class ContextParallelResnetBlock3D(nn.Module):
padding=0, padding=0,
) )
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True): def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
h = x h = x
# if isinstance(self.norm1, torch.nn.GroupNorm): # if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None: if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else: else:
h = self.norm1(h) h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm): # if isinstance(self.norm1, torch.nn.GroupNorm):
@ -694,7 +714,7 @@ class ContextParallelResnetBlock3D(nn.Module):
# if isinstance(self.norm2, torch.nn.GroupNorm): # if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None: if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else: else:
h = self.norm2(h) h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm): # if isinstance(self.norm2, torch.nn.GroupNorm):
@ -807,23 +827,24 @@ class ContextParallelEncoder3D(nn.Module):
kernel_size=3, kernel_size=3,
) )
def forward(self, x, **kwargs): def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
# timestep embedding # timestep embedding
temb = None temb = None
# downsampling # downsampling
h = self.conv_in(x) h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb) h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.down[i_level].attn) > 0: if len(self.down[i_level].attn) > 0:
print("Attention not implemented")
h = self.down[i_level].attn[i_block](h) h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h) h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb) h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
# end # end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
@ -831,7 +852,7 @@ class ContextParallelEncoder3D(nn.Module):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h return h
@ -934,6 +955,11 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
# # Symmetrical enc-dec
if i_level <= self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
if i_level < self.num_resolutions - self.temporal_compress_level: if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else: else:
@ -948,7 +974,7 @@ class ContextParallelDecoder3D(nn.Module):
kernel_size=3, kernel_size=3,
) )
def forward(self, z, clear_fake_cp_cache=True, **kwargs): def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
self.last_z_shape = z.shape self.last_z_shape = z.shape
# timestep embedding # timestep embedding
@ -961,23 +987,25 @@ class ContextParallelDecoder3D(nn.Module):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache) h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle # middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.up[i_level].block[i_block](
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
)
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq) h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0: if i_level != 0:
h = self.up[i_level].upsample(h) h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
# end # end
if self.give_pre_end: if self.give_pre_end:
return h return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache) h = self.conv_out(h, clear_cache=clear_fake_cp_cache)

View File

@ -1,22 +1,15 @@
""" """
This script demonstrates how to convert and generate video from a text prompt
using CogVideoX with 🤗Huggingface Diffusers Pipeline.
This script requires the `diffusers>=0.30.2` library to be installed.
Functions:
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
- reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
- reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
- remove_keys_inplace: Removes specified keys from the state_dict in-place.
- replace_up_keys_inplace: Replaces keys in the "up" block in-place.
- get_state_dict: Extracts the state_dict from a saved checkpoint.
- update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
- convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
- convert_vae: Converts a VAE checkpoint to the CogVideoX format.
- get_args: Parses command-line arguments for the script.
- generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
"""
The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
This script supports the conversion of the following models:
- CogVideoX-2B
- CogVideoX-5B, CogVideoX-5B-I2V
- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
Original Script:
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
"""
import argparse import argparse
from typing import Any, Dict from typing import Any, Dict
@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer( def convert_transformer(
ckpt_path: str, ckpt_path: str,
num_layers: int, num_layers: int,
num_attention_heads: int, num_attention_heads: int,
use_rotary_positional_embeddings: bool, use_rotary_positional_embeddings: bool,
i2v: bool, i2v: bool,
dtype: torch.dtype, dtype: torch.dtype,
): ):
PREFIX_KEY = "model.diffusion_model." PREFIX_KEY = "model.diffusion_model."
@ -172,7 +165,7 @@ def convert_transformer(
).to(dtype=dtype) ).to(dtype=dtype)
for key in list(original_state_dict.keys()): for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :] new_key = key[len(PREFIX_KEY):]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key) new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key) update_state_dict_inplace(original_state_dict, key, new_key)
@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint") "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@ -259,9 +253,10 @@ if __name__ == "__main__":
if args.vae_ckpt_path is not None: if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl" text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work anymore without this :shrug: # Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters(): for param in text_encoder.parameters():
param.data = param.data.contiguous() param.data = param.data.contiguous()
@ -301,4 +296,7 @@ if __name__ == "__main__":
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here). # is either fp16/bf16 here).
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
# This is necessary for users with insufficient memory, such as those using Colab or notebooks,
# as it can save some of the memory used for model loading.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
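After conversion, the sharded pipeline can be loaded back with Diffusers in the usual way; a rough sketch, assuming the Diffusers `CogVideoXPipeline` API (diffusers >= 0.30.2) and a placeholder path standing in for `args.output_path`:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# "path/to/converted-model" is a placeholder for the --output_path used above.
pipe = CogVideoXPipeline.from_pretrained("path/to/converted-model", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # keeps peak GPU memory low on small cards

video = pipe("A panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```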