mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
commit
e2987ff565
47
README.md
47
README.md
@ -171,49 +171,49 @@ models we currently offer, along with their foundational information.
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<th style="text-align: center;">Model Name</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B (Latest)</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V (Latest)</th>
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Release Date</td>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
<th style="text-align: center;">August 6, 2024</th>
|
||||
<th style="text-align: center;">August 27, 2024</th>
|
||||
<th style="text-align: center;">September 19, 2024</th>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
<th style="text-align: center;">November 8, 2024</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Video Resolution</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Inference Precision</td>
|
||||
<td style="text-align: center;"><b>FP16*(recommended)</b>, BF16, FP32, FP8*, INT8, not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(recommended)</b>, FP16, FP32, FP8*, INT8, not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
<td style="text-align: center;"><b>FP16*(Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Single GPU Memory Usage</td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8(torchao): from 3.6GB*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16 : from 5GB*</b><br><b>diffusers INT8(torchao): from 4.4GB*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td>
|
||||
<td style="text-align: center;">Single GPU Memory Usage<br></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB minimum* </b><br><b>diffusers INT8 (torchao): 3.6GB minimum*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB minimum* </b><br><b>diffusers INT8 (torchao): 4.4GB minimum* </b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Multi-GPU Memory Usage</td>
|
||||
<td colspan="2" style="text-align: center;"><b>Not Supported</b><br></td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>Not supported</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
|
||||
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
|
||||
<td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
|
||||
<td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
|
||||
<td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Prompt Language</td>
|
||||
@ -221,38 +221,37 @@ models we currently offer, along with their foundational information.
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Prompt Token Limit</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
<td colspan="2" style="text-align: center;">224 Tokens</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Video Length</td>
|
||||
<td colspan="2" style="text-align: center;">5 seconds or 10 seconds</td>
|
||||
<td colspan="3" style="text-align: center;">6 seconds</td>
|
||||
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Frame Rate</td>
|
||||
<td colspan="3" style="text-align: center;">8 frames / second</td>
|
||||
<td colspan="2" style="text-align: center;">16 frames / second</td>
|
||||
<td colspan="2" style="text-align: center;">16 frames / second </td>
|
||||
<td colspan="3" style="text-align: center;">8 frames / second </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Positional Encoding</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">Position Encoding</td>
|
||||
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Download Link (Diffusers)</td>
|
||||
<td colspan="2" style="text-align: center;"> Coming Soon </td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;"> Coming Soon </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">Download Link (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
63
README_ja.md
63
README_ja.md
@ -163,88 +163,87 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<th style="text-align: center;">モデル名</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">リリース日</td>
|
||||
<td style="text-align: center;">公開日</td>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年8月6日</th>
|
||||
<th style="text-align: center;">2024年8月27日</th>
|
||||
<th style="text-align: center;">2024年9月19日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオ解像度</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br>256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論精度</td>
|
||||
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32, FP8*, INT8, INT4は非対応</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
<td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32,FP8*,INT8,INT4非対応</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32,FP8*,INT8,INT4非対応</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">シングルGPUメモリ消費</td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB<br><b>diffusers FP16: 4GBから*</b><br><b>diffusers INT8(torchao): 3.6GBから*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB<br><b>diffusers BF16: 5GBから*</b><br><b>diffusers INT8(torchao): 4.4GBから*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB<br></td>
|
||||
<td style="text-align: center;">単一GPUメモリ消費量<br></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB以上* </b><br><b>diffusers INT8(torchao): 3.6GB以上*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB以上* </b><br><b>diffusers INT8(torchao): 4.4GB以上* </b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">マルチGPUメモリ消費</td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>サポートなし</b><br></td>
|
||||
<td style="text-align: center;">複数GPU推論メモリ消費量</td>
|
||||
<td colspan="2" style="text-align: center;"><b>非対応</b><br></td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* diffusers使用</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* diffusers使用</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論速度<br>(ステップ数 = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約1000秒(5秒動画)<br>単一H100: 約550秒(5秒動画)</td>
|
||||
<td style="text-align: center;">推論速度<br>(Step = 50, FP/BF16)</td>
|
||||
<td colspan="2" style="text-align: center;">シングルA100: ~1000秒(5秒ビデオ)<br>シングルH100: ~550秒(5秒ビデオ)</td>
|
||||
<td style="text-align: center;">シングルA100: ~90秒<br>シングルH100: ~45秒</td>
|
||||
<td colspan="2" style="text-align: center;">シングルA100: ~180秒<br>シングルH100: ~90秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプト言語</td>
|
||||
<td colspan="5" style="text-align: center;">英語*</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプトトークン制限</td>
|
||||
<td colspan="3" style="text-align: center;">226トークン</td>
|
||||
<td style="text-align: center;">プロンプト長さの上限</td>
|
||||
<td colspan="2" style="text-align: center;">224トークン</td>
|
||||
<td colspan="3" style="text-align: center;">226トークン</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオの長さ</td>
|
||||
<td colspan="3" style="text-align: center;">6秒</td>
|
||||
<td style="text-align: center;">ビデオ長さ</td>
|
||||
<td colspan="2" style="text-align: center;">5秒または10秒</td>
|
||||
<td colspan="3" style="text-align: center;">6秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">フレームレート</td>
|
||||
<td colspan="3" style="text-align: center;">8 フレーム / 秒</td>
|
||||
<td colspan="2" style="text-align: center;">16 フレーム / 秒</td>
|
||||
<td colspan="2" style="text-align: center;">16フレーム/秒</td>
|
||||
<td colspan="3" style="text-align: center;">8フレーム/秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">位置エンコーディング</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (Diffusers)</td>
|
||||
<td colspan="2" style="text-align: center;"> 近日公開 </td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;">近日公開</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
36
README_zh.md
36
README_zh.md
@ -154,49 +154,49 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
|
||||
<table style="border-collapse: collapse; width: 100%;">
|
||||
<tr>
|
||||
<th style="text-align: center;">模型名</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
|
||||
<th style="text-align: center;">CogVideoX-2B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B</th>
|
||||
<th style="text-align: center;">CogVideoX-5B-I2V </th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B</th>
|
||||
<th style="text-align: center;">CogVideoX1.5-5B-I2V</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">发布时间</td>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年8月6日</th>
|
||||
<th style="text-align: center;">2024年8月27日</th>
|
||||
<th style="text-align: center;">2024年9月19日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
<th style="text-align: center;">2024年11月8日</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">视频分辨率</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
<td colspan="1" style="text-align: center;">1360 * 768</td>
|
||||
<td colspan="1" style="text-align: center;">256 <= W <=1360<br> 256 <= H <=768<br> W,H % 16 == 0</td>
|
||||
</tr>
|
||||
<td colspan="3" style="text-align: center;">720 * 480</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推理精度</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
<td style="text-align: center;"><b>FP16*(推荐)</b>, BF16, FP32,FP8*,INT8,不支持INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16(推荐)</b>, FP16, FP32,FP8*,INT8,不支持INT4</td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">单GPU显存消耗<br></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
|
||||
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB起* </b><br><b>diffusers INT8(torchao): 3.6G起*</b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16 : 5GB起* </b><br><b>diffusers INT8(torchao): 4.4G起* </b></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 66GB <br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">多GPU推理显存消耗</td>
|
||||
<td colspan="2" style="text-align: center;"><b>不支持</b><br></td>
|
||||
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
|
||||
<td colspan="2" style="text-align: center;"><b>Not support</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推理速度<br>(Step = 50, FP/BF16)</td>
|
||||
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
|
||||
<td style="text-align: center;">单卡A100: ~90秒<br>单卡H100: ~45秒</td>
|
||||
<td colspan="2" style="text-align: center;">单卡A100: ~180秒<br>单卡H100: ~90秒</td>
|
||||
<td colspan="2" style="text-align: center;">单卡A100: ~1000秒(5秒视频)<br>单卡H100: ~550秒(5秒视频)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">提示词语言</td>
|
||||
@ -204,39 +204,37 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">提示词长度上限</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
<td colspan="2" style="text-align: center;">224 Tokens</td>
|
||||
<td colspan="3" style="text-align: center;">226 Tokens</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">视频长度</td>
|
||||
<td colspan="3" style="text-align: center;">6 秒</td>
|
||||
<td colspan="2" style="text-align: center;">5 秒 或 10 秒</td>
|
||||
<td colspan="3" style="text-align: center;">6 秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">帧率</td>
|
||||
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
|
||||
<td colspan="2" style="text-align: center;">16 帧 / 秒 </td>
|
||||
<td colspan="3" style="text-align: center;">8 帧 / 秒 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">位置编码</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">下载链接 (Diffusers)</td>
|
||||
<td colspan="2" style="text-align: center;"> 即将推出 </td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
<td colspan="2" style="text-align: center;"> 即将推出 </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">下载链接 (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
@ -23,7 +23,7 @@ model:
|
||||
params:
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 81
|
||||
num_frames: 81 # for 5 seconds and 161 for 10 seconds
|
||||
time_compressed_rate: 4
|
||||
latent_width: 300
|
||||
latent_height: 300
|
||||
|
@ -25,11 +25,10 @@ model:
|
||||
network_config:
|
||||
target: dit_video_concat.DiffusionTransformer
|
||||
params:
|
||||
# space_interpolation: 1.875
|
||||
ofs_embed_dim: 512
|
||||
time_embed_dim: 512
|
||||
elementwise_affine: True
|
||||
num_frames: 81
|
||||
num_frames: 81 # for 5 seconds and 161 for 10 seconds
|
||||
time_compressed_rate: 4
|
||||
latent_width: 300
|
||||
latent_height: 300
|
||||
|
@ -1,16 +1,14 @@
|
||||
args:
|
||||
image2video: False # True for image2video, False for text2video
|
||||
# image2video: True # True for image2video, False for text2video
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
batch_size: 1
|
||||
input_type: txt
|
||||
input_file: configs/test.txt
|
||||
sampling_image_size: [480, 720]
|
||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||
sampling_fps: 8
|
||||
# fp16: True # For CogVideoX-2B
|
||||
bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V
|
||||
output_dir: outputs/
|
||||
sampling_image_size: [768, 1360] # remove this for I2V
|
||||
sampling_num_frames: 22 # 42 for 10 seconds and 22 for 5 seconds
|
||||
sampling_fps: 16
|
||||
bf16: True
|
||||
output_dir: outputs
|
||||
force_inference: True
|
@ -192,13 +192,13 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
for i in range(fake_cp_size):
|
||||
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
|
||||
|
||||
fake_cp_rank0 = True if i == 0 else False
|
||||
use_cp = True if i == 0 else False
|
||||
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
|
||||
with torch.no_grad():
|
||||
recon = self.first_stage_model.decode(
|
||||
z_now[:, :, start_frame:end_frame].contiguous(),
|
||||
clear_fake_cp_cache=clear_fake_cp_cache,
|
||||
fake_cp_rank0=fake_cp_rank0,
|
||||
use_cp=use_cp,
|
||||
)
|
||||
recons.append(recon)
|
||||
start_frame = end_frame
|
||||
|
@ -7,7 +7,6 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sat.model.base_model import BaseModel, non_conflict
|
||||
from sat.model.mixins import BaseMixin
|
||||
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
|
||||
|
@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
||||
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM"
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM"
|
||||
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
@ -1,17 +1,13 @@
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from vae_modules.ema import LitEma
|
||||
@ -56,34 +52,16 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
self.automatic_optimization = False
|
||||
|
||||
# def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
# if ckpt is None:
|
||||
# return
|
||||
# if isinstance(ckpt, str):
|
||||
# ckpt = {
|
||||
# "target": "sgm.modules.checkpoint.CheckpointEngine",
|
||||
# "params": {"ckpt_path": ckpt},
|
||||
# }
|
||||
# engine = instantiate_from_config(ckpt)
|
||||
# engine(self)
|
||||
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
self.init_from_ckpt(ckpt)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
||||
print("Missing keys: ", missing_keys)
|
||||
print("Unexpected keys: ", unexpected_keys)
|
||||
print(f"Restored from {path}")
|
||||
if isinstance(ckpt, str):
|
||||
ckpt = {
|
||||
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
||||
"params": {"ckpt_path": ckpt},
|
||||
}
|
||||
engine = instantiate_from_config(ckpt)
|
||||
engine(self)
|
||||
|
||||
@abstractmethod
|
||||
def get_input(self, batch) -> Any:
|
||||
@ -119,7 +97,9 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -216,12 +196,13 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x)
|
||||
z = self.encoder(x, **kwargs)
|
||||
if unregularized:
|
||||
return z, dict()
|
||||
z, reg_log = self.regularization(z)
|
||||
|
@ -101,8 +101,6 @@ def _gather(input_, dim):
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
|
||||
if cp_rank == 0:
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
@ -127,12 +125,9 @@ def _gather(input_, dim):
|
||||
def _conv_split(input_, dim, kernel_size):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
# Bypass the function if context parallel is 1
|
||||
if cp_world_size == 1:
|
||||
return input_
|
||||
|
||||
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
|
||||
@ -140,14 +135,11 @@ def _conv_split(input_, dim, kernel_size):
|
||||
if cp_rank == 0:
|
||||
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
||||
else:
|
||||
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
|
||||
output = input_.transpose(dim, 0)[
|
||||
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
|
||||
].transpose(dim, 0)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@ -160,9 +152,6 @@ def _conv_gather(input_, dim, kernel_size):
|
||||
|
||||
group = get_context_parallel_group()
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
|
||||
|
||||
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
|
||||
if cp_rank == 0:
|
||||
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
|
||||
@ -255,17 +244,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
|
||||
if recv_rank % cp_world_size == cp_world_size - 1:
|
||||
recv_rank += cp_world_size
|
||||
|
||||
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
|
||||
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
|
||||
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
|
||||
# req_recv.wait()
|
||||
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
|
||||
if cp_rank < cp_world_size - 1:
|
||||
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
|
||||
if cp_rank > 0:
|
||||
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
|
||||
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
|
||||
|
||||
if cp_rank == 0:
|
||||
if cache_padding is not None:
|
||||
@ -421,7 +405,6 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs):
|
||||
# same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
@ -468,8 +451,8 @@ class SpatialNorm3D(nn.Module):
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
|
||||
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
@ -531,13 +514,11 @@ class Upsample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
def forward(self, x, fake_cp=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
# print(x.shape)
|
||||
if get_context_parallel_rank() == 0 and fake_cp:
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x_rest, 32, dim=1)
|
||||
@ -545,8 +526,6 @@ class Upsample3D(nn.Module):
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x_rest = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
else:
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
@ -555,13 +534,10 @@ class Upsample3D(nn.Module):
|
||||
]
|
||||
x = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
@ -590,12 +566,12 @@ class DownSample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
def forward(self, x, fake_cp=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
if get_context_parallel_rank() == 0 and fake_cp:
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
@ -693,17 +669,13 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
|
||||
h = x
|
||||
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
|
||||
else:
|
||||
h = self.norm1(h)
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
|
||||
@ -711,14 +683,10 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
|
||||
else:
|
||||
h = self.norm2(h)
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
@ -827,32 +795,33 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
def forward(self, x, use_cp=True):
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
print("Attention not implemented")
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
h = self.norm_out(h)
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
@ -895,11 +864,9 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=z_channels,
|
||||
@ -955,11 +922,6 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
# # Symmetrical enc-dec
|
||||
if i_level <= self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
else:
|
||||
@ -974,7 +936,9 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
|
||||
global _USE_CP
|
||||
_USE_CP = use_cp
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
@ -987,25 +951,25 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
|
||||
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
|
||||
)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.up[i_level].upsample(h, fake_cp=use_cp)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user