diff --git a/README.md b/README.md
index a27782c..4e83ec2 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,10 @@ Experience the CogVideoX-5B model online at CogVideoX-2B
CogVideoX-5B |
CogVideoX-5B-I2V |
+ CogVideoX1.5-5B |
+ CogVideoX1.5-5B-I2V |
- Model Description |
- Entry-level model, balancing compatibility. Low cost for running and secondary development. |
- Larger model with higher video generation quality and better visual effects. |
- CogVideoX-5B image-to-video version. |
+ Release Date |
+ August 6, 2024 |
+ August 27, 2024 |
+ September 19, 2024 |
+ November 8, 2024 |
+ November 8, 2024 |
+
+
+ Video Resolution |
+ 720 * 480 |
+ 1360 * 768 |
+ 256 <= W <= 1360, 256 <= H <= 768, W,H % 16 == 0 |
Inference Precision |
FP16*(recommended), BF16, FP32, FP8*, INT8, not supported: INT4 |
- BF16 (recommended), FP16, FP32, FP8*, INT8, not supported: INT4 |
+ BF16 (recommended), FP16, FP32, FP8*, INT8, not supported: INT4 |
+ BF16 |
- Single GPU Memory Usage
|
- SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB* |
- SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB* |
+ Single GPU Memory Usage |
+ SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB* |
+ SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB* |
+ SAT BF16: 66GB
|
- Multi-GPU Inference Memory Usage |
+ Multi-GPU Memory Usage |
FP16: 10GB* using diffusers
|
BF16: 15GB* using diffusers
|
+ Not supported
|
Inference Speed (Step = 50, FP/BF16) |
Single A100: ~90 seconds Single H100: ~45 seconds |
Single A100: ~180 seconds Single H100: ~90 seconds |
-
-
- Fine-tuning Precision |
- FP16 |
- BF16 |
-
-
- Fine-tuning Memory Usage |
- 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT) |
- 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT)
|
- 78 GB (bs=1, LORA) 75GB (bs=1, SFT, 16GPU)
|
+ Single A100: ~1000 seconds (5-second video) Single H100: ~550 seconds (5-second video) |
Prompt Language |
- English* |
+ English* |
- Maximum Prompt Length |
+ Prompt Token Limit |
226 Tokens |
+ 224 Tokens |
Video Length |
- 6 Seconds |
+ 6 seconds |
+ 5 or 10 seconds |
Frame Rate |
- 8 Frames / Second |
+ 8 frames / second |
+ 16 frames / second |
- Video Resolution |
- 720 x 480, no support for other resolutions (including fine-tuning) |
-
-
- Position Encoding |
+ Positional Encoding |
3d_sincos_pos_embed |
3d_sincos_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
+ 3d_sincos_pos_embed |
+ 3d_rope_pos_embed + learnable_pos_embed |
Download Link (Diffusers) |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+ Coming Soon |
Download Link (SAT) |
- SAT |
+ SAT |
+ 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
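
The resolution rule for CogVideoX1.5-5B-I2V in the table above (256 <= W <= 1360, 256 <= H <= 768, W and H divisible by 16) can be sanity-checked before launching a run. The helper below is an illustrative sketch, not code shipped in this repository:

```python
def is_valid_cogvideox15_i2v_resolution(width: int, height: int) -> bool:
    """Check the CogVideoX1.5-5B-I2V resolution constraint listed in the table above."""
    return (
        256 <= width <= 1360
        and 256 <= height <= 768
        and width % 16 == 0
        and height % 16 == 0
    )

# 1360 x 768 (the default) and 1280 x 720 pass; 1920 x 1080 exceeds the width limit.
assert is_valid_cogvideox15_i2v_resolution(1360, 768)
assert is_valid_cogvideox15_i2v_resolution(1280, 720)
assert not is_valid_cogvideox15_i2v_resolution(1920, 1080)
```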
@@ -422,7 +430,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
We welcome your contributions! You can click [here](resources/contribute.md) for more information.
-## License Agreement
+## Model License
The code in this repository is released under the [Apache 2.0 License](LICENSE).
diff --git a/README_ja.md b/README_ja.md
index 69b46b6..aa7ae37 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -1,6 +1,6 @@
# CogVideo & CogVideoX
-[Read this in English](./README_zh.md)
+[Read this in English](./README.md)
[中文阅读](./README_zh.md)
@@ -22,9 +22,14 @@
## Updates and News

-- 🔥🔥 **News**: ```2024/10/13```: To cut costs, a framework that fine-tunes `CogVideoX-5B` on a single 4090 GPU,
+- 🔥🔥 News: ```2024/11/08```: We have released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
+The CogVideoX1.5-5B series supports 10-second videos and higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at any resolution.
+The SAT code has already been updated, while the `diffusers` version is still being adapted.
+The SAT version of the code can be downloaded [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
+- 🔥 **News**: ```2024/10/13```: To cut costs, a framework that fine-tunes `CogVideoX-5B` on a single 4090 GPU,
 [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory),
- has been released. It supports fine-tuning at multiple resolutions; feel free to use it! - 🔥**News**: ```2024/10/10```:
+ has been released. It supports fine-tuning at multiple resolutions; feel free to use it!
+- 🔥**News**: ```2024/10/10```:
 The technical report has been updated with more detailed training information and a demo.
- 🔥 **News**: ```2024/10/10```: We updated the technical report. Click [here](https://arxiv.org/pdf/2408.06072)
 to view it. More training details and a demo have been added; to see the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/)
@@ -34,7 +39,7 @@
- 🔥**News**: ```2024/9/19```: We have open-sourced **CogVideoX-5B-I2V**, the image-to-video model in the CogVideoX series.
 This model takes an image as a background input and, combined with a prompt, generates a video, offering greater controllability.
 With this, the CogVideoX series now supports three tasks: text-to-video generation, video continuation, and image-to-video generation. Try it [online](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space).
-- 🔥🔥 **News**: ```2024/9/19```:
+- 🔥 **News**: ```2024/9/19```:
 We have open-sourced [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), the caption model used in the CogVideoX training process to convert video data into text descriptions. Feel free to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced **CogVideoX-5B**, a larger model in the CogVideoX series
@@ -63,11 +68,10 @@
- [Project Structure](#プロジェクト構造)
- [Inference](#推論)
- [sat](#sat)
- - [Tools](#ツール)
-- [Project Plan](#プロジェクト計画)
-- [Model License](#モデルライセンス)
+ - [Tools](#ツール)
- [CogVideo (ICLR'23) Model Introduction](#CogVideoICLR23)
- [Citation](#引用)
+- [License Agreement](#ライセンス契約)
## Quick Start
@@ -156,79 +160,91 @@ CogVideoX is an open-source video generation model of the same origin as [清影 (QingYing)](https://chatglm.cn/video?fr=osm_cogvideox).
 The table below lists basic information on the video generation models we provide:
-
+
 Model Name |
CogVideoX-2B |
CogVideoX-5B |
- CogVideoX-5B-I2V |
+ CogVideoX-5B-I2V |
+ CogVideoX1.5-5B |
+ CogVideoX1.5-5B-I2V |
+
+
+ Release Date |
+ August 6, 2024 |
+ August 27, 2024 |
+ September 19, 2024 |
+ November 8, 2024 |
+ November 8, 2024 |
+
+
+ Video Resolution |
+ 720 * 480 |
+ 1360 * 768 |
+ 256 <= W <= 1360, 256 <= H <= 768, W,H % 16 == 0 |
 Inference Precision |
 FP16* (recommended), BF16, FP32, FP8*, INT8; INT4 not supported |
 BF16 (recommended), FP16, FP32, FP8*, INT8; INT4 not supported |
-
-
- Single GPU Memory Usage
 |
- SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB* |
- SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB* |
-
-
- Multi-GPU Memory Usage |
- FP16: 10GB* using diffusers
 |
- BF16: 15GB* using diffusers
 |
-
-
- Inference Speed (Step = 50, FP/BF16) |
- Single A100: ~90 seconds Single H100: ~45 seconds |
- Single A100: ~180 seconds Single H100: ~90 seconds |
-
-
- Fine-tuning Precision |
- FP16 |
 BF16 |
- Fine-tuning Memory Usage |
- 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT) |
- 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT)
 |
- 78 GB (bs=1, LORA) 75GB (bs=1, SFT, 16GPU)
 |
+ Single GPU Memory Usage |
+ SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB* |
+ SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB* |
+ SAT BF16: 66GB
 |
+
+
+ Multi-GPU Memory Usage |
+ FP16: 10GB* using diffusers
 |
+ BF16: 15GB* using diffusers
 |
+ Not supported
 |
+
+
+ Inference Speed (Step = 50, FP/BF16) |
+ Single A100: ~90 seconds Single H100: ~45 seconds |
+ Single A100: ~180 seconds Single H100: ~90 seconds |
+ Single A100: ~1000 seconds (5-second video) Single H100: ~550 seconds (5-second video) |
 Prompt Language |
- English* |
+ English* |
- Maximum Prompt Length |
+ Prompt Token Limit |
 226 Tokens |
+ 224 Tokens |
 Video Length |
 6 seconds |
+ 5 or 10 seconds |
 Frame Rate |
- 8 frames / second |
-
-
- Video Resolution |
- 720 * 480; other resolutions not supported (including fine-tuning) |
+ 8 frames / second |
+ 16 frames / second |
 Positional Encoding |
3d_sincos_pos_embed |
3d_sincos_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
+ 3d_sincos_pos_embed |
+ 3d_rope_pos_embed + learnable_pos_embed |
 Download Link (Diffusers) |
 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+ Coming Soon |
 Download Link (SAT) |
- SAT |
+ SAT |
+ 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
diff --git a/README_zh.md b/README_zh.md
index 9f84f84..3574e7d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -1,10 +1,9 @@
# CogVideo & CogVideoX
-[Read this in English](./README_zh.md)
+[Read this in English](./README.md)
[日本語で読む](./README_ja.md)
-
@@ -23,7 +22,9 @@
## Project Updates

-- 🔥🔥 **News**: ```2024/10/13```: Lower cost: `CogVideoX-5B` can now be fine-tuned on a single 4090 GPU with
+- 🔥🔥 **News**: ```2024/11/08```: We have released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
+The CogVideoX1.5-5B series supports videos up to **10 seconds** long and higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at **any resolution**. The SAT code has been updated; the `diffusers` version is still being adapted. The SAT code can be downloaded [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
+- 🔥**News**: ```2024/10/13```: Lower cost: `CogVideoX-5B` can now be fine-tuned on a single 4090 GPU with
 the fine-tuning framework [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), which has been released and supports fine-tuning at multiple resolutions. Feel free to use it.
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please click [here](https://arxiv.org/pdf/2408.06072)
 to view it; more training details and a demo are included. For the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/) to view it.
@@ -58,10 +59,9 @@
- [Inference](#inference)
- [SAT](#sat)
- [Tools](#tools)
-- [Open-Source Project Plan](#开源项目规划)
-- [Model License](#模型协议)
- [CogVideo (ICLR'23) Model Introduction](#cogvideoiclr23)
- [Citation](#引用)
+- [Model License](#模型协议)
## Quick Start
@@ -157,62 +157,72 @@ CogVideoX is an open-source video generation model of the same origin as [清影 (QingYing)](https://chatglm.cn/video?fr=osm_cogvideox).
CogVideoX-2B |
CogVideoX-5B |
CogVideoX-5B-I2V |
+ CogVideoX1.5-5B |
+ CogVideoX1.5-5B-I2V |
+
+
+ Release Date |
+ August 6, 2024 |
+ August 27, 2024 |
+ September 19, 2024 |
+ November 8, 2024 |
+ November 8, 2024 |
+
+
+ Video Resolution |
+ 720 * 480 |
+ 1360 * 768 |
+ 256 <= W <= 1360, 256 <= H <= 768, W,H % 16 == 0 |
 Inference Precision |
 FP16* (recommended), BF16, FP32, FP8*, INT8; INT4 not supported |
 BF16 (recommended), FP16, FP32, FP8*, INT8; INT4 not supported |
+ BF16 |
 Single GPU Memory Usage |
 SAT FP16: 18GB diffusers FP16: from 4GB* diffusers INT8 (torchao): from 3.6GB* |
 SAT BF16: 26GB diffusers BF16: from 5GB* diffusers INT8 (torchao): from 4.4GB* |
+ SAT BF16: 66GB
|
 Multi-GPU Inference Memory Usage |
FP16: 10GB* using diffusers
|
BF16: 15GB* using diffusers
|
+ Not supported
 |
 Inference Speed (Step = 50, FP/BF16) |
 Single A100: ~90 seconds Single H100: ~45 seconds |
 Single A100: ~180 seconds Single H100: ~90 seconds |
-
-
- 埮è°ç²ŸåºŠ |
- FP16 |
- BF16 |
-
-
- 埮è°æŸåæ¶è |
- 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT) |
- 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT)
|
- 78 GB (bs=1, LORA) 75GB (bs=1, SFT, 16GPU)
|
+ Single A100: ~1000 seconds (5-second video) Single H100: ~550 seconds (5-second video) |
 Prompt Language |
- English* |
+ English* |
 Maximum Prompt Length |
226 Tokens |
+ 224 Tokens |
 Video Length |
 6 seconds |
+ 5 or 10 seconds |
 Frame Rate |
 8 frames / second |
+ 16 frames / second |
- Video Resolution |
- 720 * 480; other resolutions not supported (including fine-tuning) |
-
-
 Positional Encoding |
3d_sincos_pos_embed |
- 3d_sincos_pos_embed |
+ 3d_sincos_pos_embed |
+ 3d_rope_pos_embed + learnable_pos_embed |
+ 3d_sincos_pos_embed |
3d_rope_pos_embed + learnable_pos_embed |
@@ -220,10 +230,13 @@ CogVideoXæ¯ [æž
圱](https://chatglm.cn/video?fr=osm_cogvideox) åæºçåŒæº
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+ Coming Soon |
 Download Link (SAT) |
 SAT |
+ 🤗 HuggingFace 🤖 ModelScope 🟣 WiseModel |
+
diff --git a/sat/README.md b/sat/README.md
index 48c4552..c67e15c 100644
--- a/sat/README.md
+++ b/sat/README.md
@@ -1,29 +1,39 @@
-# SAT CogVideoX-2B
+# SAT CogVideoX
-[中文阅读](./README_zh.md)
+[中文阅读](./README_zh.md)
[日本語で読む](./README_ja.md)
-This folder contains the inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
-fine-tuning code for SAT weights.
+This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.
-This code is the framework used by the team to train the model. It has few comments and requires careful study.
+This code framework was used by our team during model training. There are few comments, so careful study is required.
## Inference Model
-### 1. Ensure that you have correctly installed the dependencies required by this folder.
+### 1. Make sure you have installed all the dependencies required by this folder
-```shell
+```
pip install -r requirements.txt
```
-### 2. Download the model weights
+### 2. Download the Model Weights
-### 2. Download model weights
+First, download the model weights from the SAT mirror.
-First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, please download as follows:
+#### CogVideoX1.5 Model
-```shell
+```
+git lfs install
+git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
+```
+
+This command downloads three components: the Transformer, the VAE, and the T5 Encoder.
+
+#### CogVideoX Model
+
+For the CogVideoX-2B model, download as follows:
+
+```
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@@ -34,13 +44,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```
-For the CogVideoX-5B model, please download the `transformers` file as follows link:
-(VAE files are the same as 2B)
+Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
-Next, you need to format the model files as follows:
+Arrange the model files in the following structure:
```
.
@@ -52,20 +61,24 @@ Next, you need to format the model files as follows:
└── 3d-vae.pt
```
-Due to large size of model weight file, using `git lfs` is recommended. Installation of `git lfs` can be
-found [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing)
+Since model weight files are large, it's recommended to use `git lfs`.
+See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.
-Next, clone the T5 model, which is not used for training and fine-tuning, but must be used.
-> T5 model is available on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) as well.
+```
+git lfs install
+```
-```shell
-git clone https://huggingface.co/THUDM/CogVideoX-2b.git
+Next, clone the T5 model, which is used as an encoder and doesn't require training or fine-tuning.
+> You may also use the model file location on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
+
+```
+git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
+# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
-By following the above approach, you will obtain a safetensor format T5 file. Ensure that there are no errors when
-loading it into Deepspeed in Finetune.
+This will yield a safetensor format T5 file that can be loaded without error during Deepspeed fine-tuning.
```
├── added_tokens.json
@@ -80,11 +93,11 @@ loading it into Deepspeed in Finetune.
0 directories, 8 files
```
-### 3. Modify the file in `configs/cogvideox_2b.yaml`.
+### 3. Modify the `configs/cogvideox_*.yaml` file.
```yaml
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@@ -160,14 +173,14 @@ model:
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
- model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
+ model_dir: "t5-v1_1-xxl" # absolute path to CogVideoX-2b/t5-v1_1-xxl weight folder
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
- ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to CogVideoX-2b-sat/vae/3d-vae.pt file
ignore_keys: [ 'loss' ]
loss_config:
@@ -239,48 +252,46 @@ model:
num_steps: 50
```
-### 4. Modify the file in `configs/inference.yaml`.
+### 4. Modify the `configs/inference.yaml` file.
```yaml
args:
latent_channels: 16
mode: inference
- load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
+ load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
- input_type: txt # You can choose txt for pure text input, or change to cli for command line input
- input_file: configs/test.txt # Pure text file, which can be edited
- sampling_num_frames: 13 # Must be 13, 11 or 9
+ input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
+ input_file: configs/test.txt # Plain text file, can be edited
+ sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
sampling_fps: 8
fp16: True # For CogVideoX-2B
- # bf16: True # For CogVideoX-5B
+ # bf16: True # For CogVideoX-5B
output_dir: outputs/
force_inference: True
```
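
As a rough way to relate `sampling_num_frames` to clip length, the decoded frame count can be estimated from the roughly 4x temporal compression of the VAE (see `time_compressed_rate: 4` in the model configs). This is an orientation-only sketch under that assumption, not an exact formula from the configuration files:

```python
def approx_clip_seconds(sampling_num_frames: int, fps: int, temporal_compression: int = 4) -> float:
    """Estimate clip length: latent frames expand roughly as (n - 1) * compression + 1."""
    decoded_frames = (sampling_num_frames - 1) * temporal_compression + 1
    return decoded_frames / fps

print(approx_clip_seconds(13, fps=8))   # CogVideoX-2B / 5B: ~6 s
print(approx_clip_seconds(22, fps=16))  # CogVideoX1.5-5B: ~5 s
print(approx_clip_seconds(42, fps=16))  # CogVideoX1.5-5B: ~10 s
```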
-+ Modify `configs/test.txt` if multiple prompts is required, in which each line makes a prompt.
-+ For better prompt formatting, refer to [convert_demo.py](../inference/convert_demo.py), for which you should set the
- OPENAI_API_KEY as your environmental variable.
-+ Modify `input_type` in `configs/inference.yaml` if use command line as prompt iuput.
++ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
++ To use command-line input, modify:
-```yaml
+```
input_type: cli
```
-This allows input from the command line as prompts.
+This allows you to enter prompts from the command line.
-Change `output_dir` if you wish to modify the address of the output video
+To modify the output video location, change:
-```yaml
+```
output_dir: outputs/
```
-It is saved by default in the `.outputs/` folder.
+The default location is the `.outputs/` folder.
-### 5. Run the inference code to perform inference.
+### 5. Run the Inference Code to Perform Inference
-```shell
+```
bash inference.sh
```
@@ -288,95 +299,91 @@ bash inference.sh
### Preparing the Dataset
-The dataset format should be as follows:
+The dataset should be structured as follows:
```
.
├── labels
-│   ├── 1.txt
-│   ├── 2.txt
-│   └── ...
+│ ├── 1.txt
+│ ├── 2.txt
+│ └── ...
└── videos
    ├── 1.mp4
    ├── 2.mp4
    └── ...
```
-Each text file shares the same name as its corresponding video, serving as the label for that video. Videos and labels
-should be matched one-to-one. Generally, a single video should not be associated with multiple labels.
+Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
-For style fine-tuning, please prepare at least 50 videos and labels with similar styles to ensure proper fitting.
+For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
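
Before starting a run, the one-to-one pairing described above can be verified with a short script. This is an optional sketch that assumes the `labels/` and `videos/` layout shown; it is not a tool provided by this repository:

```python
from pathlib import Path

def check_dataset(root: str) -> None:
    """Report label/video files that do not pair up one-to-one."""
    labels = {p.stem for p in Path(root, "labels").glob("*.txt")}
    videos = {p.stem for p in Path(root, "videos").glob("*.mp4")}
    for name in sorted(labels - videos):
        print(f"label without video: {name}.txt")
    for name in sorted(videos - labels):
        print(f"video without label: {name}.mp4")
    print(f"{len(labels & videos)} matched pairs found")

check_dataset(".")  # run from the dataset root
```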
-### Modifying Configuration Files
+### Modifying the Configuration File
-We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Please note that both methods only fine-tune
-the `transformer` part and do not modify the `VAE` section. `T5` is used solely as an Encoder. Please modify
-the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows:
+We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
+Modify the `configs/sft.yaml` file (for full-parameter fine-tuning) as follows:
-```
- # checkpoint_activations: True ## Using gradient checkpointing (Both checkpoint_activations in the config file need to be set to True)
+```yaml
+ # checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
- experiment_name: lora-disney # Experiment name (do not modify)
- mode: finetune # Mode (do not modify)
- load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer model path
- no_load_rng: True # Whether to load random seed
+ experiment_name: lora-disney # Experiment name (do not change)
+ mode: finetune # Mode (do not change)
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
+ no_load_rng: True # Whether to load random number seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
- save: ckpts # Model save path
- save_interval: 100 # Model save interval
+ save: ckpts # Model save path
+ save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
- valid_data: [ "your val data path" ] # Training and validation datasets can be the same
- split: 1,0,0 # Training, validation, and test set ratio
- num_workers: 8 # Number of worker threads for data loader
- force_train: True # Allow missing keys when loading checkpoint (T5 and VAE are loaded separately)
- only_log_video_latents: True # Avoid memory overhead caused by VAE decode
+ valid_data: [ "your val data path" ] # Training and validation sets can be the same
+ split: 1,0,0 # Proportion for training, validation, and test sets
+ num_workers: 8 # Number of data loader workers
+ force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
+ only_log_video_latents: True # Avoid memory usage from VAE decoding
deepspeed:
bf16:
- enabled: False # For CogVideoX-2B set to False and for CogVideoX-5B set to True
+ enabled: False # Set to False for CogVideoX-2B and True for CogVideoX-5B
fp16:
- enabled: True # For CogVideoX-2B set to True and for CogVideoX-5B set to False
+ enabled: True # Set to True for CogVideoX-2B and False for CogVideoX-5B
```
-If you wish to use Lora fine-tuning, you also need to modify the `cogvideox__lora` file:
+To use Lora fine-tuning, you also need to modify the `cogvideox__lora` file:
-Here, take `CogVideoX-2B` as a reference:
+Here's an example using `CogVideoX-2B`:
```
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
- not_trainable_prefixes: [ 'all' ] ## Uncomment
+ not_trainable_prefixes: [ 'all' ] ## Uncomment
log_keys:
- - txt'
+ - txt
- lora_config: ## Uncomment
+ lora_config: ## Uncomment
target: sat.model.finetune.lora2.LoraMixin
params:
r: 256
```
-### Modifying Run Scripts
+### Modify the Run Script
-Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Below are two examples:
+Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
-1. If you want to use the `CogVideoX-2B` model and the `Lora` method, you need to modify `finetune_single_gpu.sh`
- or `finetune_multi_gpus.sh`:
+1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```
-2. If you want to use the `CogVideoX-2B` model and the `full-parameter fine-tuning` method, you need to
- modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`:
+2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```
-### Fine-Tuning and Evaluation
+### Fine-tuning and Validation
Run the fine-tuning script to start fine-tuning.
@@ -385,45 +392,42 @@ bash finetune_single_gpu.sh # Single GPU
bash finetune_multi_gpus.sh # Multi GPUs
```
-### Using the Fine-Tuned Model
+### Using the Fine-tuned Model
-The fine-tuned model cannot be merged; here is how to modify the inference configuration file `inference.sh`:
+The fine-tuned model cannot be merged. Here's how to modify the inference configuration file `inference.sh`:
```
-run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42"
+run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42"
```
-Then, execute the code:
+Then, run the code:
```
bash inference.sh
```
-### Converting to Huggingface Diffusers Supported Weights
+### Converting to Huggingface Diffusers-compatible Weights
-The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
+The SAT weight format is different from Huggingface's format and requires conversion. Run:
-```shell
+```
python ../tools/convert_weight_sat2hf.py
```
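
Once the weights are in Diffusers format they can be loaded with the standard `CogVideoXPipeline`. The snippet below uses the official `THUDM/CogVideoX-2b` release as an example; swapping in the local path of your converted checkpoint assumes the conversion produced a full pipeline folder, so treat this as a usage sketch rather than part of the conversion script:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # keeps single-GPU memory usage low

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```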
-### Exporting Huggingface Diffusers lora LoRA Weights from SAT Checkpoints
+### Exporting Lora Weights from SAT to Huggingface Diffusers
-After completing the training using the above steps, we get a SAT checkpoint with LoRA weights. You can find the file
-at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
+Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
+After training with the above steps, you'll find the SAT checkpoint with Lora weights at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
-The script for exporting LoRA weights can be found in the CogVideoX repository at `tools/export_sat_lora_weight.py`.
-After exporting, you can use `load_cogvideox_lora.py` for inference.
+The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
Export command:
-```bash
-python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
+```
+python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
-This training mainly modified the following model structures. The table below lists the corresponding structure mappings
-for converting to the HF (Hugging Face) format LoRA structure. As you can see, LoRA adds a low-rank weight to the
-model's attention structure.
+The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@@ -436,5 +440,5 @@ model's attention structure.
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
-Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format.
-
+Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
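
Conceptually, the export boils down to renaming the LoRA matrices in the checkpoint's state dict according to the table above. The snippet below is a simplified, hypothetical illustration of that renaming idea (the mapping shown is partial, and the real logic lives in `tools/export_sat_lora_weight.py`):

```python
# Partial SAT-to-HF key mapping, taken from the table above (per transformer block).
SAT_TO_HF = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.dense.matrix_A.0": "attn1.to_out.0.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}

def rename_lora_keys(sat_state_dict: dict) -> dict:
    """Return a new state dict with SAT LoRA key suffixes replaced by their HF equivalents."""
    hf_state_dict = {}
    for key, tensor in sat_state_dict.items():
        for sat_suffix, hf_suffix in SAT_TO_HF.items():
            if key.endswith(sat_suffix):
                key = key[: -len(sat_suffix)] + hf_suffix
                break
        hf_state_dict[key] = tensor
    return hf_state_dict
```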
+
\ No newline at end of file
diff --git a/sat/README_ja.md b/sat/README_ja.md
index ee1abcd..3685ba3 100644
--- a/sat/README_ja.md
+++ b/sat/README_ja.md
@@ -1,27 +1,37 @@
-# SAT CogVideoX-2B
+# SAT CogVideoX
-[Read this in English.](./README_zh)
+[Read this in English.](./README.md)
[中文阅读](./README_zh.md)
-This folder contains the inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights and the
-fine-tuning code for SAT weights.
-
-This code is the framework the team used to train the model. It has few comments and requires careful study.
+This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.
+This code framework was used by our team during model training. There are few comments, so careful study is required.
## Inference Model
-### 1. Make sure the dependencies required by this folder are installed correctly.
+### 1. Make sure all the dependencies required by this folder are installed
-```shell
+```
pip install -r requirements.txt
```
-### 2. Download the model weights
+### 2. Download the Model Weights
+ First, download the model weights from the SAT mirror.
-First, go to the SAT mirror to download the model weights. For the CogVideoX-2B model, download as follows:
+#### CogVideoX1.5 Model
-```shell
+```
+git lfs install
+git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
+```
+
+This command downloads three components: the Transformer, the VAE, and the T5 Encoder.
+
+#### CogVideoX Model
+
+For the CogVideoX-2B model, download as follows:
+
+```
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
@@ -32,12 +42,12 @@ mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```
-Download the `transformers` file for the CogVideoX-5B model from the links below (the VAE file is the same as for 2B):
+Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):
+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)
-Next, you need to arrange the model files in the following structure:
+Arrange the model files in the following structure:
```
.
@@ -49,24 +59,24 @@ CogVideoX-5B ã¢ãã«ã® `transformers` ãã¡ã€ã«ã以äžã®ãªã³ã¯ãã
└── 3d-vae.pt
```
-Since the model weight files are large, we recommend using `git lfs`. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for how to install `git lfs`.
+Since model weight files are large, using `git lfs` is recommended.
+See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.
-```shell
+```
git lfs install
```
-Next, clone the T5 model. It is not used for training or fine-tuning, but it must be present.
-> When cloning, you may also use the model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b).
+Next, clone the T5 model. It is used only as an encoder and does not require training or fine-tuning.
+> The model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) can also be used.
-```shell
-git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download the model from Huggingface
-# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download the model from Modelscope
+```
+git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
+# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```
-By following the above approach, you will obtain a safetensor-format T5 file that loads without errors during Deepspeed fine-tuning.
+This will yield a safetensor-format T5 file that can be loaded without error during Deepspeed fine-tuning.
```
├── added_tokens.json
@@ -81,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```
-### 3. Modify the `configs/cogvideox_2b.yaml` file.
+### 3. Modify the `configs/cogvideox_*.yaml` file
```yaml
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@@ -123,7 +133,7 @@ model:
num_attention_heads: 30
transformer_args:
- checkpoint_activations: True ## Use gradient checkpointing
+ checkpoint_activations: True ## using gradient checkpointing
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
@@ -161,14 +171,14 @@ model:
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
- model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl folder
+ model_dir: "t5-v1_1-xxl" # Absolute path to the CogVideoX-2b/t5-v1_1-xxl weights folder
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
- ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt folder
+ ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # Absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
ignore_keys: [ 'loss' ]
loss_config:
@@ -240,7 +250,7 @@ model:
num_steps: 50
```
-### 4. Modify the `configs/inference.yaml` file.
+### 4. Modify the `configs/inference.yaml` file
```yaml
args:
@@ -250,38 +260,36 @@ args:
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
batch_size: 1
- input_type: txt # Choose a txt plain-text file as input, or change to cli for command-line input
- input_file: configs/test.txt # Path to the plain text file, which can be edited
- sampling_num_frames: 13 # Must be 13, 11 or 9
+ input_type: txt # "txt" for plain-text file input, or "cli" for command-line input
+ input_file: configs/test.txt # Plain text file, can be edited
+ sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
 sampling_fps: 8
- fp16: True # For CogVideoX-2B
- # bf16: True # For CogVideoX-5B
+ fp16: True # For CogVideoX-2B
+ # bf16: True # For CogVideoX-5B
output_dir: outputs/
force_inference: True
```
-+ If you use a txt file to store multiple prompts, refer to `configs/test.txt` and edit it. One prompt per line. If you do not know how to write prompts, you can first use [this code](../inference/convert_demo.py) to call an LLM for refinement.
-+ To use the command line as input, change:
++ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
++ To use command-line input, modify:
-```yaml
+```
input_type: cli
```
This allows you to enter prompts from the command line.
-To change the directory of the output video, you can modify:
+To change where the output video is saved, edit:
-```yaml
+```
output_dir: outputs/
```
-By default, it is saved in the `.outputs/` folder.
+The default location is the `.outputs/` folder.
-### 5. Run the inference code to perform inference.
+### 5. Run the Inference Code to Perform Inference
-```shell
+```
bash inference.sh
```
@@ -289,7 +297,7 @@ bash inference.sh
### Preparing the Dataset
-The dataset format should be as follows:
+The dataset should be structured as follows:
```
.
@@ -303,123 +311,215 @@ bash inference.sh
    └── ...
```
-Each txt file has the same name as its corresponding video and contains the label for that video. Each video must correspond one-to-one with a label. Generally, a single video should not have multiple labels.
+Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
-For style fine-tuning, prepare at least 50 videos and labels with a similar style to make fitting easier.
+For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
-### Modifying the Configuration Files
+### Modifying the Configuration File
-We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part and do not modify the `VAE` part; `T5` is used only as an encoder. Modify the `configs/sft.yaml` (for full-parameter fine-tuning) file as follows:
+We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
+Modify the `configs/sft.yaml` file (for full-parameter fine-tuning) as follows:
```
- # checkpoint_activations: True ## Use gradient checkpointing (both checkpoint_activations entries in the config file must be set to True)
+ # checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
 model_parallel_size: 1 # Model parallel size
- experiment_name: lora-disney # Experiment name (do not modify)
- mode: finetune # Mode (do not modify)
- load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to the Transformer model
- no_load_rng: True # Whether to load the random seed
+ experiment_name: lora-disney # Experiment name (do not change)
+ mode: finetune # Mode (do not change)
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
+ no_load_rng: True # Whether to load random number seed
 train_iters: 1000 # Training iterations
- eval_iters: 1 # Evaluation iterations
- eval_interval: 100 # Evaluation interval
- eval_batch_size: 1 # Evaluation batch size
- save: ckpts # Model save path
- save_interval: 100 # Model save interval
+ eval_iters: 1 # Validation iterations
+ eval_interval: 100 # Validation interval
+ eval_batch_size: 1 # Validation batch size
+ save: ckpts # Model save path
+ save_interval: 100 # Save interval
 log_interval: 20 # Log output interval
 train_data: [ "your train data path" ]
- valid_data: [ "your val data path" ] # Training and validation data can be the same
- split: 1,0,0 # Ratio of training, validation, and test sets
- num_workers: 8 # Number of data loader worker threads
- force_train: True # Allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
- only_log_video_latents: True # Avoid memory overhead caused by VAE decoding
+ valid_data: [ "your val data path" ] # Training and validation sets can be the same
+ split: 1,0,0 # Proportion of training, validation, and test sets
+ num_workers: 8 # Number of data loader workers
+ force_train: True # Allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
+ only_log_video_latents: True # Avoid memory usage from VAE decoding
 deepspeed:
   bf16:
-     enabled: False # Set to False for CogVideoX-2B and True for CogVideoX-5B
+     enabled: False # Set to False for CogVideoX-2B and True for CogVideoX-5B
   fp16:
-     enabled: True # Set to True for CogVideoX-2B and False for CogVideoX-5B
+     enabled: True # Set to True for CogVideoX-2B and False for CogVideoX-5B
+```
+```yaml
+args:
+ latent_channels: 16
+ mode: inference
+ load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
+ # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
+
+ batch_size: 1
+ input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
+ input_file: configs/test.txt # Plain text file, can be edited
+ sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
+ sampling_fps: 8
+ fp16: True # For CogVideoX-2B
+ # bf16: True # For CogVideoX-5B
+ output_dir: outputs/
+ force_inference: True
```
-If you want to use Lora fine-tuning, you also need to modify the `cogvideox__lora` file.
-
-Here, `CogVideoX-2B` is used as a reference.
++ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
++ To use command-line input, modify:
```
+input_type: cli
+```
+
+This allows you to enter prompts from the command line.
+
+To modify the output video location, change:
+
+```
+output_dir: outputs/
+```
+
+The default location is the `.outputs/` folder.
+
+### 5. Run the Inference Code to Perform Inference
+
+```
+bash inference.sh
+```
+
+## Fine-tuning the Model
+
+### Preparing the Dataset
+
+The dataset should be structured as follows:
+
+```
+.
+âââ labels
+â âââ 1.txt
+â âââ 2.txt
+â âââ ...
+âââ videos
+ âââ 1.mp4
+ âââ 2.mp4
+ âââ ...
+```
+
+Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
+
+For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
+
+### Modifying the Configuration File
+
+We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part. The `VAE` part is not modified, and `T5` is only used as an encoder.
+Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
+
+```yaml
+ # checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` in the config file need to be set to True)
+ model_parallel_size: 1 # Model parallel size
+ experiment_name: lora-disney # Experiment name (do not change)
+ mode: finetune # Mode (do not change)
+ load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to Transformer model
+ no_load_rng: True # Whether to load random number seed
+ train_iters: 1000 # Training iterations
+ eval_iters: 1 # Evaluation iterations
+ eval_interval: 100 # Evaluation interval
+ eval_batch_size: 1 # Evaluation batch size
+ save: ckpts # Model save path
+ save_interval: 100 # Save interval
+ log_interval: 20 # Log output interval
+ train_data: [ "your train data path" ]
+ valid_data: [ "your val data path" ] # Training and validation sets can be the same
+ split: 1,0,0 # Proportion for training, validation, and test sets
+ num_workers: 8 # Number of data loader workers
+ force_train: True # Allow missing keys when loading checkpoint (T5 and VAE loaded separately)
+ only_log_video_latents: True # Avoid memory usage from VAE decoding
+ deepspeed:
+ bf16:
+ enabled: False # Set to False for CogVideoX-2B and True for CogVideoX-5B
+ fp16:
+ enabled: True # Set to True for CogVideoX-2B and False for CogVideoX-5B
+```
+
+To use Lora fine-tuning, you also need to modify the `cogvideox__lora` file:
+
+Here's an example using `CogVideoX-2B`:
+
+```yaml
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
- not_trainable_prefixes: [ 'all' ] ## Uncomment
+ not_trainable_prefixes: [ 'all' ] ## Uncomment
log_keys:
- - txt'
+ - txt
- lora_config: ## Uncomment
+ lora_config: ## Uncomment
target: sat.model.finetune.lora2.LoraMixin
params:
r: 256
```
-### Modifying the Run Script
+### Modify the Run Script
-Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` to select the configuration file. Two examples are shown below.
+Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:
-1. If you want to use the `CogVideoX-2B` model with the `Lora` method, you need to modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`.
+1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```
-2. If you want to use the `CogVideoX-2B` model with full-parameter fine-tuning, you need to modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh`.
+2. If you want to use the `CogVideoX-2B` model with full fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:
```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```
-### 埮調æŽãšè©äŸ¡
+### Fine-tuning and Validation
-æšè«ã³ãŒããå®è¡ããŠåŸ®èª¿æŽãéå§ããŸãã
+Run the inference code to start fine-tuning.
```
-bash finetune_single_gpu.sh # Single GPU
-bash finetune_multi_gpus.sh # Multi GPUs
+bash finetune_single_gpu.sh # Single GPU
+bash finetune_multi_gpus.sh # Multi GPUs
```
-### 埮調æŽåŸã®ã¢ãã«ã®äœ¿çš
+### Using the Fine-tuned Model
-埮調æŽãããã¢ãã«ã¯çµ±åã§ããŸãããããã§ã¯ãæšè«èšå®ãã¡ã€ã« `inference.sh` ãå€æŽããæ¹æ³ã瀺ããŸãã
+The fine-tuned model cannot be merged. Hereâs how to modify the inference configuration file `inference.sh`
```
-run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42"
+run_cmd="$environs python sample_video.py --base configs/cogvideox__lora.yaml configs/inference.yaml --seed 42"
```
-Then, run the following code:
+Then, run the code:
```
bash inference.sh
```
-### Converting to Huggingface Diffusers-supported Weights
+### Converting to Huggingface Diffusers-compatible Weights
-The SAT weight format is different from Huggingface's weight format and needs to be converted. Please run:
+The SAT weight format is different from Huggingface's format and requires conversion. Run:
-```shell
+```
python ../tools/convert_weight_sat2hf.py
```
-### Exporting LoRA Weights from SAT Checkpoints to Huggingface Diffusers
+### Exporting Lora Weights from SAT to Huggingface Diffusers
-After completing the above steps, you will obtain a SAT checkpoint with LoRA weights. The file is located at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
+Support is provided for exporting Lora weights from SAT to Huggingface Diffusers format.
+After training with the above steps, you'll find the SAT checkpoint with Lora weights at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.
-The script for exporting LoRA weights is located in the CogVideoX repository at `tools/export_sat_lora_weight.py`. After exporting, you can use `load_cogvideox_lora.py` for inference.
+The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.
-Export command:
+Export command:
-```bash
-python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
+```
+python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```
-This training mainly modified the following model structures. The table below shows the mapping to the HF (Hugging Face) LoRA structure. As you can see, LoRA adds low-rank weights to the model's attention mechanism.
+The following model structures were modified during training. Here is the mapping between SAT and HF Lora structures. Lora adds a low-rank weight to the attention structure of the model.
```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
@@ -431,8 +531,6 @@ python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_nam
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```
-
-Using export_sat_lora_weight.py, you can convert the SAT checkpoint into the HF LoRA format.
-
-
+Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
+
\ No newline at end of file
diff --git a/sat/README_zh.md b/sat/README_zh.md
index c605da8..c25c6b7 100644
--- a/sat/README_zh.md
+++ b/sat/README_zh.md
@@ -1,6 +1,6 @@
-# SAT CogVideoX-2B
+# SAT CogVideoX
-[Read this in English.](./README_zh)
+[Read this in English.](./README.md)
[日本語で読む](./README_ja.md)
@@ -20,6 +20,15 @@ pip install -r requirements.txt
First, download the model weights from the SAT mirror.
+#### CogVideoX1.5 Model
+
+```shell
+git lfs install
+git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
+```
+This downloads three components: the Transformer, the VAE, and the T5 Encoder.
+
+#### CogVideoX Model
For the CogVideoX-2B model, download as follows:
```shell
@@ -82,11 +91,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```
-### 3. Modify the `configs/cogvideox_2b.yaml` file.
+### 3. Modify the `configs/cogvideox_*.yaml` file.
```yaml
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
log_keys:
- txt
@@ -253,7 +262,7 @@ args:
batch_size: 1
 input_type: txt # Choose a txt plain-text file as input, or change to cli for command-line input
 input_file: configs/test.txt # Plain text file, which can be edited
- sampling_num_frames: 13 # Must be 13, 11 or 9
+ sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
@@ -346,7 +355,7 @@ Encoder 䜿çšã
```yaml
model:
- scale_factor: 1.15258426
+ scale_factor: 1.55258426
disable_first_stage_autocast: true
 not_trainable_prefixes: [ 'all' ] ## Uncomment
log_keys:
diff --git a/sat/arguments.py b/sat/arguments.py
index 44767d3..9b0a1bb 100644
--- a/sat/arguments.py
+++ b/sat/arguments.py
@@ -36,6 +36,7 @@ def add_sampling_config_args(parser):
group.add_argument("--input-dir", type=str, default=None)
group.add_argument("--input-type", type=str, default="cli")
group.add_argument("--input-file", type=str, default="input.txt")
+ group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
group.add_argument("--final-size", type=int, default=2048)
group.add_argument("--sdedit", action="store_true")
group.add_argument("--grid-num-rows", type=int, default=1)
diff --git a/sat/configs/cogvideox1.5_5b.yaml b/sat/configs/cogvideox1.5_5b.yaml
new file mode 100644
index 0000000..0000ec2
--- /dev/null
+++ b/sat/configs/cogvideox1.5_5b.yaml
@@ -0,0 +1,149 @@
+model:
+ scale_factor: 0.7
+ disable_first_stage_autocast: true
+ latent_input: true
+ log_keys:
+ - txt
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+ quantize_c_noise: False
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+
+ network_config:
+ target: dit_video_concat.DiffusionTransformer
+ params:
+ time_embed_dim: 512
+ elementwise_affine: True
+ num_frames: 81
+ time_compressed_rate: 4
+ latent_width: 300
+ latent_height: 300
+ num_layers: 42
+ patch_size: [2, 2, 2]
+ in_channels: 16
+ out_channels: 16
+ hidden_size: 3072
+ adm_in_channels: 256
+ num_attention_heads: 48
+
+ transformer_args:
+ checkpoint_activations: True
+ vocab_size: 1
+ max_sequence_length: 64
+ layernorm_order: pre
+ skip_init: false
+ model_parallel_size: 1
+ is_decoder: false
+
+ modules:
+ pos_embed_config:
+ target: dit_video_concat.Rotary3DPositionEmbeddingMixin
+ params:
+ hidden_size_head: 64
+ text_length: 224
+
+ patch_embed_config:
+ target: dit_video_concat.ImagePatchEmbeddingMixin
+ params:
+ text_hidden_size: 4096
+
+ adaln_layer_config:
+ target: dit_video_concat.AdaLNMixin
+ params:
+ qk_ln: True
+
+ final_layer_config:
+ target: dit_video_concat.FinalLayerMixin
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl"
+ max_length: 224
+
+
+ first_stage_config:
+ target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt"
+ ignore_keys: ['loss']
+
+ loss_config:
+ target: torch.nn.Identity
+
+ regularizer_config:
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+ encoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+ params:
+ double_z: true
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 2, 4]
+ attn_resolutions: []
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ decoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+ params:
+ double_z: True
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 2, 4]
+ attn_resolutions: []
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+ params:
+ offset_noise_level: 0
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ uniform_sampling: True
+ group_num: 40
+ num_idx: 1000
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+ params:
+ num_steps: 50
+ verbose: True
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+ params:
+ scale: 6
+ exp: 5
+ num_steps: 50
diff --git a/sat/configs/cogvideox1.5_5b_i2v.yaml b/sat/configs/cogvideox1.5_5b_i2v.yaml
new file mode 100644
index 0000000..c65f0b7
--- /dev/null
+++ b/sat/configs/cogvideox1.5_5b_i2v.yaml
@@ -0,0 +1,160 @@
+model:
+ scale_factor: 0.7
+ disable_first_stage_autocast: true
+ latent_input: false
+ noised_image_input: true
+ noised_image_all_concat: false
+ noised_image_dropout: 0.05
+ augmentation_dropout: 0.15
+ log_keys:
+ - txt
+
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+ params:
+ num_idx: 1000
+ quantize_c_noise: False
+
+ weighting_config:
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+
+ network_config:
+ target: dit_video_concat.DiffusionTransformer
+ params:
+# space_interpolation: 1.875
+ ofs_embed_dim: 512
+ time_embed_dim: 512
+ elementwise_affine: True
+ num_frames: 81
+ time_compressed_rate: 4
+ latent_width: 300
+ latent_height: 300
+ num_layers: 42
+ patch_size: [2, 2, 2]
+ in_channels: 32
+ out_channels: 16
+ hidden_size: 3072
+ adm_in_channels: 256
+ num_attention_heads: 48
+
+ transformer_args:
+ checkpoint_activations: True
+ vocab_size: 1
+ max_sequence_length: 64
+ layernorm_order: pre
+ skip_init: false
+ model_parallel_size: 1
+ is_decoder: false
+
+ modules:
+ pos_embed_config:
+ target: dit_video_concat.Rotary3DPositionEmbeddingMixin
+ params:
+ hidden_size_head: 64
+ text_length: 224
+
+ patch_embed_config:
+ target: dit_video_concat.ImagePatchEmbeddingMixin
+ params:
+ text_hidden_size: 4096
+
+
+ adaln_layer_config:
+ target: dit_video_concat.AdaLNMixin
+ params:
+ qk_ln: True
+
+ final_layer_config:
+ target: dit_video_concat.FinalLayerMixin
+
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+
+ - is_trainable: false
+ input_key: txt
+ ucg_rate: 0.1
+ target: sgm.modules.encoders.modules.FrozenT5Embedder
+ params:
+ model_dir: "google/t5-v1_1-xxl"
+ max_length: 224
+
+
+ first_stage_config:
+ target : vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+ params:
+ cp_size: 1
+ ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
+ ignore_keys: ['loss']
+
+ loss_config:
+ target: torch.nn.Identity
+
+ regularizer_config:
+ target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+ encoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+ params:
+ double_z: true
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 2, 4]
+ attn_resolutions: []
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ decoder_config:
+ target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+ params:
+ double_z: True
+ z_channels: 16
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1, 2, 2, 4]
+ attn_resolutions: []
+ num_res_blocks: 3
+ dropout: 0.0
+ gather_norm: True
+
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
+ params:
+ fixed_frames: 0
+ offset_noise_level: 0.0
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+ params:
+ uniform_sampling: True
+ group_num: 40
+ num_idx: 1000
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+ params:
+ fixed_frames: 0
+ num_steps: 50
+ verbose: True
+
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+ params:
+ scale: 6
+ exp: 5
+ num_steps: 50
\ No newline at end of file
diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
index 963038b..10635b4 100644
--- a/sat/diffusion_video.py
+++ b/sat/diffusion_video.py
@@ -179,14 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
- with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
- for n in range(n_rounds):
- if isinstance(self.first_stage_model.decoder, VideoDecoder):
- kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
- else:
- kwargs = {}
- out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
- all_out.append(out)
+ for n in range(n_rounds):
+ z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
+ latent_time = z_now.shape[2] # check the time latent
+ temporal_compress_times = 4
+
+ fake_cp_size = min(10, latent_time // 2)
+ start_frame = 0
+
+ recons = []
+ start_frame = 0
+ for i in range(fake_cp_size):
+ end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
+
+ fake_cp_rank0 = True if i == 0 else False
+ clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
+ with torch.no_grad():
+ recon = self.first_stage_model.decode(
+ z_now[:, :, start_frame:end_frame].contiguous(),
+ clear_fake_cp_cache=clear_fake_cp_cache,
+ fake_cp_rank0=fake_cp_rank0,
+ )
+ recons.append(recon)
+ start_frame = end_frame
+ recons = torch.cat(recons, dim=2)
+ all_out.append(recons)
out = torch.cat(all_out, dim=0)
return out
@@ -218,6 +235,7 @@ class SATVideoDiffusionEngine(nn.Module):
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
+ ofs=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@@ -241,7 +259,7 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
)
- samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
+ samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
samples = samples.to(self.dtype)
return samples
diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
index 7692116..b55a3f1 100644
--- a/sat/dit_video_concat.py
+++ b/sat/dit_video_concat.py
@@ -1,5 +1,7 @@
from functools import partial
from einops import rearrange, repeat
+from functools import reduce
+from operator import mul
import numpy as np
import torch
@@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.openaimodel import Timestep
-from sgm.modules.diffusionmodules.util import (
- linear,
- timestep_embedding,
-)
+from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sat.ops.layernorm import LayerNorm, RMSNorm
class ImagePatchEmbeddingMixin(BaseMixin):
- def __init__(
- self,
- in_channels,
- hidden_size,
- patch_size,
- bias=True,
- text_hidden_size=None,
- ):
+ def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
super().__init__()
- self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
+ self.patch_size = patch_size
+ self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else:
self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs):
- # now is 3d patch
images = kwargs["images"] # (b,t,c,h,w)
- B, T = images.shape[:2]
- emb = images.view(-1, *images.shape[2:])
- emb = self.proj(emb) # ((b t),d,h/2,w/2)
- emb = emb.view(B, T, *emb.shape[1:])
- emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
- emb = rearrange(emb, "b t n d -> b (t n) d")
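+ # 3D patchify: group each (o, p, q) voxel patch into a single token and project it with a Linear layer (replacing the old 2D conv).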
+ emb = rearrange(images, "b t c h w -> b (t h w) c")
+ emb = rearrange(
+ emb,
+ "b (t o h p w q) c -> b (t h w) (c o p q)",
+ t=kwargs["rope_T"],
+ h=kwargs["rope_H"],
+ w=kwargs["rope_W"],
+ o=self.patch_size[0],
+ p=self.patch_size[1],
+ q=self.patch_size[2],
+ )
+ emb = self.proj(emb)
if self.text_proj is not None:
text_emb = self.text_proj(kwargs["encoder_outputs"])
@@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
grid_size: int of the grid height and width
t_size: int of the temporal size
return:
- pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
+ (w/ or w/o cls_token)
"""
assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3
@@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed(
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
- # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
return pos_embed # [T, H*W, D]
@@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
text_length,
theta=10000,
rot_v=False,
+ height_interpolation=1.0,
+ width_interpolation=1.0,
+ time_interpolation=1.0,
learnable_pos_embed=False,
):
super().__init__()
@@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
- freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.contiguous()
- freqs_sin = freqs.sin()
- freqs_cos = freqs.cos()
- self.register_buffer("freqs_sin", freqs_sin)
- self.register_buffer("freqs_cos", freqs_cos)
-
+ self.freqs_sin = freqs.sin().cuda()
+ self.freqs_cos = freqs.cos().cuda()
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
@@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.pos_embedding = None
def rotary(self, t, **kwargs):
- seq_len = t.shape[2]
- freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
- freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
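+ # Slice the stored (t, h, w, d) frequency grid to the current rope_T/H/W and flatten it to sequence order, so varying resolutions are supported.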
+ def reshape_freq(freqs):
+ freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
+ freqs = rearrange(freqs, "t h w d -> (t h w) d")
+ freqs = freqs.unsqueeze(0).unsqueeze(0)
+ return freqs
+
+ freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
+ freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
- return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
+ return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else:
return None
@@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
- query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
- key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
+ query_layer = torch.cat(
+ (
+ query_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ query_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ],
+ **kwargs,
+ ),
+ ),
+ dim=2,
+ )
+ key_layer = torch.cat(
+ (
+ key_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ key_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ],
+ **kwargs,
+ ),
+ ),
+ dim=2,
+ )
if self.rot_v:
- value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
+ value_layer = torch.cat(
+ (
+ value_layer[
+ :,
+ :,
+ : kwargs["text_length"],
+ ],
+ self.rotary(
+ value_layer[
+ :,
+ :,
+ kwargs["text_length"] :,
+ ],
+ **kwargs,
+ ),
+ ),
+ dim=2,
+ )
return attention_fn_default(
query_layer,
@@ -347,21 +400,25 @@ def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
+def unpatchify(x, c, patch_size, w, h, **kwargs):
"""
x: (N, T/2 * S, patch_size**3 * C)
imgs: (N, T, H, W, C)
+
+ patch_size is split into three dimensions (o, p, q) corresponding to depth (o), height (p), and width (q), so the patch size can differ along each axis.
"""
- if rope_position_ids is not None:
- assert NotImplementedError
- # do pix2struct unpatchify
- L = x.shape[1]
- x = x.reshape(shape=(x.shape[0], L, p, p, c))
- x = torch.einsum("nlpqc->ncplq", x)
- imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
- else:
- b = x.shape[0]
- imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
+
+ imgs = rearrange(
+ x,
+ "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
+ c=c,
+ o=patch_size[0],
+ p=patch_size[1],
+ q=patch_size[2],
+ t=kwargs["rope_T"],
+ h=kwargs["rope_H"],
+ w=kwargs["rope_W"],
+ )
return imgs
@@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin):
self.patch_size = patch_size
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
- self.spatial_length = latent_width * latent_height // patch_size**2
- self.latent_width = latent_width
- self.latent_height = latent_height
-
def final_forward(self, logits, **kwargs):
- x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
+ x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x: (b, (t n), d); keep only the image tokens of x (the part after the text tokens)
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return unpatchify(
- x,
- c=self.out_channels,
- p=self.patch_size,
- w=self.latent_width // self.patch_size,
- h=self.latent_height // self.patch_size,
- rope_position_ids=kwargs.get("rope_position_ids", None),
- **kwargs,
+ x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
)
def reinit(self, parent_model=None):
@@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin):
class AdaLNMixin(BaseMixin):
def __init__(
self,
- width,
- height,
hidden_size,
num_layers,
time_embed_dim,
@@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin):
):
super().__init__()
self.num_layers = num_layers
- self.width = width
- self.height = height
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList(
@@ -611,7 +654,7 @@ class DiffusionTransformer(BaseModel):
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
- zero_init_y_embed=False,
+ ofs_embed_dim=None,
**kwargs,
):
self.latent_width = latent_width
@@ -619,12 +662,13 @@ class DiffusionTransformer(BaseModel):
self.patch_size = patch_size
self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate
- self.spatial_length = latent_width * latent_height // patch_size**2
+ self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
+ self.ofs_embed_dim = ofs_embed_dim
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.input_time = input_time
@@ -636,7 +680,6 @@ class DiffusionTransformer(BaseModel):
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4
- self.zero_init_y_embed = zero_init_y_embed
try:
self.dtype = str_to_dtype[kwargs.pop("dtype")]
except:
@@ -669,7 +712,6 @@ class DiffusionTransformer(BaseModel):
def _build_modules(self, module_configs):
model_channels = self.hidden_size
- # time_embed_dim = model_channels * 4
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
@@ -677,6 +719,13 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim),
)
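+ # Optional "ofs" embedding with the same MLP structure as the timestep embedding; when configured, forward() expects an `ofs` value (passed e.g. by the image-to-video sampling path).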
+ if self.ofs_embed_dim is not None:
+ self.ofs_embed = nn.Sequential(
+ linear(self.ofs_embed_dim, self.ofs_embed_dim),
+ nn.SiLU(),
+ linear(self.ofs_embed_dim, self.ofs_embed_dim),
+ )
+
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -701,9 +750,6 @@ class DiffusionTransformer(BaseModel):
linear(time_embed_dim, time_embed_dim),
)
)
- if self.zero_init_y_embed:
- nn.init.constant_(self.label_emb[0][2].weight, 0)
- nn.init.constant_(self.label_emb[0][2].bias, 0)
else:
raise ValueError()
@@ -712,10 +758,13 @@ class DiffusionTransformer(BaseModel):
"pos_embed",
instantiate_from_config(
pos_embed_config,
- height=self.latent_height // self.patch_size,
- width=self.latent_width // self.patch_size,
+ height=self.latent_height // self.patch_size[1],
+ width=self.latent_width // self.patch_size[2],
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size=self.hidden_size,
+ height_interpolation=self.height_interpolation,
+ width_interpolation=self.width_interpolation,
+ time_interpolation=self.time_interpolation,
),
reinit=True,
)
@@ -737,8 +786,6 @@ class DiffusionTransformer(BaseModel):
"adaln_layer",
instantiate_from_config(
adaln_layer_config,
- height=self.latent_height // self.patch_size,
- width=self.latent_width // self.patch_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@@ -749,7 +796,6 @@ class DiffusionTransformer(BaseModel):
)
else:
raise NotImplementedError
-
final_layer_config = module_configs["final_layer_config"]
self.add_mixin(
"final_layer",
@@ -766,25 +812,18 @@ class DiffusionTransformer(BaseModel):
reinit=True,
)
- if "lora_config" in module_configs:
- lora_config = module_configs["lora_config"]
- self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
-
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
-
- # This is not use in inference
if "concat_images" in kwargs and kwargs["concat_images"] is not None:
if kwargs["concat_images"].shape[0] != x.shape[0]:
concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs["concat_images"]
x = torch.cat([x, concat_images], dim=2)
-
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@@ -792,17 +831,25 @@ class DiffusionTransformer(BaseModel):
emb = self.time_embed(t_emb)
if self.num_classes is not None:
- # assert y.shape[0] == x.shape[0]
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
- kwargs["seq_length"] = t * h * w // (self.patch_size**2)
+ if self.ofs_embed_dim is not None:
+ ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
+ ofs_emb = self.ofs_embed(ofs_emb)
+ emb = emb + ofs_emb
+
+ kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
kwargs["images"] = x
kwargs["emb"] = emb
kwargs["encoder_outputs"] = context
kwargs["text_length"] = context.shape[1]
+ kwargs["rope_T"] = t // self.patch_size[0]
+ kwargs["rope_H"] = h // self.patch_size[1]
+ kwargs["rope_W"] = w // self.patch_size[2]
+
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
output = super().forward(**kwargs)[0]
return output
diff --git a/sat/inference.sh b/sat/inference.sh
index c798fa5..a22ef87 100755
--- a/sat/inference.sh
+++ b/sat/inference.sh
@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
-run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
+run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/test_inference.yaml --seed $RANDOM"
echo ${run_cmd}
eval ${run_cmd}
diff --git a/sat/requirements.txt b/sat/requirements.txt
index 75b4649..3c1c501 100644
--- a/sat/requirements.txt
+++ b/sat/requirements.txt
@@ -1,16 +1,11 @@
-SwissArmyTransformer==0.4.12
-omegaconf==2.3.0
-torch==2.4.0
-torchvision==0.19.0
-pytorch_lightning==2.3.3
-kornia==0.7.3
-beartype==0.18.5
-numpy==2.0.1
-fsspec==2024.5.0
-safetensors==0.4.3
-imageio-ffmpeg==0.5.1
-imageio==2.34.2
-scipy==1.14.0
-decord==0.6.0
-wandb==0.17.5
-deepspeed==0.14.4
\ No newline at end of file
+SwissArmyTransformer>=0.4.12
+omegaconf>=2.3.0
+pytorch_lightning>=2.4.0
+kornia>=0.7.3
+beartype>=0.19.0
+fsspec>=2024.2.0
+safetensors>=0.4.5
+scipy>=1.14.1
+decord>=0.6.0
+wandb>=0.18.5
+deepspeed>=0.15.3
\ No newline at end of file
diff --git a/sat/sample_video.py b/sat/sample_video.py
index 49cfcac..c34e6a7 100644
--- a/sat/sample_video.py
+++ b/sat/sample_video.py
@@ -4,24 +4,20 @@ import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
+from PIL import Image
import imageio
import torch
import numpy as np
-from einops import rearrange
+from einops import rearrange, repeat
import torchvision.transforms as TT
-
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
-from torchvision.transforms.functional import center_crop, resize
-from torchvision.transforms import InterpolationMode
-from PIL import Image
-
def read_from_cli():
cnt = 0
@@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+ elif key == "original_size_as_tuple":
+ batch["original_size_as_tuple"] = (
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
+ )
+ elif key == "crop_coords_top_left":
+ batch["crop_coords_top_left"] = (
+ torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
+ )
+ elif key == "aesthetic_score":
+ batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+ batch_uc["aesthetic_score"] = (
+ torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
+ )
+
+ elif key == "target_size_as_tuple":
+ batch["target_size_as_tuple"] = (
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
+ )
+ elif key == "fps":
+ batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
+ elif key == "fps_id":
+ batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
+ elif key == "motion_bucket_id":
+ batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
+ elif key == "pool_image":
+ batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
+ elif key == "cond_aug":
+ batch[key] = repeat(
+ torch.tensor([value_dict["cond_aug"]]).to("cuda"),
+ "1 -> b",
+ b=math.prod(N),
+ )
+ elif key == "cond_frames":
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+ elif key == "cond_frames_without_noise":
+ batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
else:
batch[key] = value_dict[key]
@@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
writer.append_data(frame)
-def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
- if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
- arr = resize(
- arr,
- size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
- interpolation=InterpolationMode.BICUBIC,
- )
- else:
- arr = resize(
- arr,
- size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
- interpolation=InterpolationMode.BICUBIC,
- )
-
- h, w = arr.shape[2], arr.shape[3]
- arr = arr.squeeze(0)
-
- delta_h = h - image_size[0]
- delta_w = w - image_size[1]
-
- if reshape_mode == "random" or reshape_mode == "none":
- top = np.random.randint(0, delta_h + 1)
- left = np.random.randint(0, delta_w + 1)
- elif reshape_mode == "center":
- top, left = delta_h // 2, delta_w // 2
- else:
- raise NotImplementedError
- arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
- return arr
-
-
def sampling_main(args, model_cls):
if isinstance(model_cls, type):
model = get_model(args, model_cls)
@@ -127,45 +128,62 @@ def sampling_main(args, model_cls):
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
- print("rank and world_size", rank, world_size)
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
- image_size = [480, 720]
-
- if args.image2video:
- chained_trainsforms = []
- chained_trainsforms.append(TT.ToTensor())
- transform = TT.Compose(chained_trainsforms)
-
sample_func = model.sample
- T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
- device = model.device
+ T, C = args.sampling_num_frames, args.latent_channels
with torch.no_grad():
for text, cnt in tqdm(data_iter):
if args.image2video:
+ # derive the sampling resolution from the input image's shape
text, image_path = text.split("@@")
assert os.path.exists(image_path), image_path
image = Image.open(image_path).convert("RGB")
+ (img_W, img_H) = image.size
+
+ def nearest_multiple_of_16(n):
+ lower_multiple = (n // 16) * 16
+ upper_multiple = (n // 16 + 1) * 16
+ if abs(n - lower_multiple) < abs(n - upper_multiple):
+ return lower_multiple
+ else:
+ return upper_multiple
+
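+ # H and W are latent-space sizes (1/8 of the pixel resolution): the shorter side is pinned to 96 (768 px) and the longer side keeps the aspect ratio, rounded to a multiple of 16 px.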
+ if img_H < img_W:
+ H = 96
+ W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
+ else:
+ W = 96
+ H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
+ chained_transforms = []
+ chained_transforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
+ chained_transforms.append(TT.ToTensor())
+ transform = TT.Compose(chained_transforms)
image = transform(image).unsqueeze(0).to("cuda")
- image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
image = image * 2.0 - 1.0
image = image.unsqueeze(2).to(torch.bfloat16)
image = model.encode_first_stage(image, None)
+ image = image / model.scale_factor
image = image.permute(0, 2, 1, 3, 4).contiguous()
- pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
+ pad_shape = (image.shape[0], T - 1, C, H, W)
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
else:
+ image_size = args.sampling_image_size
+ H, W = image_size[0], image_size[1]
+ F = 8 # 8x downsampled
image = None
- value_dict = {
- "prompt": text,
- "negative_prompt": "",
- "num_frames": torch.tensor(T).unsqueeze(0),
- }
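+ # Keep the prompt identical across model-parallel ranks by broadcasting it from each group's source rank.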
+ text_cast = [text]
+ mp_size = mpu.get_model_parallel_world_size()
+ global_rank = torch.distributed.get_rank() // mp_size
+ src = global_rank * mp_size
+ torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
+ text = text_cast[0]
+ value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@@ -187,57 +205,42 @@ def sampling_main(args, model_cls):
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
- if args.image2video and image is not None:
+ if args.image2video:
c["concat"] = image
uc["concat"] = image
for index in range(args.batch_size):
- # reload model on GPU
- model.to(device)
- samples_z = sample_func(
- c,
- uc=uc,
- batch_size=1,
- shape=(T, C, H // F, W // F),
- )
+ if args.image2video:
+ samples_z = sample_func(
+ c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
+ )
+ else:
+ samples_z = sample_func(
+ c,
+ uc=uc,
+ batch_size=1,
+ shape=(T, C, H // F, W // F),
+ ).to("cuda")
+
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
-
- # Unload the model from GPU to save GPU memory
- model.to("cpu")
- torch.cuda.empty_cache()
- first_stage_model = model.first_stage_model
- first_stage_model = first_stage_model.to(device)
-
- latent = 1.0 / model.scale_factor * samples_z
-
- # Decode latent serial to save GPU memory
- recons = []
- loop_num = (T - 1) // 2
- for i in range(loop_num):
- if i == 0:
- start_frame, end_frame = 0, 3
- else:
- start_frame, end_frame = i * 2 + 1, i * 2 + 3
- if i == loop_num - 1:
- clear_fake_cp_cache = True
- else:
- clear_fake_cp_cache = False
- with torch.no_grad():
- recon = first_stage_model.decode(
- latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
- )
-
- recons.append(recon)
-
- recon = torch.cat(recons, dim=2).to(torch.float32)
- samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-
- save_path = os.path.join(
- args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
- )
- if mpu.get_model_parallel_rank() == 0:
- save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
+ if args.only_save_latents:
+ samples_z = 1.0 / model.scale_factor * samples_z
+ save_path = os.path.join(
+ args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
+ )
+ os.makedirs(save_path, exist_ok=True)
+ torch.save(samples_z, os.path.join(save_path, "latent.pt"))
+ with open(os.path.join(save_path, "text.txt"), "w") as f:
+ f.write(text)
+ else:
+ samples_x = model.decode_first_stage(samples_z).to(torch.float32)
+ samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+ save_path = os.path.join(
+ args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
+ )
+ if mpu.get_model_parallel_rank() == 0:
+ save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
if __name__ == "__main__":
diff --git a/sat/sgm/modules/diffusionmodules/sampling.py b/sat/sgm/modules/diffusionmodules/sampling.py
index f0f1830..6efd154 100644
--- a/sat/sgm/modules/diffusionmodules/sampling.py
+++ b/sat/sgm/modules/diffusionmodules/sampling.py
@@ -1,7 +1,8 @@
"""
-Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
+
from typing import Dict, Union
import torch
@@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import (
to_sigma,
)
from ...util import append_dims, default, instantiate_from_config
-from ...util import SeededNoise
from .guiders import DynamicCFG
@@ -44,7 +44,9 @@ class BaseDiffusionSampler:
self.device = device
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
- sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
+ sigmas = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device
+ )
uc = default(uc, cond)
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
@@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):
class EDMSampler(SingleStepDiffusionSampler):
- def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
+ def __init__(
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
+ ):
super().__init__(*args, **kwargs)
self.s_churn = s_churn
@@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler):
dt = append_dims(next_sigma - sigma_hat, x.ndim)
euler_step = self.euler_step(x, d, dt)
- x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ x = self.possible_correction_step(
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ )
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
for i in self.get_sigma_gen(num_sigmas):
gamma = (
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
)
x = self.sampler_step(
s_in * sigmas[i],
@@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler):
class DDIMSampler(SingleStepDiffusionSampler):
- def __init__(self, s_noise=0.1, *args, **kwargs):
+ def __init__(
+ self, s_noise=0.1, *args, **kwargs
+ ):
super().__init__(*args, **kwargs)
self.s_noise = s_noise
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
+
denoised = self.denoise(x, denoiser, sigma, cond, uc)
d = to_d(x, sigma, denoised)
- dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
+ dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
- x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ x = self.possible_correction_step(
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ )
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
@@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler):
return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step(
@@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler):
self.order = order
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
ds = []
sigmas_cpu = sigmas.detach().cpu().numpy()
for i in self.get_sigma_gen(num_sigmas):
sigma = s_in * sigmas[i]
- denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
+ denoised = denoiser(
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
+ )
denoised = self.guider(denoised, sigma)
d = to_d(x, sigma, denoised)
ds.append(d)
if len(ds) > self.order:
ds.pop(0)
cur_order = min(i + 1, self.order)
- coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+ coeffs = [
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
+ for j in range(cur_order)
+ ]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
class EulerEDMSampler(EDMSampler):
- def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
return euler_step
class HeunEDMSampler(EDMSampler):
- def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
if torch.sum(next_sigma) < 1e-14:
# Save a network evaluation if all noise levels are 0
return euler_step
@@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler):
d_prime = (d + d_new) / 2.0
# apply correction if noise level is not 0
- x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
+ )
return x
@@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler):
x = x_euler
else:
h, s, t, t_next = self.get_variables(sigma, sigma_down)
- mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
+ mult = [
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
+ ]
x2 = mult[0] * x - mult[1] * denoised
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler):
denoised = self.denoise(x, denoiser, sigma, cond, uc)
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
- mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+ ]
x_standard = mult[0] * x - mult[1] * denoised
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d
# apply correction if noise level is not 0 and not first step
- x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+ )
return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
old_denoised = None
for i in self.get_sigma_gen(num_sigmas):
@@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler):
return x
-
class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_mult(self, h, r, t, t_next, previous_sigma):
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
- mult2 = (-2 * h).expm1()
+ mult2 = (-2*h).expm1()
if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r)
@@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
denoised = self.denoise(x, denoiser, sigma, cond, uc)
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
- mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
- mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+ ]
+ mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
# apply correction if noise level is not 0 and not first step
- x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+ )
return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
- x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
old_denoised = None
for i in self.get_sigma_gen(num_sigmas):
@@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
return x
-
class SdeditEDMSampler(EulerEDMSampler):
def __init__(self, edit_ratio=0.5, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler):
def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
randn_unit = randn.clone()
- randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
+ randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ randn, cond, uc, num_steps
+ )
if num_steps is None:
num_steps = self.num_steps
@@ -461,7 +509,9 @@ class SdeditEDMSampler(EulerEDMSampler):
x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
gamma = (
- min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
)
x = self.sampler_step(
s_in * sigmas[i],
@@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler):
return x
-
class VideoDDIMSampler(BaseDiffusionSampler):
+
def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
super().__init__(**kwargs)
self.fixed_frames = fixed_frames
@@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler):
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
alpha_cumprod_sqrt, timesteps = self.discretization(
- self.num_steps if num_steps is None else num_steps,
- device=self.device,
- return_idx=True,
- do_append_zero=False,
+ self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
)
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
- timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])
+ timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])
uc = default(uc, cond)
@@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler):
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
- def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
+ def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
additional_model_inputs = {}
+ if ofs is not None:
+ additional_model_inputs['ofs'] = ofs
+
if isinstance(scale, torch.Tensor) == False and scale == 1:
- additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
+ additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
if scale_emb is not None:
- additional_model_inputs["scale_emb"] = scale_emb
+ additional_model_inputs['scale_emb'] = scale_emb
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
else:
- additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
- denoised = denoiser(
- *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
- ).to(torch.float32)
+ additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
+ denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
if isinstance(self.guider, DynamicCFG):
- denoised = self.guider(
- denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
- )
+ denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
else:
- denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
+ denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
return denoised
- def sampler_step(
- self,
- alpha_cumprod_sqrt,
- next_alpha_cumprod_sqrt,
- denoiser,
- x,
- cond,
- uc=None,
- idx=None,
- timestep=None,
- scale=None,
- scale_emb=None,
- ):
- denoised = self.denoise(
- x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
- ).to(torch.float32)
+ def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
+ denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
- a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
+ a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
return x
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
@@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler):
cond,
uc,
idx=self.num_steps - i,
- timestep=timesteps[-(i + 1)],
+ timestep=timesteps[-(i+1)],
scale=scale,
scale_emb=scale_emb,
+ ofs=ofs # 1020
)
return x
+class Image2VideoDDIMSampler(BaseDiffusionSampler):
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ alpha_cumprod_sqrt, timesteps = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
+ )
+ uc = default(uc, cond)
+
+ num_sigmas = len(alpha_cumprod_sqrt)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
+
+ def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
+ additional_model_inputs = {}
+ additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
+ denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
+ torch.float32)
+ if isinstance(self.guider, DynamicCFG):
+ denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
+ else:
+ denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
+ return denoised
+
+ def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
+ timestep=None):
+ # note: the "sigma" passed in here is actually alpha_cumprod_sqrt
+ denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
+ if idx == 1:
+ return denoised
+
+ a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
+ b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
+
+ x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
+ return x
+
+ def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * alpha_cumprod_sqrt[i],
+ s_in * alpha_cumprod_sqrt[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ idx=self.num_steps - i,
+ timestep=timesteps[-(i + 1)]
+ )
+
+ return x
+
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
- alpha_cumprod = alpha_cumprod_sqrt**2
- lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
- next_alpha_cumprod = next_alpha_cumprod_sqrt**2
- lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
+ alpha_cumprod = alpha_cumprod_sqrt ** 2
+ lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
+ next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
+ lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None:
- previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
- lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
+ previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
+ lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
@@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
- mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
- mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
+ mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
+ mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt
if previous_alpha_cumprod_sqrt is not None:
mult3 = 1 + 1 / (2 * r)
@@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
timestep=None,
scale=None,
scale_emb=None,
+ ofs=None # 1020
):
- denoised = self.denoise(
- x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
- ).to(torch.float32)
+ denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
if idx == 1:
return denoised, denoised
- h, r, lamb, lamb_next = self.get_variables(
- alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
- )
+ h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
]
- mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
+ mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return x, denoised
- def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps
)
if self.fixed_frames > 0:
- prefix_frames = x[:, : self.fixed_frames]
+ prefix_frames = x[:, :self.fixed_frames]
old_denoised = None
for i in self.get_sigma_gen(num_sigmas):
+
if self.fixed_frames > 0:
if self.sdedit:
rd = torch.randn_like(prefix_frames)
- noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
- s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
- )
- x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
+ noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
+ x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
else:
- x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
+ x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
x, old_denoised = self.sampler_step(
old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
@@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
cond,
uc=uc,
idx=self.num_steps - i,
- timestep=timesteps[-(i + 1)],
+ timestep=timesteps[-(i+1)],
scale=scale,
scale_emb=scale_emb,
+ ofs=ofs # 1020
)
if self.fixed_frames > 0:
- x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
+ x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
return x
class VPODEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
- alpha_cumprod = alpha_cumprod_sqrt**2
- lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
- next_alpha_cumprod = next_alpha_cumprod_sqrt**2
- lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
+ alpha_cumprod = alpha_cumprod_sqrt ** 2
+ lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
+ next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
+ lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None:
- previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
- lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
+ previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
+ lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
@@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
- mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
+ mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
if previous_alpha_cumprod_sqrt is not None:
@@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond,
uc=None,
idx=None,
- timestep=None,
+ timestep=None
):
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
if idx == 1:
return denoised, denoised
- h, r, lamb, lamb_next = self.get_variables(
- alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
- )
+ h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
@@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond,
uc=uc,
idx=self.num_steps - i,
- timestep=timesteps[-(i + 1)],
+ timestep=timesteps[-(i+1)]
)
return x
+
+class VideoDDPMSampler(VideoDDIMSampler):
+ def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
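+ # Standard DDPM ancestral (posterior) update expressed with alpha_cumprod_sqrt.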
+ # note: the "sigma" passed in here is actually alpha_cumprod_sqrt
+ denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
+ if idx == 1:
+ return denoised
+
+ alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
+ x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
+ + append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
+ + append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
+
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * alpha_cumprod_sqrt[i],
+ s_in * alpha_cumprod_sqrt[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ idx=self.num_steps - i
+ )
+
+ return x
\ No newline at end of file
diff --git a/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/sat/sgm/modules/diffusionmodules/sigma_sampling.py
index 770de42..8bb623e 100644
--- a/sat/sgm/modules/diffusionmodules/sigma_sampling.py
+++ b/sat/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -17,23 +17,20 @@ class EDMSampling:
class DiscreteSampling:
- def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
self.num_idx = num_idx
- self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
+ self.sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
world_size = mpu.get_data_parallel_world_size()
+ if world_size <= 8:
+ uniform_sampling = False
self.uniform_sampling = uniform_sampling
+ self.group_num = group_num
if self.uniform_sampling:
- i = 1
- while True:
- if world_size % i != 0 or num_idx % (world_size // i) != 0:
- i += 1
- else:
- self.group_num = world_size // i
- break
-
assert self.group_num > 0
- assert world_size % self.group_num == 0
- self.group_width = world_size // self.group_num # the number of rank in one group
+ assert world_size % group_num == 0
+ self.group_width = world_size // group_num # the number of rank in one group
self.sigma_interval = self.num_idx // self.group_num
def idx_to_sigma(self, idx):
@@ -45,9 +42,7 @@ class DiscreteSampling:
group_index = rank // self.group_width
idx = default(
rand,
- torch.randint(
- group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
- ),
+ torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
)
else:
idx = default(
@@ -59,7 +54,6 @@ class DiscreteSampling:
else:
return self.idx_to_sigma(idx)
-
class PartialDiscreteSampling:
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
self.total_num_idx = total_num_idx
diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py
index 7c0cc80..9642fb4 100644
--- a/sat/vae_modules/autoencoder.py
+++ b/sat/vae_modules/autoencoder.py
@@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
+ use_cp: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
- if self.cp_size > 0 and not input_cp:
+ if self.cp_size <= 1:
+ use_cp = False
+ if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
@@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
- z, reg_log = super().encode(x, return_reg_log, unregularized)
+ z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
else:
- z = super().encode(x, return_reg_log, unregularized)
+ z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
- if self.cp_size > 0 and not output_cp:
+ if self.cp_size > 0 and use_cp and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
@@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
- split_kernel_size: int = 1,
+ use_cp: bool = True,
**kwargs,
):
- if self.cp_size > 0 and not input_cp:
+ if self.cp_size <= 1:
+ use_cp = False
+ if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized:
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
- z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
+ z = _conv_split(z, dim=2, kernel_size=1)
- x = super().decode(z, **kwargs)
-
- if self.cp_size > 0 and not output_cp:
- x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
+ x = super().decode(z, use_cp=use_cp, **kwargs)
+ if self.cp_size > 0 and use_cp and not output_cp:
+ x = _conv_gather(x, dim=2, kernel_size=1)
return x
def forward(
diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py
index d50720d..1d9c34f 100644
--- a/sat/vae_modules/cp_enc_dec.py
+++ b/sat/vae_modules/cp_enc_dec.py
@@ -16,11 +16,7 @@ from sgm.util import (
get_context_parallel_group_rank,
)
-# try:
from vae_modules.utils import SafeConv3d as Conv3d
-# except:
-# # Degrade to normal Conv3d if SafeConv3d is not available
-# from torch.nn import Conv3d
def cast_tuple(t, length=1):
@@ -81,8 +77,6 @@ def _split(input_, dim):
cp_rank = get_context_parallel_rank()
- # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
-
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
@@ -94,8 +88,6 @@ def _split(input_, dim):
output = torch.cat([inpu_first_frame_, output], dim=dim)
output = output.contiguous()
- # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
-
return output
@@ -382,19 +374,6 @@ class ContextParallelCausalConv3d(nn.Module):
self.cache_padding = None
def forward(self, input_, clear_cache=True):
- # if input_.shape[2] == 1: # handle image
- # # first frame padding
- # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
- # else:
- # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
-
- # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
- # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
-
- # output_parallel = self.conv(input_parallel)
- # output = output_parallel
- # return output
-
input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
)
@@ -441,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
return output
-def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
+def Normalize(in_channels, gather=False, **kwargs):
+ # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
@@ -488,24 +468,34 @@ class SpatialNorm3D(nn.Module):
kernel_size=1,
)
- def forward(self, f, zq, clear_fake_cp_cache=True):
- if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+ def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
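+ # When fake_cp_rank0 is set, the first frame is handled separately only on context-parallel rank 0; other ranks interpolate all frames uniformly.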
+ if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
- zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+
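+ # Interpolate zq in chunks of 32 channels instead of one large call (the original single call is kept commented out below for reference).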
+ zq_rest_splits = torch.split(zq_rest, 32, dim=1)
+ interpolated_splits = [
+ torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
+ ]
+
+ zq_rest = torch.cat(interpolated_splits, dim=1)
+ # zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
- zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+ f_size = f.shape[-3:]
+
+ zq_splits = torch.split(zq, 32, dim=1)
+ interpolated_splits = [
+ torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
+ ]
+ zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
- # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
- # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
-
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
@@ -541,23 +531,44 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
- def forward(self, x):
+ def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1:
- if x.shape[2] % 2 == 1:
+ if get_context_parallel_rank() == 0 and fake_cp_rank0:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
- x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+
+ splits = torch.split(x_rest, 32, dim=1)
+ interpolated_splits = [
+ torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+ ]
+ x_rest = torch.cat(interpolated_splits, dim=1)
+
+ # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ splits = torch.split(x, 32, dim=1)
+ interpolated_splits = [
+ torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+ ]
+ x = torch.cat(interpolated_splits, dim=1)
+
+ # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ splits = torch.split(x, 32, dim=1)
+ interpolated_splits = [
+ torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+ ]
+ x = torch.cat(interpolated_splits, dim=1)
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
@@ -579,21 +590,30 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
- def forward(self, x):
+ def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
- if x.shape[-1] % 2 == 1:
+ if get_context_parallel_rank() == 0 and fake_cp_rank0:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
- x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+ splits = torch.split(x_rest, 32, dim=1)
+ pooled_splits = [
+ torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+ ]
+ x_rest = torch.cat(pooled_splits, dim=1)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
- x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ splits = torch.split(x, 32, dim=1)
+ pooled_splits = [
+ torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+ ]
+ x = torch.cat(pooled_splits, dim=1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
@@ -673,13 +693,13 @@ class ContextParallelResnetBlock3D(nn.Module):
padding=0,
)
- def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
+ def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
- h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
@@ -694,7 +714,7 @@ class ContextParallelResnetBlock3D(nn.Module):
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
- h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
@@ -807,23 +827,24 @@ class ContextParallelEncoder3D(nn.Module):
kernel_size=3,
)
- def forward(self, x, **kwargs):
+ def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
# timestep embedding
temb = None
# downsampling
- h = self.conv_in(x)
+ h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
+ h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.down[i_level].attn) > 0:
+ print("Attention not implemented")
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
- h = self.down[i_level].downsample(h)
+ h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
# middle
- h = self.mid.block_1(h, temb)
- h = self.mid.block_2(h, temb)
+ h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
@@ -831,7 +852,7 @@ class ContextParallelEncoder3D(nn.Module):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
- h = self.conv_out(h)
+ h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h
@@ -934,6 +955,11 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
+ # Symmetrical enc-dec (alternative kept for reference only; the active assignment follows below)
+ # if i_level <= self.temporal_compress_level:
+ #     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+ # else:
+ #     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
else:
@@ -948,7 +974,7 @@ class ContextParallelDecoder3D(nn.Module):
kernel_size=3,
)
- def forward(self, z, clear_fake_cp_cache=True, **kwargs):
+ def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
self.last_z_shape = z.shape
# timestep embedding
@@ -961,23 +987,25 @@ class ContextParallelDecoder3D(nn.Module):
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
- h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
- h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+ h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
- h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.up[i_level].block[i_block](
+ h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+ )
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
- h = self.up[i_level].upsample(h)
+ h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
# end
if self.give_pre_end:
return h
- h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+ h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index 183be62..f325018 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -1,22 +1,15 @@
"""
-This script demonstrates how to convert and generate video from a text prompt
-using CogVideoX with ð€Huggingface Diffusers Pipeline.
-This script requires the `diffusers>=0.30.2` library to be installed.
-
-Functions:
- - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
- - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
- - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
- - remove_keys_inplace: Removes specified keys from the state_dict in-place.
- - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
- - get_state_dict: Extracts the state_dict from a saved checkpoint.
- - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
- - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
- - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
- - get_args: Parses command-line arguments for the script.
- - generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
-"""
+This script converts the weights of the CogVideoX model from the SAT format to the Hugging Face (Diffusers) format.
+It supports converting the following models:
+- CogVideoX-2B
+- CogVideoX-5B, CogVideoX-5B-I2V
+- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
+
+Original Script:
+https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
+
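+Example usage (the checkpoint paths below are placeholders; the flags are defined in get_args()):
+    python tools/convert_weight_sat2hf.py \
+        --transformer_ckpt_path /path/to/sat/transformer_ckpt.pt \
+        --vae_ckpt_path /path/to/sat/vae_ckpt.pt \
+        --output_path ./cogvideox-diffusers \
+        --fp16
+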
+"""
import argparse
from typing import Any, Dict
@@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
- ckpt_path: str,
- num_layers: int,
- num_attention_heads: int,
- use_rotary_positional_embeddings: bool,
- i2v: bool,
- dtype: torch.dtype,
+ ckpt_path: str,
+ num_layers: int,
+ num_attention_heads: int,
+ use_rotary_positional_embeddings: bool,
+ i2v: bool,
+ dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."
@@ -172,7 +165,7 @@ def convert_transformer(
).to(dtype=dtype)
for key in list(original_state_dict.keys()):
- new_key = key[len(PREFIX_KEY) :]
+ new_key = key[len(PREFIX_KEY):]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
@@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
- "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@@ -259,9 +253,10 @@ if __name__ == "__main__":
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
- text_encoder_id = "google/t5-v1_1-xxl"
+ text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@@ -301,4 +296,7 @@ if __name__ == "__main__":
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).
- pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+ # Sharding the saved weights (max_shard_size="5GB") is necessary for users with insufficient memory,
+ # such as those using Colab and notebooks, as it can save some memory used for model loading.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)