mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
commit
ddd3dcd7eb
74
README.md
74
README.md
@ -22,7 +22,10 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
|
||||
|
||||
## Project Updates
|
||||
|
||||
- 🔥🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
|
||||
- 🔥🔥 News: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
|
||||
The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
|
||||
The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
|
||||
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
|
||||
4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports
|
||||
fine-tuning with multiple resolutions. Feel free to use it!
|
||||
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please
|
||||
@ -68,7 +71,6 @@ Jump to a specific section:
|
||||
- [Tools](#tools)
|
||||
- [Introduction to CogVideo(ICLR'23) Model](#cogvideoiclr23)
|
||||
- [Citations](#Citation)
|
||||
- [Open Source Project Plan](#Open-Source-Project-Plan)
|
||||
- [Model License](#Model-License)
|
||||
|
||||
## Quick Start
|
||||
@ -172,79 +174,85 @@ models we currently offer, along with their foundational information.
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Model Description</td>
|
||||
<td style="text-align: center;">Entry-level model, balancing compatibility. Low cost for running and secondary development.</td>
|
||||
<td style="text-align: center;">Larger model with higher video generation quality and better visual effects.</td>
|
||||
<td style="text-align: center;">CogVideoX-5B image-to-video version.</td>
|
||||
<td style="text-align: center;">Release Date</td>
|
||||
<th style="text-align: center;">August 6, 2024</th>
|
||||
<th style="text-align: center;">August 27, 2024</th>
|
||||
<th style="text-align: center;">September 19, 2024</th>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Video Resolution</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Inference Precision</td>
|
||||
<td style="text-align: center;"><b>FP16*(recommended)</b>, BF16, FP32, FP8*, INT8, not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8, not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(recommended)</b>, FP16, FP32, FP8*, INT8, not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Single GPU Memory Usage<br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: from 4GB* </b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16: from 5GB* </b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
|
||||
<td style="text-align: center;">Single GPU Memory Usage</td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8(torchao): from 3.6GB*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16 : from 5GB*</b><br><b>diffusers INT8(torchao): from 4.4GB*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
|
||||
<td style="text-align: center;">Multi-GPU Memory Usage</td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>Not supported</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
|
||||
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Fine-tuning Precision</td>
|
||||
<td style="text-align: center;"><b>FP16</b></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Fine-tuning Memory Usage</td>
|
||||
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
|
||||
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
|
||||
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
|
||||
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Prompt Language</td>
|
||||
<td colspan="3" style="text-align: center;">English*</td>
|
||||
<td colspan="5" style="text-align: center;">English*</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Maximum Prompt Length</td>
|
||||
<td style="text-align: center;">Prompt Token Limit</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
<td colspan="2" style="text-align: center;">224 Tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Video Length</td>
|
||||
<td colspan="3" style="text-align: center;">6 Seconds</td>
|
||||
<td colspan="3" style="text-align: center;">6 seconds</td>
|
||||
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Frame Rate</td>
|
||||
<td colspan="3" style="text-align: center;">8 Frames / Second</td>
|
||||
<td colspan="3" style="text-align: center;">8 frames / second</td>
|
||||
<td colspan="2" style="text-align: center;">16 frames / second</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Video Resolution</td>
|
||||
<td colspan="3" style="text-align: center;">720 x 480, no support for other resolutions (including fine-tuning)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Position Encoding</td>
|
||||
<td style="text-align: center;">Positional Encoding</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Download Link (Diffusers)</td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;"> Coming Soon </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Download Link (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README.md">SAT</a></td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
@ -422,7 +430,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
|
||||
|
||||
We welcome your contributions! You can click [here](resources/contribute.md) for more information.
|
||||
|
||||
## License Agreement
|
||||
## Model-License
|
||||
|
||||
The code in this repository is released under the [Apache 2.0 License](LICENSE).
|
||||
|
||||
|
96
README_ja.md
96
README_ja.md
@ -1,6 +1,6 @@
|
||||
# CogVideo & CogVideoX
|
||||
|
||||
[Read this in English](./README_zh.md)
|
||||
[Read this in English](./README.md)
|
||||
|
||||
[中文阅读](./README_zh.md)
|
||||
|
||||
@ -22,9 +22,14 @@
|
||||
|
||||
## 更新とニュース
|
||||
|
||||
- 🔥🔥 **ニュース**: ```2024/10/13```: コスト削減のため、単一の4090 GPUで`CogVideoX-5B`
|
||||
- 🔥🔥 ニュース: ```2024/11/08```: `CogVideoX1.5` モデルをリリースしました。CogVideoX1.5 は CogVideoX オープンソースモデルのアップグレードバージョンです。
|
||||
CogVideoX1.5-5B シリーズモデルは、10秒 長の動画とより高い解像度をサポートしており、`CogVideoX1.5-5B-I2V` は任意の解像度での動画生成に対応しています。
|
||||
SAT コードはすでに更新されており、`diffusers` バージョンは現在適応中です。
|
||||
SAT バージョンのコードは [こちら](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) からダウンロードできます。
|
||||
- 🔥 **ニュース**: ```2024/10/13```: コスト削減のため、単一の4090 GPUで`CogVideoX-5B`
|
||||
を微調整できるフレームワーク [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)
|
||||
がリリースされました。複数の解像度での微調整に対応しています。ぜひご利用ください!- 🔥**ニュース**: ```2024/10/10```:
|
||||
がリリースされました。複数の解像度での微調整に対応しています。ぜひご利用ください!
|
||||
- 🔥**ニュース**: ```2024/10/10```:
|
||||
技術報告書を更新し、より詳細なトレーニング情報とデモを追加しました。
|
||||
- 🔥 **ニュース**: ```2024/10/10```: 技術報告書を更新しました。[こちら](https://arxiv.org/pdf/2408.06072)
|
||||
をクリックしてご覧ください。さらにトレーニングの詳細とデモを追加しました。デモを見るには[こちら](https://yzy-thu.github.io/CogVideoX-demo/)
|
||||
@ -34,7 +39,7 @@
|
||||
- 🔥**ニュース**: ```2024/9/19```: CogVideoXシリーズの画像生成ビデオモデル **CogVideoX-5B-I2V**
|
||||
をオープンソース化しました。このモデルは、画像を背景入力として使用し、プロンプトワードと組み合わせてビデオを生成することができ、より高い制御性を提供します。これにより、CogVideoXシリーズのモデルは、テキストからビデオ生成、ビデオの継続、画像からビデオ生成の3つのタスクをサポートするようになりました。オンラインでの[体験](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)
|
||||
をお楽しみください。
|
||||
- 🔥🔥 **ニュース**: ```2024/9/19```:
|
||||
- 🔥 **ニュース**: ```2024/9/19```:
|
||||
CogVideoXのトレーニングプロセスでビデオデータをテキスト記述に変換するために使用されるキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption)
|
||||
をオープンソース化しました。ダウンロードしてご利用ください。
|
||||
- 🔥 ```2024/8/27```: CogVideoXシリーズのより大きなモデル **CogVideoX-5B**
|
||||
@ -63,11 +68,10 @@
|
||||
- [プロジェクト構造](#プロジェクト構造)
|
||||
- [推論](#推論)
|
||||
- [sat](#sat)
|
||||
- [ツール](#ツール)
|
||||
- [プロジェクト計画](#プロジェクト計画)
|
||||
- [モデルライセンス](#モデルライセンス)
|
||||
- [ツール](#ツール)=
|
||||
- [CogVideo(ICLR'23)モデル紹介](#CogVideoICLR23)
|
||||
- [引用](#引用)
|
||||
- [ライセンス契約](#ライセンス契約)
|
||||
|
||||
## クイックスタート
|
||||
|
||||
@ -156,79 +160,91 @@ pip install -r requirements.txt
|
||||
CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源のオープンソース版ビデオ生成モデルです。
|
||||
以下の表に、提供しているビデオ生成モデルの基本情報を示します:
|
||||
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<th style="text-align: center;">モデル名</th>
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V </th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">リリース日</td>
|
||||
<th style="text-align: center;">2024年8月6日</th>
|
||||
<th style="text-align: center;">2024年8月27日</th>
|
||||
<th style="text-align: center;">2024年9月19日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオ解像度</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論精度</td>
|
||||
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32, FP8*, INT8, INT4は非対応</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">単一GPUのメモリ消費<br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GBから* </b><br><b>diffusers INT8(torchao): 3.6GBから*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GBから* </b><br><b>diffusers INT8(torchao): 4.4GBから* </b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">マルチGPUのメモリ消費</td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論速度<br>(ステップ = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ファインチューニング精度</td>
|
||||
<td style="text-align: center;"><b>FP16</b></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ファインチューニング時のメモリ消費</td>
|
||||
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
|
||||
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
|
||||
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
|
||||
<td style="text-align: center;">シングルGPUメモリ消費</td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: 4GBから*</b><br><b>diffusers INT8(torchao): 3.6GBから*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: 5GBから*</b><br><b>diffusers INT8(torchao): 4.4GBから*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">マルチGPUメモリ消費</td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>サポートなし</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論速度<br>(ステップ数 = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約1000秒(5秒動画)<br>単一H100: 約550秒(5秒動画)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプト言語</td>
|
||||
<td colspan="3" style="text-align: center;">英語*</td>
|
||||
<td colspan="5" style="text-align: center;">英語*</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプトの最大トークン数</td>
|
||||
<td style="text-align: center;">プロンプトトークン制限</td>
|
||||
<td colspan="3" style="text-align: center;">226トークン</td>
|
||||
<td colspan="2" style="text-align: center;">224トークン</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオの長さ</td>
|
||||
<td colspan="3" style="text-align: center;">6秒</td>
|
||||
<td colspan="2" style="text-align: center;">5秒または10秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">フレームレート</td>
|
||||
<td colspan="3" style="text-align: center;">8フレーム/秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオ解像度</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480、他の解像度は非対応(ファインチューニング含む)</td>
|
||||
<td colspan="3" style="text-align: center;">8 フレーム / 秒</td>
|
||||
<td colspan="2" style="text-align: center;">16 フレーム / 秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">位置エンコーディング</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (Diffusers)</td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;">近日公開</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_ja.md">SAT</a></td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
57
README_zh.md
57
README_zh.md
@ -1,10 +1,9 @@
|
||||
# CogVideo & CogVideoX
|
||||
|
||||
[Read this in English](./README_zh.md)
|
||||
[Read this in English](./README.md)
|
||||
|
||||
[日本語で読む](./README_ja.md)
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src=resources/logo.svg width="50%"/>
|
||||
</div>
|
||||
@ -23,7 +22,9 @@
|
||||
|
||||
## 项目更新
|
||||
|
||||
- 🔥🔥 **News**: ```2024/10/13```: 成本更低,单卡4090可微调`CogVideoX-5B`
|
||||
- 🔥🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。
|
||||
CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨率,其中 `CogVideoX1.5-5B-I2V` 支持 **任意分辨率** 的视频生成,SAT代码已经更新。`diffusers`版本还在适配中。SAT版本代码前往 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 下载。
|
||||
- 🔥**News**: ```2024/10/13```: 成本更低,单卡4090可微调 `CogVideoX-5B`
|
||||
的微调框架[cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)已经推出,多种分辨率微调,欢迎使用。
|
||||
- 🔥 **News**: ```2024/10/10```: 我们更新了我们的技术报告,请点击 [这里](https://arxiv.org/pdf/2408.06072)
|
||||
查看,附上了更多的训练细节和demo,关于demo,点击[这里](https://yzy-thu.github.io/CogVideoX-demo/) 查看。
|
||||
@ -58,10 +59,9 @@
|
||||
- [Inference](#inference)
|
||||
- [SAT](#sat)
|
||||
- [Tools](#tools)
|
||||
- [开源项目规划](#开源项目规划)
|
||||
- [模型协议](#模型协议)
|
||||
- [CogVideo(ICLR'23)模型介绍](#cogvideoiclr23)
|
||||
- [引用](#引用)
|
||||
- [模型协议](#模型协议)
|
||||
|
||||
## 快速开始
|
||||
|
||||
@ -157,62 +157,72 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V </th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">发布时间</td>
|
||||
<th style="text-align: center;">2024年8月6日</th>
|
||||
<th style="text-align: center;">2024年8月27日</th>
|
||||
<th style="text-align: center;">2024年9月19日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">视频分辨率</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推理精度</td>
|
||||
<td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32,FP8*,INT8,不支持INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32,FP8*,INT8,不支持INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">单GPU显存消耗<br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB起* </b><br><b>diffusers INT8(torchao): 3.6G起*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB起* </b><br><b>diffusers INT8(torchao): 4.4G起* </b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">多GPU推理显存消耗</td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>Not support</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推理速度<br>(Step = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">单卡A100: ~90秒<br>单卡H100: ~45秒</td>
|
||||
<td colspan="2" style="text-align: center;">单卡A100: ~180秒<br>单卡H100: ~90秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">微调精度</td>
|
||||
<td style="text-align: center;"><b>FP16</b></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">微调显存消耗</td>
|
||||
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
|
||||
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
|
||||
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
|
||||
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">提示词语言</td>
|
||||
<td colspan="3" style="text-align: center;">English*</td>
|
||||
<td colspan="5" style="text-align: center;">English*</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">提示词长度上限</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
<td colspan="2" style="text-align: center;">224 Tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">视频长度</td>
|
||||
<td colspan="3" style="text-align: center;">6 秒</td>
|
||||
<td colspan="2" style="text-align: center;">5 秒 或 10 秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">帧率</td>
|
||||
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
|
||||
<td colspan="2" style="text-align: center;">16 帧 / 秒 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">视频分辨率</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480,不支持其他分辨率(含微调)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">位置编码</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
@ -220,10 +230,13 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;"> 即将推出 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">下载链接 (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
206
sat/README.md
206
sat/README.md
@ -1,29 +1,39 @@
|
||||
# SAT CogVideoX-2B
|
||||
# SAT CogVideoX
|
||||
|
||||
[中文阅读](./README_zh.md)
|
||||
[Read this in English.](./README_zh.md)
|
||||
|
||||
[日本語で読む](./README_ja.md)
|
||||
|
||||
This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
|
||||
fine-tuning code for SAT weights.
|
||||
This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.
|
||||
|
||||
This code is the framework used by the team to train the model. It has few comments and requires careful study.
|
||||
This code framework was used by our team during model training. There are few comments, so careful study is required.
|
||||
|
||||
## Inference Model
|
||||
|
||||
### 1. Ensure that you have correctly installed the dependencies required by this folder.
|
||||
### 1. Make sure you have installed all dependencies in this folder
|
||||
|
||||
```shell
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. Download the model weights
|
||||
### 2. Download the Model Weights
|
||||
|
||||
### 2. Download model weights
|
||||
First, download the model weights from the SAT mirror.
|
||||
|
||||
First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, please download as follows:
|
||||
#### CogVideoX1.5 Model
|
||||
|
||||
```shell
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
|
||||
```
|
||||
|
||||
This command downloads three models: Transformers, VAE, and T5 Encoder.
|
||||
|
||||
#### CogVideoX Model
|
||||
|
||||
For the CogVideoX-2B model, download as follows:
|
||||
|
||||
```
|
||||
mkdir CogVideoX-2b-sat
|
||||
cd CogVideoX-2b-sat
|
||||
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
||||
@ -34,13 +44,12 @@ mv 'index.html?dl=1' transformer.zip
|
||||
unzip transformer.zip
|
||||
```
|
||||
|
||||
For the CogVideoX-5B model, please download the `transformers` file as follows link:
|
||||
(VAE files are the same as 2B)
|
||||
Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):
|
||||
|
||||
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
|
||||
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
|
||||
|
||||
Next, you need to format the model files as follows:
|
||||
Arrange the model files in the following structure:
|
||||
|
||||
```
|
||||
.
|
||||
@ -52,20 +61,24 @@ Next, you need to format the model files as follows:
|
||||
└── 3d-vae.pt
|
||||
```
|
||||
|
||||
Due to large size of model weight file, using `git lfs` is recommended. Installation of `git lfs` can be
|
||||
found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)
|
||||
Since model weight files are large, it’s recommended to use `git lfs`.
|
||||
See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.
|
||||
|
||||
Next, clone the T5 model, which is not used for training and fine-tuning, but must be used.
|
||||
> T5 model is available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) as well.
|
||||
```
|
||||
git lfs install
|
||||
```
|
||||
|
||||
```shell
|
||||
git clone https://huggingface.co/THUDM/CogVideoX-2b.git
|
||||
Next, clone the T5 model, which is used as an encoder and doesn’t require training or fine-tuning.
|
||||
> You may also use the model file location on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
|
||||
|
||||
```
|
||||
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
|
||||
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
|
||||
mkdir t5-v1_1-xxl
|
||||
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
|
||||
```
|
||||
|
||||
By following the above approach, you will obtain a safetensor format T5 file. Ensure that there are no errors when
|
||||
loading it into Deepspeed in Finetune.
|
||||
This will yield a safetensor format T5 file that can be loaded without error during Deepspeed fine-tuning.
|
||||
|
||||
```
|
||||
├── added_tokens.json
|
||||
@ -80,11 +93,11 @@ loading it into Deepspeed in Finetune.
|
||||
0 directories, 8 files
|
||||
```
|
||||
|
||||
### 3. Modify the file in `configs/cogvideox_2b.yaml`.
|
||||
### 3. Modify `configs/cogvideox_*.yaml` file.
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
log_keys:
|
||||
- txt
|
||||
@ -160,14 +173,14 @@ model:
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
|
||||
model_dir: "t5-v1_1-xxl" # absolute path to CogVideoX-2b/t5-v1_1-xxl weight folder
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to CogVideoX-2b-sat/vae/3d-vae.pt file
|
||||
ignore_keys: [ 'loss' ]
|
||||
|
||||
loss_config:
|
||||
@ -239,48 +252,46 @@ model:
|
||||
num_steps: 50
|
||||
```
|
||||
|
||||
### 4. Modify the file in `configs/inference.yaml`.
|
||||
### 4. Modify `configs/inference.yaml` file.
|
||||
|
||||
```yaml
|
||||
args:
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
|
||||
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
|
||||
batch_size: 1
|
||||
input_type: txt # You can choose txt for pure text input, or change to cli for command line input
|
||||
input_file: configs/test.txt # Pure text file, which can be edited
|
||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
|
||||
input_file: configs/test.txt # Plain text file, can be edited
|
||||
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
|
||||
sampling_fps: 8
|
||||
fp16: True # For CogVideoX-2B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
output_dir: outputs/
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
+ Modify `configs/test.txt` if multiple prompts is required, in which each line makes a prompt.
|
||||
+ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the
|
||||
OPENAI_API_KEY as your environmental variable.
|
||||
+ Modify `input_type` in `configs/inference.yaml` if use command line as prompt iuput.
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ To use command-line input, modify:
|
||||
|
||||
```yaml
|
||||
```
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
This allows input from the command line as prompts.
|
||||
This allows you to enter prompts from the command line.
|
||||
|
||||
Change `output_dir` if you wish to modify the address of the output video
|
||||
To modify the output video location, change:
|
||||
|
||||
```yaml
|
||||
```
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
It is saved by default in the `.outputs/` folder.
|
||||
The default location is the `.outputs/` folder.
|
||||
|
||||
### 5. Run the inference code to perform inference.
|
||||
### 5. Run the Inference Code to Perform Inference
|
||||
|
||||
```shell
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
@ -288,95 +299,91 @@ bash inference.sh
|
||||
|
||||
### Preparing the Dataset
|
||||
|
||||
The dataset format should be as follows:
|
||||
The dataset should be structured as follows:
|
||||
|
||||
```
|
||||
.
|
||||
├── labels
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
└── videos
|
||||
├── 1.mp4
|
||||
├── 2.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels
|
||||
should be matched one-to-one. Generally, a single video should not be associated with multiple labels.
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
|
||||
For style fine-tuning, please prepare at least 50 videos and labels with similar styles to ensure proper fitting.
|
||||
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
|
||||
|
||||
### Modifying Configuration Files
|
||||
### Modifying the Configuration File
|
||||
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Please note that both methods only fine-tune
|
||||
the `transformer` part and do not modify the `VAE` section. `T5` is used solely as an Encoder. Please modify
|
||||
the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows:
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
|
||||
|
||||
```
|
||||
# checkpoint_activations: True ## Using gradient checkpointing (Both checkpoint_activations in the config file need to be set to True)
|
||||
```yaml
|
||||
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
|
||||
model_parallel_size: 1 # Model parallel size
|
||||
experiment_name: lora-disney # Experiment name (do not modify)
|
||||
mode: finetune # Mode (do not modify)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
|
||||
no_load_rng: True # Whether to load random seed
|
||||
experiment_name: lora-disney # Experiment name (do not change)
|
||||
mode: finetune # Mode (do not change)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
|
||||
no_load_rng: True # Whether to load random number seed
|
||||
train_iters: 1000 # Training iterations
|
||||
eval_iters: 1 # Evaluation iterations
|
||||
eval_interval: 100 # Evaluation interval
|
||||
eval_batch_size: 1 # Evaluation batch size
|
||||
save: ckpts # Model save path
|
||||
save_interval: 100 # Model save interval
|
||||
save: ckpts # Model save path
|
||||
save_interval: 100 # Save interval
|
||||
log_interval: 20 # Log output interval
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # Training and validation datasets can be the same
|
||||
split: 1,0,0 # Training, validation, and test set ratio
|
||||
num_workers: 8 # Number of worker threads for data loader
|
||||
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE are loaded separately)
|
||||
only_log_video_latents: True # Avoid memory overhead caused by VAE decode
|
||||
valid_data: [ "your val data path" ] # Training and validation sets can be the same
|
||||
split: 1,0,0 # Proportion for training, validation, and test sets
|
||||
num_workers: 8 # Number of data loader workers
|
||||
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
|
||||
only_log_video_latents: True # Avoid memory usage from VAE decoding
|
||||
deepspeed:
|
||||
bf16:
|
||||
enabled: False # For CogVideoX-2B set to False and for CogVideoX-5B set to True
|
||||
enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
|
||||
fp16:
|
||||
enabled: True # For CogVideoX-2B set to True and for CogVideoX-5B set to False
|
||||
enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
|
||||
```
|
||||
|
||||
If you wish to use Lora fine-tuning, you also need to modify the `cogvideox_<model_parameters>_lora` file:
|
||||
``` To use Lora fine-tuning, you also need to modify `cogvideox_<model parameters>_lora` file:
|
||||
|
||||
Here, take `CogVideoX-2B` as a reference:
|
||||
Here's an example using `CogVideoX-2B`:
|
||||
|
||||
```
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## Uncomment
|
||||
not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
|
||||
log_keys:
|
||||
- txt'
|
||||
- txt
|
||||
|
||||
lora_config: ## Uncomment
|
||||
lora_config: ## Uncomment to unlock
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
```
|
||||
|
||||
### Modifying Run Scripts
|
||||
### Modify the Run Script
|
||||
|
||||
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Below are two examples:
|
||||
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
|
||||
|
||||
1. If you want to use the `CogVideoX-2B` model and the `Lora` method, you need to modify `finetune_single_gpu.sh`
|
||||
or `finetune_multi_gpus.sh`:
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
2. If you want to use the `CogVideoX-2B` model and the `full-parameter fine-tuning` method, you need to
|
||||
modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`:
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
### Fine-Tuning and Evaluation
|
||||
### Fine-tuning and Validation
|
||||
|
||||
Run the inference code to start fine-tuning.
|
||||
|
||||
@ -385,45 +392,42 @@ bash finetune_single_gpu.sh # Single GPU
|
||||
bash finetune_multi_gpus.sh # Multi GPUs
|
||||
```
|
||||
|
||||
### Using the Fine-Tuned Model
|
||||
### Using the Fine-tuned Model
|
||||
|
||||
The fine-tuned model cannot be merged; here is how to modify the inference configuration file `inference.sh`:
|
||||
The fine-tuned model cannot be merged. Here’s how to modify the inference configuration file `inference.sh`
|
||||
|
||||
```
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
|
||||
```
|
||||
|
||||
Then, execute the code:
|
||||
Then, run the code:
|
||||
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
### Converting to Huggingface Diffusers Supported Weights
|
||||
### Converting to Huggingface Diffusers-compatible Weights
|
||||
|
||||
The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
|
||||
The SAT weight format is different from Huggingface’s format and requires conversion. Run
|
||||
|
||||
```shell
|
||||
```
|
||||
python ../tools/convert_weight_sat2hf.py
|
||||
```
|
||||
|
||||
### Exporting Huggingface Diffusers lora LoRA Weights from SAT Checkpoints
|
||||
### Exporting Lora Weights from SAT to Huggingface Diffusers
|
||||
|
||||
After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. You can find the file
|
||||
at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
|
||||
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
|
||||
The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`.
|
||||
After exporting, you can use `load_cogvideox_lora.py` for inference.
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
|
||||
|
||||
Export command:
|
||||
|
||||
```bash
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
|
||||
This training mainly modified the following model structures. The table below lists the corresponding structure mappings
|
||||
for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the
|
||||
model's attention structure.
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
|
||||
|
||||
```
|
||||
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
|
||||
@ -436,5 +440,5 @@ model's attention structure.
|
||||
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
|
||||
```
|
||||
|
||||
Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format.
|
||||

|
||||
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
|
||||

|
302
sat/README_ja.md
302
sat/README_ja.md
@ -1,27 +1,37 @@
|
||||
# SAT CogVideoX-2B
|
||||
# SAT CogVideoX
|
||||
|
||||
[Read this in English.](./README_zh)
|
||||
[Read this in English.](./README.md)
|
||||
|
||||
[中文阅读](./README_zh.md)
|
||||
|
||||
このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer) ウェイトを使用した推論コードと、SAT
|
||||
ウェイトのファインチューニングコードが含まれています。
|
||||
|
||||
このコードは、チームがモデルをトレーニングするために使用したフレームワークです。コメントが少なく、注意深く研究する必要があります。
|
||||
このフォルダには、[SAT](https://github.com/THUDM/SwissArmyTransformer)の重みを使用した推論コードと、SAT重みのファインチューニングコードが含まれています。
|
||||
このコードは、チームがモデルを訓練する際に使用したフレームワークです。コメントが少ないため、注意深く確認する必要があります。
|
||||
|
||||
## 推論モデル
|
||||
|
||||
### 1. このフォルダに必要な依存関係が正しくインストールされていることを確認してください。
|
||||
### 1. このフォルダ内の必要な依存関係がすべてインストールされていることを確認してください
|
||||
|
||||
```shell
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. モデルウェイトをダウンロードします
|
||||
### 2. モデルの重みをダウンロード
|
||||
まず、SATミラーからモデルの重みをダウンロードしてください。
|
||||
|
||||
まず、SAT ミラーに移動してモデルの重みをダウンロードします。 CogVideoX-2B モデルの場合は、次のようにダウンロードしてください。
|
||||
#### CogVideoX1.5 モデル
|
||||
|
||||
```shell
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
|
||||
```
|
||||
|
||||
これにより、Transformers、VAE、T5 Encoderの3つのモデルがダウンロードされます。
|
||||
|
||||
#### CogVideoX モデル
|
||||
|
||||
CogVideoX-2B モデルについては、以下のようにダウンロードしてください:
|
||||
|
||||
```
|
||||
mkdir CogVideoX-2b-sat
|
||||
cd CogVideoX-2b-sat
|
||||
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
|
||||
@ -32,12 +42,12 @@ mv 'index.html?dl=1' transformer.zip
|
||||
unzip transformer.zip
|
||||
```
|
||||
|
||||
CogVideoX-5B モデルの `transformers` ファイルを以下のリンクからダウンロードしてください (VAE ファイルは 2B と同じです):
|
||||
CogVideoX-5B モデルの `transformers` ファイルをダウンロードしてください(VAEファイルは2Bと同じです):
|
||||
|
||||
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
|
||||
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
|
||||
|
||||
次に、モデルファイルを以下の形式にフォーマットする必要があります:
|
||||
モデルファイルを以下のように配置してください:
|
||||
|
||||
```
|
||||
.
|
||||
@ -49,24 +59,24 @@ CogVideoX-5B モデルの `transformers` ファイルを以下のリンクから
|
||||
└── 3d-vae.pt
|
||||
```
|
||||
|
||||
モデルの重みファイルが大きいため、`git lfs`を使用することをお勧めいたします。`git lfs`
|
||||
のインストールについては、[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)をご参照ください。
|
||||
モデルの重みファイルが大きいため、`git lfs`の使用をお勧めします。
|
||||
`git lfs`のインストール方法は[こちら](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)を参照してください。
|
||||
|
||||
```shell
|
||||
```
|
||||
git lfs install
|
||||
```
|
||||
|
||||
次に、T5 モデルをクローンします。これはトレーニングやファインチューニングには使用されませんが、使用する必要があります。
|
||||
> モデルを複製する際には、[Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)のモデルファイルの場所もご使用いただけます。
|
||||
次に、T5モデルをクローンします。このモデルはEncoderとしてのみ使用され、訓練やファインチューニングは必要ありません。
|
||||
> [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b)上のモデルファイルも使用可能です。
|
||||
|
||||
```shell
|
||||
git clone https://huggingface.co/THUDM/CogVideoX-2b.git #ハギングフェイス(huggingface.org)からモデルをダウンロードいただきます
|
||||
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git #Modelscopeからモデルをダウンロードいただきます
|
||||
```
|
||||
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Huggingfaceからモデルをダウンロード
|
||||
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Modelscopeからダウンロード
|
||||
mkdir t5-v1_1-xxl
|
||||
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
|
||||
```
|
||||
|
||||
上記の方法に従うことで、safetensor 形式の T5 ファイルを取得できます。これにより、Deepspeed でのファインチューニング中にエラーが発生しないようにします。
|
||||
これにより、Deepspeedファインチューニング中にエラーなくロードできるsafetensor形式のT5ファイルが作成されます。
|
||||
|
||||
```
|
||||
├── added_tokens.json
|
||||
@ -81,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
|
||||
0 directories, 8 files
|
||||
```
|
||||
|
||||
### 3. `configs/cogvideox_2b.yaml` ファイルを変更します。
|
||||
### 3. `configs/cogvideox_*.yaml`ファイルを編集
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
log_keys:
|
||||
- txt
|
||||
@ -123,7 +133,7 @@ model:
|
||||
num_attention_heads: 30
|
||||
|
||||
transformer_args:
|
||||
checkpoint_activations: True ## グラデーション チェックポイントを使用する
|
||||
checkpoint_activations: True ## using gradient checkpointing
|
||||
vocab_size: 1
|
||||
max_sequence_length: 64
|
||||
layernorm_order: pre
|
||||
@ -161,14 +171,14 @@ model:
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxlフォルダの絶対パス
|
||||
model_dir: "t5-v1_1-xxl" # CogVideoX-2b/t5-v1_1-xxl 重みフォルダの絶対パス
|
||||
max_length: 226
|
||||
|
||||
first_stage_config:
|
||||
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptフォルダの絶対パス
|
||||
ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # CogVideoX-2b-sat/vae/3d-vae.ptファイルの絶対パス
|
||||
ignore_keys: [ 'loss' ]
|
||||
|
||||
loss_config:
|
||||
@ -240,7 +250,7 @@ model:
|
||||
num_steps: 50
|
||||
```
|
||||
|
||||
### 4. `configs/inference.yaml` ファイルを変更します。
|
||||
### 4. `configs/inference.yaml`ファイルを編集
|
||||
|
||||
```yaml
|
||||
args:
|
||||
@ -250,38 +260,36 @@ args:
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
|
||||
batch_size: 1
|
||||
input_type: txt #TXTのテキストファイルを入力として選択されたり、CLIコマンドラインを入力として変更されたりいただけます
|
||||
input_file: configs/test.txt #テキストファイルのパスで、これに対して編集がさせていただけます
|
||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||
input_type: txt # "txt"でプレーンテキスト入力、"cli"でコマンドライン入力を選択可能
|
||||
input_file: configs/test.txt # プレーンテキストファイル、編集可能
|
||||
sampling_num_frames: 13 # CogVideoX1.5-5Bでは42または22、CogVideoX-5B / 2Bでは13, 11, または9
|
||||
sampling_fps: 8
|
||||
fp16: True # For CogVideoX-2B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
fp16: True # CogVideoX-2B用
|
||||
# bf16: True # CogVideoX-5B用
|
||||
output_dir: outputs/
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
+ 複数のプロンプトを保存するために txt を使用する場合は、`configs/test.txt`
|
||||
を参照して変更してください。1行に1つのプロンプトを記述します。プロンプトの書き方がわからない場合は、最初に [このコード](../inference/convert_demo.py)
|
||||
を使用して LLM によるリファインメントを呼び出すことができます。
|
||||
+ コマンドラインを入力として使用する場合は、次のように変更します。
|
||||
+ 複数のプロンプトを含むテキストファイルを使用する場合、`configs/test.txt`を適宜編集してください。1行につき1プロンプトです。プロンプトの書き方が分からない場合は、[こちらのコード](../inference/convert_demo.py)を使用してLLMで補正できます。
|
||||
+ コマンドライン入力を使用する場合、以下のように変更します:
|
||||
|
||||
```yaml
|
||||
```
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
これにより、コマンドラインからプロンプトを入力できます。
|
||||
|
||||
出力ビデオのディレクトリを変更したい場合は、次のように変更できます:
|
||||
出力ビデオの保存場所を変更する場合は、以下を編集してください:
|
||||
|
||||
```yaml
|
||||
```
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
デフォルトでは `.outputs/` フォルダに保存されます。
|
||||
デフォルトでは`.outputs/`フォルダに保存されます。
|
||||
|
||||
### 5. 推論コードを実行して推論を開始します。
|
||||
### 5. 推論コードを実行して推論を開始
|
||||
|
||||
```shell
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
@ -289,7 +297,7 @@ bash inference.sh
|
||||
|
||||
### データセットの準備
|
||||
|
||||
データセットの形式は次のようになります:
|
||||
データセットは以下の構造である必要があります:
|
||||
|
||||
```
|
||||
.
|
||||
@ -303,123 +311,215 @@ bash inference.sh
|
||||
├── ...
|
||||
```
|
||||
|
||||
各 txt ファイルは対応するビデオファイルと同じ名前であり、そのビデオのラベルを含んでいます。各ビデオはラベルと一対一で対応する必要があります。通常、1つのビデオに複数のラベルを持たせることはありません。
|
||||
各txtファイルは対応するビデオファイルと同じ名前で、ビデオのラベルを含んでいます。ビデオとラベルは一対一で対応させる必要があります。通常、1つのビデオに複数のラベルを使用することは避けてください。
|
||||
|
||||
スタイルファインチューニングの場合、少なくとも50本のスタイルが似たビデオとラベルを準備し、フィッティングを容易にします。
|
||||
スタイルのファインチューニングの場合、スタイルが似たビデオとラベルを少なくとも50本準備し、フィッティングを促進します。
|
||||
|
||||
### 設定ファイルの変更
|
||||
### 設定ファイルの編集
|
||||
|
||||
`Lora` とフルパラメータ微調整の2つの方法をサポートしています。両方の微調整方法は、`transformer` 部分のみを微調整し、`VAE`
|
||||
部分には変更を加えないことに注意してください。`T5` はエンコーダーとしてのみ使用されます。以下のように `configs/sft.yaml` (
|
||||
フルパラメータ微調整用) ファイルを変更してください。
|
||||
``` `Lora`と全パラメータのファインチューニングの2種類をサポートしています。どちらも`transformer`部分のみをファインチューニングし、`VAE`部分は変更されず、`T5`はエンコーダーとしてのみ使用されます。
|
||||
``` 以下のようにして`configs/sft.yaml`(全量ファインチューニング)ファイルを編集してください:
|
||||
|
||||
```
|
||||
# checkpoint_activations: True ## 勾配チェックポイントを使用する場合 (設定ファイル内の2つの checkpoint_activations を True に設定する必要があります)
|
||||
# checkpoint_activations: True ## using gradient checkpointing (configファイル内の2つの`checkpoint_activations`を両方Trueに設定)
|
||||
model_parallel_size: 1 # モデル並列サイズ
|
||||
experiment_name: lora-disney # 実験名 (変更しないでください)
|
||||
mode: finetune # モード (変更しないでください)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer モデルのパス
|
||||
no_load_rng: True # 乱数シードを読み込むかどうか
|
||||
experiment_name: lora-disney # 実験名(変更不要)
|
||||
mode: finetune # モード(変更不要)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformerモデルのパス
|
||||
no_load_rng: True # 乱数シードをロードするかどうか
|
||||
train_iters: 1000 # トレーニングイテレーション数
|
||||
eval_iters: 1 # 評価イテレーション数
|
||||
eval_interval: 100 # 評価間隔
|
||||
eval_batch_size: 1 # 評価バッチサイズ
|
||||
save: ckpts # モデル保存パス
|
||||
save_interval: 100 # モデル保存間隔
|
||||
eval_iters: 1 # 検証イテレーション数
|
||||
eval_interval: 100 # 検証間隔
|
||||
eval_batch_size: 1 # 検証バッチサイズ
|
||||
save: ckpts # モデル保存パス
|
||||
save_interval: 100 # 保存間隔
|
||||
log_interval: 20 # ログ出力間隔
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # トレーニングデータと評価データは同じでも構いません
|
||||
split: 1,0,0 # トレーニングセット、評価セット、テストセットの割合
|
||||
num_workers: 8 # データローダーのワーカースレッド数
|
||||
force_train: True # チェックポイントをロードするときに欠落したキーを許可 (T5 と VAE は別々にロードされます)
|
||||
only_log_video_latents: True # VAE のデコードによるメモリオーバーヘッドを回避
|
||||
valid_data: [ "your val data path" ] # トレーニングセットと検証セットは同じでも構いません
|
||||
split: 1,0,0 # トレーニングセット、検証セット、テストセットの割合
|
||||
num_workers: 8 # データローダーのワーカー数
|
||||
force_train: True # チェックポイントをロードする際に`missing keys`を許可(T5とVAEは別途ロード)
|
||||
only_log_video_latents: True # VAEのデコードによるメモリ使用量を抑える
|
||||
deepspeed:
|
||||
bf16:
|
||||
enabled: False # CogVideoX-2B の場合は False に設定し、CogVideoX-5B の場合は True に設定
|
||||
enabled: False # CogVideoX-2B 用は False、CogVideoX-5B 用は True に設定
|
||||
fp16:
|
||||
enabled: True # CogVideoX-2B の場合は True に設定し、CogVideoX-5B の場合は False に設定
|
||||
enabled: True # CogVideoX-2B 用は True、CogVideoX-5B 用は False に設定
|
||||
```
|
||||
```yaml
|
||||
args:
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
|
||||
batch_size: 1
|
||||
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
|
||||
input_file: configs/test.txt # Plain text file, can be edited
|
||||
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
|
||||
sampling_fps: 8
|
||||
fp16: True # For CogVideoX-2B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
output_dir: outputs/
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
Lora 微調整を使用したい場合は、`cogvideox_<model_parameters>_lora` ファイルも変更する必要があります。
|
||||
|
||||
ここでは、`CogVideoX-2B` を参考にします。
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ To use command-line input, modify:
|
||||
|
||||
```
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
This allows you to enter prompts from the command line.
|
||||
|
||||
To modify the output video location, change:
|
||||
|
||||
```
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
The default location is the `.outputs/` folder.
|
||||
|
||||
### 5. Run the Inference Code to Perform Inference
|
||||
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
## Fine-tuning the Model
|
||||
|
||||
### Preparing the Dataset
|
||||
|
||||
The dataset should be structured as follows:
|
||||
|
||||
```
|
||||
.
|
||||
├── labels
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
└── videos
|
||||
├── 1.mp4
|
||||
├── 2.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
|
||||
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
|
||||
|
||||
### Modifying the Configuration File
|
||||
|
||||
We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
|
||||
Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
|
||||
|
||||
```yaml
|
||||
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
|
||||
model_parallel_size: 1 # Model parallel size
|
||||
experiment_name: lora-disney # Experiment name (do not change)
|
||||
mode: finetune # Mode (do not change)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
|
||||
no_load_rng: True # Whether to load random number seed
|
||||
train_iters: 1000 # Training iterations
|
||||
eval_iters: 1 # Evaluation iterations
|
||||
eval_interval: 100 # Evaluation interval
|
||||
eval_batch_size: 1 # Evaluation batch size
|
||||
save: ckpts # Model save path
|
||||
save_interval: 100 # Save interval
|
||||
log_interval: 20 # Log output interval
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # Training and validation sets can be the same
|
||||
split: 1,0,0 # Proportion for training, validation, and test sets
|
||||
num_workers: 8 # Number of data loader workers
|
||||
force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
|
||||
only_log_video_latents: True # Avoid memory usage from VAE decoding
|
||||
deepspeed:
|
||||
bf16:
|
||||
enabled: False # For CogVideoX-2B Turn to False and For CogVideoX-5B Turn to True
|
||||
fp16:
|
||||
enabled: True # For CogVideoX-2B Turn to True and For CogVideoX-5B Turn to False
|
||||
```
|
||||
|
||||
``` To use Lora fine-tuning, you also need to modify `cogvideox_<model parameters>_lora` file:
|
||||
|
||||
Here's an example using `CogVideoX-2B`:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## コメントを解除
|
||||
not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
|
||||
log_keys:
|
||||
- txt'
|
||||
- txt
|
||||
|
||||
lora_config: ## コメントを解除
|
||||
lora_config: ## Uncomment to unlock
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
```
|
||||
|
||||
### 実行スクリプトの変更
|
||||
### Modify the Run Script
|
||||
|
||||
設定ファイルを選択するために `finetune_single_gpu.sh` または `finetune_multi_gpus.sh` を編集します。以下に2つの例を示します。
|
||||
Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
|
||||
|
||||
1. `CogVideoX-2B` モデルを使用し、`Lora` 手法を利用する場合は、`finetune_single_gpu.sh` または `finetune_multi_gpus.sh`
|
||||
を変更する必要があります。
|
||||
1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
2. `CogVideoX-2B` モデルを使用し、`フルパラメータ微調整` 手法を利用する場合は、`finetune_single_gpu.sh`
|
||||
または `finetune_multi_gpus.sh` を変更する必要があります。
|
||||
2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
|
||||
```
|
||||
|
||||
### 微調整と評価
|
||||
### Fine-tuning and Validation
|
||||
|
||||
推論コードを実行して微調整を開始します。
|
||||
Run the inference code to start fine-tuning.
|
||||
|
||||
```
|
||||
bash finetune_single_gpu.sh # シングルGPU
|
||||
bash finetune_multi_gpus.sh # マルチGPU
|
||||
bash finetune_single_gpu.sh # Single GPU
|
||||
bash finetune_multi_gpus.sh # Multi GPUs
|
||||
```
|
||||
|
||||
### 微調整後のモデルの使用
|
||||
### Using the Fine-tuned Model
|
||||
|
||||
微調整されたモデルは統合できません。ここでは、推論設定ファイル `inference.sh` を変更する方法を示します。
|
||||
The fine-tuned model cannot be merged. Here’s how to modify the inference configuration file `inference.sh`
|
||||
|
||||
```
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
|
||||
```
|
||||
|
||||
その後、次のコードを実行します。
|
||||
Then, run the code:
|
||||
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
### Huggingface Diffusers サポートのウェイトに変換
|
||||
### Converting to Huggingface Diffusers-compatible Weights
|
||||
|
||||
SAT ウェイト形式は Huggingface のウェイト形式と異なり、変換が必要です。次のコマンドを実行してください:
|
||||
The SAT weight format is different from Huggingface’s format and requires conversion. Run
|
||||
|
||||
```shell
|
||||
```
|
||||
python ../tools/convert_weight_sat2hf.py
|
||||
```
|
||||
|
||||
### SATチェックポイントからHuggingface Diffusers lora LoRAウェイトをエクスポート
|
||||
### Exporting Lora Weights from SAT to Huggingface Diffusers
|
||||
|
||||
上記のステップを完了すると、LoRAウェイト付きのSATチェックポイントが得られます。ファイルは `{args.save}/1000/1000/mp_rank_00_model_states.pt` にあります。
|
||||
Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
|
||||
After training with the above steps, you’ll find the SAT model with Lora weights in {args.save}/1000/1000/mp_rank_00_model_states.pt
|
||||
|
||||
LoRAウェイトをエクスポートするためのスクリプトは、CogVideoXリポジトリの `tools/export_sat_lora_weight.py` にあります。エクスポート後、`load_cogvideox_lora.py` を使用して推論を行うことができます。
|
||||
The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
|
||||
|
||||
エクスポートコマンド:
|
||||
Export command:
|
||||
|
||||
```bash
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
|
||||
```
|
||||
|
||||
このトレーニングでは主に以下のモデル構造が変更されました。以下の表は、HF (Hugging Face) 形式のLoRA構造に変換する際の対応関係を示しています。ご覧の通り、LoRAはモデルの注意メカニズムに低ランクの重みを追加しています。
|
||||
The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
|
||||
|
||||
```
|
||||
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
|
||||
@ -431,8 +531,6 @@ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_nam
|
||||
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
|
||||
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
|
||||
```
|
||||
|
||||
export_sat_lora_weight.py を使用して、SATチェックポイントをHF LoRA形式に変換できます。
|
||||
|
||||
|
||||

|
||||
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
|
||||

|
@ -1,6 +1,6 @@
|
||||
# SAT CogVideoX-2B
|
||||
# SAT CogVideoX
|
||||
|
||||
[Read this in English.](./README_zh)
|
||||
[Read this in English.](./README.md)
|
||||
|
||||
[日本語で読む](./README_ja.md)
|
||||
|
||||
@ -20,6 +20,15 @@ pip install -r requirements.txt
|
||||
|
||||
首先,前往 SAT 镜像下载模型权重。
|
||||
|
||||
#### CogVideoX1.5 模型
|
||||
|
||||
```shell
|
||||
git lfs install
|
||||
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
|
||||
```
|
||||
此操作会下载 Transformers, VAE, T5 Encoder 这三个模型。
|
||||
|
||||
#### CogVideoX 模型
|
||||
对于 CogVideoX-2B 模型,请按照如下方式下载:
|
||||
|
||||
```shell
|
||||
@ -82,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
|
||||
0 directories, 8 files
|
||||
```
|
||||
|
||||
### 3. 修改`configs/cogvideox_2b.yaml`中的文件。
|
||||
### 3. 修改`configs/cogvideox_*.yaml`中的文件。
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
log_keys:
|
||||
- txt
|
||||
@ -253,7 +262,7 @@ args:
|
||||
batch_size: 1
|
||||
input_type: txt #可以选择txt纯文字档作为输入,或者改成cli命令行作为输入
|
||||
input_file: configs/test.txt #纯文字档,可以对此做编辑
|
||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||
sampling_num_frames: 13 #CogVideoX1.5-5B 必须是 42 或 22。 CogVideoX-5B / 2B 必须是 13 11 或 9。
|
||||
sampling_fps: 8
|
||||
fp16: True # For CogVideoX-2B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
@ -346,7 +355,7 @@ Encoder 使用。
|
||||
|
||||
```yaml
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
scale_factor: 1.55258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## 解除注释
|
||||
log_keys:
|
||||
|
@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
|
||||
group.add_argument("--input-dir", type=str, default=None)
|
||||
group.add_argument("--input-type", type=str, default="cli")
|
||||
group.add_argument("--input-file", type=str, default="input.txt")
|
||||
group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
|
||||
group.add_argument("--final-size", type=int, default=2048)
|
||||
group.add_argument("--sdedit", action="store_true")
|
||||
group.add_argument("--grid-num-rows", type=int, default=1)
|
||||
|
149
sat/configs/cogvideox1.5_5b.yaml
Normal file
149
sat/configs/cogvideox1.5_5b.yaml
Normal file
@ -0,0 +1,149 @@
|
||||
model:
|
||||
scale_factor: 0.7
|
||||
disable_first_stage_autocast: true
|
||||
latent_input: true
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
quantize_c_noise: False
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: dit_video_concat.DiffusionTransformer
|
||||
params:
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 81
|
||||
time_compressed_rate: 4
|
||||
latent_width: 300
|
||||
latent_height: 300
|
||||
num_layers: 42
|
||||
patch_size: [2, 2, 2]
|
||||
in_channels: 16
|
||||
out_channels: 16
|
||||
hidden_size: 3072
|
||||
adm_in_channels: 256
|
||||
num_attention_heads: 48
|
||||
|
||||
transformer_args:
|
||||
checkpoint_activations: True
|
||||
vocab_size: 1
|
||||
max_sequence_length: 64
|
||||
layernorm_order: pre
|
||||
skip_init: false
|
||||
model_parallel_size: 1
|
||||
is_decoder: false
|
||||
|
||||
modules:
|
||||
pos_embed_config:
|
||||
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
|
||||
params:
|
||||
hidden_size_head: 64
|
||||
text_length: 224
|
||||
|
||||
patch_embed_config:
|
||||
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||
params:
|
||||
text_hidden_size: 4096
|
||||
|
||||
adaln_layer_config:
|
||||
target: dit_video_concat.AdaLNMixin
|
||||
params:
|
||||
qk_ln: True
|
||||
|
||||
final_layer_config:
|
||||
target: dit_video_concat.FinalLayerMixin
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl"
|
||||
max_length: 224
|
||||
|
||||
|
||||
first_stage_config:
|
||||
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt"
|
||||
ignore_keys: ['loss']
|
||||
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
|
||||
regularizer_config:
|
||||
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
encoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||
params:
|
||||
double_z: true
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 2, 4]
|
||||
attn_resolutions: []
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
decoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||
params:
|
||||
double_z: True
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 2, 4]
|
||||
attn_resolutions: []
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||
params:
|
||||
offset_noise_level: 0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
uniform_sampling: True
|
||||
group_num: 40
|
||||
num_idx: 1000
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||
params:
|
||||
num_steps: 50
|
||||
verbose: True
|
||||
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||
params:
|
||||
scale: 6
|
||||
exp: 5
|
||||
num_steps: 50
|
160
sat/configs/cogvideox1.5_5b_i2v.yaml
Normal file
160
sat/configs/cogvideox1.5_5b_i2v.yaml
Normal file
@ -0,0 +1,160 @@
|
||||
model:
|
||||
scale_factor: 0.7
|
||||
disable_first_stage_autocast: true
|
||||
latent_input: false
|
||||
noised_image_input: true
|
||||
noised_image_all_concat: false
|
||||
noised_image_dropout: 0.05
|
||||
augmentation_dropout: 0.15
|
||||
log_keys:
|
||||
- txt
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
quantize_c_noise: False
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: dit_video_concat.DiffusionTransformer
|
||||
params:
|
||||
# space_interpolation: 1.875
|
||||
ofs_embed_dim: 512
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 81
|
||||
time_compressed_rate: 4
|
||||
latent_width: 300
|
||||
latent_height: 300
|
||||
num_layers: 42
|
||||
patch_size: [2, 2, 2]
|
||||
in_channels: 32
|
||||
out_channels: 16
|
||||
hidden_size: 3072
|
||||
adm_in_channels: 256
|
||||
num_attention_heads: 48
|
||||
|
||||
transformer_args:
|
||||
checkpoint_activations: True
|
||||
vocab_size: 1
|
||||
max_sequence_length: 64
|
||||
layernorm_order: pre
|
||||
skip_init: false
|
||||
model_parallel_size: 1
|
||||
is_decoder: false
|
||||
|
||||
modules:
|
||||
pos_embed_config:
|
||||
target: dit_video_concat.Rotary3DPositionEmbeddingMixin
|
||||
params:
|
||||
hidden_size_head: 64
|
||||
text_length: 224
|
||||
|
||||
patch_embed_config:
|
||||
target: dit_video_concat.ImagePatchEmbeddingMixin
|
||||
params:
|
||||
text_hidden_size: 4096
|
||||
|
||||
|
||||
adaln_layer_config:
|
||||
target: dit_video_concat.AdaLNMixin
|
||||
params:
|
||||
qk_ln: True
|
||||
|
||||
final_layer_config:
|
||||
target: dit_video_concat.FinalLayerMixin
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
|
||||
- is_trainable: false
|
||||
input_key: txt
|
||||
ucg_rate: 0.1
|
||||
target: sgm.modules.encoders.modules.FrozenT5Embedder
|
||||
params:
|
||||
model_dir: "google/t5-v1_1-xxl"
|
||||
max_length: 224
|
||||
|
||||
|
||||
first_stage_config:
|
||||
target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
|
||||
params:
|
||||
cp_size: 1
|
||||
ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
|
||||
ignore_keys: ['loss']
|
||||
|
||||
loss_config:
|
||||
target: torch.nn.Identity
|
||||
|
||||
regularizer_config:
|
||||
target: vae_modules.regularizers.DiagonalGaussianRegularizer
|
||||
|
||||
encoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
|
||||
params:
|
||||
double_z: true
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 2, 4]
|
||||
attn_resolutions: []
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
decoder_config:
|
||||
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
|
||||
params:
|
||||
double_z: True
|
||||
z_channels: 16
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 2, 4]
|
||||
attn_resolutions: []
|
||||
num_res_blocks: 3
|
||||
dropout: 0.0
|
||||
gather_norm: True
|
||||
|
||||
loss_fn_config:
|
||||
target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
|
||||
params:
|
||||
fixed_frames: 0
|
||||
offset_noise_level: 0.0
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
|
||||
params:
|
||||
uniform_sampling: True
|
||||
group_num: 40
|
||||
num_idx: 1000
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
|
||||
sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
|
||||
params:
|
||||
fixed_frames: 0
|
||||
num_steps: 50
|
||||
verbose: True
|
||||
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
|
||||
|
||||
guider_config:
|
||||
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
|
||||
params:
|
||||
scale: 6
|
||||
exp: 5
|
||||
num_steps: 50
|
@ -179,14 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
||||
n_rounds = math.ceil(z.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
for n in range(n_rounds):
|
||||
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
||||
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
||||
else:
|
||||
kwargs = {}
|
||||
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
|
||||
all_out.append(out)
|
||||
for n in range(n_rounds):
|
||||
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
|
||||
latent_time = z_now.shape[2] # check the time latent
|
||||
temporal_compress_times = 4
|
||||
|
||||
fake_cp_size = min(10, latent_time // 2)
|
||||
start_frame = 0
|
||||
|
||||
recons = []
|
||||
start_frame = 0
|
||||
for i in range(fake_cp_size):
|
||||
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
|
||||
|
||||
fake_cp_rank0 = True if i == 0 else False
|
||||
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
|
||||
with torch.no_grad():
|
||||
recon = self.first_stage_model.decode(
|
||||
z_now[:, :, start_frame:end_frame].contiguous(),
|
||||
clear_fake_cp_cache=clear_fake_cp_cache,
|
||||
fake_cp_rank0=fake_cp_rank0,
|
||||
)
|
||||
recons.append(recon)
|
||||
start_frame = end_frame
|
||||
recons = torch.cat(recons, dim=2)
|
||||
all_out.append(recons)
|
||||
out = torch.cat(all_out, dim=0)
|
||||
return out
|
||||
|
||||
@ -218,6 +235,7 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
shape: Union[None, Tuple, List] = None,
|
||||
prefix=None,
|
||||
concat_images=None,
|
||||
ofs=None,
|
||||
**kwargs,
|
||||
):
|
||||
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
|
||||
@ -241,7 +259,7 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
|
||||
)
|
||||
|
||||
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
|
||||
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
|
||||
samples = samples.to(self.dtype)
|
||||
return samples
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
from functools import partial
|
||||
from einops import rearrange, repeat
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear
|
||||
from sgm.util import instantiate_from_config
|
||||
|
||||
from sgm.modules.diffusionmodules.openaimodel import Timestep
|
||||
from sgm.modules.diffusionmodules.util import (
|
||||
linear,
|
||||
timestep_embedding,
|
||||
)
|
||||
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
|
||||
from sat.ops.layernorm import LayerNorm, RMSNorm
|
||||
|
||||
|
||||
class ImagePatchEmbeddingMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
bias=True,
|
||||
text_hidden_size=None,
|
||||
):
|
||||
def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
|
||||
if text_hidden_size is not None:
|
||||
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
|
||||
else:
|
||||
self.text_proj = None
|
||||
|
||||
def word_embedding_forward(self, input_ids, **kwargs):
|
||||
# now is 3d patch
|
||||
images = kwargs["images"] # (b,t,c,h,w)
|
||||
B, T = images.shape[:2]
|
||||
emb = images.view(-1, *images.shape[2:])
|
||||
emb = self.proj(emb) # ((b t),d,h/2,w/2)
|
||||
emb = emb.view(B, T, *emb.shape[1:])
|
||||
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
|
||||
emb = rearrange(emb, "b t n d -> b (t n) d")
|
||||
emb = rearrange(images, "b t c h w -> b (t h w) c")
|
||||
emb = rearrange(
|
||||
emb,
|
||||
"b (t o h p w q) c -> b (t h w) (c o p q)",
|
||||
t=kwargs["rope_T"],
|
||||
h=kwargs["rope_H"],
|
||||
w=kwargs["rope_W"],
|
||||
o=self.patch_size[0],
|
||||
p=self.patch_size[1],
|
||||
q=self.patch_size[2],
|
||||
)
|
||||
emb = self.proj(emb)
|
||||
|
||||
if self.text_proj is not None:
|
||||
text_emb = self.text_proj(kwargs["encoder_outputs"])
|
||||
@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
|
||||
grid_size: int of the grid height and width
|
||||
t_size: int of the temporal size
|
||||
return:
|
||||
pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
|
||||
(w/ or w/o cls_token)
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
embed_dim_spatial = embed_dim // 4 * 3
|
||||
@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed(
|
||||
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
|
||||
|
||||
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
|
||||
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
|
||||
|
||||
return pos_embed # [T, H*W, D]
|
||||
|
||||
@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
text_length,
|
||||
theta=10000,
|
||||
rot_v=False,
|
||||
height_interpolation=1.0,
|
||||
width_interpolation=1.0,
|
||||
time_interpolation=1.0,
|
||||
learnable_pos_embed=False,
|
||||
):
|
||||
super().__init__()
|
||||
@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
||||
|
||||
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
|
||||
freqs = freqs.contiguous()
|
||||
freqs_sin = freqs.sin()
|
||||
freqs_cos = freqs.cos()
|
||||
self.register_buffer("freqs_sin", freqs_sin)
|
||||
self.register_buffer("freqs_cos", freqs_cos)
|
||||
|
||||
self.freqs_sin = freqs.sin().cuda()
|
||||
self.freqs_cos = freqs.cos().cuda()
|
||||
self.text_length = text_length
|
||||
if learnable_pos_embed:
|
||||
num_patches = height * width * compressed_num_frames + text_length
|
||||
@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
self.pos_embedding = None
|
||||
|
||||
def rotary(self, t, **kwargs):
|
||||
seq_len = t.shape[2]
|
||||
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
|
||||
def reshape_freq(freqs):
|
||||
freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
|
||||
freqs = rearrange(freqs, "t h w d -> (t h w) d")
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(0)
|
||||
return freqs
|
||||
|
||||
freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
|
||||
freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
|
||||
|
||||
return t * freqs_cos + rotate_half(t) * freqs_sin
|
||||
|
||||
def position_embedding_forward(self, position_ids, **kwargs):
|
||||
if self.pos_embedding is not None:
|
||||
return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
|
||||
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
|
||||
):
|
||||
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
|
||||
|
||||
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
|
||||
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
|
||||
query_layer = torch.cat(
|
||||
(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
query_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
],
|
||||
**kwargs,
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
key_layer = torch.cat(
|
||||
(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
key_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
],
|
||||
**kwargs,
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
if self.rot_v:
|
||||
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
|
||||
value_layer = torch.cat(
|
||||
(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
: kwargs["text_length"],
|
||||
],
|
||||
self.rotary(
|
||||
value_layer[
|
||||
:,
|
||||
:,
|
||||
kwargs["text_length"] :,
|
||||
],
|
||||
**kwargs,
|
||||
),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
|
||||
return attention_fn_default(
|
||||
query_layer,
|
||||
@ -347,21 +400,25 @@ def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
|
||||
def unpatchify(x, c, patch_size, w, h, **kwargs):
|
||||
"""
|
||||
x: (N, T/2 * S, patch_size**3 * C)
|
||||
imgs: (N, T, H, W, C)
|
||||
|
||||
patch_size 被拆解为三个不同的维度 (o, p, q),分别对应了深度(o)、高度(p)和宽度(q)。这使得 patch 大小在不同维度上可以不相等,增加了灵活性。
|
||||
"""
|
||||
if rope_position_ids is not None:
|
||||
assert NotImplementedError
|
||||
# do pix2struct unpatchify
|
||||
L = x.shape[1]
|
||||
x = x.reshape(shape=(x.shape[0], L, p, p, c))
|
||||
x = torch.einsum("nlpqc->ncplq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
|
||||
else:
|
||||
b = x.shape[0]
|
||||
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
|
||||
|
||||
imgs = rearrange(
|
||||
x,
|
||||
"b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
|
||||
c=c,
|
||||
o=patch_size[0],
|
||||
p=patch_size[1],
|
||||
q=patch_size[2],
|
||||
t=kwargs["rope_T"],
|
||||
h=kwargs["rope_H"],
|
||||
w=kwargs["rope_W"],
|
||||
)
|
||||
|
||||
return imgs
|
||||
|
||||
@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin):
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
|
||||
|
||||
self.spatial_length = latent_width * latent_height // patch_size**2
|
||||
self.latent_width = latent_width
|
||||
self.latent_height = latent_height
|
||||
|
||||
def final_forward(self, logits, **kwargs):
|
||||
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
|
||||
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d),只取了x中后面images的部分
|
||||
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
|
||||
return unpatchify(
|
||||
x,
|
||||
c=self.out_channels,
|
||||
p=self.patch_size,
|
||||
w=self.latent_width // self.patch_size,
|
||||
h=self.latent_height // self.patch_size,
|
||||
rope_position_ids=kwargs.get("rope_position_ids", None),
|
||||
**kwargs,
|
||||
x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
|
||||
)
|
||||
|
||||
def reinit(self, parent_model=None):
|
||||
@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin):
|
||||
class AdaLNMixin(BaseMixin):
|
||||
def __init__(
|
||||
self,
|
||||
width,
|
||||
height,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
time_embed_dim,
|
||||
@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin):
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.compressed_num_frames = compressed_num_frames
|
||||
|
||||
self.adaLN_modulations = nn.ModuleList(
|
||||
@ -611,7 +654,7 @@ class DiffusionTransformer(BaseModel):
|
||||
time_interpolation=1.0,
|
||||
use_SwiGLU=False,
|
||||
use_RMSNorm=False,
|
||||
zero_init_y_embed=False,
|
||||
ofs_embed_dim=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.latent_width = latent_width
|
||||
@ -619,12 +662,13 @@ class DiffusionTransformer(BaseModel):
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.time_compressed_rate = time_compressed_rate
|
||||
self.spatial_length = latent_width * latent_height // patch_size**2
|
||||
self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.model_channels = hidden_size
|
||||
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
|
||||
self.ofs_embed_dim = ofs_embed_dim
|
||||
self.num_classes = num_classes
|
||||
self.adm_in_channels = adm_in_channels
|
||||
self.input_time = input_time
|
||||
@ -636,7 +680,6 @@ class DiffusionTransformer(BaseModel):
|
||||
self.width_interpolation = width_interpolation
|
||||
self.time_interpolation = time_interpolation
|
||||
self.inner_hidden_size = hidden_size * 4
|
||||
self.zero_init_y_embed = zero_init_y_embed
|
||||
try:
|
||||
self.dtype = str_to_dtype[kwargs.pop("dtype")]
|
||||
except:
|
||||
@ -669,7 +712,6 @@ class DiffusionTransformer(BaseModel):
|
||||
|
||||
def _build_modules(self, module_configs):
|
||||
model_channels = self.hidden_size
|
||||
# time_embed_dim = model_channels * 4
|
||||
time_embed_dim = self.time_embed_dim
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
@ -677,6 +719,13 @@ class DiffusionTransformer(BaseModel):
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.ofs_embed_dim is not None:
|
||||
self.ofs_embed = nn.Sequential(
|
||||
linear(self.ofs_embed_dim, self.ofs_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(self.ofs_embed_dim, self.ofs_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
|
||||
@ -701,9 +750,6 @@ class DiffusionTransformer(BaseModel):
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
)
|
||||
if self.zero_init_y_embed:
|
||||
nn.init.constant_(self.label_emb[0][2].weight, 0)
|
||||
nn.init.constant_(self.label_emb[0][2].bias, 0)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
@ -712,10 +758,13 @@ class DiffusionTransformer(BaseModel):
|
||||
"pos_embed",
|
||||
instantiate_from_config(
|
||||
pos_embed_config,
|
||||
height=self.latent_height // self.patch_size,
|
||||
width=self.latent_width // self.patch_size,
|
||||
height=self.latent_height // self.patch_size[1],
|
||||
width=self.latent_width // self.patch_size[2],
|
||||
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
height_interpolation=self.height_interpolation,
|
||||
width_interpolation=self.width_interpolation,
|
||||
time_interpolation=self.time_interpolation,
|
||||
),
|
||||
reinit=True,
|
||||
)
|
||||
@ -737,8 +786,6 @@ class DiffusionTransformer(BaseModel):
|
||||
"adaln_layer",
|
||||
instantiate_from_config(
|
||||
adaln_layer_config,
|
||||
height=self.latent_height // self.patch_size,
|
||||
width=self.latent_width // self.patch_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
|
||||
@ -749,7 +796,6 @@ class DiffusionTransformer(BaseModel):
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
final_layer_config = module_configs["final_layer_config"]
|
||||
self.add_mixin(
|
||||
"final_layer",
|
||||
@ -766,25 +812,18 @@ class DiffusionTransformer(BaseModel):
|
||||
reinit=True,
|
||||
)
|
||||
|
||||
if "lora_config" in module_configs:
|
||||
lora_config = module_configs["lora_config"]
|
||||
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
|
||||
|
||||
return
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
b, t, d, h, w = x.shape
|
||||
if x.dtype != self.dtype:
|
||||
x = x.to(self.dtype)
|
||||
|
||||
# This is not use in inference
|
||||
if "concat_images" in kwargs and kwargs["concat_images"] is not None:
|
||||
if kwargs["concat_images"].shape[0] != x.shape[0]:
|
||||
concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
|
||||
else:
|
||||
concat_images = kwargs["concat_images"]
|
||||
x = torch.cat([x, concat_images], dim=2)
|
||||
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
@ -792,17 +831,25 @@ class DiffusionTransformer(BaseModel):
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
# assert y.shape[0] == x.shape[0]
|
||||
assert x.shape[0] % y.shape[0] == 0
|
||||
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
kwargs["seq_length"] = t * h * w // (self.patch_size**2)
|
||||
if self.ofs_embed_dim is not None:
|
||||
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
|
||||
ofs_emb = self.ofs_embed(ofs_emb)
|
||||
emb = emb + ofs_emb
|
||||
|
||||
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
|
||||
kwargs["images"] = x
|
||||
kwargs["emb"] = emb
|
||||
kwargs["encoder_outputs"] = context
|
||||
kwargs["text_length"] = context.shape[1]
|
||||
|
||||
kwargs["rope_T"] = t // self.patch_size[0]
|
||||
kwargs["rope_H"] = h // self.patch_size[1]
|
||||
kwargs["rope_W"] = w // self.patch_size[2]
|
||||
|
||||
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
|
||||
output = super().forward(**kwargs)[0]
|
||||
return output
|
||||
|
@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
||||
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM"
|
||||
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
@ -1,16 +1,11 @@
|
||||
SwissArmyTransformer==0.4.12
|
||||
omegaconf==2.3.0
|
||||
torch==2.4.0
|
||||
torchvision==0.19.0
|
||||
pytorch_lightning==2.3.3
|
||||
kornia==0.7.3
|
||||
beartype==0.18.5
|
||||
numpy==2.0.1
|
||||
fsspec==2024.5.0
|
||||
safetensors==0.4.3
|
||||
imageio-ffmpeg==0.5.1
|
||||
imageio==2.34.2
|
||||
scipy==1.14.0
|
||||
decord==0.6.0
|
||||
wandb==0.17.5
|
||||
deepspeed==0.14.4
|
||||
SwissArmyTransformer>=0.4.12
|
||||
omegaconf>=2.3.0
|
||||
pytorch_lightning>=2.4.0
|
||||
kornia>=0.7.3
|
||||
beartype>=0.19.0
|
||||
fsspec>=2024.2.0
|
||||
safetensors>=0.4.5
|
||||
scipy>=1.14.1
|
||||
decord>=0.6.0
|
||||
wandb>=0.18.5
|
||||
deepspeed>=0.15.3
|
@ -4,24 +4,20 @@ import argparse
|
||||
from typing import List, Union
|
||||
from tqdm import tqdm
|
||||
from omegaconf import ListConfig
|
||||
from PIL import Image
|
||||
import imageio
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from einops import rearrange, repeat
|
||||
import torchvision.transforms as TT
|
||||
|
||||
|
||||
from sat.model.base_model import get_model
|
||||
from sat.training.model_io import load_checkpoint
|
||||
from sat import mpu
|
||||
|
||||
from diffusion_video import SATVideoDiffusionEngine
|
||||
from arguments import get_args
|
||||
from torchvision.transforms.functional import center_crop, resize
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def read_from_cli():
|
||||
cnt = 0
|
||||
@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
|
||||
if key == "txt":
|
||||
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
elif key == "fps":
|
||||
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
|
||||
elif key == "fps_id":
|
||||
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
|
||||
elif key == "motion_bucket_id":
|
||||
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
|
||||
elif key == "pool_image":
|
||||
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
|
||||
elif key == "cond_aug":
|
||||
batch[key] = repeat(
|
||||
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
|
||||
"1 -> b",
|
||||
b=math.prod(N),
|
||||
)
|
||||
elif key == "cond_frames":
|
||||
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
|
||||
elif key == "cond_frames_without_noise":
|
||||
batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
|
||||
writer.append_data(frame)
|
||||
|
||||
|
||||
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
|
||||
def sampling_main(args, model_cls):
|
||||
if isinstance(model_cls, type):
|
||||
model = get_model(args, model_cls)
|
||||
@ -127,45 +128,62 @@ def sampling_main(args, model_cls):
|
||||
data_iter = read_from_cli()
|
||||
elif args.input_type == "txt":
|
||||
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
|
||||
print("rank and world_size", rank, world_size)
|
||||
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
image_size = [480, 720]
|
||||
|
||||
if args.image2video:
|
||||
chained_trainsforms = []
|
||||
chained_trainsforms.append(TT.ToTensor())
|
||||
transform = TT.Compose(chained_trainsforms)
|
||||
|
||||
sample_func = model.sample
|
||||
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
|
||||
num_samples = [1]
|
||||
force_uc_zero_embeddings = ["txt"]
|
||||
device = model.device
|
||||
T, C = args.sampling_num_frames, args.latent_channels
|
||||
with torch.no_grad():
|
||||
for text, cnt in tqdm(data_iter):
|
||||
if args.image2video:
|
||||
# use with input image shape
|
||||
text, image_path = text.split("@@")
|
||||
assert os.path.exists(image_path), image_path
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
(img_W, img_H) = image.size
|
||||
|
||||
def nearest_multiple_of_16(n):
|
||||
lower_multiple = (n // 16) * 16
|
||||
upper_multiple = (n // 16 + 1) * 16
|
||||
if abs(n - lower_multiple) < abs(n - upper_multiple):
|
||||
return lower_multiple
|
||||
else:
|
||||
return upper_multiple
|
||||
|
||||
if img_H < img_W:
|
||||
H = 96
|
||||
W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
|
||||
else:
|
||||
W = 96
|
||||
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
|
||||
chained_trainsforms = []
|
||||
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
|
||||
chained_trainsforms.append(TT.ToTensor())
|
||||
transform = TT.Compose(chained_trainsforms)
|
||||
image = transform(image).unsqueeze(0).to("cuda")
|
||||
image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
|
||||
image = image * 2.0 - 1.0
|
||||
image = image.unsqueeze(2).to(torch.bfloat16)
|
||||
image = model.encode_first_stage(image, None)
|
||||
image = image / model.scale_factor
|
||||
image = image.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
|
||||
pad_shape = (image.shape[0], T - 1, C, H, W)
|
||||
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
|
||||
else:
|
||||
image_size = args.sampling_image_size
|
||||
H, W = image_size[0], image_size[1]
|
||||
F = 8 # 8x downsampled
|
||||
image = None
|
||||
|
||||
value_dict = {
|
||||
"prompt": text,
|
||||
"negative_prompt": "",
|
||||
"num_frames": torch.tensor(T).unsqueeze(0),
|
||||
}
|
||||
text_cast = [text]
|
||||
mp_size = mpu.get_model_parallel_world_size()
|
||||
global_rank = torch.distributed.get_rank() // mp_size
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
|
||||
text = text_cast[0]
|
||||
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
|
||||
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
||||
@ -187,57 +205,42 @@ def sampling_main(args, model_cls):
|
||||
if not k == "crossattn":
|
||||
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
|
||||
|
||||
if args.image2video and image is not None:
|
||||
if args.image2video:
|
||||
c["concat"] = image
|
||||
uc["concat"] = image
|
||||
|
||||
for index in range(args.batch_size):
|
||||
# reload model on GPU
|
||||
model.to(device)
|
||||
samples_z = sample_func(
|
||||
c,
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H // F, W // F),
|
||||
)
|
||||
if args.image2video:
|
||||
samples_z = sample_func(
|
||||
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
|
||||
)
|
||||
else:
|
||||
samples_z = sample_func(
|
||||
c,
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H // F, W // F),
|
||||
).to("cuda")
|
||||
|
||||
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
# Unload the model from GPU to save GPU memory
|
||||
model.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
first_stage_model = model.first_stage_model
|
||||
first_stage_model = first_stage_model.to(device)
|
||||
|
||||
latent = 1.0 / model.scale_factor * samples_z
|
||||
|
||||
# Decode latent serial to save GPU memory
|
||||
recons = []
|
||||
loop_num = (T - 1) // 2
|
||||
for i in range(loop_num):
|
||||
if i == 0:
|
||||
start_frame, end_frame = 0, 3
|
||||
else:
|
||||
start_frame, end_frame = i * 2 + 1, i * 2 + 3
|
||||
if i == loop_num - 1:
|
||||
clear_fake_cp_cache = True
|
||||
else:
|
||||
clear_fake_cp_cache = False
|
||||
with torch.no_grad():
|
||||
recon = first_stage_model.decode(
|
||||
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
|
||||
)
|
||||
|
||||
recons.append(recon)
|
||||
|
||||
recon = torch.cat(recons, dim=2).to(torch.float32)
|
||||
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||
|
||||
save_path = os.path.join(
|
||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
||||
)
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||
if args.only_save_latents:
|
||||
samples_z = 1.0 / model.scale_factor * samples_z
|
||||
save_path = os.path.join(
|
||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
||||
)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
|
||||
with open(os.path.join(save_path, "text.txt"), "w") as f:
|
||||
f.write(text)
|
||||
else:
|
||||
samples_x = model.decode_first_stage(samples_z).to(torch.float32)
|
||||
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||
save_path = os.path.join(
|
||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
||||
)
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
"""
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import (
|
||||
to_sigma,
|
||||
)
|
||||
from ...util import append_dims, default, instantiate_from_config
|
||||
from ...util import SeededNoise
|
||||
|
||||
from .guiders import DynamicCFG
|
||||
|
||||
@ -44,7 +44,9 @@ class BaseDiffusionSampler:
|
||||
self.device = device
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
|
||||
sigmas = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps, device=self.device
|
||||
)
|
||||
uc = default(uc, cond)
|
||||
|
||||
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):
|
||||
|
||||
|
||||
class EDMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
|
||||
def __init__(
|
||||
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_churn = s_churn
|
||||
@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler):
|
||||
dt = append_dims(next_sigma - sigma_hat, x.ndim)
|
||||
|
||||
euler_step = self.euler_step(x, d, dt)
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
x = self.possible_correction_step(
|
||||
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
gamma = (
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
|
||||
if self.s_tmin <= sigmas[i] <= self.s_tmax
|
||||
else 0.0
|
||||
)
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler):
|
||||
|
||||
|
||||
class DDIMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(self, s_noise=0.1, *args, **kwargs):
|
||||
def __init__(
|
||||
self, s_noise=0.1, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_noise = s_noise
|
||||
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
|
||||
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
d = to_d(x, sigma, denoised)
|
||||
dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
|
||||
dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
|
||||
|
||||
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
|
||||
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
x = self.possible_correction_step(
|
||||
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler):
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler):
|
||||
self.order = order
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
ds = []
|
||||
sigmas_cpu = sigmas.detach().cpu().numpy()
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
sigma = s_in * sigmas[i]
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
|
||||
)
|
||||
denoised = self.guider(denoised, sigma)
|
||||
d = to_d(x, sigma, denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > self.order:
|
||||
ds.pop(0)
|
||||
cur_order = min(i + 1, self.order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
coeffs = [
|
||||
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
|
||||
for j in range(cur_order)
|
||||
]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerEDMSampler(EDMSampler):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
def possible_correction_step(
|
||||
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
):
|
||||
return euler_step
|
||||
|
||||
|
||||
class HeunEDMSampler(EDMSampler):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
def possible_correction_step(
|
||||
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
):
|
||||
if torch.sum(next_sigma) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0
|
||||
return euler_step
|
||||
@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler):
|
||||
d_prime = (d + d_new) / 2.0
|
||||
|
||||
# apply correction if noise level is not 0
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler):
|
||||
x = x_euler
|
||||
else:
|
||||
h, s, t, t_next = self.get_variables(sigma, sigma_down)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
|
||||
mult = [
|
||||
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
|
||||
]
|
||||
|
||||
x2 = mult[0] * x - mult[1] * denoised
|
||||
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
|
||||
@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
|
||||
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
|
||||
]
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised
|
||||
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
|
||||
@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
|
||||
)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
def get_variables(self, sigma, next_sigma, previous_sigma=None):
|
||||
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
|
||||
@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
def get_mult(self, h, r, t, t_next, previous_sigma):
|
||||
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
|
||||
mult2 = (-2 * h).expm1()
|
||||
mult2 = (-2*h).expm1()
|
||||
|
||||
if previous_sigma is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
|
||||
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
|
||||
mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
|
||||
]
|
||||
mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
|
||||
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
|
||||
@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
|
||||
)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SdeditEDMSampler(EulerEDMSampler):
|
||||
def __init__(self, edit_ratio=0.5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler):
|
||||
|
||||
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
|
||||
randn_unit = randn.clone()
|
||||
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
|
||||
randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
randn, cond, uc, num_steps
|
||||
)
|
||||
|
||||
if num_steps is None:
|
||||
num_steps = self.num_steps
|
||||
@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler):
|
||||
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
|
||||
|
||||
gamma = (
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
|
||||
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
|
||||
if self.s_tmin <= sigmas[i] <= self.s_tmax
|
||||
else 0.0
|
||||
)
|
||||
x = self.sampler_step(
|
||||
s_in * sigmas[i],
|
||||
@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.fixed_frames = fixed_frames
|
||||
@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
alpha_cumprod_sqrt, timesteps = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps,
|
||||
device=self.device,
|
||||
return_idx=True,
|
||||
do_append_zero=False,
|
||||
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
|
||||
)
|
||||
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
|
||||
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
|
||||
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])
|
||||
|
||||
uc = default(uc, cond)
|
||||
|
||||
@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
|
||||
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
|
||||
additional_model_inputs = {}
|
||||
|
||||
if ofs is not None:
|
||||
additional_model_inputs['ofs'] = ofs
|
||||
|
||||
if isinstance(scale, torch.Tensor) == False and scale == 1:
|
||||
additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
|
||||
additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
|
||||
if scale_emb is not None:
|
||||
additional_model_inputs["scale_emb"] = scale_emb
|
||||
additional_model_inputs['scale_emb'] = scale_emb
|
||||
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
|
||||
else:
|
||||
additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
|
||||
).to(torch.float32)
|
||||
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
|
||||
if isinstance(self.guider, DynamicCFG):
|
||||
denoised = self.guider(
|
||||
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
|
||||
)
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
|
||||
else:
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
|
||||
return denoised
|
||||
|
||||
def sampler_step(
|
||||
self,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
):
|
||||
denoised = self.denoise(
|
||||
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
|
||||
).to(torch.float32)
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
|
||||
|
||||
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
|
||||
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
|
||||
|
||||
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
timestep=timesteps[-(i+1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs # 1020
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Image2VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
alpha_cumprod_sqrt, timesteps = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
|
||||
)
|
||||
uc = default(uc, cond)
|
||||
|
||||
num_sigmas = len(alpha_cumprod_sqrt)
|
||||
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
|
||||
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
|
||||
additional_model_inputs = {}
|
||||
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
|
||||
torch.float32)
|
||||
if isinstance(self.guider, DynamicCFG):
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
|
||||
else:
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
|
||||
return denoised
|
||||
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
|
||||
timestep=None):
|
||||
# 此处的sigma实际上是alpha_cumprod_sqrt
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised
|
||||
|
||||
a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
|
||||
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
|
||||
|
||||
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
|
||||
return x
|
||||
|
||||
def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
s_in * alpha_cumprod_sqrt[i],
|
||||
s_in * alpha_cumprod_sqrt[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)]
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
alpha_cumprod = alpha_cumprod_sqrt ** 2
|
||||
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
|
||||
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
|
||||
mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
|
||||
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
|
||||
mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
ofs=None # 1020
|
||||
):
|
||||
denoised = self.denoise(
|
||||
x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
|
||||
).to(torch.float32)
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
]
|
||||
mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
|
||||
mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
|
||||
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
|
||||
@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
prefix_frames = x[:, : self.fixed_frames]
|
||||
prefix_frames = x[:, :self.fixed_frames]
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
if self.sdedit:
|
||||
rd = torch.randn_like(prefix_frames)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
|
||||
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
|
||||
)
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
else:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
|
||||
@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
timestep=timesteps[-(i+1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs # 1020
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
alpha_cumprod = alpha_cumprod_sqrt ** 2
|
||||
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
|
||||
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
|
||||
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
timestep=None
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)],
|
||||
timestep=timesteps[-(i+1)]
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
class VideoDDPMSampler(VideoDDIMSampler):
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
|
||||
# 此处的sigma实际上是alpha_cumprod_sqrt
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised
|
||||
|
||||
alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
|
||||
x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
|
||||
+ append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
|
||||
+ append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
s_in * alpha_cumprod_sqrt[i],
|
||||
s_in * alpha_cumprod_sqrt[i + 1],
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i
|
||||
)
|
||||
|
||||
return x
|
@ -17,23 +17,20 @@ class EDMSampling:
|
||||
|
||||
|
||||
class DiscreteSampling:
|
||||
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
|
||||
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
|
||||
self.num_idx = num_idx
|
||||
self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
self.sigmas = instantiate_from_config(discretization_config)(
|
||||
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
)
|
||||
world_size = mpu.get_data_parallel_world_size()
|
||||
if world_size <= 8:
|
||||
uniform_sampling = False
|
||||
self.uniform_sampling = uniform_sampling
|
||||
self.group_num = group_num
|
||||
if self.uniform_sampling:
|
||||
i = 1
|
||||
while True:
|
||||
if world_size % i != 0 or num_idx % (world_size // i) != 0:
|
||||
i += 1
|
||||
else:
|
||||
self.group_num = world_size // i
|
||||
break
|
||||
|
||||
assert self.group_num > 0
|
||||
assert world_size % self.group_num == 0
|
||||
self.group_width = world_size // self.group_num # the number of rank in one group
|
||||
assert world_size % group_num == 0
|
||||
self.group_width = world_size // group_num # the number of rank in one group
|
||||
self.sigma_interval = self.num_idx // self.group_num
|
||||
|
||||
def idx_to_sigma(self, idx):
|
||||
@ -45,9 +42,7 @@ class DiscreteSampling:
|
||||
group_index = rank // self.group_width
|
||||
idx = default(
|
||||
rand,
|
||||
torch.randint(
|
||||
group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
|
||||
),
|
||||
torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
|
||||
)
|
||||
else:
|
||||
idx = default(
|
||||
@ -59,7 +54,6 @@ class DiscreteSampling:
|
||||
else:
|
||||
return self.idx_to_sigma(idx)
|
||||
|
||||
|
||||
class PartialDiscreteSampling:
|
||||
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
|
||||
self.total_num_idx = total_num_idx
|
||||
|
@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
unregularized: bool = False,
|
||||
input_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
use_cp: bool = True,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.cp_size > 0 and not input_cp:
|
||||
if self.cp_size <= 1:
|
||||
use_cp = False
|
||||
if self.cp_size > 0 and use_cp and not input_cp:
|
||||
if not is_context_parallel_initialized:
|
||||
initialize_context_parallel(self.cp_size)
|
||||
|
||||
@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
x = _conv_split(x, dim=2, kernel_size=1)
|
||||
|
||||
if return_reg_log:
|
||||
z, reg_log = super().encode(x, return_reg_log, unregularized)
|
||||
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
||||
else:
|
||||
z = super().encode(x, return_reg_log, unregularized)
|
||||
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
|
||||
|
||||
if self.cp_size > 0 and not output_cp:
|
||||
if self.cp_size > 0 and use_cp and not output_cp:
|
||||
z = _conv_gather(z, dim=2, kernel_size=1)
|
||||
|
||||
if return_reg_log:
|
||||
@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
z: torch.Tensor,
|
||||
input_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
split_kernel_size: int = 1,
|
||||
use_cp: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if self.cp_size > 0 and not input_cp:
|
||||
if self.cp_size <= 1:
|
||||
use_cp = False
|
||||
if self.cp_size > 0 and use_cp and not input_cp:
|
||||
if not is_context_parallel_initialized:
|
||||
initialize_context_parallel(self.cp_size)
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.cp_size
|
||||
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
|
||||
|
||||
z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
|
||||
z = _conv_split(z, dim=2, kernel_size=1)
|
||||
|
||||
x = super().decode(z, **kwargs)
|
||||
|
||||
if self.cp_size > 0 and not output_cp:
|
||||
x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
|
||||
x = super().decode(z, use_cp=use_cp, **kwargs)
|
||||
|
||||
if self.cp_size > 0 and use_cp and not output_cp:
|
||||
x = _conv_gather(x, dim=2, kernel_size=1)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
|
@ -16,11 +16,7 @@ from sgm.util import (
|
||||
get_context_parallel_group_rank,
|
||||
)
|
||||
|
||||
# try:
|
||||
from vae_modules.utils import SafeConv3d as Conv3d
|
||||
# except:
|
||||
# # Degrade to normal Conv3d if SafeConv3d is not available
|
||||
# from torch.nn import Conv3d
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
@ -81,8 +77,6 @@ def _split(input_, dim):
|
||||
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
dim_size = input_.size()[dim] // cp_world_size
|
||||
@ -94,8 +88,6 @@ def _split(input_, dim):
|
||||
output = torch.cat([inpu_first_frame_, output], dim=dim)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@ -382,19 +374,6 @@ class ContextParallelCausalConv3d(nn.Module):
|
||||
self.cache_padding = None
|
||||
|
||||
def forward(self, input_, clear_cache=True):
|
||||
# if input_.shape[2] == 1: # handle image
|
||||
# # first frame padding
|
||||
# input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
|
||||
# else:
|
||||
# input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
|
||||
|
||||
# padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
# input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
|
||||
|
||||
# output_parallel = self.conv(input_parallel)
|
||||
# output = output_parallel
|
||||
# return output
|
||||
|
||||
input_parallel = fake_cp_pass_from_previous_rank(
|
||||
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
|
||||
)
|
||||
@ -441,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
return output
|
||||
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
|
||||
def Normalize(in_channels, gather=False, **kwargs):
|
||||
# same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
@ -488,24 +468,34 @@ class SpatialNorm3D(nn.Module):
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True):
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
|
||||
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
|
||||
]
|
||||
|
||||
zq_rest = torch.cat(interpolated_splits, dim=1)
|
||||
# zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
f_size = f.shape[-3:]
|
||||
|
||||
zq_splits = torch.split(zq, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
|
||||
]
|
||||
zq = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
|
||||
norm_f = self.norm_layer(f)
|
||||
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
|
||||
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
@ -541,23 +531,44 @@ class Upsample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
if x.shape[2] % 2 == 1:
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
# print(x.shape)
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x_rest = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.with_conv:
|
||||
@ -579,21 +590,30 @@ class DownSample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
splits = torch.split(x_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
|
||||
]
|
||||
x_rest = torch.cat(interpolated_splits, dim=1)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
|
||||
]
|
||||
x = torch.cat(interpolated_splits, dim=1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
|
||||
if self.with_conv:
|
||||
@ -673,13 +693,13 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
h = x
|
||||
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
else:
|
||||
h = self.norm1(h)
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
@ -694,7 +714,7 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
else:
|
||||
h = self.norm2(h)
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
@ -807,23 +827,24 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
print("Attention not implemented")
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
# end
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
@ -831,7 +852,7 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
return h
|
||||
|
||||
@ -934,6 +955,11 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
# # Symmetrical enc-dec
|
||||
if i_level <= self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
@ -948,7 +974,7 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, z, clear_fake_cp_cache=True, **kwargs):
|
||||
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
@ -961,23 +987,25 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.up[i_level].block[i_block](
|
||||
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
|
||||
)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
|
@ -1,22 +1,15 @@
|
||||
"""
|
||||
This script demonstrates how to convert and generate video from a text prompt
|
||||
using CogVideoX with 🤗Huggingface Diffusers Pipeline.
|
||||
This script requires the `diffusers>=0.30.2` library to be installed.
|
||||
|
||||
Functions:
|
||||
- reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
|
||||
- reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
|
||||
- reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
|
||||
- remove_keys_inplace: Removes specified keys from the state_dict in-place.
|
||||
- replace_up_keys_inplace: Replaces keys in the "up" block in-place.
|
||||
- get_state_dict: Extracts the state_dict from a saved checkpoint.
|
||||
- update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
|
||||
- convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
|
||||
- convert_vae: Converts a VAE checkpoint to the CogVideoX format.
|
||||
- get_args: Parses command-line arguments for the script.
|
||||
- generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
|
||||
"""
|
||||
|
||||
The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
|
||||
This script supports the conversion of the following models:
|
||||
- CogVideoX-2B
|
||||
- CogVideoX-5B, CogVideoX-5B-I2V
|
||||
- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
|
||||
|
||||
Original Script:
|
||||
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
|
||||
|
||||
"""
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
|
||||
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
num_layers: int,
|
||||
num_attention_heads: int,
|
||||
use_rotary_positional_embeddings: bool,
|
||||
i2v: bool,
|
||||
dtype: torch.dtype,
|
||||
ckpt_path: str,
|
||||
num_layers: int,
|
||||
num_attention_heads: int,
|
||||
use_rotary_positional_embeddings: bool,
|
||||
i2v: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
@ -172,7 +165,7 @@ def convert_transformer(
|
||||
).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
new_key = key[len(PREFIX_KEY):]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
@ -259,9 +253,10 @@ if __name__ == "__main__":
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# Apparently, the conversion does not work anymore without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
@ -301,4 +296,7 @@ if __name__ == "__main__":
|
||||
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
||||
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
||||
# is either fp16/bf16 here).
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
|
||||
|
||||
# This is necessary This is necessary for users with insufficient memory,
|
||||
# such as those using Colab and notebooks, as it can save some memory used for model loading.
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
||||
|
Loading…
x
Reference in New Issue
Block a user