Merge pull request #474 from THUDM/CogVideoX_dev

Fix #472 #473
This commit is contained in:
Yuxuan.Zhang 2024-11-09 00:18:01 +08:00 committed by GitHub
commit e2987ff565
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 128 additions and 191 deletions

View File

@ -171,49 +171,49 @@ models we currently offer, along with their foundational information.
<table style="border-collapse: collapse; width: 100%;"> <table style="border-collapse: collapse; width: 100%;">
<tr> <tr>
<th style="text-align: center;">Model Name</th> <th style="text-align: center;">Model Name</th>
<th style="text-align: center;">CogVideoX1.5-5B (Latest)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (Latest)</th>
<th style="text-align: center;">CogVideoX-2B</th> <th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th> <th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th> <th style="text-align: center;">CogVideoX-5B-I2V</th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Release Date</td> <td style="text-align: center;">Release Date</td>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">August 6, 2024</th> <th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th> <th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th> <th style="text-align: center;">September 19, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Video Resolution</td> <td style="text-align: center;">Video Resolution</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td> <td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td> <td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Inference Precision</td> <td style="text-align: center;">Inference Precision</td>
<td style="text-align: center;"><b>FP16*(recommended)</b>, BF16, FP32, FP8*, INT8, not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16(recommended)</b>, FP16, FP32, FP8*, INT8, not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td> <td colspan="2" style="text-align: center;"><b>BF16</b></td>
<td style="text-align: center;"><b>FP16*(Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Single GPU Memory Usage</td> <td style="text-align: center;">Single GPU Memory Usage<br></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8(torchao): from 3.6GB*</b></td> <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16 : from 5GB*</b><br><b>diffusers INT8(torchao): from 4.4GB*</b></td> <td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB minimum* </b><br><b>diffusers INT8 (torchao): 3.6GB minimum*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td> <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB minimum* </b><br><b>diffusers INT8 (torchao): 4.4GB minimum* </b></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Multi-GPU Memory Usage</td> <td style="text-align: center;">Multi-GPU Memory Usage</td>
<td colspan="2" style="text-align: center;"><b>Not Supported</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td> <td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td> <td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>Not supported</b><br></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td> <td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td> <td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td> <td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Prompt Language</td> <td style="text-align: center;">Prompt Language</td>
@ -221,38 +221,37 @@ models we currently offer, along with their foundational information.
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Prompt Token Limit</td> <td style="text-align: center;">Prompt Token Limit</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
<td colspan="2" style="text-align: center;">224 Tokens</td> <td colspan="2" style="text-align: center;">224 Tokens</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Video Length</td> <td style="text-align: center;">Video Length</td>
<td colspan="2" style="text-align: center;">5 seconds or 10 seconds</td>
<td colspan="3" style="text-align: center;">6 seconds</td> <td colspan="3" style="text-align: center;">6 seconds</td>
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Frame Rate</td> <td style="text-align: center;">Frame Rate</td>
<td colspan="3" style="text-align: center;">8 frames / second</td> <td colspan="2" style="text-align: center;">16 frames / second </td>
<td colspan="2" style="text-align: center;">16 frames / second</td> <td colspan="3" style="text-align: center;">8 frames / second </td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Positional Encoding</td> <td style="text-align: center;">Position Encoding</td>
<td style="text-align: center;">3d_sincos_pos_embed</td> <td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Download Link (Diffusers)</td> <td style="text-align: center;">Download Link (Diffusers)</td>
<td colspan="2" style="text-align: center;"> Coming Soon </td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;"> Coming Soon </td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">Download Link (SAT)</td> <td style="text-align: center;">Download Link (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td> <td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr> </tr>
</table> </table>

View File

@ -163,88 +163,87 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
<table style="border-collapse: collapse; width: 100%;"> <table style="border-collapse: collapse; width: 100%;">
<tr> <tr>
<th style="text-align: center;">モデル名</th> <th style="text-align: center;">モデル名</th>
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
<th style="text-align: center;">CogVideoX-2B</th> <th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th> <th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th> <th style="text-align: center;">CogVideoX-5B-I2V</th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">リリース日</td> <td style="text-align: center;">公開日</td>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年8月6日</th> <th style="text-align: center;">2024年8月6日</th>
<th style="text-align: center;">2024年8月27日</th> <th style="text-align: center;">2024年8月27日</th>
<th style="text-align: center;">2024年9月19日</th> <th style="text-align: center;">2024年9月19日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">ビデオ解像度</td> <td style="text-align: center;">ビデオ解像度</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td> <td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td> <td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">推論精度</td> <td style="text-align: center;">推論精度</td>
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32, FP8*, INT8, INT4は非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td> <td colspan="2" style="text-align: center;"><b>BF16</b></td>
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32FP8*INT8INT4非対応</td>
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32FP8*INT8INT4非対応</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">シングルGPUメモリ消費</td> <td style="text-align: center;">単一GPUメモリ消費量<br></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: 4GBから*</b><br><b>diffusers INT8(torchao): 3.6GBから*</b></td> <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: 5GBから*</b><br><b>diffusers INT8(torchao): 4.4GBから*</b></td> <td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB以上* </b><br><b>diffusers INT8(torchao): 3.6GB以上*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td> <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB以上* </b><br><b>diffusers INT8(torchao): 4.4GB以上* </b></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">マルチGPUメモリ消費</td> <td style="text-align: center;">複数GPU推論メモリ消費量</td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td> <td colspan="2" style="text-align: center;"><b>非対応</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td> <td style="text-align: center;"><b>FP16: 10GB* diffusers使用</b><br></td>
<td colspan="2" style="text-align: center;"><b>サポートなし</b><br></td> <td colspan="2" style="text-align: center;"><b>BF16: 15GB* diffusers使用</b><br></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">推論速度<br>(ステップ数 = 50, FP/BF16)</td> <td style="text-align: center;">推論速度<br>(Step = 50, FP/BF16)</td>
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td> <td colspan="2" style="text-align: center;">シングルA100: ~1000秒(5秒ビデオ)<br>シングルH100: ~550秒(5秒ビデオ)</td>
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90</td> <td style="text-align: center;">シングルA100: ~90秒<br>シングルH100: ~45</td>
<td colspan="2" style="text-align: center;">単一A100: 約1000秒(5秒動画)<br>単一H100: 約550秒(5秒動画)</td> <td colspan="2" style="text-align: center;">シングルA100: ~180秒<br>シングルH100: ~90秒</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">プロンプト言語</td> <td style="text-align: center;">プロンプト言語</td>
<td colspan="5" style="text-align: center;">英語*</td> <td colspan="5" style="text-align: center;">英語*</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">プロンプトトークン制限</td> <td style="text-align: center;">プロンプト長さの上限</td>
<td colspan="3" style="text-align: center;">226トークン</td>
<td colspan="2" style="text-align: center;">224トークン</td> <td colspan="2" style="text-align: center;">224トークン</td>
<td colspan="3" style="text-align: center;">226トークン</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">ビデオの長さ</td> <td style="text-align: center;">ビデオ長さ</td>
<td colspan="3" style="text-align: center;">6秒</td>
<td colspan="2" style="text-align: center;">5秒または10秒</td> <td colspan="2" style="text-align: center;">5秒または10秒</td>
<td colspan="3" style="text-align: center;">6秒</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">フレームレート</td> <td style="text-align: center;">フレームレート</td>
<td colspan="3" style="text-align: center;">8 フレーム / </td> <td colspan="2" style="text-align: center;">16フレーム/</td>
<td colspan="2" style="text-align: center;">16 フレーム / </td> <td colspan="3" style="text-align: center;">8フレーム/</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">位置エンコーディング</td> <td style="text-align: center;">位置エンコーディング</td>
<td style="text-align: center;">3d_sincos_pos_embed</td> <td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">ダウンロードリンク (Diffusers)</td> <td style="text-align: center;">ダウンロードリンク (Diffusers)</td>
<td colspan="2" style="text-align: center;"> 近日公開 </td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;">近日公開</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">ダウンロードリンク (SAT)</td> <td style="text-align: center;">ダウンロードリンク (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td> <td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr> </tr>
</table> </table>

View File

@ -154,49 +154,49 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
<table style="border-collapse: collapse; width: 100%;"> <table style="border-collapse: collapse; width: 100%;">
<tr> <tr>
<th style="text-align: center;">模型名</th> <th style="text-align: center;">模型名</th>
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
<th style="text-align: center;">CogVideoX-2B</th> <th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th> <th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V </th> <th style="text-align: center;">CogVideoX-5B-I2V </th>
<th style="text-align: center;">CogVideoX1.5-5B</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">发布时间</td> <td style="text-align: center;">发布时间</td>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年8月6日</th> <th style="text-align: center;">2024年8月6日</th>
<th style="text-align: center;">2024年8月27日</th> <th style="text-align: center;">2024年8月27日</th>
<th style="text-align: center;">2024年9月19日</th> <th style="text-align: center;">2024年9月19日</th>
<th style="text-align: center;">2024年11月8日</th>
<th style="text-align: center;">2024年11月8日</th>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">视频分辨率</td> <td style="text-align: center;">视频分辨率</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
<td colspan="1" style="text-align: center;">1360 * 768</td> <td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td> <td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
</tr> <td colspan="3" style="text-align: center;">720 * 480</td>
</tr>
<tr> <tr>
<td style="text-align: center;">推理精度</td> <td style="text-align: center;">推理精度</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
<td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32FP8*INT8不支持INT4</td> <td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32FP8*INT8不支持INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32FP8*INT8不支持INT4</td> <td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32FP8*INT8不支持INT4</td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">单GPU显存消耗<br></td> <td style="text-align: center;">单GPU显存消耗<br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB起* </b><br><b>diffusers INT8(torchao): 3.6G起*</b></td> <td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB起* </b><br><b>diffusers INT8(torchao): 3.6G起*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB起* </b><br><b>diffusers INT8(torchao): 4.4G起* </b></td> <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB起* </b><br><b>diffusers INT8(torchao): 4.4G起* </b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">多GPU推理显存消耗</td> <td style="text-align: center;">多GPU推理显存消耗</td>
<td colspan="2" style="text-align: center;"><b>不支持</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td> <td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td> <td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>Not support</b><br></td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">推理速度<br>(Step = 50, FP/BF16)</td> <td style="text-align: center;">推理速度<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
<td style="text-align: center;">单卡A100: ~90秒<br>单卡H100: ~45秒</td> <td style="text-align: center;">单卡A100: ~90秒<br>单卡H100: ~45秒</td>
<td colspan="2" style="text-align: center;">单卡A100: ~180秒<br>单卡H100: ~90秒</td> <td colspan="2" style="text-align: center;">单卡A100: ~180秒<br>单卡H100: ~90秒</td>
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">提示词语言</td> <td style="text-align: center;">提示词语言</td>
@ -204,39 +204,37 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">提示词长度上限</td> <td style="text-align: center;">提示词长度上限</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
<td colspan="2" style="text-align: center;">224 Tokens</td> <td colspan="2" style="text-align: center;">224 Tokens</td>
<td colspan="3" style="text-align: center;">226 Tokens</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">视频长度</td> <td style="text-align: center;">视频长度</td>
<td colspan="3" style="text-align: center;">6 秒</td>
<td colspan="2" style="text-align: center;">5 秒 或 10 秒</td> <td colspan="2" style="text-align: center;">5 秒 或 10 秒</td>
<td colspan="3" style="text-align: center;">6 秒</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">帧率</td> <td style="text-align: center;">帧率</td>
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
<td colspan="2" style="text-align: center;">16 帧 / 秒 </td> <td colspan="2" style="text-align: center;">16 帧 / 秒 </td>
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">位置编码</td> <td style="text-align: center;">位置编码</td>
<td style="text-align: center;">3d_sincos_pos_embed</td> <td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td> <td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">下载链接 (Diffusers)</td> <td style="text-align: center;">下载链接 (Diffusers)</td>
<td colspan="2" style="text-align: center;"> 即将推出 </td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td> <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
<td colspan="2" style="text-align: center;"> 即将推出 </td>
</tr> </tr>
<tr> <tr>
<td style="text-align: center;">下载链接 (SAT)</td> <td style="text-align: center;">下载链接 (SAT)</td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td> <td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr> </tr>
</table> </table>

View File

@ -23,7 +23,7 @@ model:
params: params:
time_embed_dim: 512 time_embed_dim: 512
elementwise_affine: True elementwise_affine: True
num_frames: 81 num_frames: 81 # for 5 seconds and 161 for 10 seconds
time_compressed_rate: 4 time_compressed_rate: 4
latent_width: 300 latent_width: 300
latent_height: 300 latent_height: 300

View File

@ -25,11 +25,10 @@ model:
network_config: network_config:
target: dit_video_concat.DiffusionTransformer target: dit_video_concat.DiffusionTransformer
params: params:
# space_interpolation: 1.875
ofs_embed_dim: 512 ofs_embed_dim: 512
time_embed_dim: 512 time_embed_dim: 512
elementwise_affine: True elementwise_affine: True
num_frames: 81 num_frames: 81 # for 5 seconds and 161 for 10 seconds
time_compressed_rate: 4 time_compressed_rate: 4
latent_width: 300 latent_width: 300
latent_height: 300 latent_height: 300

View File

@ -1,16 +1,14 @@
args: args:
image2video: False # True for image2video, False for text2video # image2video: True # True for image2video, False for text2video
latent_channels: 16 latent_channels: 16
mode: inference mode: inference
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1 batch_size: 1
input_type: txt input_type: txt
input_file: configs/test.txt input_file: configs/test.txt
sampling_image_size: [480, 720] sampling_image_size: [768, 1360] # remove this for I2V
sampling_num_frames: 13 # Must be 13, 11 or 9 sampling_num_frames: 22 # 42 for 10 seconds and 22 for 5 seconds
sampling_fps: 8 sampling_fps: 16
# fp16: True # For CogVideoX-2B bf16: True
bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V output_dir: outputs
output_dir: outputs/
force_inference: True force_inference: True

View File

@ -192,13 +192,13 @@ class SATVideoDiffusionEngine(nn.Module):
for i in range(fake_cp_size): for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0) end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
fake_cp_rank0 = True if i == 0 else False use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad(): with torch.no_grad():
recon = self.first_stage_model.decode( recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(), z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache, clear_fake_cp_cache=clear_fake_cp_cache,
fake_cp_rank0=fake_cp_rank0, use_cp=use_cp,
) )
recons.append(recon) recons.append(recon)
start_frame = end_frame start_frame = end_frame

View File

@ -7,7 +7,6 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from sat.model.base_model import BaseModel, non_conflict from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default

View File

@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1" environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM" run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM"
echo ${run_cmd} echo ${run_cmd}
eval ${run_cmd} eval ${run_cmd}

View File

@ -1,17 +1,13 @@
import logging import logging
import math import math
import re import re
import random
from abc import abstractmethod from abc import abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.distributed import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version from packaging import version
from vae_modules.ema import LitEma from vae_modules.ema import LitEma
@ -56,34 +52,16 @@ class AbstractAutoencoder(pl.LightningModule):
if version.parse(torch.__version__) >= version.parse("2.0.0"): if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False self.automatic_optimization = False
# def apply_ckpt(self, ckpt: Union[None, str, dict]):
# if ckpt is None:
# return
# if isinstance(ckpt, str):
# ckpt = {
# "target": "sgm.modules.checkpoint.CheckpointEngine",
# "params": {"ckpt_path": ckpt},
# }
# engine = instantiate_from_config(ckpt)
# engine(self)
def apply_ckpt(self, ckpt: Union[None, str, dict]): def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None: if ckpt is None:
return return
self.init_from_ckpt(ckpt) if isinstance(ckpt, str):
ckpt = {
def init_from_ckpt(self, path, ignore_keys=list()): "target": "sgm.modules.checkpoint.CheckpointEngine",
sd = torch.load(path, map_location="cpu")["state_dict"] "params": {"ckpt_path": ckpt},
keys = list(sd.keys()) }
for k in keys: engine = instantiate_from_config(ckpt)
for ik in ignore_keys: engine(self)
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
@abstractmethod @abstractmethod
def get_input(self, batch) -> Any: def get_input(self, batch) -> Any:
@ -119,7 +97,9 @@ class AbstractAutoencoder(pl.LightningModule):
def instantiate_optimizer_from_config(self, params, lr, cfg): def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self) -> Any: def configure_optimizers(self) -> Any:
raise NotImplementedError() raise NotImplementedError()
@ -216,12 +196,13 @@ class AutoencodingEngine(AbstractAutoencoder):
return self.decoder.get_last_layer() return self.decoder.get_last_layer()
def encode( def encode(
self, self,
x: torch.Tensor, x: torch.Tensor,
return_reg_log: bool = False, return_reg_log: bool = False,
unregularized: bool = False, unregularized: bool = False,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x) z = self.encoder(x, **kwargs)
if unregularized: if unregularized:
return z, dict() return z, dict()
z, reg_log = self.regularization(z) z, reg_log = self.regularization(z)

View File

@ -101,8 +101,6 @@ def _gather(input_, dim):
group = get_context_parallel_group() group = get_context_parallel_group()
cp_rank = get_context_parallel_rank() cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0: if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
@ -127,12 +125,9 @@ def _gather(input_, dim):
def _conv_split(input_, dim, kernel_size): def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size() cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1: if cp_world_size == 1:
return input_ return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank() cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
if cp_rank == 0: if cp_rank == 0:
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else: else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[ output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0) ].transpose(dim, 0)
output = output.contiguous() output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output return output
@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
group = get_context_parallel_group() group = get_context_parallel_group()
cp_rank = get_context_parallel_rank() cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
if cp_rank == 0: if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
if recv_rank % cp_world_size == cp_world_size - 1: if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size recv_rank += cp_world_size
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1: if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0: if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0: if cp_rank == 0:
if cache_padding is not None: if cache_padding is not None:
@ -421,7 +405,6 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
def Normalize(in_channels, gather=False, **kwargs): def Normalize(in_channels, gather=False, **kwargs):
# same for 3D and 2D
if gather: if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else: else:
@ -468,8 +451,8 @@ class SpatialNorm3D(nn.Module):
kernel_size=1, kernel_size=1,
) )
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True): def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0: if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
@ -531,13 +514,11 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x, fake_cp_rank0=True): def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1: if self.compress_time and x.shape[2] > 1:
if get_context_parallel_rank() == 0 and fake_cp_rank0: if get_context_parallel_rank() == 0 and fake_cp:
# print(x.shape)
# split first frame # split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:] x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
splits = torch.split(x_rest, 32, dim=1) splits = torch.split(x_rest, 32, dim=1)
@ -545,8 +526,6 @@ class Upsample3D(nn.Module):
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
] ]
x_rest = torch.cat(interpolated_splits, dim=1) x_rest = torch.cat(interpolated_splits, dim=1)
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else: else:
splits = torch.split(x, 32, dim=1) splits = torch.split(x, 32, dim=1)
@ -555,13 +534,10 @@ class Upsample3D(nn.Module):
] ]
x = torch.cat(interpolated_splits, dim=1) x = torch.cat(interpolated_splits, dim=1)
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else: else:
# only interpolate 2D # only interpolate 2D
t = x.shape[2] t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 32, dim=1) splits = torch.split(x, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
@ -590,12 +566,12 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x, fake_cp_rank0=True): def forward(self, x, fake_cp=True):
if self.compress_time and x.shape[2] > 1: if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:] h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t") x = rearrange(x, "b c t h w -> (b h w) c t")
if get_context_parallel_rank() == 0 and fake_cp_rank0: if get_context_parallel_rank() == 0 and fake_cp:
# split first frame # split first frame
x_first, x_rest = x[..., 0], x[..., 1:] x_first, x_rest = x[..., 0], x[..., 1:]
@ -693,17 +669,13 @@ class ContextParallelResnetBlock3D(nn.Module):
padding=0, padding=0,
) )
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True): def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
h = x h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None: if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else: else:
h = self.norm1(h) h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache) h = self.conv1(h, clear_cache=clear_fake_cp_cache)
@ -711,14 +683,10 @@ class ContextParallelResnetBlock3D(nn.Module):
if temb is not None: if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None: if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
else: else:
h = self.norm2(h) h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h) h = nonlinearity(h)
h = self.dropout(h) h = self.dropout(h)
@ -827,32 +795,33 @@ class ContextParallelEncoder3D(nn.Module):
kernel_size=3, kernel_size=3,
) )
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True): def forward(self, x, use_cp=True):
global _USE_CP
_USE_CP = use_cp
# timestep embedding # timestep embedding
temb = None temb = None
# downsampling # downsampling
h = self.conv_in(x, clear_cache=clear_fake_cp_cache) hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions): for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks): for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache) h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0: if len(self.down[i_level].attn) > 0:
print("Attention not implemented")
h = self.down[i_level].attn[i_block](h) h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0) hs.append(self.down[i_level].downsample(hs[-1]))
# middle # middle
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache) h = hs[-1]
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache) h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end # end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h) h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache) h = self.conv_out(h)
return h return h
@ -895,11 +864,9 @@ class ContextParallelDecoder3D(nn.Module):
zq_ch = z_channels zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res # compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1] block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1) curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res) self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d( self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels, chan_in=z_channels,
@ -955,11 +922,6 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block up.block = block
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
# # Symmetrical enc-dec
if i_level <= self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
if i_level < self.num_resolutions - self.temporal_compress_level: if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else: else:
@ -974,7 +936,9 @@ class ContextParallelDecoder3D(nn.Module):
kernel_size=3, kernel_size=3,
) )
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True): def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
global _USE_CP
_USE_CP = use_cp
self.last_z_shape = z.shape self.last_z_shape = z.shape
# timestep embedding # timestep embedding
@ -987,25 +951,25 @@ class ContextParallelDecoder3D(nn.Module):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache) h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle # middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1): for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block]( h = self.up[i_level].block[i_block](
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0 h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
) )
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq) h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0: if i_level != 0:
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0) h = self.up[i_level].upsample(h, fake_cp=use_cp)
# end # end
if self.give_pre_end: if self.give_pre_end:
return h return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0) h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache) h = self.conv_out(h, clear_cache=clear_fake_cp_cache)