mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Compare commits
173 Commits
28 .github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md (vendored, Normal file)
@ -0,0 +1,28 @@
# Contribution Guide

We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines.

## What We Accept

+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.

## Code Style Guide

Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:

1. Install the required dependencies:

   ```shell
   pip install ruff pre-commit
   ```

2. Then, run the following command:

   ```shell
   pre-commit run --all-files
   ```

   If your code complies with the standards, you should not see any errors.

## Naming Conventions

- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.
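As a small illustration of the naming rules above (the function and variable names here are invented for this example, not taken from the repository):

```python
# Hypothetical illustration of the naming conventions above: English,
# PEP8-style snake_case identifiers instead of opaque names like `a` or `b`.

def count_frames(duration_seconds: float, frames_per_second: int = 8) -> int:
    """Return the total number of frames for a clip of the given duration."""
    return int(duration_seconds * frames_per_second)


# Unclear:  def f(a, b): return int(a * b)
# Clear:
total_frames = count_frames(duration_seconds=6.0)
print(total_frames)  # 48
```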
34 .github/PULL_REQUEST_TEMPLATE/pr_template.md (vendored)
@ -1,34 +0,0 @@
# Raise valuable PR / 提出有价值的PR

## Caution / 注意事项:

Users should keep the following points in mind when submitting PRs:

1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if there are multiple ideas and optimizations, they should be split into separate PRs.

用户在提交PR时候应该注意以下几点:

1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。

## 不应该提出的PR / PRs that should not be proposed

If a developer proposes a PR about any of the following, it may be closed or rejected:

1. PRs that do not describe an improvement plan.
2. PRs that combine multiple issues of different types.
3. PRs that are highly duplicative of already existing PRs.

如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。

1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。

# Check your PR / 检查您的PR

- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 GitHub 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue? / 您的PR是否仅针对一个问题?
27 .gitignore (vendored)
@ -8,3 +8,30 @@ logs/
.idea
output*
test*
venv
**/.swp
**/*.log
**/*.debug
**/.vscode

**/*debug*
**/.gitignore
**/finetune/*-lora-*
**/finetune/Disney-*
**/wandb
**/results
**/*.mp4
**/validation_set
CogVideo-1.0

**/**foo**
**/sat
**/train_results
**/train_res*

**/uv.lock
19 .pre-commit-config.yaml (Normal file)
@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.5
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
      - id: ruff-format
        args: [--config=pyproject.toml]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements
136 README.md
@ -22,7 +22,14 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
## Project Updates

- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version of the model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
- 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
- 🔥 **News**: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
  The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
  The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single 4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports fine-tuning with multiple resolutions. Feel free to use it!
- 🔥 **News**: ```2024/10/10```: We have updated our technical report. Please
@ -40,11 +47,11 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
model [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), used in the training process of CogVideoX to convert video data into text descriptions, has been open-sourced. Welcome to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced a larger model in the CogVideoX series, **CogVideoX-5B**. We have significantly optimized the model's inference performance, greatly lowering the inference threshold.
  You can run **CogVideoX-2B** on older GPUs like the `GTX 1080TI`, and **CogVideoX-5B** on desktop GPUs like the `RTX 3060`. Please strictly follow the [requirements](requirements.txt) to update and install dependencies, and refer to [cli_demo](inference/cli_demo.py) for inference code. Additionally, the open-source license for the **CogVideoX-2B** model has been changed to the **Apache 2.0 License**.
- 🔥 ```2024/8/6```: We have open-sourced **3D Causal VAE**, used for **CogVideoX-2B**, which can reconstruct videos with almost no loss.
- 🔥 ```2024/8/6```: We have open-sourced the first model of the CogVideoX series of video generation models, **CogVideoX-2B
@ -57,19 +64,24 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
Jump to a specific section:

- [Quick Start](#quick-start)
  - [Prompt Optimization](#prompt-optimization)
  - [SAT](#sat)
  - [Diffusers](#diffusers)
- [Gallery](#gallery)
  - [CogVideoX-5B](#cogvideox-5b)
  - [CogVideoX-2B](#cogvideox-2b)
- [Model Introduction](#model-introduction)
- [Friendly Links](#friendly-links)
- [Project Structure](#project-structure)
  - [Quick Start with Colab](#quick-start-with-colab)
  - [Inference](#inference)
  - [finetune](#finetune)
  - [sat](#sat-1)
  - [Tools](#tools)
- [CogVideo(ICLR'23)](#cogvideoiclr23)
- [Citation](#citation)
- [Model-License](#model-license)

## Quick Start
@ -169,82 +181,93 @@ models we currently offer, along with their foundational information.
<table style="border-collapse: collapse; width: 100%;">
  <tr>
    <th style="text-align: center;">Model Name</th>
    <th style="text-align: center;">CogVideoX1.5-5B (Latest)</th>
    <th style="text-align: center;">CogVideoX1.5-5B-I2V (Latest)</th>
    <th style="text-align: center;">CogVideoX-2B</th>
    <th style="text-align: center;">CogVideoX-5B</th>
    <th style="text-align: center;">CogVideoX-5B-I2V</th>
  </tr>
  <tr>
    <td style="text-align: center;">Release Date</td>
    <th style="text-align: center;">November 8, 2024</th>
    <th style="text-align: center;">November 8, 2024</th>
    <th style="text-align: center;">August 6, 2024</th>
    <th style="text-align: center;">August 27, 2024</th>
    <th style="text-align: center;">September 19, 2024</th>
  </tr>
  <tr>
    <td style="text-align: center;">Video Resolution</td>
    <td colspan="1" style="text-align: center;">1360 * 768</td>
    <td colspan="1" style="text-align: center;">Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0</td>
    <td colspan="3" style="text-align: center;">720 * 480</td>
  </tr>
  <tr>
    <td style="text-align: center;">Number of Frames</td>
    <td colspan="2" style="text-align: center;">Should be <b>16N + 1</b> where N <= 10 (default 81)</td>
    <td colspan="3" style="text-align: center;">Should be <b>8N + 1</b> where N <= 6 (default 49)</td>
  </tr>
  <tr>
    <td style="text-align: center;">Inference Precision</td>
    <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
    <td style="text-align: center;"><b>FP16* (Recommended)</b>, BF16, FP32, FP8*, INT8, Not supported: INT4</td>
    <td colspan="2" style="text-align: center;"><b>BF16 (Recommended)</b>, FP16, FP32, FP8*, INT8, Not supported: INT4</td>
  </tr>
  <tr>
    <td style="text-align: center;">Single GPU Memory Usage<br></td>
    <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF16: from 10GB*</b><br><b>diffusers INT8 (torchao): from 7GB*</b></td>
    <td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB minimum*</b><br><b>diffusers INT8 (torchao): 3.6GB minimum*</b></td>
    <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16: 5GB minimum*</b><br><b>diffusers INT8 (torchao): 4.4GB minimum*</b></td>
  </tr>
  <tr>
    <td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
    <td colspan="2" style="text-align: center;"><b>BF16: 24GB* using diffusers</b><br></td>
    <td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
    <td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
  </tr>
  <tr>
    <td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
    <td colspan="2" style="text-align: center;">Single A100: ~1000 seconds (5-second video)<br>Single H100: ~550 seconds (5-second video)</td>
    <td style="text-align: center;">Single A100: ~90 seconds<br>Single H100: ~45 seconds</td>
    <td colspan="2" style="text-align: center;">Single A100: ~180 seconds<br>Single H100: ~90 seconds</td>
  </tr>
  <tr>
    <td style="text-align: center;">Fine-tuning Precision</td>
    <td style="text-align: center;"><b>FP16</b></td>
    <td colspan="2" style="text-align: center;"><b>BF16</b></td>
  </tr>
  <tr>
    <td style="text-align: center;">Fine-tuning Memory Usage</td>
    <td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
    <td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
    <td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
  </tr>
  <tr>
    <td style="text-align: center;">Prompt Language</td>
    <td colspan="5" style="text-align: center;">English*</td>
  </tr>
  <tr>
    <td style="text-align: center;">Prompt Token Limit</td>
    <td colspan="2" style="text-align: center;">224 Tokens</td>
    <td colspan="3" style="text-align: center;">226 Tokens</td>
  </tr>
  <tr>
    <td style="text-align: center;">Video Length</td>
    <td colspan="2" style="text-align: center;">5 seconds or 10 seconds</td>
    <td colspan="3" style="text-align: center;">6 seconds</td>
  </tr>
  <tr>
    <td style="text-align: center;">Frame Rate</td>
    <td colspan="2" style="text-align: center;">16 frames / second</td>
    <td colspan="3" style="text-align: center;">8 frames / second</td>
  </tr>
  <tr>
    <td style="text-align: center;">Position Encoding</td>
    <td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
    <td style="text-align: center;">3d_sincos_pos_embed</td>
    <td style="text-align: center;">3d_rope_pos_embed</td>
    <td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
  </tr>
  <tr>
    <td style="text-align: center;">Download Link (Diffusers)</td>
    <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
    <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
    <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
    <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
    <td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
  </tr>
  <tr>
    <td style="text-align: center;">Download Link (SAT)</td>
    <td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
    <td colspan="3" style="text-align: center;"><a href="./sat/README.md">SAT</a></td>
  </tr>
</table>
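The frame-count and resolution constraints in the table above can be checked programmatically. The sketch below encodes only the rules stated in the table (16N + 1 or 8N + 1 frames; for CogVideoX1.5-5B-I2V, the short side fixed at 768 and the long side between 768 and 1360 and divisible by 16); the helper names are our own, not part of the repository:

```python
# Sketch of the generation constraints from the table above; helper names are
# invented for this illustration.

def is_valid_frame_count(frames: int, model: str) -> bool:
    """Frames must be 16N + 1 (N <= 10) for CogVideoX1.5, 8N + 1 (N <= 6) otherwise."""
    step, max_n = (16, 10) if "1.5" in model else (8, 6)
    n, rem = divmod(frames - 1, step)
    return rem == 0 and 1 <= n <= max_n

def is_valid_i2v_resolution(width: int, height: int) -> bool:
    """CogVideoX1.5-5B-I2V: min side exactly 768, max side in [768, 1360], % 16 == 0."""
    lo, hi = min(width, height), max(width, height)
    return lo == 768 and 768 <= hi <= 1360 and hi % 16 == 0

print(is_valid_frame_count(81, "CogVideoX1.5-5B"))  # True  (81 = 16 * 5 + 1)
print(is_valid_frame_count(49, "CogVideoX-5B"))     # True  (49 = 8 * 6 + 1)
print(is_valid_i2v_resolution(1360, 768))           # True
```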
@ -271,21 +294,20 @@ pipe.vae.enable_tiling()
used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on devices with NVIDIA H100 and above, and requires source installation of the `torch` and `torchao` Python packages. CUDA 12.4 is recommended.
+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed increases by about 10%. Only the `diffusers` version of the model supports quantization.
+ The model only supports English input; prompts in other languages can be translated into English during refinement by a large language model.
+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically uses `Zero 2` optimization. If a specific number of GPUs is marked in the table, that number or more GPUs must be used for fine-tuning.
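To build intuition for why the INT8 figures above are smaller than the BF16/FP16 ones, here is a toy, framework-free sketch of symmetric INT8 weight quantization. It illustrates the general technique only; it is not the TorchAO implementation used by the project:

```python
# Toy symmetric INT8 weight quantization (illustration only, NOT TorchAO).
# Each weight is stored as one signed byte plus a shared per-tensor scale,
# instead of two bytes for BF16/FP16 -- roughly halving weight memory at the
# cost of a small reconstruction error. Assumes at least one nonzero weight.

def quantize_int8(weights):
    """Map floats to integer codes in [-127, 127] sharing one scale factor."""
    scale = max(abs(w) for w in weights) / 127.0
    return [round(w / scale) for w in weights], scale

def dequantize_int8(codes, scale):
    """Recover approximate float weights from the integer codes."""
    return [code * scale for code in codes]

weights = [0.5, -1.0, 0.25, 0.0]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
# The round trip is close but lossy; the per-weight error is at most scale / 2.
```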
## Friendly Links
We highly welcome contributions from the community and actively contribute to the open-source community. The following works have already been adapted for CogVideoX, and we invite everyone to use them:

+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx): RIFLEx extends the video with just one line of code: `freq[k-1] = (2 * np.pi) / (L * s)`. The framework not only supports training-free inference, but also offers models fine-tuned based on CogVideoX. By fine-tuning the model for just 1,000 steps on original-length videos, RIFLEx significantly enhances its length extrapolation capability.
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture, supporting flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository for CogVideo's Gradio Web UI, which
@ -312,6 +334,10 @@ works have already been adapted for CogVideoX, and we invite everyone to use the
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth Studio is a diffusion engine. It has restructured the architecture, including text encoders, UNet, VAE, etc., enhancing computational performance while maintaining compatibility with open-source community models. The framework has been adapted for CogVideoX.
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): A simple ControlNet module that works with the CogVideoX model.
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna is the first repository to integrate multiple AI video generation models, covering text-to-video, image-to-video, and text-to-image generation.
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): An identity-preserving text-to-video generation model, based on CogVideoX-5B, which keeps the face consistent in the generated video via frequency decomposition.
+ [A Step by Step Tutorial](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): A step-by-step guide to installing and optimizing the CogVideoX1.5-5B-I2V model in Windows and cloud environments. Special thanks to [FurkanGozukara](https://github.com/FurkanGozukara) for his effort and support!
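The one-line frequency edit that RIFLEx performs (mentioned in the Friendly Links above) can be sketched without the surrounding model code. Everything here is a toy reconstruction: the RoPE-style frequency table and the values of `L` (trained length), `s` (extrapolation scale), and `k` (the targeted frequency index) are illustrative, not taken from the RIFLEx repository:

```python
import math

# Toy RoPE-style frequency table for a rotary position embedding.
dim = 16
freq = [1.0 / (10000.0 ** (2 * i / dim)) for i in range(dim // 2)]

L = 49   # hypothetical trained sequence length (frames)
s = 2.0  # hypothetical extrapolation factor
k = 4    # hypothetical index of the frequency to adjust

# The one-line RIFLEx-style edit: make the k-th frequency complete exactly one
# period over the extrapolated length L * s, avoiding repetition artifacts.
freq[k - 1] = (2 * math.pi) / (L * s)
```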
## Project Structure
@ -420,9 +446,7 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
}
```

We welcome your contributions! You can click [here](resources/contribute.md) for more information.

## Model-License

The code in this repository is released under the [Apache 2.0 License](LICENSE).
138 README_ja.md
@ -1,6 +1,6 @@
# CogVideo & CogVideoX

[Read this in English](./README.md)

[中文阅读](./README_zh.md)
@ -21,10 +21,18 @@
</p>

## 更新とニュース

- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
- **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。
- **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAM(ビデオメモリ)で動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
- **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
- **ニュース**: ```2024/11/08```: `CogVideoX1.5` モデルをリリースしました。CogVideoX1.5 は CogVideoX オープンソースモデルのアップグレードバージョンです。
  CogVideoX1.5-5B シリーズモデルは、10秒長の動画とより高い解像度をサポートしており、`CogVideoX1.5-5B-I2V` は任意の解像度での動画生成に対応しています。
  SAT コードはすでに更新されており、`diffusers` バージョンは現在適応中です。SAT バージョンのコードは [こちら](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) からダウンロードできます。
- 🔥 **ニュース**: ```2024/10/13```: コスト削減のため、単一の4090 GPUで`CogVideoX-5B`を微調整できるフレームワーク [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory) がリリースされました。複数の解像度での微調整に対応しています。ぜひご利用ください!
- 🔥 **ニュース**: ```2024/10/10```: 技術報告書を更新しました。[こちら](https://arxiv.org/pdf/2408.06072) をクリックしてご覧ください。さらにトレーニングの詳細とデモを追加しました。デモを見るには[こちら](https://yzy-thu.github.io/CogVideoX-demo/)
@ -34,7 +42,7 @@
- 🔥**ニュース**: ```2024/9/19```: CogVideoXシリーズの画像生成ビデオモデル **CogVideoX-5B-I2V** をオープンソース化しました。このモデルは、画像を背景入力として使用し、プロンプトワードと組み合わせてビデオを生成することができ、より高い制御性を提供します。これにより、CogVideoXシリーズのモデルは、テキストからビデオ生成、ビデオの継続、画像からビデオ生成の3つのタスクをサポートするようになりました。オンラインでの[体験](https://huggingface.co/spaces/THUDM/CogVideoX-5B-Space)をお楽しみください。
- 🔥 **ニュース**: ```2024/9/19```: CogVideoXのトレーニングプロセスでビデオデータをテキスト記述に変換するために使用されるキャプションモデル [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption) をオープンソース化しました。ダウンロードしてご利用ください。
- 🔥 ```2024/8/27```: CogVideoXシリーズのより大きなモデル **CogVideoX-5B**
@ -56,18 +64,23 @@
特定のセクションにジャンプ:

- [クイックスタート](#クイックスタート)
  - [プロンプトの最適化](#プロンプトの最適化)
  - [SAT](#sat)
  - [Diffusers](#diffusers)
- [Gallery](#gallery)
  - [CogVideoX-5B](#cogvideox-5b)
  - [CogVideoX-2B](#cogvideox-2b)
- [モデル紹介](#モデル紹介)
- [友好的リンク](#友好的リンク)
- [プロジェクト構造](#プロジェクト構造)
  - [Colabでのクイックスタート](#colabでのクイックスタート)
  - [Inference](#inference)
  - [finetune](#finetune)
  - [sat](#sat-1)
  - [ツール](#ツール)
- [CogVideo(ICLR'23)](#cogvideoiclr23)
- [引用](#引用)
- [ライセンス契約](#ライセンス契約)

## クイックスタート
@ -156,79 +169,96 @@ pip install -r requirements.txt
CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源のオープンソース版ビデオ生成モデルです。
以下の表に、提供しているビデオ生成モデルの基本情報を示します:

<table style="border-collapse: collapse; width: 100%;">
  <tr>
    <th style="text-align: center;">モデル名</th>
    <th style="text-align: center;">CogVideoX1.5-5B (最新)</th>
    <th style="text-align: center;">CogVideoX1.5-5B-I2V (最新)</th>
    <th style="text-align: center;">CogVideoX-2B</th>
    <th style="text-align: center;">CogVideoX-5B</th>
    <th style="text-align: center;">CogVideoX-5B-I2V</th>
  </tr>
  <tr>
    <td style="text-align: center;">公開日</td>
    <th style="text-align: center;">2024年11月8日</th>
    <th style="text-align: center;">2024年11月8日</th>
    <th style="text-align: center;">2024年8月6日</th>
    <th style="text-align: center;">2024年8月27日</th>
    <th style="text-align: center;">2024年9月19日</th>
  </tr>
  <tr>
    <td style="text-align: center;">ビデオ解像度</td>
    <td colspan="1" style="text-align: center;">1360 * 768</td>
    <td colspan="1" style="text-align: center;">Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0</td>
    <td colspan="3" style="text-align: center;">720 * 480</td>
  </tr>
  <tr>
    <td style="text-align: center;">フレーム数</td>
    <td colspan="2" style="text-align: center;"><b>16N + 1</b> (N <= 10) である必要があります (デフォルト 81)</td>
    <td colspan="3" style="text-align: center;"><b>8N + 1</b> (N <= 6) である必要があります (デフォルト 49)</td>
  </tr>
  <tr>
    <td style="text-align: center;">推論精度</td>
    <td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
    <td style="text-align: center;"><b>FP16*(推奨)</b>, BF16, FP32, FP8*, INT8, INT4は非対応</td>
    <td colspan="2" style="text-align: center;"><b>BF16(推奨)</b>, FP16, FP32, FP8*, INT8, INT4は非対応</td>
  </tr>
  <tr>
    <td style="text-align: center;">単一GPUメモリ消費量<br></td>
    <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF16: 10GBから*</b><br><b>diffusers INT8(torchao): 7GBから*</b></td>
    <td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: 4GB以上*</b><br><b>diffusers INT8(torchao): 3.6GB以上*</b></td>
    <td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16: 5GB以上*</b><br><b>diffusers INT8(torchao): 4.4GB以上*</b></td>
  </tr>
  <tr>
    <td style="text-align: center;">複数GPU推論メモリ消費量</td>
    <td colspan="2" style="text-align: center;"><b>BF16: 24GB* diffusers使用</b><br></td>
    <td style="text-align: center;"><b>FP16: 10GB* diffusers使用</b><br></td>
    <td colspan="2" style="text-align: center;"><b>BF16: 15GB* diffusers使用</b><br></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">推論速度<br>(ステップ = 50, FP/BF16)</td>
|
||||
<td style="text-align: center;">単一A100: 約90秒<br>単一H100: 約45秒</td>
|
||||
<td colspan="2" style="text-align: center;">単一A100: 約180秒<br>単一H100: 約90秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ファインチューニング精度</td>
|
||||
<td style="text-align: center;"><b>FP16</b></td>
|
||||
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ファインチューニング時のメモリ消費</td>
|
||||
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62GB (bs=1, SFT)</td>
|
||||
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75GB (bs=1, SFT)<br></td>
|
||||
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75GB (bs=1, SFT, 16GPU)<br></td>
|
||||
<td style="text-align: center;">推論速度<br>(Step = 50, FP/BF16)</td>
|
||||
<td colspan="2" style="text-align: center;">シングルA100: ~1000秒(5秒ビデオ)<br>シングルH100: ~550秒(5秒ビデオ)</td>
|
||||
<td style="text-align: center;">シングルA100: ~90秒<br>シングルH100: ~45秒</td>
|
||||
<td colspan="2" style="text-align: center;">シングルA100: ~180秒<br>シングルH100: ~90秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプト言語</td>
|
||||
<td colspan="3" style="text-align: center;">英語*</td>
|
||||
<td colspan="5" style="text-align: center;">英語*</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">プロンプトの最大トークン数</td>
|
||||
<td style="text-align: center;">プロンプト長さの上限</td>
|
||||
<td colspan="2" style="text-align: center;">224トークン</td>
|
||||
<td colspan="3" style="text-align: center;">226トークン</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオの長さ</td>
|
||||
<td style="text-align: center;">ビデオ長さ</td>
|
||||
<td colspan="2" style="text-align: center;">5秒または10秒</td>
|
||||
<td colspan="3" style="text-align: center;">6秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">フレームレート</td>
|
||||
<td colspan="2" style="text-align: center;">16フレーム/秒</td>
|
||||
<td colspan="3" style="text-align: center;">8フレーム/秒</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ビデオ解像度</td>
|
||||
<td colspan="3" style="text-align: center;">720 * 480、他の解像度は非対応(ファインチューニング含む)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">位置エンコーディング</td>
|
||||
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_sincos_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed</td>
|
||||
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (Diffusers)</td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
|
||||
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align: center;">ダウンロードリンク (SAT)</td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_ja.md">SAT</a></td>
|
||||
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
|
||||
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
|
||||
</tr>
|
||||
</table>
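The frame-count and resolution rules in the table above can be checked programmatically. Below is a minimal sketch; the helper names are our own illustration, not part of the CogVideoX codebase:

```python
def valid_frame_count(frames: int, is_1_5: bool) -> bool:
    """Check the table's rule: 16N + 1 (N <= 10) for CogVideoX1.5,
    8N + 1 (N <= 6) for CogVideoX."""
    step, max_n = (16, 10) if is_1_5 else (8, 6)
    n, rem = divmod(frames - 1, step)
    return rem == 0 and 1 <= n <= max_n

def valid_i2v_1_5_resolution(width: int, height: int) -> bool:
    """CogVideoX1.5-5B-I2V: Min(W, H) = 768, 768 <= Max(W, H) <= 1360,
    and Max(W, H) divisible by 16."""
    lo, hi = min(width, height), max(width, height)
    return lo == 768 and 768 <= hi <= 1360 and hi % 16 == 0

print(valid_frame_count(81, is_1_5=True))    # default for CogVideoX1.5
print(valid_frame_count(49, is_1_5=False))   # default for CogVideoX
print(valid_i2v_1_5_resolution(1360, 768))
```

Running the defaults from the table (81 frames for CogVideoX1.5, 49 for CogVideoX, 1360 * 768 for the 1.5 I2V model) passes all three checks.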
+ [PytorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer, and VAE modules to reduce CogVideoX's memory requirements. This makes it possible to run the model on a free T4 Colab or on GPUs with less memory. Equally important, TorchAO quantization is fully compatible with `torch.compile`, which can significantly improve inference speed. `FP8` precision must be used on `NVIDIA H100` and later devices, which requires installing the `torch` and `torchao` Python packages from source. `CUDA 12.4` is recommended.
+ The inference speed tests also used the memory optimizations above; without memory optimization, inference is about 10% faster. Only the `diffusers` versions of the models support quantization.
+ The models only support English input; other languages can be translated into English during prompt refinement by a large language model.
+ The memory usage for model fine-tuning was tested in an `8 * H100` environment, and the program automatically uses `Zero 2` optimization. Where a specific number of GPUs is noted in the table, at least that many GPUs are required for fine-tuning.
## Friendly Links

We warmly welcome contributions from the community and actively contribute to the open-source community ourselves. The following works have already been adapted for CogVideoX; feel free to use them:

+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx): RIFLEx is a video length extrapolation method that extends a generated video to twice its original length with just one line of code. RIFLEx not only supports training-free inference but also provides models fine-tuned on CogVideoX: fine-tuning for only 1000 steps on videos of the original length greatly improves length extrapolation.
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture that supports flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository implementing a Gradio Web UI for CogVideo, supporting a more feature-rich Web UI.
+ [CogVideoX-Interpolation](https://github.com/feizc/CogvideX-Interpolation): A pipeline modified from the CogVideoX structure, aimed at providing greater flexibility for keyframe interpolation generation.
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth Studio is a diffusion engine. It has restructured the architecture, including the text encoder, UNet, and VAE, improving computational performance while maintaining compatibility with open-source community models. The framework has been adapted for CogVideoX.
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): Code for a simple ControlNet module that includes the CogVideoX model.
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna is the first repository to integrate multiple AI video generation models, supporting text-to-video, image-to-video, and text-to-image generation.
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): An identity-preserving text-to-video generation model based on CogVideoX-5B that uses frequency decomposition to keep faces consistent in the generated videos.
+ [Step-by-step tutorial](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): A step-by-step guide on installing and optimizing the CogVideoX1.5-5B-I2V model on Windows and in the cloud. Special thanks to [FurkanGozukara](https://github.com/FurkanGozukara) for his efforts and support!
## Project Structure

}
```

We look forward to your contributions! Click [here](resources/contribute_ja.md) for details.

## License Agreement

The code in this repository is released under the [Apache 2.0 License](LICENSE).
README_zh.md (103 changes)
# CogVideo & CogVideoX

[Read this in English](./README.md)

[日本語で読む](./README_ja.md)

<div align="center">
<img src=resources/logo.svg width="50%"/>
</div>

## Project Updates

- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series: a toolkit for working with our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported for `CogVideoX-5B` and `CogVideoX1.5-5B`; see [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the `Lora` fine-tuning code for the `diffusers` version of the model, which uses less GPU memory; see [here](finetune/README_zh.md) for details.
- 🔥 **News**: ```2024/11/15```: We have released the diffusers version of the `CogVideoX1.5` model; only a few parameter adjustments are needed to reuse the previous code.
- 🔥 **News**: ```2024/11/08```: We have released the `CogVideoX1.5` model, an upgraded version of the open-source CogVideoX model. The CogVideoX1.5-5B series supports **10-second** videos at higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at **any resolution**. The SAT code has been updated; the `diffusers` version is still being adapted. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), a lower-cost fine-tuning framework that can fine-tune `CogVideoX-5B` on a single 4090 GPU at multiple resolutions, is now available; feel free to use it.
- 🔥 **News**: ```2024/10/10```: We have updated our technical report, see [here](https://arxiv.org/pdf/2408.06072), with more training details and a demo; for the demo, click [here](https://yzy-thu.github.io/CogVideoX-demo/).
- 🔥 ```2024/8/27```: We have open-sourced **CogVideoX-5B**, a larger model in the CogVideoX series. We have significantly optimized inference performance, greatly lowering the inference threshold: you can run **CogVideoX-2B** on early GPUs such as the `GTX 1080TI`, and **CogVideoX-5B** on desktop cards such as the `RTX 3060`. Please strictly follow the [requirements](requirements.txt) to update and install dependencies; for inference code, see [cli_demo](inference/cli_demo.py). The **CogVideoX-2B** open-source license has also been changed to the **Apache 2.0 License**.
- 🔥 ```2024/8/6```: We have open-sourced the **3D Causal VAE** used in **CogVideoX-2B**, which can reconstruct videos almost losslessly.
- 🔥 ```2024/8/6```: We have open-sourced **CogVideoX-2B**, the first model in the CogVideoX series of video generation models.
- 🌱 **Source**: ```2022/5/19```: We open-sourced the CogVideo video generation model (now visible in the `CogVideo` branch), the first open-source transformer-based text-to-video generation model.

Jump to a specific section:

- [Quick Start](#快速开始)
    - [Prompt Optimization](#提示词优化)
    - [SAT](#sat)
    - [Diffusers](#diffusers)
- [Video Works](#视频作品)
    - [CogVideoX-5B](#cogvideox-5b)
    - [CogVideoX-2B](#cogvideox-2b)
- [Model Introduction](#模型介绍)
- [Friendly Links](#友情链接)
- [Full Project Code Structure](#完整项目代码结构)
    - [inference](#inference)
    - [finetune](#finetune)
    - [sat](#sat-1)
    - [tools](#tools)
- [CogVideo (ICLR'23)](#cogvideoiclr23)
- [Citation](#引用)
- [Model License](#模型协议)

## Quick Start

CogVideoX is the open-source version of the video generation model that shares its origins with [清影 (Qingying)](https://chatglm.cn/video?fr=osm_cogvideox).

The table below shows the basic information of the video generation models we provide:

<table style="border-collapse: collapse; width: 100%;">
<tr>
<th style="text-align: center;">Model Name</th>
<th style="text-align: center;">CogVideoX1.5-5B (Latest)</th>
<th style="text-align: center;">CogVideoX1.5-5B-I2V (Latest)</th>
<th style="text-align: center;">CogVideoX-2B</th>
<th style="text-align: center;">CogVideoX-5B</th>
<th style="text-align: center;">CogVideoX-5B-I2V</th>
</tr>
<tr>
<td style="text-align: center;">Release Date</td>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">November 8, 2024</th>
<th style="text-align: center;">August 6, 2024</th>
<th style="text-align: center;">August 27, 2024</th>
<th style="text-align: center;">September 19, 2024</th>
</tr>
<tr>
<td style="text-align: center;">Video Resolution</td>
<td colspan="1" style="text-align: center;">1360 * 768</td>
<td colspan="1" style="text-align: center;">Min(W, H) = 768 <br> 768 ≤ Max(W, H) ≤ 1360 <br> Max(W, H) % 16 = 0</td>
<td colspan="3" style="text-align: center;">720 * 480</td>
</tr>
<tr>
<td style="text-align: center;">Number of Frames</td>
<td colspan="2" style="text-align: center;">Must be <b>16N + 1</b> where N <= 10 (default 81)</td>
<td colspan="3" style="text-align: center;">Must be <b>8N + 1</b> where N <= 6 (default 49)</td>
</tr>
<tr>
<td style="text-align: center;">Inference Precision</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
<td style="text-align: center;"><b>FP16* (recommended)</b>, BF16, FP32, FP8*, INT8; INT4 not supported</td>
<td colspan="2" style="text-align: center;"><b>BF16 (recommended)</b>, FP16, FP32, FP8*, INT8; INT4 not supported</td>
</tr>
<tr>
<td style="text-align: center;">Single-GPU Memory Usage<br></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 76GB <br><b>diffusers BF16: from 10GB*</b><br><b>diffusers INT8 (torchao): from 7GB*</b></td>
<td style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> FP16: 18GB <br><b>diffusers FP16: from 4GB*</b><br><b>diffusers INT8 (torchao): from 3.6GB*</b></td>
<td colspan="2" style="text-align: center;"><a href="https://github.com/THUDM/SwissArmyTransformer">SAT</a> BF16: 26GB <br><b>diffusers BF16: from 5GB*</b><br><b>diffusers INT8 (torchao): from 4.4GB*</b></td>
</tr>
<tr>
<td style="text-align: center;">Multi-GPU Inference Memory Usage</td>
<td colspan="2" style="text-align: center;"><b>BF16: 24GB* using diffusers</b><br></td>
<td style="text-align: center;"><b>FP16: 10GB* using diffusers</b><br></td>
<td colspan="2" style="text-align: center;"><b>BF16: 15GB* using diffusers</b><br></td>
</tr>
<tr>
<td style="text-align: center;">Inference Speed<br>(Step = 50, FP/BF16)</td>
<td colspan="2" style="text-align: center;">Single A100: ~1000 s (5-second video)<br>Single H100: ~550 s (5-second video)</td>
<td style="text-align: center;">Single A100: ~90 s<br>Single H100: ~45 s</td>
<td colspan="2" style="text-align: center;">Single A100: ~180 s<br>Single H100: ~90 s</td>
</tr>
<tr>
<td style="text-align: center;">Fine-tuning Precision</td>
<td style="text-align: center;"><b>FP16</b></td>
<td colspan="2" style="text-align: center;"><b>BF16</b></td>
</tr>
<tr>
<td style="text-align: center;">Fine-tuning Memory Usage</td>
<td style="text-align: center;">47 GB (bs=1, LORA)<br> 61 GB (bs=2, LORA)<br> 62 GB (bs=1, SFT)</td>
<td style="text-align: center;">63 GB (bs=1, LORA)<br> 80 GB (bs=2, LORA)<br> 75 GB (bs=1, SFT)<br></td>
<td style="text-align: center;">78 GB (bs=1, LORA)<br> 75 GB (bs=1, SFT, 16 GPUs)<br></td>
</tr>
<tr>
<td style="text-align: center;">Prompt Language</td>
<td colspan="5" style="text-align: center;">English*</td>
</tr>
<tr>
<td style="text-align: center;">Prompt Token Limit</td>
<td colspan="2" style="text-align: center;">224 tokens</td>
<td colspan="3" style="text-align: center;">226 tokens</td>
</tr>
<tr>
<td style="text-align: center;">Video Length</td>
<td colspan="2" style="text-align: center;">5 or 10 seconds</td>
<td colspan="3" style="text-align: center;">6 seconds</td>
</tr>
<tr>
<td style="text-align: center;">Frame Rate</td>
<td colspan="2" style="text-align: center;">16 frames / second</td>
<td colspan="3" style="text-align: center;">8 frames / second</td>
</tr>
<tr>
<td style="text-align: center;">Position Encoding</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
<tr>
<td style="text-align: center;">Download Link (Diffusers)</td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5B-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5B-I2V">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-2b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-2b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b">🟣 WiseModel</a></td>
<td style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX-5b-I2V">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX-5b-I2V">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX-5b-I2V">🟣 WiseModel</a></td>
</tr>
<tr>
<td style="text-align: center;">Download Link (SAT)</td>
<td colspan="2" style="text-align: center;"><a href="https://huggingface.co/THUDM/CogVideoX1.5-5b-SAT">🤗 HuggingFace</a><br><a href="https://modelscope.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🤖 ModelScope</a><br><a href="https://wisemodel.cn/models/ZhipuAI/CogVideoX1.5-5b-SAT">🟣 WiseModel</a></td>
<td colspan="3" style="text-align: center;"><a href="./sat/README_zh.md">SAT</a></td>
</tr>
</table>
+ [PytorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer, and VAE modules to reduce CogVideoX's memory requirements. This makes it possible to run the model on a free T4 Colab or on GPUs with smaller memory! It is also worth noting that TorchAO quantization is fully compatible with `torch.compile`, which can significantly improve inference speed. `FP8` precision must be used on `NVIDIA H100` and later devices, which requires installing the `torch` and `torchao` Python packages from source. `CUDA 12.4` is recommended.
+ The inference speed tests also used the memory optimization scheme above; without memory optimization, inference is about 10% faster. Only the `diffusers` versions of the models support quantization.
+ The models only support English input; other languages can be translated into English during prompt refinement by a large language model.
+ The memory usage for model fine-tuning was tested in an `8 * H100` environment, and the program automatically uses `Zero 2` optimization. Where a specific number of GPUs is noted in the table, at least that many GPUs are required for fine-tuning.

## Friendly Links

We warmly welcome contributions from the community and actively contribute to the open-source community ourselves. The following works have already been adapted for CogVideoX; feel free to use them:

+ [RIFLEx-CogVideoX](https://github.com/thu-ml/RIFLEx): RIFLEx is a video length extrapolation method that extends a generated video to twice its original length with just one line of code. RIFLEx not only supports training-free inference but also provides models fine-tuned on CogVideoX: fine-tuning for only 1000 steps on videos of the original length greatly improves length extrapolation.
+ [CogVideoX-Fun](https://github.com/aigc-apps/CogVideoX-Fun): CogVideoX-Fun is a modified pipeline based on the CogVideoX architecture that supports flexible resolutions and multiple launch methods.
+ [CogStudio](https://github.com/pinokiofactory/cogstudio): A separate repository implementing a Gradio Web UI for CogVideo, supporting a more feature-rich Web UI.
+ [CogVideoX-Interpolation](https://github.com/feizc/CogvideX-Interpolation): A pipeline modified from the CogVideoX structure, aimed at providing greater flexibility for keyframe interpolation generation.
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): DiffSynth Studio is a diffusion engine. It has restructured the architecture, including the text encoder, UNet, and VAE, improving computational performance while maintaining compatibility with open-source community models. The framework has been adapted for CogVideoX.
+ [CogVideoX-Controlnet](https://github.com/TheDenk/cogvideox-controlnet): Code for a simple ControlNet module that includes the CogVideoX model.
+ [VideoTuna](https://github.com/VideoVerses/VideoTuna): VideoTuna is the first repository to integrate multiple AI video generation models, supporting text-to-video, image-to-video, and text-to-image generation.
+ [ConsisID](https://github.com/PKU-YuanGroup/ConsisID): An identity-preserving text-to-video generation model based on CogVideoX-5B that uses frequency decomposition to keep faces consistent in the generated videos.
+ [Tutorial](https://www.youtube.com/watch?v=5UCkMzP2VLE&ab_channel=SECourses): A step-by-step guide on installing and optimizing the CogVideoX1.5-5B-I2V model on Windows and in the cloud. Special thanks to [FurkanGozukara](https://github.com/FurkanGozukara) for his efforts and support!

## Full Project Code Structure

}
```

We welcome your contributions; click [here](resources/contribute_zh.md) for more information.

## Model License

The code in this repository is released under the [Apache 2.0 License](LICENSE).
# CogVideoX Diffusers Fine-tuning Guide

[中文阅读](./README_zh.md)

[日本語で読む](./README_ja.md)

If you're looking for the fine-tuning instructions for the SAT version, please check [here](../sat/README_zh.md). The dataset format for this version differs from the one used here.

## Hardware Requirements

| Model                      | Training Type  | Distribution Strategy                | Mixed Precision | Training Resolution (FxHxW) | Hardware Requirements   |
|----------------------------|----------------|--------------------------------------|-----------------|-----------------------------|-------------------------|
| cogvideox-t2v-2b           | lora (rank128) | DDP                                  | fp16            | 49x480x720                  | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b    | lora (rank128) | DDP                                  | bf16            | 49x480x720                  | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP                                  | bf16            | 81x768x1360                 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b           | sft            | DDP                                  | fp16            | 49x480x720                  | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b           | sft            | 1-GPU zero-2 + opt offload           | fp16            | 49x480x720                  | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b           | sft            | 8-GPU zero-2                         | fp16            | 49x480x720                  | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b           | sft            | 8-GPU zero-3                         | fp16            | 49x480x720                  | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b           | sft            | 8-GPU zero-3 + opt and param offload | bf16            | 49x480x720                  | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b    | sft            | 1-GPU zero-2 + opt offload           | bf16            | 49x480x720                  | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b    | sft            | 8-GPU zero-2                         | bf16            | 49x480x720                  | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b    | sft            | 8-GPU zero-3                         | bf16            | 49x480x720                  | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b    | sft            | 8-GPU zero-3 + opt and param offload | bf16            | 49x480x720                  | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft            | 1-GPU zero-2 + opt offload           | bf16            | 81x768x1360                 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft            | 8-GPU zero-2                         | bf16            | 81x768x1360                 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft            | 8-GPU zero-3                         | bf16            | 81x768x1360                 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft            | 8-GPU zero-3 + opt and param offload | bf16            | 81x768x1360                 | 40GB VRAM (NVIDIA A100) |
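For quick planning, a few rows of the table can be encoded as a small lookup. This is only a sketch for illustration; the dictionary below transcribes selected rows (with the `{t2v, i2v}` shorthand expanded) and is not part of the training scripts:

```python
# Approximate peak VRAM (GB) per GPU, transcribed from selected rows of the
# hardware-requirements table above. Keys: (model, training_type, strategy).
VRAM_GB = {
    ("cogvideox-t2v-2b", "lora", "DDP"): 16,
    ("cogvideox-t2v-5b", "lora", "DDP"): 24,
    ("cogvideox1.5-t2v-5b", "lora", "DDP"): 35,
    ("cogvideox-t2v-2b", "sft", "8-GPU zero-2"): 17,
    ("cogvideox-t2v-5b", "sft", "8-GPU zero-3"): 43,
}

def fits(model: str, training_type: str, strategy: str, gpu_vram_gb: int) -> bool:
    """Return True if the configuration is listed and fits in the given VRAM."""
    need = VRAM_GB.get((model, training_type, strategy))
    return need is not None and need <= gpu_vram_gb

print(fits("cogvideox-t2v-2b", "lora", "DDP", 24))  # e.g. a 24GB RTX 4090
```

The actual requirement for your run may differ; treat the table, not this helper, as authoritative.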
## Install Dependencies

Since the relevant code has not yet been merged into the official `diffusers` release, you need to fine-tune based on the diffusers branch. Follow the steps below to install the dependencies:

```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now on the main branch
pip install -e .
```

## Prepare the Dataset

First, you need to prepare your dataset. Depending on your task type (T2V or I2V), the dataset format will vary slightly:

```
.
├── prompts.txt
├── videos
├── videos.txt
├── images # (Optional) For I2V, if not provided, first frame will be extracted from video as reference
└── images.txt # (Optional) For I2V, if not provided, first frame will be extracted from video as reference
```

Where:

- `prompts.txt`: Contains the prompts
- `videos/`: Contains the .mp4 video files
- `videos.txt`: Contains the list of video files in the `videos/` directory
- `images/`: (Optional) Contains the .png reference image files
- `images.txt`: (Optional) Contains the list of reference image files

You can download a sample dataset (T2V): [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset).

If you need to use a validation dataset during training, make sure to provide a validation dataset with the same format as the training dataset.
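A small helper can generate `videos.txt` and a matching `prompts.txt` skeleton from a `videos/` directory. This is a sketch of one possible workflow (the helper name and the relative-path convention are our own assumptions, not a script shipped with the repo):

```python
from pathlib import Path

def build_manifests(dataset_root: str) -> int:
    """Write videos.txt listing every .mp4 under videos/, plus one empty
    prompts.txt line per video for you to fill in. Returns the video count."""
    root = Path(dataset_root)
    videos = sorted(p.name for p in (root / "videos").glob("*.mp4"))
    # One "videos/<name>.mp4" path per line, matching the dataset layout above.
    (root / "videos.txt").write_text("\n".join(f"videos/{v}" for v in videos) + "\n")
    # One prompt line per video, in the same order; edit these by hand.
    (root / "prompts.txt").write_text("\n".join("" for _ in videos) + "\n")
    return len(videos)
```

Keeping the two files line-aligned matters: the training script pairs the Nth prompt with the Nth video.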
## Running Scripts to Start Fine-tuning

Before starting training, please note the following resolution requirements:

1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), such as 49, 81 ...
2. Recommended video resolutions for each model:
    - CogVideoX: 480x720 (height x width)
    - CogVideoX1.5: 768x1360 (height x width)
3. For samples (videos or images) that don't match the training resolution, the code will directly resize them. This may cause aspect ratio distortion and affect training results. It's recommended to preprocess your samples (e.g., using crop + resize to maintain aspect ratio) before training.
|
||||
|
||||
> **Important Note**: To improve training efficiency, we automatically encode videos and cache the results on disk
|
||||
> before training. If you modify the data after training, please delete the latent directory under the video directory to
|
||||
> ensure the latest data is used.
|
||||
|
||||
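The two constraints above can be checked mechanically. The sketch below (the function names are illustrative, not from this repo) validates the 8N+1 frame rule and computes an aspect-preserving center crop that can then be resized to the training resolution without distortion:

```python
def is_valid_frame_count(n: int) -> bool:
    """CogVideoX expects 8N+1 frames (49, 81, ...)."""
    return n >= 1 and (n - 1) % 8 == 0

def center_crop_box(src_w: int, src_h: int, dst_w: int, dst_h: int):
    """Largest centered crop of (src_w, src_h) with the dst_w:dst_h aspect
    ratio; resizing this crop to (dst_w, dst_h) preserves the aspect ratio."""
    scale = min(src_w / dst_w, src_h / dst_h)
    crop_w, crop_h = round(dst_w * scale), round(dst_h * scale)
    left, top = (src_w - crop_w) // 2, (src_h - crop_h) // 2
    return left, top, left + crop_w, top + crop_h

# e.g. a 1080p source cropped for CogVideoX's 480x720 training resolution
assert is_valid_frame_count(49) and not is_valid_frame_count(48)
assert center_crop_box(1920, 1080, 720, 480) == (150, 0, 1770, 1080)
```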
### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh
# Main parameters to modify:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to the prompt file
# --image_column: Optional for I2V; path to the reference-image file list (remove this parameter to use the first frame of each video as the image condition)
# --video_column: Path to the video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, please refer to the launch script

bash train_ddp_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh # Image-to-Video (I2V) fine-tuning
```

## Running the Script to Start Fine-tuning

### SFT

Single-node (one GPU or multi-GPU) fine-tuning:

We provide several ZeRO configuration templates in the `configs/` directory. Choose the training configuration that fits your needs (set the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh # Image-to-Video (I2V) fine-tuning
```

Multi-node fine-tuning:

In addition to setting the bash script parameters, you need to set the relevant training options in the ZeRO configuration file and make sure they match the parameters in the bash script, such as `batch_size`, `gradient_accumulation_steps`, and `mixed_precision`. For details, refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/).

```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```

When using SFT training, please note:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB VRAM, it's recommended to disable validation.

2. Validation is slow when ZeRO-3 is enabled, so it's recommended to disable validation when using ZeRO-3.

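A mismatch between the launcher's settings and the ZeRO config is easy to catch programmatically. The sketch below (file paths and key names follow the `zero2.yaml`-style JSON configs shipped in `configs/`; the inline config here is a stand-in, not loaded from disk) cross-checks the two:

```python
import json

# Stand-in for a ZeRO config such as configs/zero2.yaml
# (the configs are JSON despite the .yaml extension).
zero_cfg = json.loads("""{
  "bf16": {"enabled": true},
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": "auto"
}""")

# Settings passed to the bash launch script.
script_args = {"gradient_accumulation_steps": 1, "mixed_precision": "bf16"}

assert zero_cfg["gradient_accumulation_steps"] == script_args["gradient_accumulation_steps"]
assert zero_cfg["bf16"]["enabled"] == (script_args["mixed_precision"] == "bf16")
```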
## Load the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) for instructions on how to load the fine-tuned model.

+ For SFT-trained models, first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.

## Best Practices

+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). Through frame skipping in the data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experiments; the maximum frame count recommended by the CogVideoX team is 49 frames. The 70 videos were split into three groups of 10, 25, and 50 videos with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ Training works better with an identifier token, specified via `--id_token`. This is similar to Dreambooth training, though regular fine-tuning without such a token also works.
+ The original repository used `lora_alpha` set to 1. We found this value ineffective across multiple runs, likely due to differences in the model backend and training setup. We recommend setting `lora_alpha` equal to the rank or `rank // 2`.
+ We recommend using a rank of 64 or higher.

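The `lora_alpha` recommendation has a simple arithmetic motivation: in the standard LoRA convention the low-rank update is scaled by `lora_alpha / rank`, so `alpha = 1` with `rank = 128` shrinks the update by more than two orders of magnitude:

```python
def lora_scale(lora_alpha: int, rank: int) -> float:
    """Standard LoRA scaling: the low-rank update BA is multiplied by alpha/rank."""
    return lora_alpha / rank

# alpha = 1 with rank 128 scales the update by ~0.008,
# while alpha = rank keeps the scale at 1.0 regardless of rank.
assert lora_scale(1, 128) == 1 / 128
assert lora_scale(128, 128) == 1.0
assert lora_scale(64, 128) == 0.5  # the rank // 2 recommendation
```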
# CogVideoX Diffusers Fine-tuning Guide

[中文阅读](./README_zh.md)

[Read in English](./README.md)

For instructions on fine-tuning the SAT version, see [here](../sat/README_zh.md). The dataset format of that version differs from the one used here.

## Hardware Requirements

| Model | Training type | Distributed strategy | Mixed precision | Training resolution (frames x height x width) | Hardware requirements |
|----------------------------|----------------|--------------------------------------------|------|-------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + optimizer offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + optimizer and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + optimizer offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + optimizer and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + optimizer offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + optimizer and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |

## Installing Dependencies

Since the relevant code has not yet been merged into an official `diffusers` release, you need to fine-tune based on the `diffusers` branch. Follow the steps below to install the dependencies:

```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now on the main branch
pip install -e .
```

## Preparing the Dataset

First, you need to prepare your dataset. The dataset format differs slightly depending on the task type (T2V or I2V):

```
.
├── prompts.txt
├── videos
├── videos.txt
├── images     # (optional) for I2V; if not provided, the first frame of each video is used as the reference image
└── images.txt # (optional) for I2V; if not provided, the first frame of each video is used as the reference image
```

You can download the [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset) dataset from here.

The role of each file is as follows:

- `prompts.txt`: stores the prompts
- `videos/`: stores the .mp4 video files
- `videos.txt`: stores the list of video files in the `videos/` directory
- `images/`: (optional) the .png reference image files
- `images.txt`: (optional) the list of reference image files

If you want to use a validation dataset during training, you must provide one in the same format as the training dataset.

## Running the Script to Start Fine-tuning

Before starting training, note the following resolution requirements:

1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), e.g., 49, 81, ...
2. It is recommended to use the model's default video resolution:
    - CogVideoX: 480x720 (height x width)
    - CogVideoX1.5: 768x1360 (height x width)
3. Samples (videos or images) that don't match the training resolution are resized directly in the code. This can distort the aspect ratio and hurt training results, so it's recommended to preprocess your samples (e.g., crop + resize to preserve the aspect ratio) before training.

> **Important note**: To improve training efficiency, videos are encoded before training and the results cached on disk. If you modify the data after training, delete the `latent` directory under the video directory to ensure the latest data is used.

```
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Launch multi-GPU training with accelerate, using the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Run the training script train_cogvideox_lora.py for LoRA fine-tuning
--gradient_checkpointing \ # Enable gradient checkpointing to reduce memory usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling to process videos in chunks and save memory
--enable_slicing \ # Enable slicing of the input for further memory optimization
--instance_data_root $DATASET_PATH \ # Dataset path, specified by $DATASET_PATH
--caption_column prompts.txt \ # Use prompts.txt as the file of video captions for training
--video_column videos.txt \ # Use videos.txt as the file of video paths for training
--validation_prompt "" \ # Prompt used to generate validation videos during training
--validation_prompt_separator ::: \ # Set the separator for validation prompts to :::
--num_validation_videos 1 \ # Generate 1 video per validation round
--validation_epochs 100 \ # Run validation every 100 training epochs
--seed 42 \ # Set the random seed to 42 for reproducibility
--rank 128 \ # Set the LoRA rank to 128
--lora_alpha 64 \ # Set the LoRA alpha parameter to 64, which adjusts the LoRA scaling
--mixed_precision bf16 \ # Train with bf16 mixed precision to save memory
--output_dir $OUTPUT_PATH \ # Output directory for the model, specified by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of each video
--skip_frames_end 0 \ # Skip 0 frames at the end of each video
--train_batch_size 4 \ # Set the training batch size to 4
--num_train_epochs 30 \ # Set the total number of training epochs to 30
--checkpointing_steps 1000 \ # Save a model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, i.e., update after every batch
--learning_rate 1e-3 \ # Set the learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use the cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate over the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set the Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set the Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set the maximum gradient-clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights & Biases (wandb) for logging and monitoring
```

### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh
# Main parameters to modify:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to the prompt file
# --image_column: For I2V, path to the reference-image file list (remove this parameter to use the first frame of each video as the image condition)
# --video_column: Path to the video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, refer to the launch script

bash train_ddp_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh # Image-to-Video (I2V) fine-tuning
```

## Starting Fine-tuning

### SFT

Single-node (single-GPU or multi-GPU) fine-tuning:

Several ZeRO configuration templates are provided in the `configs/` directory. Choose the training configuration that fits your needs (set the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh # Image-to-Video (I2V) fine-tuning
```

Multi-node, multi-GPU fine-tuning:

In addition to setting the parameters in the bash script, you need to set the relevant training options in the ZeRO configuration file and make sure they match the parameters in the bash script, e.g., `batch_size`, `gradient_accumulation_steps`, `mixed_precision`. For details, refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/).

```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```

When using SFT training, please note:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB VRAM, it is recommended to disable validation.

2. Validation is slow when ZeRO-3 is enabled, so it is recommended to disable validation under ZeRO-3.

## Loading the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.

+ For SFT-trained models, first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.

## Best Practices

+ We included 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width). Through frame skipping in the data preprocessing, we created two smaller datasets of 49 and 16 frames to speed up experiments; the maximum frame count recommended by the CogVideoX team is 49 frames. The 70 videos were split into three groups of 10, 25, and 50 videos with similar conceptual nature.
+ Using 25 or more videos works best when training new concepts and styles.
+ Using an identifier token, which can be specified with `--id_token`, gives better training results. This is similar to Dreambooth training, though regular fine-tuning without this token also works.
+ The original repository sets `lora_alpha` to 1. We found this value performed poorly in several runs, possibly due to differences in the model backend and training settings. We recommend setting `lora_alpha` equal to the rank or `rank // 2`.
+ A rank of 64 or higher is recommended.

# CogVideoX diffusers Fine-tuning Guide

[Read this in English](./README.md)

[日本語で読む](./README_ja.md)

If you want to see the SAT-version fine-tuning, check [here](../sat/README_zh.md). Its dataset format differs from this version.

## Hardware Requirements

| Model | Training type | Distributed strategy | Mixed precision | Training resolution (frames x height x width) | Hardware requirements |
|----------------------------|----------------|--------------------------------------------|------|-------------|-------------------------|
| cogvideox-t2v-2b | lora (rank128) | DDP | fp16 | 49x480x720 | 16GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 49x480x720 | 24GB VRAM (NVIDIA 4090) |
| cogvideox1.5-{t2v, i2v}-5b | lora (rank128) | DDP | bf16 | 81x768x1360 | 35GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | DDP | fp16 | 49x480x720 | 36GB VRAM (NVIDIA A100) |
| cogvideox-t2v-2b | sft | 1-GPU zero-2 + opt offload | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-2 | fp16 | 49x480x720 | 17GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 | fp16 | 49x480x720 | 19GB VRAM (NVIDIA 4090) |
| cogvideox-t2v-2b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 14GB VRAM (NVIDIA 4080) |
| cogvideox-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 49x480x720 | 42GB VRAM (NVIDIA A100) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 49x480x720 | 42GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 49x480x720 | 43GB VRAM (NVIDIA 4090) |
| cogvideox-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 49x480x720 | 28GB VRAM (NVIDIA 5090) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 1-GPU zero-2 + opt offload | bf16 | 81x768x1360 | 56GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-2 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 | bf16 | 81x768x1360 | 55GB VRAM (NVIDIA A100) |
| cogvideox1.5-{t2v, i2v}-5b | sft | 8-GPU zero-3 + opt and param offload | bf16 | 81x768x1360 | 40GB VRAM (NVIDIA A100) |

## Installing Dependencies

Since the relevant code has not yet been merged into an official `diffusers` release, you need to fine-tune based on the `diffusers` branch. Install the dependencies as follows:

```shell
git clone https://github.com/huggingface/diffusers.git
cd diffusers # Now on the main branch
pip install -e .
```

## Preparing the Dataset

First, you need to prepare your dataset. The format differs slightly depending on the task type (T2V or I2V):

```
.
├── prompts.txt
├── videos
├── videos.txt
├── images     # (optional) for I2V; if not provided, the first frame of each video is used as the reference image
└── images.txt # (optional) for I2V; if not provided, the first frame of each video is used as the reference image
```

Where:

- `prompts.txt`: stores the prompts
- `videos/`: stores the .mp4 video files
- `videos.txt`: stores the list of video files in the `videos/` directory
- `images/`: (optional) stores the .png reference image files
- `images.txt`: (optional) stores the list of reference image files

You can download a sample dataset (T2V) here: [Disney Steamboat Willie](https://huggingface.co/datasets/Wild-Heart/Disney-VideoGeneration-Dataset)

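Because the caption and video lists are matched line by line, a misaligned pair silently trains on the wrong captions. The following sketch (not part of the repo; the function name is illustrative) verifies the layout described above:

```python
from pathlib import Path

def check_dataset(root: str) -> int:
    """Check that prompts.txt and videos.txt line up one-to-one and that
    every listed video exists under the dataset root. Returns the sample count."""
    root = Path(root)
    prompts = [l for l in (root / "prompts.txt").read_text().splitlines() if l.strip()]
    videos = [l for l in (root / "videos.txt").read_text().splitlines() if l.strip()]
    assert len(prompts) == len(videos), "prompts.txt and videos.txt are misaligned"
    for rel in videos:
        assert (root / rel).is_file(), f"missing video: {rel}"
    return len(videos)
```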
## Configuration Files and Running

The `accelerate` configuration files are:

+ accelerate_config_machine_multi.yaml, for multi-GPU use
+ accelerate_config_machine_single.yaml, for single-GPU use

An example `finetune` script configuration:

```shell
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \ # Launch multi-GPU training with accelerate, using the config file accelerate_config_machine_single.yaml
train_cogvideox_lora.py \ # Run the training script train_cogvideox_lora.py for LoRA fine-tuning on the CogVideoX model
--gradient_checkpointing \ # Enable gradient checkpointing to reduce VRAM usage
--pretrained_model_name_or_path $MODEL_PATH \ # Path to the pretrained model, specified by $MODEL_PATH
--cache_dir $CACHE_PATH \ # Cache directory for model files, specified by $CACHE_PATH
--enable_tiling \ # Enable tiling to process videos in chunks and save VRAM
--enable_slicing \ # Enable slicing of the input for further memory optimization
--instance_data_root $DATASET_PATH \ # Dataset path, specified by $DATASET_PATH
--caption_column prompts.txt \ # Use prompts.txt as the file of video captions for training
--video_column videos.txt \ # Use videos.txt as the file of video paths for training
--validation_prompt "" \ # Prompt used to generate validation videos during training
--validation_prompt_separator ::: \ # Set the separator for validation prompts to :::
--num_validation_videos 1 \ # Generate 1 video per validation round
--validation_epochs 100 \ # Run validation every 100 training epochs
--seed 42 \ # Set the random seed to 42 for reproducibility
--rank 128 \ # Set the LoRA rank to 128
--lora_alpha 64 \ # Set the LoRA alpha parameter to 64, which adjusts the LoRA scaling
--mixed_precision bf16 \ # Train with bf16 mixed precision to save VRAM
--output_dir $OUTPUT_PATH \ # Output directory for the model, specified by $OUTPUT_PATH
--height 480 \ # Set video height to 480 pixels
--width 720 \ # Set video width to 720 pixels
--fps 8 \ # Set video frame rate to 8 frames per second
--max_num_frames 49 \ # Set the maximum number of frames per video to 49
--skip_frames_start 0 \ # Skip 0 frames at the start of each video
--skip_frames_end 0 \ # Skip 0 frames at the end of each video
--train_batch_size 4 \ # Set the training batch size to 4
--num_train_epochs 30 \ # Set the total number of training epochs to 30
--checkpointing_steps 1000 \ # Save a model checkpoint every 1000 steps
--gradient_accumulation_steps 1 \ # Accumulate gradients for 1 step, i.e., update after every batch
--learning_rate 1e-3 \ # Set the learning rate to 0.001
--lr_scheduler cosine_with_restarts \ # Use the cosine learning rate scheduler with restarts
--lr_warmup_steps 200 \ # Warm up the learning rate over the first 200 steps
--lr_num_cycles 1 \ # Set the number of learning rate cycles to 1
--optimizer AdamW \ # Use the AdamW optimizer
--adam_beta1 0.9 \ # Set the Adam optimizer beta1 parameter to 0.9
--adam_beta2 0.95 \ # Set the Adam optimizer beta2 parameter to 0.95
--max_grad_norm 1.0 \ # Set the maximum gradient-clipping value to 1.0
--allow_tf32 \ # Enable TF32 to speed up training
--report_to wandb # Use Weights & Biases (wandb) for logging and monitoring
```
If you need validation during training, you must additionally provide a validation dataset in the same format as the training set.

## Running the Script to Start Fine-tuning

Before starting training, note the following resolution requirements:

1. The number of frames must be a multiple of 8 **plus 1** (i.e., 8N+1), e.g., 49, 81, ...
2. It is recommended to use the model's default video resolution:
    - CogVideoX: 480x720 (height x width)
    - CogVideoX1.5: 768x1360 (height x width)
3. Samples (videos or images) that don't match the training resolution are resized directly in the code. This can distort the aspect ratio and hurt training results, so it's recommended to preprocess samples (e.g., crop + resize to preserve the aspect ratio) before training.

> **Important note**: To improve training efficiency, videos are automatically encoded before training and the results cached on disk. If you modify the data after training, delete the `latent` directory under the video directory to ensure the latest data is used.

### LoRA

```bash
# Modify the configuration parameters in train_ddp_t2v.sh
# Main parameters to modify:
# --output_dir: Output directory
# --data_root: Dataset root directory
# --caption_column: Path to the prompt file
# --image_column: Optional for I2V; path to the reference-image file list (remove this parameter to use the first frame of each video as the image condition)
# --video_column: Path to the video file list
# --train_resolution: Training resolution (frames x height x width)
# For other important parameters, refer to the launch script

bash train_ddp_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_ddp_i2v.sh # Image-to-Video (I2V) fine-tuning
```

### SFT

We provide several ZeRO configuration templates in the `configs/` directory. Choose the training configuration that fits your needs (set the `deepspeed_config_file` option in `accelerate_config.yaml`).

```bash
# The parameters to configure are the same as for LoRA training

bash train_zero_t2v.sh # Text-to-Video (T2V) fine-tuning
bash train_zero_i2v.sh # Image-to-Video (I2V) fine-tuning
```

Multi-node, multi-GPU fine-tuning:

In addition to setting the bash script parameters, you also need to set the relevant training options in the ZeRO configuration file and make sure they match the parameters in the bash script, e.g., `batch_size`, `gradient_accumulation_steps`, `mixed_precision`. For details, refer to the [DeepSpeed official documentation](https://www.deepspeed.ai/docs/config-json/).

```shell
bash finetune_multi_rank.sh # Needs to be run on each node
```

When using SFT training, note the following:

1. For SFT training, model offload is not used during validation, so peak VRAM usage may exceed 24GB. For GPUs with less than 24GB VRAM, it is recommended to disable validation.

2. Validation is slow when ZeRO-3 is enabled, so it is recommended to disable validation under ZeRO-3.

## Loading the Fine-tuned Model

+ Please refer to [cli_demo.py](../inference/cli_demo.py) for how to load the fine-tuned model.

+ For SFT-trained models, first use the `zero_to_fp32.py` script in the `checkpoint-*/` directory to merge the model weights.

## Best Practices

+ Includes 70 training videos with a resolution of `200 x 480 x 720` (frames x height x width) …
+ The original repository used `lora_alpha` set to 1. We found this value performed poorly across multiple runs, possibly due to differences in the model backend and training setup. We recommend setting `lora_alpha` equal to the rank or `rank // 2`.
+ A rank of 64 or higher is recommended.

```yaml
compute_environment: LOCAL_MACHINE
debug: false
gpu_ids: "0,1,2,3,4,5,6,7"
deepspeed_config:
  deepspeed_config_file: configs/zero2.yaml  # e.g. configs/zero2.yaml; use an absolute path
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8  # should match the number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
```
```yaml
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  deepspeed_hostfile: hostfile.txt
  deepspeed_multinode_launcher: pdsh
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
enable_cpu_affinity: true
main_process_ip: 10.250.128.19
main_process_port: 12355
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
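The `num_processes` values in these accelerate configs feed directly into DeepSpeed's batch-size arithmetic: the effective (global) batch size is the per-GPU micro batch times the gradient-accumulation steps times the number of processes, which is what DeepSpeed fills in when `train_batch_size` is `"auto"`. A quick sketch of the relation:

```python
def effective_batch_size(micro_batch_per_gpu: int, grad_accum_steps: int, num_gpus: int) -> int:
    """Global batch size = micro batch per GPU x gradient accumulation x number of processes."""
    return micro_batch_per_gpu * grad_accum_steps * num_gpus

# Single-node config above: 8 processes, micro batch 1, no accumulation.
assert effective_batch_size(1, 1, 8) == 8
# Multi-node config above: 4 machines x 8 GPUs = 32 processes.
assert effective_batch_size(1, 1, 32) == 32
```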
New file `finetune/configs/zero2.yaml`:

```json
{
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
```
New file `finetune/configs/zero2_offload.yaml`:

```json
{
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
```
New file `finetune/configs/zero3.yaml` (note: despite the `.yaml` extension, the DeepSpeed configs are JSON; the original file listed `stage3_prefetch_bucket_size` and `stage3_param_persistence_threshold` twice, which is removed here):

```json
{
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto"
  },
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
```
New file `finetune/configs/zero3_offload.yaml` (duplicate `stage3_prefetch_bucket_size` / `stage3_param_persistence_threshold` keys from the original file removed here):

```json
{
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "torch_adam": true,
      "adam_w_mode": true
    }
  },
  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto"
  },
  "gradient_accumulation_steps": 1,
  "train_micro_batch_size_per_gpu": 1,
  "train_batch_size": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
```
New file `finetune/constants.py`:

```python
LOG_NAME = "trainer"
LOG_LEVEL = "INFO"
```
New file `finetune/datasets/__init__.py`:

```python
from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize


__all__ = [
    "I2VDatasetWithResize",
    "I2VDatasetWithBuckets",
    "T2VDatasetWithResize",
    "T2VDatasetWithBuckets",
    "BucketSampler",
]
```
finetune/datasets/bucket_sampler.py (new file)
@@ -0,0 +1,79 @@
import logging
import random

from torch.utils.data import Dataset, Sampler


logger = logging.getLogger(__name__)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.
    """

    def __init__(
        self,
        data_source: Dataset,
        batch_size: int = 8,
        shuffle: bool = True,
        drop_last: bool = False,
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for index, data in enumerate(self.data_source):
            video_metadata = data["video_metadata"]
            f, h, w = (
                video_metadata["num_frames"],
                video_metadata["height"],
                video_metadata["width"],
            )

            self.buckets[(f, h, w)].append(data)
            if len(self.buckets[(f, h, w)]) == self.batch_size:
                if self.shuffle:
                    random.shuffle(self.buckets[(f, h, w)])
                yield self.buckets[(f, h, w)]
                del self.buckets[(f, h, w)]
                self.buckets[(f, h, w)] = []

        if self.drop_last:
            return

        for fhw, bucket in list(self.buckets.items()):
            if len(bucket) == 0:
                continue
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
            del self.buckets[fhw]
            self.buckets[fhw] = []
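The grouping behavior of `BucketSampler.__iter__` above can be illustrated without torch. A minimal sketch on plain dicts (hypothetical `iter_buckets` helper, not part of the PR): samples accumulate per `(frames, height, width)` key, a full bucket is yielded as soon as it reaches `batch_size`, and leftovers are flushed at the end unless `drop_last` is set.

```python
import random


# Minimal sketch of BucketSampler's bucketing logic on plain dicts.
def iter_buckets(samples, batch_size, shuffle=False, drop_last=False, seed=None):
    rng = random.Random(seed)
    buckets = {}
    for sample in samples:
        key = (sample["num_frames"], sample["height"], sample["width"])
        buckets.setdefault(key, []).append(sample)
        if len(buckets[key]) == batch_size:
            batch = buckets.pop(key)  # full bucket: yield immediately
            if shuffle:
                rng.shuffle(batch)
            yield batch
    if drop_last:
        return
    for batch in buckets.values():  # flush incomplete buckets
        if batch:
            yield batch


samples = [{"num_frames": f, "height": 480, "width": 720} for f in (49, 49, 49, 81)]
batches = list(iter_buckets(samples, batch_size=2))
# Three 49-frame samples yield one full batch of 2 plus one leftover;
# the single 81-frame sample forms its own leftover batch.
print([len(b) for b in batches])  # -> [2, 1, 1]
```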
finetune/datasets/i2v_dataset.py (new file)
@@ -0,0 +1,325 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_images,
    load_images_from_videos,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
)


if TYPE_CHECKING:
    from finetune.trainer import Trainer

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports, but it happens. Look in decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class BaseI2VDataset(Dataset):
    """
    Base dataset class for Image-to-Video (I2V) training.

    This dataset loads prompts, videos and corresponding conditioning images for I2V training.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        image_column (str): Path to file containing image paths
        device (torch.device): Device to load the data on
        trainer (Trainer): Trainer instance providing the `encode_video` and `encode_text` functions
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        if image_column is not None:
            self.images = load_images(data_root / image_column)
        else:
            self.images = load_images_from_videos(self.videos)
        self.trainer = trainer

        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
            )

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if all image files exist
        if any(not path.is_file() for path in self.images):
            raise ValueError(
                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cached data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        image = self.images[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = (
            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        )
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(
                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
            )

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # shape of image: [C, H, W]
            _, image = self.preprocess(None, self.images[index])
            image = self.image_transform(image)
        else:
            frames, image = self.preprocess(video, image)
            frames = frames.to(self.device)
            image = image.to(self.device)
            image = self.image_transform(image)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)

            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)

            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            image = image.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        # shape of image: [C, H, W]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(
        self, video_path: Path | None, image_path: Path | None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads and preprocesses a video and an image.
        If either path is None, no preprocessing will be done for that input.

        Args:
            video_path: Path to the video file to load
            image_path: Path to the image file to load

        Returns:
            A tuple containing:
                - video (torch.Tensor) of shape [F, C, H, W] where F is number of frames,
                  C is number of channels, H is height and W is width
                - image (torch.Tensor) of shape [C, H, W]
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to an image.

        Args:
            image (torch.Tensor): A 3D tensor representing an image
                with shape [C, H, W] where:
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed image tensor
        """
        raise NotImplementedError("Subclass must implement this method")


class I2VDatasetWithResize(BaseI2VDataset):
    """
    A dataset class for image-to-video generation that resizes inputs to fixed dimensions.

    This class preprocesses videos and images by resizing them to specified dimensions:
    - Videos are resized to max_num_frames x height x width
    - Images are resized to height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(
        self, video_path: Path | None, image_path: Path | None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(
                video_path, self.max_num_frames, self.height, self.width
            )
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)


class I2VDatasetWithBuckets(BaseI2VDataset):
    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = [
            (
                int(b[0] / vae_temporal_compression_ratio),
                int(b[1] / vae_height_compression_ratio),
                int(b[2] / vae_width_compression_ratio),
            )
            for b in video_resolution_buckets
        ]
        self.__frame_transforms = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
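The prompt-embedding cache above keys files by the SHA-256 digest of the prompt text, so identical prompts always resolve to the same `.safetensors` file across epochs and runs. A small sketch of that naming scheme (hypothetical `prompt_cache_filename` helper mirroring the `hashlib.sha256(prompt.encode()).hexdigest()` call in `__getitem__`):

```python
import hashlib


# Sketch of the cache-key construction used in BaseI2VDataset.__getitem__:
# the file name is the SHA-256 hex digest of the UTF-8 encoded prompt.
def prompt_cache_filename(prompt: str) -> str:
    return hashlib.sha256(prompt.encode()).hexdigest() + ".safetensors"


name = prompt_cache_filename("a panda playing guitar")
print(name.endswith(".safetensors"))  # -> True; 64 hex chars + extension
```

Because the key depends only on the prompt text, the same cache entry is reused by both the I2V and T2V datasets when they share a `prompt_embeddings` directory.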
finetune/datasets/t2v_dataset.py (new file)
@@ -0,0 +1,264 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_prompts,
    load_videos,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
)


if TYPE_CHECKING:
    from finetune.trainer import Trainer

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports, but it happens. Look in decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class BaseT2VDataset(Dataset):
    """
    Base dataset class for Text-to-Video (T2V) training.

    This dataset loads prompts and videos for T2V training.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        device (torch.device): Device to load the data on
        trainer (Trainer): Trainer instance providing the `encode_video` and `encode_text` functions
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        device: torch.device = None,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text
        self.trainer = trainer

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if number of prompts matches number of videos
        if len(self.videos) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cached data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = (
            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        )
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(
                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
            )

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
        else:
            frames = self.preprocess(video)
            frames = frames.to(self.device)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)
            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)

            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        return {
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path) -> torch.Tensor:
        """
        Loads and preprocesses a video.

        Args:
            video_path: Path to the video file to load.

        Returns:
            torch.Tensor: Video tensor of shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor with the same shape as the input
        """
        raise NotImplementedError("Subclass must implement this method")


class T2VDatasetWithResize(BaseT2VDataset):
    """
    A dataset class for text-to-video generation that resizes inputs to fixed dimensions.

    This class preprocesses videos by resizing them to specified dimensions:
    - Videos are resized to max_num_frames x height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos
        width (int): Target width for resizing videos
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transform = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_resize(
            video_path,
            self.max_num_frames,
            self.height,
            self.width,
        )

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)


class T2VDatasetWithBuckets(BaseT2VDataset):
    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = [
            (
                int(b[0] / vae_temporal_compression_ratio),
                int(b[1] / vae_height_compression_ratio),
                int(b[2] / vae_width_compression_ratio),
            )
            for b in video_resolution_buckets
        ]

        self.__frame_transform = transforms.Compose(
            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
        )

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
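Both `__getitem__` implementations above do the same layout dance before calling the video encoder: `[F, C, H, W]` frames gain a batch dimension and are permuted to `[B, C, F, H, W]`. A sketch of that bookkeeping tracked on plain shape tuples (hypothetical `to_encoder_layout` helper, for illustration only):

```python
# Shape bookkeeping mirroring frames.unsqueeze(0).permute(0, 2, 1, 3, 4) above.
def to_encoder_layout(shape_fchw):
    f, c, h, w = shape_fchw
    shape_bfchw = (1, f, c, h, w)   # unsqueeze(0): add batch dim
    b, f, c, h, w = shape_bfchw
    return (b, c, f, h, w)          # permute(0, 2, 1, 3, 4): swap F and C


print(to_encoder_layout((49, 3, 480, 720)))  # -> (1, 3, 49, 480, 720)
```

After encoding, indexing `encoded_video[0]` drops the batch dimension again, which is why the returned `video_metadata` reads frames, height and width from `encoded_video.shape[1:]`.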
196
finetune/datasets/utils.py
Normal file
196
finetune/datasets/utils.py
Normal file
@ -0,0 +1,196 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from torchvision.transforms.functional import resize
|
||||
|
||||
|
||||
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
||||
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
||||
import decord # isort:skip
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
|
||||
########## loaders ##########
|
||||
|
||||
|
||||
def load_prompts(prompt_path: Path) -> List[str]:
|
||||
with open(prompt_path, "r", encoding="utf-8") as file:
|
||||
return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
||||
|
||||
|
||||
def load_videos(video_path: Path) -> List[Path]:
|
||||
with open(video_path, "r", encoding="utf-8") as file:
|
||||
return [
|
||||
video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
|
||||
]
|
||||
|
||||
|
||||
def load_images(image_path: Path) -> List[Path]:
|
||||
with open(image_path, "r", encoding="utf-8") as file:
|
||||
return [
|
||||
image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
|
||||
]
|
||||
|
||||
|
||||
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
|
||||
first_frames_dir = videos_path[0].parent.parent / "first_frames"
|
||||
first_frames_dir.mkdir(exist_ok=True)
|
||||
|
||||
first_frame_paths = []
|
||||
for video_path in videos_path:
|
||||
frame_path = first_frames_dir / f"{video_path.stem}.png"
|
||||
if frame_path.exists():
|
||||
first_frame_paths.append(frame_path)
|
||||
continue
|
||||
|
||||
# Open video
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
|
||||
# Read first frame
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
raise RuntimeError(f"Failed to read video: {video_path}")
|
||||
|
||||
# Save frame as PNG with same name as video
|
||||
cv2.imwrite(str(frame_path), frame)
|
||||
logging.info(f"Saved first frame to {frame_path}")
|
||||
|
||||
# Release video capture
|
||||
cap.release()
|
||||
|
||||
first_frame_paths.append(frame_path)
|
||||
|
||||
return first_frame_paths
|
||||
|
||||
|
||||
########## preprocessors ##########
|
||||
|
||||
|
||||
def preprocess_image_with_resize(
|
||||
image_path: Path | str,
|
||||
height: int,
|
||||
width: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Loads and resizes a single image.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file.
|
||||
height: Target height for resizing.
|
||||
width: Target width for resizing.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Image tensor with shape [C, H, W] where:
|
||||
C = number of channels (3 for RGB)
|
||||
H = height
|
||||
W = width
|
||||
"""
|
||||
if isinstance(image_path, str):
|
||||
image_path = Path(image_path)
|
||||
image = cv2.imread(image_path.as_posix())
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (width, height))
|
||||
image = torch.from_numpy(image).float()
|
||||
image = image.permute(2, 0, 1).contiguous()
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_video_with_resize(
|
||||
video_path: Path | str,
|
||||
max_num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Loads and resizes a single video.
|
||||
|
||||
The function processes the video through these steps:
|
||||
1. If video frame count > max_num_frames, downsample frames evenly
|
||||
2. If video dimensions don't match (height, width), resize frames
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
max_num_frames: Maximum number of frames to keep.
|
||||
height: Target height for resizing.
|
||||
width: Target width for resizing.
|
||||
|
||||
Returns:
|
||||
A torch.Tensor with shape [F, C, H, W] where:
|
||||
F = number of frames
|
||||
C = number of channels (3 for RGB)
|
||||
H = height
|
||||
W = width
|
||||
"""
|
||||
if isinstance(video_path, str):
|
||||
video_path = Path(video_path)
|
||||
video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
|
||||
video_num_frames = len(video_reader)
|
||||
if video_num_frames < max_num_frames:
|
||||
# Get all frames first
|
||||
frames = video_reader.get_batch(list(range(video_num_frames)))
|
||||
# Repeat the last frame until we reach max_num_frames
|
||||
last_frame = frames[-1:]
|
||||
num_repeats = max_num_frames - video_num_frames
|
||||
repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
|
||||
frames = torch.cat([frames, repeated_frames], dim=0)
|
||||
return frames.float().permute(0, 3, 1, 2).contiguous()
|
||||
else:
|
||||
indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
|
||||
frames = video_reader.get_batch(indices)
|
||||
frames = frames[:max_num_frames].float()
|
||||
frames = frames.permute(0, 3, 1, 2).contiguous()
|
||||
return frames
|
||||
|
||||
|
||||
def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    The function processes the video through these steps:
        1. Finds the nearest frame bucket <= video frame count
        2. Downsamples frames evenly to match the bucket size
        3. Finds the nearest resolution bucket based on dimensions
        4. Resizes frames to match the bucket resolution
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)
    resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if len(resolution_buckets) == 0:
        raise ValueError(
            f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
        )

    nearest_frame_bucket = min(
        resolution_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
        default=1,
    )[0]
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()

    nearest_res = min(
        resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
    )
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

    return frames
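The bucket-selection logic above can be checked in isolation. A minimal, hypothetical sketch (the function name `select_buckets` is ours, not the repo's): the frame bucket is the largest frame count that still fits the video, and the spatial bucket is the nearest (H, W) by L1 distance.

```python
from typing import List, Tuple


def select_buckets(
    video_num_frames: int,
    height: int,
    width: int,
    resolution_buckets: List[Tuple[int, int, int]],
) -> Tuple[int, Tuple[int, int]]:
    # Keep only buckets whose frame count fits within the video.
    candidates = [b for b in resolution_buckets if b[0] <= video_num_frames]
    if not candidates:
        raise ValueError("video is shorter than every frame bucket")
    # Smallest (video_num_frames - b[0]) <=> largest frame count that still fits.
    nearest_frames = min(candidates, key=lambda b: video_num_frames - b[0])[0]
    # L1 distance in (H, W) picks the spatially closest bucket.
    nearest = min(candidates, key=lambda b: abs(b[1] - height) + abs(b[2] - width))
    return nearest_frames, (nearest[1], nearest[2])
```

For example, a 60-frame 480x720 video against buckets `[(49, 480, 720), (17, 256, 256)]` selects the 49-frame, 480x720 bucket.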
@@ -1,52 +0,0 @@
#!/bin/bash

export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-multi-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

# max batch-size is 2.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu --machine_rank 0 \
  train_cogvideox_lora.py \
  --gradient_checkpointing \
  --pretrained_model_name_or_path $MODEL_PATH \
  --cache_dir $CACHE_PATH \
  --enable_tiling \
  --enable_slicing \
  --instance_data_root $DATASET_PATH \
  --caption_column prompts.txt \
  --video_column videos.txt \
  --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 100 \
  --seed 42 \
  --rank 128 \
  --lora_alpha 64 \
  --mixed_precision bf16 \
  --output_dir $OUTPUT_PATH \
  --height 480 \
  --width 720 \
  --fps 8 \
  --max_num_frames 49 \
  --skip_frames_start 0 \
  --skip_frames_end 0 \
  --train_batch_size 1 \
  --num_train_epochs 30 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-3 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --enable_slicing \
  --enable_tiling \
  --gradient_checkpointing \
  --optimizer AdamW \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb
@@ -1,52 +0,0 @@
#!/bin/bash

export MODEL_PATH="THUDM/CogVideoX-2b"
export CACHE_PATH="~/.cache"
export DATASET_PATH="Disney-VideoGeneration-Dataset"
export OUTPUT_PATH="cogvideox-lora-single-node"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

# If you are not using 8 GPUs, change num_processes in accelerate_config_machine_single.yaml to your GPU count.
accelerate launch --config_file accelerate_config_machine_single.yaml --multi_gpu \
  train_cogvideox_lora.py \
  --gradient_checkpointing \
  --pretrained_model_name_or_path $MODEL_PATH \
  --cache_dir $CACHE_PATH \
  --enable_tiling \
  --enable_slicing \
  --instance_data_root $DATASET_PATH \
  --caption_column prompts.txt \
  --video_column videos.txt \
  --validation_prompt "DISNEY A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
  --validation_prompt_separator ::: \
  --num_validation_videos 1 \
  --validation_epochs 100 \
  --seed 42 \
  --rank 128 \
  --lora_alpha 64 \
  --mixed_precision bf16 \
  --output_dir $OUTPUT_PATH \
  --height 480 \
  --width 720 \
  --fps 8 \
  --max_num_frames 49 \
  --skip_frames_start 0 \
  --skip_frames_end 0 \
  --train_batch_size 1 \
  --num_train_epochs 30 \
  --checkpointing_steps 1000 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-3 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 200 \
  --lr_num_cycles 1 \
  --enable_slicing \
  --enable_tiling \
  --gradient_checkpointing \
  --optimizer AdamW \
  --adam_beta1 0.9 \
  --adam_beta2 0.95 \
  --max_grad_norm 1.0 \
  --allow_tf32 \
  --report_to wandb
@@ -1,2 +0,0 @@
node1 slots=8
node2 slots=8
finetune/models/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import importlib
from pathlib import Path


package_dir = Path(__file__).parent

for subdir in package_dir.iterdir():
    if subdir.is_dir() and not subdir.name.startswith("_"):
        for module_path in subdir.glob("*.py"):
            module_name = module_path.stem
            full_module_name = f".{subdir.name}.{module_name}"
            importlib.import_module(full_module_name, package=__name__)
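The auto-discovery loop above imports every trainer module so that its import-time `register(...)` call runs. A self-contained sketch of the same pattern, using a throwaway package built in a temp directory (all names here, e.g. `demo_pkg`, are hypothetical):

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package on disk, mirroring finetune/models/<subdir>/<module>.py.
root = Path(tempfile.mkdtemp())
pkg = root / "demo_pkg"
(pkg / "sub").mkdir(parents=True)
(pkg / "__init__.py").write_text("")
(pkg / "sub" / "__init__.py").write_text("")
(pkg / "sub" / "mod.py").write_text("LOADED = True\n")
sys.path.insert(0, str(root))

# Same discovery loop as the package __init__: walk subpackages and import
# every module so import-time side effects run.
loaded = []
for subdir in sorted(p for p in pkg.iterdir() if p.is_dir() and not p.name.startswith("_")):
    for module_path in sorted(subdir.glob("*.py")):
        if module_path.stem.startswith("_"):
            continue  # skip __init__.py itself
        loaded.append(importlib.import_module(f"demo_pkg.{subdir.name}.{module_path.stem}"))
```

After the loop, `demo_pkg.sub.mod` is in `sys.modules` and any registration it performed has taken effect.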
finetune/models/cogvideox1_5_i2v/lora_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoX1_5I2VLoraTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox1.5-i2v", "lora", CogVideoX1_5I2VLoraTrainer)
finetune/models/cogvideox1_5_i2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.sft_trainer import CogVideoXI2VSftTrainer
from ..utils import register


class CogVideoX1_5I2VSftTrainer(CogVideoXI2VSftTrainer):
    pass


register("cogvideox1.5-i2v", "sft", CogVideoX1_5I2VSftTrainer)
finetune/models/cogvideox1_5_t2v/lora_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoX1_5T2VLoraTrainer(CogVideoXT2VLoraTrainer):
    pass


register("cogvideox1.5-t2v", "lora", CogVideoX1_5T2VLoraTrainer)
finetune/models/cogvideox1_5_t2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.sft_trainer import CogVideoXT2VSftTrainer
from ..utils import register


class CogVideoX1_5T2VSftTrainer(CogVideoXT2VSftTrainer):
    pass


register("cogvideox1.5-t2v", "sft", CogVideoX1_5T2VSftTrainer)
finetune/models/cogvideox_i2v/lora_trainer.py (new file, 272 lines)
@@ -0,0 +1,272 @@
from typing import Any, Dict, List, Tuple

import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override

from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model

from ..utils import register


class CogVideoXI2VLoraTrainer(Trainer):
    UNLOAD_LIST = ["text_encoder"]

    @override
    def load_components(self) -> Dict[str, Any]:
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = CogVideoXImageToVideoPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

        components.text_encoder = T5EncoderModel.from_pretrained(
            model_path, subfolder="text_encoder"
        )

        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
            model_path, subfolder="transformer"
        )

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
            model_path, subfolder="scheduler"
        )

        return components

    @override
    def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
        pipe = CogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            scheduler=self.components.scheduler,
        )
        return pipe

    @override
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W]
        vae = self.components.vae
        video = video.to(vae.device, dtype=vae.dtype)
        latent_dist = vae.encode(video).latent_dist
        latent = latent_dist.sample() * vae.config.scaling_factor
        return latent

    @override
    def encode_text(self, prompt: str) -> torch.Tensor:
        prompt_token_ids = self.components.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.state.transformer_config.max_text_seq_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        prompt_token_ids = prompt_token_ids.input_ids
        prompt_embedding = self.components.text_encoder(
            prompt_token_ids.to(self.accelerator.device)
        )[0]
        return prompt_embedding

    @override
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}

        for sample in samples:
            encoded_video = sample["encoded_video"]
            prompt_embedding = sample["prompt_embedding"]
            image = sample["image"]

            ret["encoded_videos"].append(encoded_video)
            ret["prompt_embedding"].append(prompt_embedding)
            ret["images"].append(image)

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
        ret["images"] = torch.stack(ret["images"])

        return ret

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        prompt_embedding = batch["prompt_embedding"]
        latent = batch["encoded_videos"]
        images = batch["images"]

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of latent: [B, C, F, H, W]
        # Shape of images: [B, C, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Add frame dimension to images [B, C, H, W] -> [B, C, F, H, W]
        images = images.unsqueeze(2)
        # Add noise to images
        image_noise_sigma = torch.normal(
            mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
        )
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
        noisy_images = (
            images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
        )
        image_latent_dist = self.components.vae.encode(
            noisy_images.to(dtype=self.components.vae.dtype)
        ).latent_dist
        image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.components.scheduler.config.num_train_timesteps,
            (batch_size,),
            device=self.accelerator.device,
        )
        timesteps = timesteps.long()

        # from [B, C, F, H, W] to [B, F, C, H, W]
        latent = latent.permute(0, 2, 1, 3, 4)
        image_latents = image_latents.permute(0, 2, 1, 3, 4)
        assert (latent.shape[0], *latent.shape[2:]) == (
            image_latents.shape[0],
            *image_latents.shape[2:],
        )

        # Padding image_latents to the same frame number as latent
        padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
        latent_padding = image_latents.new_zeros(padding_shape)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Add noise to latent
        noise = torch.randn_like(latent)
        latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)

        # Concatenate latent and image_latents in the channel dimension
        latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise; the ofs embedding is used by CogVideoX1.5 only.
        ofs_emb = (
            None
            if self.state.transformer_config.ofs_embed_dim is None
            else latent.new_full((1,), fill_value=2.0)
        )
        predicted_noise = self.components.transformer(
            hidden_states=latent_img_noisy,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            ofs=ofs_emb,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise
        latent_pred = self.components.scheduler.get_velocity(
            predicted_noise, latent_noisy, timesteps
        )

        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss

    @override
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL
        """
        prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,
            image=image,
            generator=self.state.generator,
        ).frames[0]
        return [("video", video_generate)]

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            base_num_frames = (
                num_frames + transformer_config.patch_size_t - 1
            ) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin


register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
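The loss at the end of `compute_loss` is a weighted MSE between the scheduler's velocity-derived prediction and the clean latent, with an SNR-style weight 1 / (1 - alpha_cumprod[t]) broadcast over the non-batch dimensions. A minimal, hypothetical extraction of just that step (the function name `weighted_latent_mse` is ours):

```python
import torch


def weighted_latent_mse(latent_pred, latent, alphas_cumprod, timesteps):
    # One scalar weight per sample: 1 / (1 - alpha_cumprod[t]).
    weights = 1.0 / (1.0 - alphas_cumprod[timesteps])
    # Broadcast the [B] weights over the remaining dimensions.
    while weights.dim() < latent_pred.dim():
        weights = weights.unsqueeze(-1)
    per_sample = (weights * (latent_pred - latent) ** 2).reshape(latent.shape[0], -1).mean(dim=1)
    return per_sample.mean()
```

With alpha_cumprod[t] = 0.5 and a unit prediction error, the weight is 2 and the loss is exactly 2.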
finetune/models/cogvideox_i2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
finetune/models/cogvideox_t2v/lora_trainer.py (new file, 228 lines)
@@ -0,0 +1,228 @@
from typing import Any, Dict, List, Tuple

import torch
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override

from finetune.schemas import Components
from finetune.trainer import Trainer
from finetune.utils import unwrap_model

from ..utils import register


class CogVideoXT2VLoraTrainer(Trainer):
    UNLOAD_LIST = ["text_encoder", "vae"]

    @override
    def load_components(self) -> Components:
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = CogVideoXPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

        components.text_encoder = T5EncoderModel.from_pretrained(
            model_path, subfolder="text_encoder"
        )

        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
            model_path, subfolder="transformer"
        )

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
            model_path, subfolder="scheduler"
        )

        return components

    @override
    def initialize_pipeline(self) -> CogVideoXPipeline:
        pipe = CogVideoXPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            scheduler=self.components.scheduler,
        )
        return pipe

    @override
    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W]
        vae = self.components.vae
        video = video.to(vae.device, dtype=vae.dtype)
        latent_dist = vae.encode(video).latent_dist
        latent = latent_dist.sample() * vae.config.scaling_factor
        return latent

    @override
    def encode_text(self, prompt: str) -> torch.Tensor:
        prompt_token_ids = self.components.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.state.transformer_config.max_text_seq_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        prompt_token_ids = prompt_token_ids.input_ids
        prompt_embedding = self.components.text_encoder(
            prompt_token_ids.to(self.accelerator.device)
        )[0]
        return prompt_embedding

    @override
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        ret = {"encoded_videos": [], "prompt_embedding": []}

        for sample in samples:
            encoded_video = sample["encoded_video"]
            prompt_embedding = sample["prompt_embedding"]

            ret["encoded_videos"].append(encoded_video)
            ret["prompt_embedding"].append(prompt_embedding)

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])

        return ret

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        prompt_embedding = batch["prompt_embedding"]
        latent = batch["encoded_videos"]

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of latent: [B, C, F, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t
        if patch_size_t is not None:
            ncopy = latent.shape[2] % patch_size_t
            # Copy the first frame ncopy times to match patch_size_t
            first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            assert latent.shape[2] % patch_size_t == 0

        batch_size, num_channels, num_frames, height, width = latent.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0,
            self.components.scheduler.config.num_train_timesteps,
            (batch_size,),
            device=self.accelerator.device,
        )
        timesteps = timesteps.long()

        # Add noise to latent
        latent = latent.permute(0, 2, 1, 3, 4)  # from [B, C, F, H, W] to [B, F, C, H, W]
        noise = torch.randn_like(latent)
        latent_added_noise = self.components.scheduler.add_noise(latent, noise, timesteps)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise
        predicted_noise = self.components.transformer(
            hidden_states=latent_added_noise,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise
        latent_pred = self.components.scheduler.get_velocity(
            predicted_noise, latent_added_noise, timesteps
        )

        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss

    @override
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: CogVideoXPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL
        """
        prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]

        video_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,
            generator=self.state.generator,
        ).frames[0]
        return [("video", video_generate)]

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            base_num_frames = (
                num_frames + transformer_config.patch_size_t - 1
            ) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin


register("cogvideox-t2v", "lora", CogVideoXT2VLoraTrainer)
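The grid sizing inside `prepare_rotary_positional_embeddings` is easy to check in isolation: spatial dimensions shrink by `vae_scale_factor_spatial * patch_size`, and the frame count is ceil-divided by `patch_size_t` when it is set. A hypothetical standalone sketch (`rope_grid` is our name, not the repo's):

```python
def rope_grid(height, width, num_frames, patch_size, patch_size_t, vae_scale_factor_spatial):
    # Spatial grid: pixels -> latent pixels -> patches.
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    # Temporal grid: ceil-divide frames by the temporal patch size, if any.
    if patch_size_t is None:
        base_num_frames = num_frames
    else:
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
    return grid_height, grid_width, base_num_frames
```

For a 480x720 clip with 13 latent frames, patch_size 2, patch_size_t 2, and spatial VAE scale factor 8, this yields a 30x45 spatial grid and 7 temporal positions.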
finetune/models/cogvideox_t2v/sft_trainer.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from ..cogvideox_t2v.lora_trainer import CogVideoXT2VLoraTrainer
from ..utils import register


class CogVideoXT2VSftTrainer(CogVideoXT2VLoraTrainer):
    pass


register("cogvideox-t2v", "sft", CogVideoXT2VSftTrainer)
finetune/models/utils.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import Dict, Literal

from finetune.trainer import Trainer


SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}


def register(model_name: str, training_type: Literal["lora", "sft"], trainer_cls: Trainer):
    """Register a model and its associated functions for a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft"]): Type of training - either "lora" or "sft"
        trainer_cls (Trainer): Trainer class to register.
    """

    # Check whether model_name and training_type already exist in SUPPORTED_MODELS
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    elif training_type in SUPPORTED_MODELS[model_name]:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    SUPPORTED_MODELS[model_name][training_type] = trainer_cls


def show_supported_models():
    """Print all currently supported models and their training types."""

    print("\nSupported Models:")
    print("================")

    for model_name, training_types in SUPPORTED_MODELS.items():
        print(f"\n{model_name}")
        print("-" * len(model_name))
        for training_type in training_types:
            print(f"  • {training_type}")


def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Trainer:
    """Get the trainer class for a specific model and training type."""
    if model_type not in SUPPORTED_MODELS:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"  • {supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in SUPPORTED_MODELS[model_type]:
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in SUPPORTED_MODELS[model_type]:
            print(f"  • {supported_type}")
        raise ValueError(
            f"Training type '{training_type}' is not supported for model '{model_type}'"
        )

    return SUPPORTED_MODELS[model_type][training_type]
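The registry pattern above — trainers register themselves under a (model name, training type) pair, with duplicate registration rejected — can be sketched standalone. All names below (`SUPPORTED`, `DemoTrainer`) are hypothetical stand-ins, not the repo's:

```python
from typing import Dict

SUPPORTED: Dict[str, Dict[str, type]] = {}


def register(model_name: str, training_type: str, trainer_cls: type) -> None:
    # Create the per-model bucket on first use, then reject duplicates.
    bucket = SUPPORTED.setdefault(model_name, {})
    if training_type in bucket:
        raise ValueError(f"Training type {training_type} already exists for {model_name}")
    bucket[training_type] = trainer_cls


def get_model_cls(model_name: str, training_type: str) -> type:
    try:
        return SUPPORTED[model_name][training_type]
    except KeyError:
        raise ValueError(f"{model_name}/{training_type} is not supported") from None


class DemoTrainer:  # hypothetical stand-in for a Trainer subclass
    pass


register("demo-model", "lora", DemoTrainer)
```

Lookup then returns the registered class, and both a duplicate `register` call and a lookup of an unregistered pair raise `ValueError`.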
finetune/schemas/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from .args import Args
from .components import Components
from .state import State


__all__ = ["Args", "State", "Components"]
254
finetune/schemas/args.py
Normal file
254
finetune/schemas/args.py
Normal file
@ -0,0 +1,254 @@
import argparse
import datetime
import logging
from pathlib import Path
from typing import Any, List, Literal, Tuple

from pydantic import BaseModel, ValidationInfo, field_validator


class Args(BaseModel):
    ########## Model ##########
    model_path: Path
    model_name: str
    model_type: Literal["i2v", "t2v"]
    training_type: Literal["lora", "sft"] = "lora"

    ########## Output ##########
    output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
    report_to: Literal["tensorboard", "wandb", "all"] | None = None
    tracker_name: str = "finetrainer-cogvideo"

    ########## Data ##########
    data_root: Path
    caption_column: Path
    image_column: Path | None = None
    video_column: Path

    ########## Training ##########
    resume_from_checkpoint: Path | None = None

    seed: int | None = None
    train_epochs: int
    train_steps: int | None = None
    checkpointing_steps: int = 200
    checkpointing_limit: int = 10

    batch_size: int
    gradient_accumulation_steps: int = 1

    train_resolution: Tuple[int, int, int]  # shape: (frames, height, width)

    #### deprecated args: video_resolution_buckets
    # If buckets are used for training, this should not be None.
    # Note 1: at least one frame count in the buckets must be less than or equal to the frame count of every video in the dataset.
    # Note 2: for cogvideox and cogvideox1.5,
    #   frames - 1 must be a multiple of 8 (temporal_compression_rate[4] * patch_t[2] = 8),
    #   and the height and width set in the bucket must be integer multiples of 8 (spatial_compression_rate[8]).
    # video_resolution_buckets: List[Tuple[int, int, int]] | None = None

    mixed_precision: Literal["no", "fp16", "bf16"]

    learning_rate: float = 2e-5
    optimizer: str = "adamw"
    beta1: float = 0.9
    beta2: float = 0.95
    beta3: float = 0.98
    epsilon: float = 1e-8
    weight_decay: float = 1e-4
    max_grad_norm: float = 1.0

    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 100
    lr_num_cycles: int = 1
    lr_power: float = 1.0

    num_workers: int = 8
    pin_memory: bool = True

    gradient_checkpointing: bool = True
    enable_slicing: bool = True
    enable_tiling: bool = True
    nccl_timeout: int = 1800

    ########## LoRA ##########
    rank: int = 128
    lora_alpha: int = 64
    target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]

    ########## Validation ##########
    do_validation: bool = False
    validation_steps: int | None  # if set, should be a multiple of checkpointing_steps
    validation_dir: Path | None  # required when do_validation is True
    validation_prompts: str | None  # required when do_validation is True
    validation_images: str | None  # required when do_validation is True and model_type == "i2v"
    validation_videos: str | None  # required when do_validation is True and model_type == "v2v"
    gen_fps: int = 15

    #### deprecated args: gen_video_resolution
    # 1. Required when do_validation is True.
    # 2. Prefer the bucket from `video_resolution_buckets` that is closest to the resolution chosen for fine-tuning,
    #    or the resolution recommended by the model.
    # 3. Note: for cogvideox and cogvideox1.5,
    #    frames - 1 must be a multiple of 8 (temporal_compression_rate[4] * patch_t[2] = 8),
    #    and the height and width set in the bucket must be integer multiples of 8 (spatial_compression_rate[8]).
    # gen_video_resolution: Tuple[int, int, int] | None  # shape: (frames, height, width)

    @field_validator("image_column")
    def validate_image_column(cls, v: str | None, info: ValidationInfo) -> str | None:
        values = info.data
        if values.get("model_type") == "i2v" and not v:
            logging.warning(
                "No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
            )
        return v

    @field_validator("validation_dir", "validation_prompts")
    def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
        values = info.data
        if values.get("do_validation") and not v:
            field_name = info.field_name
            raise ValueError(f"{field_name} must be specified when do_validation is True")
        return v

    @field_validator("validation_images")
    def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
            raise ValueError(
                "validation_images must be specified when do_validation is True and model_type is i2v"
            )
        return v

    @field_validator("validation_videos")
    def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
            raise ValueError(
                "validation_videos must be specified when do_validation is True and model_type is v2v"
            )
        return v

    @field_validator("validation_steps")
    def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
        values = info.data
        if values.get("do_validation"):
            if v is None:
                raise ValueError("validation_steps must be specified when do_validation is True")
            if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
                raise ValueError("validation_steps must be a multiple of checkpointing_steps")
        return v

    @field_validator("train_resolution")
    def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> Tuple[int, int, int]:
        try:
            frames, height, width = v

            # Check that (frames - 1) is a multiple of 8
            if (frames - 1) % 8 != 0:
                raise ValueError("Number of frames - 1 must be a multiple of 8")

            # Check resolution for cogvideox-5b models
            model_name = info.data.get("model_name", "")
            if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
                if (height, width) != (480, 720):
                    raise ValueError(
                        "For cogvideox-5b models, height must be 480 and width must be 720"
                    )

            return v

        except ValueError as e:
            if (
                str(e) == "not enough values to unpack (expected 3, got 0)"
                or str(e) == "invalid literal for int() with base 10"
            ):
                raise ValueError("train_resolution must be in format 'frames x height x width'")
            raise e

    @field_validator("mixed_precision")
    def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
        if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
            logging.warning(
                "All CogVideoX models except cogvideox-2b were trained with bfloat16. "
                "Using fp16 precision may lead to training instability."
            )
        return v

    @classmethod
    def parse_args(cls):
        """Parse command line arguments and return an Args instance."""
        parser = argparse.ArgumentParser()
        # Required arguments
        parser.add_argument("--model_path", type=str, required=True)
        parser.add_argument("--model_name", type=str, required=True)
        parser.add_argument("--model_type", type=str, required=True)
        parser.add_argument("--training_type", type=str, required=True)
        parser.add_argument("--output_dir", type=str, required=True)
        parser.add_argument("--data_root", type=str, required=True)
        parser.add_argument("--caption_column", type=str, required=True)
        parser.add_argument("--video_column", type=str, required=True)
        parser.add_argument("--train_resolution", type=str, required=True)
        parser.add_argument("--report_to", type=str, required=True)

        # Training hyperparameters
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--train_epochs", type=int, default=10)
        parser.add_argument("--train_steps", type=int, default=None)
        parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--learning_rate", type=float, default=2e-5)
        parser.add_argument("--optimizer", type=str, default="adamw")
        parser.add_argument("--beta1", type=float, default=0.9)
        parser.add_argument("--beta2", type=float, default=0.95)
        parser.add_argument("--beta3", type=float, default=0.98)
        parser.add_argument("--epsilon", type=float, default=1e-8)
        parser.add_argument("--weight_decay", type=float, default=1e-4)
        parser.add_argument("--max_grad_norm", type=float, default=1.0)

        # Learning rate scheduler
        parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
        parser.add_argument("--lr_warmup_steps", type=int, default=100)
        parser.add_argument("--lr_num_cycles", type=int, default=1)
        parser.add_argument("--lr_power", type=float, default=1.0)

        # Data loading
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--pin_memory", type=bool, default=True)
        parser.add_argument("--image_column", type=str, default=None)

        # Model configuration
        parser.add_argument("--mixed_precision", type=str, default="no")
        parser.add_argument("--gradient_checkpointing", type=bool, default=True)
        parser.add_argument("--enable_slicing", type=bool, default=True)
        parser.add_argument("--enable_tiling", type=bool, default=True)
        parser.add_argument("--nccl_timeout", type=int, default=1800)

        # LoRA parameters
        parser.add_argument("--rank", type=int, default=128)
        parser.add_argument("--lora_alpha", type=int, default=64)
        parser.add_argument(
            "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
        )

        # Checkpointing
        parser.add_argument("--checkpointing_steps", type=int, default=200)
        parser.add_argument("--checkpointing_limit", type=int, default=10)
        parser.add_argument("--resume_from_checkpoint", type=str, default=None)

        # Validation
        parser.add_argument("--do_validation", type=lambda x: x.lower() == "true", default=False)
        parser.add_argument("--validation_steps", type=int, default=None)
        parser.add_argument("--validation_dir", type=str, default=None)
        parser.add_argument("--validation_prompts", type=str, default=None)
        parser.add_argument("--validation_images", type=str, default=None)
        parser.add_argument("--validation_videos", type=str, default=None)
        parser.add_argument("--gen_fps", type=int, default=15)

        args = parser.parse_args()

        # Convert the "frames x height x width" train_resolution string into a tuple of ints
        frames, height, width = args.train_resolution.split("x")
        args.train_resolution = (int(frames), int(height), int(width))

        return cls(**vars(args))
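A minimal stdlib-only sketch of the same "frames x height x width" parsing and frame-count check that `parse_args` and `validate_train_resolution` perform; `parse_resolution` is a hypothetical helper name, not part of the repository:

```python
def parse_resolution(s: str) -> tuple[int, int, int]:
    """Parse a 'frames x height x width' string such as '81x768x1360'."""
    try:
        frames, height, width = (int(part) for part in s.split("x"))
    except ValueError:
        # Wrong number of parts or non-integer parts both surface as ValueError here.
        raise ValueError("train_resolution must be in format 'frames x height x width'")
    # CogVideoX requires frames = 8N + 1 (temporal compression 4 x temporal patch 2).
    if (frames - 1) % 8 != 0:
        raise ValueError("Number of frames - 1 must be a multiple of 8")
    return frames, height, width
```

Doing the conversion once, up front, means every later consumer (dataset, trainer, validators) works with a typed tuple instead of re-splitting the string.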
28  finetune/schemas/components.py  Normal file
@@ -0,0 +1,28 @@
from typing import Any

from pydantic import BaseModel


class Components(BaseModel):
    # Pipeline class
    pipeline_cls: Any = None

    # Tokenizers
    tokenizer: Any = None
    tokenizer_2: Any = None
    tokenizer_3: Any = None

    # Text encoders
    text_encoder: Any = None
    text_encoder_2: Any = None
    text_encoder_3: Any = None

    # Autoencoder
    vae: Any = None

    # Denoiser
    transformer: Any = None
    unet: Any = None

    # Scheduler
    scheduler: Any = None
29  finetune/schemas/state.py  Normal file
@@ -0,0 +1,29 @@
from pathlib import Path
from typing import Any, Dict, List

import torch
from pydantic import BaseModel


class State(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    train_frames: int
    train_height: int
    train_width: int

    transformer_config: Dict[str, Any] = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    using_deepspeed: bool = False
60  finetune/scripts/extract_images.py  Normal file
@@ -0,0 +1,60 @@
import argparse
from pathlib import Path

import cv2


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--datadir",
        type=str,
        required=True,
        help="Root directory containing videos.txt and video subdirectory",
    )
    return parser.parse_args()


args = parse_args()

# Create the images directory if it doesn't exist
data_dir = Path(args.datadir)
image_dir = data_dir / "images"
image_dir.mkdir(exist_ok=True)

# Read videos.txt
videos_file = data_dir / "videos.txt"
with open(videos_file, "r") as f:
    video_paths = [line.strip() for line in f.readlines() if line.strip()]

# Process each video file and collect image paths
image_paths = []
for video_rel_path in video_paths:
    video_path = data_dir / video_rel_path

    # Open video
    cap = cv2.VideoCapture(str(video_path))

    # Read first frame
    ret, frame = cap.read()
    if not ret:
        print(f"Failed to read video: {video_path}")
        cap.release()
        continue

    # Save frame as PNG with same name as video
    image_name = f"images/{video_path.stem}.png"
    image_path = data_dir / image_name
    cv2.imwrite(str(image_path), frame)

    # Release video capture
    cap.release()

    print(f"Extracted first frame from {video_path} to {image_path}")
    image_paths.append(image_name)

# Write images.txt
images_file = data_dir / "images.txt"
with open(images_file, "w") as f:
    for path in image_paths:
        f.write(f"{path}\n")
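The path bookkeeping in the script (mapping each entry of `videos.txt` to an `images/<stem>.png` line for `images.txt`) can be sketched without OpenCV; `first_frame_image_names` is a hypothetical helper, not part of the repository:

```python
from pathlib import Path


def first_frame_image_names(video_lines: list[str]) -> list[str]:
    """Map each relative video path to the PNG name the script writes under images/."""
    names = []
    for line in video_lines:
        rel = line.strip()
        if not rel:
            continue  # skip blank lines, as the script does
        # Path.stem drops the directory and extension, e.g. "videos/a.mp4" -> "a".
        names.append(f"images/{Path(rel).stem}.png")
    return names
```

Keeping the derived image names relative to the data root matches what the finetuning code expects in `images.txt`.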
19  finetune/train.py  Normal file
@@ -0,0 +1,19 @@
import sys
from pathlib import Path


sys.path.append(str(Path(__file__).parent.parent))

from finetune.models.utils import get_model_cls
from finetune.schemas import Args


def main():
    args = Args.parse_args()
    trainer_cls = get_model_cls(args.model_name, args.training_type)
    trainer = trainer_cls(args)
    trainer.fit()


if __name__ == "__main__":
    main()
File diff suppressed because it is too large  Load Diff
70  finetune/train_ddp_i2v.sh  Normal file
@@ -0,0 +1,70 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B-I2V"
    --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
    --model_type "i2v"
    --training_type "lora"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # if this line is commented out, the first frame of each video is used as the conditioning image
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  # only CogVideoX-2B supports fp16 training
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep; the oldest is deleted once the limit is exceeded
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
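Conceptually, bash expands each quoted `"${..._ARGS[@]}"` array into one flat argument vector that `Args.parse_args()` then consumes with argparse. A condensed Python sketch of that flattening, using only a small subset of the real flags:

```python
import argparse

# Hypothetical condensed launcher: each list stands in for one bash array,
# and concatenating them reproduces the flat argv the bash expansion builds.
MODEL_ARGS = ["--model_path", "THUDM/CogVideoX1.5-5B-I2V", "--model_type", "i2v"]
TRAIN_ARGS = ["--batch_size", "1", "--train_resolution", "81x768x1360"]

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--model_type", type=str, required=True)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--train_resolution", type=str, required=True)

# Equivalent to: accelerate launch train.py "${MODEL_ARGS[@]}" "${TRAIN_ARGS[@]}"
args = parser.parse_args(MODEL_ARGS + TRAIN_ARGS)
```

Because the arrays are expanded with `"${...[@]}"` (quoted, element-wise), values containing spaces survive intact, which is why the scripts group related flags into arrays instead of one long string.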
68  finetune/train_ddp_t2v.sh  Normal file
@@ -0,0 +1,68 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B"
    --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
    --model_type "t2v"
    --training_type "lora"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  # only CogVideoX-2B supports fp16 training
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep; the oldest is deleted once the limit is exceeded
    --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, comment this line out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/your/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
73  finetune/train_zero_i2v.sh  Normal file
@@ -0,0 +1,73 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B-I2V"
    --model_name "cogvideox1.5-i2v"  # ["cogvideox-i2v"]
    --model_type "i2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    # --image_column "images.txt"  # if this line is commented out, the first frame of each video is used as the conditioning image
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height, width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with the deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  only CogVideoX-2B supports fp16 training
    ##########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep; the oldest is deleted once the limit is exceeded
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, leave it commented out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --validation_images "images.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
71  finetune/train_zero_t2v.sh  Normal file
@@ -0,0 +1,71 @@
#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX1.5-5B"
    --model_name "cogvideox1.5-t2v"  # ["cogvideox-t2v"]
    --model_type "t2v"
    --training_type "sft"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "/absolute/path/to/your/output_dir"
    --report_to "tensorboard"
)

# Data Configuration
DATA_ARGS=(
    --data_root "/absolute/path/to/your/data_root"
    --caption_column "prompt.txt"
    --video_column "videos.txt"
    --train_resolution "81x768x1360"  # (frames x height x width), frames should be 8N+1 and height, width should be multiples of 16
)

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10  # number of training epochs
    --seed 42  # random seed

    ######### Please keep consistent with the deepspeed config file ##########
    --batch_size 1
    --gradient_accumulation_steps 1
    --mixed_precision "bf16"  # ["no", "fp16"]  only CogVideoX-2B supports fp16 training
    ##########################################################################
)

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 10  # save a checkpoint every x steps
    --checkpointing_limit 2  # maximum number of checkpoints to keep; the oldest is deleted once the limit is exceeded
    # --resume_from_checkpoint "/absolute/path/to/checkpoint_dir"  # set this to resume from a checkpoint; otherwise, leave it commented out
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation false  # ["true", "false"]
    --validation_dir "/absolute/path/to/validation_set"
    --validation_steps 20  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt"
    --gen_fps 16
)

# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}"
811  finetune/trainer.py  Normal file
@@ -0,0 +1,811 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
import transformers
|
||||
import wandb
|
||||
from accelerate.accelerator import Accelerator, DistributedType
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import (
|
||||
DistributedDataParallelKwargs,
|
||||
InitProcessGroupKwargs,
|
||||
ProjectConfiguration,
|
||||
gather_object,
|
||||
set_seed,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.utils.export_utils import export_to_video
|
||||
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from finetune.constants import LOG_LEVEL, LOG_NAME
|
||||
from finetune.datasets import I2VDatasetWithResize, T2VDatasetWithResize
|
||||
from finetune.datasets.utils import (
|
||||
load_images,
|
||||
load_prompts,
|
||||
load_videos,
|
||||
preprocess_image_with_resize,
|
||||
preprocess_video_with_resize,
|
||||
)
|
||||
from finetune.schemas import Args, Components, State
|
||||
from finetune.utils import (
|
||||
cast_training_params,
|
||||
free_memory,
|
||||
get_intermediate_ckpt_path,
|
||||
get_latest_ckpt_path_to_resume_from,
|
||||
get_memory_statistics,
|
||||
get_optimizer,
|
||||
string_to_filename,
|
||||
unload_model,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
||||
|
||||
_DTYPE_MAP = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16, # FP16 is Only Support for CogVideoX-2B
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
class Trainer:
|
||||
# If set, should be a list of components to unload (refer to `Components``)
|
||||
UNLOAD_LIST: List[str] = None
|
||||
|
||||
def __init__(self, args: Args) -> None:
|
||||
self.args = args
|
||||
self.state = State(
|
||||
weight_dtype=self.__get_training_dtype(),
|
||||
train_frames=self.args.train_resolution[0],
|
||||
train_height=self.args.train_resolution[1],
|
||||
train_width=self.args.train_resolution[2],
|
||||
)
|
||||
|
||||
self.components: Components = self.load_components()
|
||||
self.accelerator: Accelerator = None
|
||||
self.dataset: Dataset = None
|
||||
self.data_loader: DataLoader = None
|
||||
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
|
||||
self._init_distributed()
|
||||
self._init_logging()
|
||||
self._init_directories()
|
||||
|
||||
self.state.using_deepspeed = self.accelerator.state.deepspeed_plugin is not None
|
||||
|
||||
def _init_distributed(self):
|
||||
logging_dir = Path(self.args.output_dir, "logs")
|
||||
project_config = ProjectConfiguration(
|
||||
project_dir=self.args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
init_process_group_kwargs = InitProcessGroupKwargs(
|
||||
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
|
||||
)
|
||||
mixed_precision = "no" if torch.backends.mps.is_available() else self.args.mixed_precision
|
||||
report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
|
||||
|
||||
accelerator = Accelerator(
|
||||
project_config=project_config,
|
||||
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
|
||||
mixed_precision=mixed_precision,
|
||||
log_with=report_to,
|
||||
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
self.accelerator = accelerator
|
||||
|
||||
if self.args.seed is not None:
|
||||
set_seed(self.args.seed)
|
||||
|
||||
def _init_logging(self) -> None:
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=LOG_LEVEL,
|
||||
)
|
||||
if self.accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
logger.info("Initialized Trainer")
|
||||
logger.info(f"Accelerator state: \n{self.accelerator.state}", main_process_only=False)
|
||||
|
||||
def _init_directories(self) -> None:
|
||||
if self.accelerator.is_main_process:
|
||||
self.args.output_dir = Path(self.args.output_dir)
|
||||
self.args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def check_setting(self) -> None:
|
||||
# Check for unload_list
|
||||
if self.UNLOAD_LIST is None:
|
||||
logger.warning(
|
||||
"\033[91mNo unload_list specified for this Trainer. All components will be loaded to GPU during training.\033[0m"
|
||||
)
|
||||
else:
|
||||
for name in self.UNLOAD_LIST:
|
||||
if name not in self.components.model_fields:
|
||||
raise ValueError(f"Invalid component name in unload_list: {name}")
|
||||
|
||||
def prepare_models(self) -> None:
|
||||
logger.info("Initializing models")
|
||||
|
||||
if self.components.vae is not None:
|
||||
if self.args.enable_slicing:
|
||||
self.components.vae.enable_slicing()
|
||||
if self.args.enable_tiling:
|
||||
self.components.vae.enable_tiling()
|
||||
|
||||
self.state.transformer_config = self.components.transformer.config
|
||||
|
||||
    def prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        if self.args.model_type == "i2v":
            self.dataset = I2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,
            )
        elif self.args.model_type == "t2v":
            self.dataset = T2VDatasetWithResize(
                **(self.args.model_dump()),
                device=self.accelerator.device,
                max_num_frames=self.state.train_frames,
                height=self.state.train_height,
                width=self.state.train_width,
                trainer=self,
            )
        else:
            raise ValueError(f"Invalid model type: {self.args.model_type}")

        # Prepare VAE and text encoder for encoding
        self.components.vae.requires_grad_(False)
        self.components.text_encoder.requires_grad_(False)
        self.components.vae = self.components.vae.to(
            self.accelerator.device, dtype=self.state.weight_dtype
        )
        self.components.text_encoder = self.components.text_encoder.to(
            self.accelerator.device, dtype=self.state.weight_dtype
        )

        # Precompute latents for videos and prompt embeddings
        logger.info("Precomputing latent for video and prompt embedding ...")
        tmp_data_loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_size=1,
            num_workers=0,
            pin_memory=self.args.pin_memory,
        )
        tmp_data_loader = self.accelerator.prepare_data_loader(tmp_data_loader)
        # Iterating the dataloader once triggers the dataset to compute and cache
        # latents and embeddings; the batches themselves are discarded.
        for _ in tmp_data_loader:
            ...
        self.accelerator.wait_for_everyone()
        logger.info("Precomputing latent for video and prompt embedding ... Done")

        unload_model(self.components.vae)
        unload_model(self.components.text_encoder)
        free_memory()

        self.data_loader = torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.pin_memory,
            shuffle=True,
        )

    def prepare_trainable_parameters(self):
        logger.info("Initializing trainable parameters")

        # For mixed precision training we cast all non-trainable weights to half-precision,
        # as these weights are only used for inference; keeping them in full precision is not required.
        weight_dtype = self.state.weight_dtype

        if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
            # due to pytorch#99272, MPS does not yet support bfloat16.
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
            )

        # For LoRA, we freeze all the parameters.
        # For SFT, we train all the parameters of the transformer model.
        for attr_name, component in vars(self.components).items():
            if hasattr(component, "requires_grad_"):
                if self.args.training_type == "sft" and attr_name == "transformer":
                    component.requires_grad_(True)
                else:
                    component.requires_grad_(False)

        if self.args.training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.components.transformer.add_adapter(transformer_lora_config)
            self.__prepare_saving_loading_hooks(transformer_lora_config)

        # Load components needed for training to GPU (except the transformer), and cast them to the specified data type
        ignore_list = ["transformer"] + self.UNLOAD_LIST
        self.__move_components_to_device(dtype=weight_dtype, ignore_list=ignore_list)

        if self.args.gradient_checkpointing:
            self.components.transformer.enable_gradient_checkpointing()

    def prepare_optimizer(self) -> None:
        logger.info("Initializing optimizer and lr scheduler")

        # Make sure the trainable params are in float32
        cast_training_params([self.components.transformer], dtype=torch.float32)

        # For LoRA, we only want to train the LoRA weights.
        # For SFT, we want to train all the parameters.
        trainable_parameters = list(
            filter(lambda p: p.requires_grad, self.components.transformer.parameters())
        )
        transformer_parameters_with_lr = {
            "params": trainable_parameters,
            "lr": self.args.learning_rate,
        }
        params_to_optimize = [transformer_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in trainable_parameters)

        use_deepspeed_opt = (
            self.accelerator.state.deepspeed_plugin is not None
            and "optimizer" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.args.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_deepspeed=use_deepspeed_opt,
        )

        num_update_steps_per_epoch = math.ceil(
            len(self.data_loader) / self.args.gradient_accumulation_steps
        )
        if self.args.train_steps is None:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = (
            self.accelerator.state.deepspeed_plugin is not None
            and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = self.args.train_steps * self.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
            self.accelerator.prepare(
                self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
            )
        )

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(
            len(self.data_loader) / self.args.gradient_accumulation_steps
        )
        if self.state.overwrote_max_train_steps:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        self.args.train_epochs = math.ceil(self.args.train_steps / num_update_steps_per_epoch)
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_for_validation(self):
        validation_prompts = load_prompts(self.args.validation_dir / self.args.validation_prompts)

        if self.args.validation_images is not None:
            validation_images = load_images(self.args.validation_dir / self.args.validation_images)
        else:
            validation_images = [None] * len(validation_prompts)

        if self.args.validation_videos is not None:
            validation_videos = load_videos(self.args.validation_dir / self.args.validation_videos)
        else:
            validation_videos = [None] * len(validation_prompts)

        self.state.validation_prompts = validation_prompts
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos

    def prepare_trackers(self) -> None:
        logger.info("Initializing trackers")

        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.accelerator.init_trackers(tracker_name, config=self.args.model_dump())

    def train(self) -> None:
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        self.state.total_batch_size_count = (
            self.args.batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.args.train_epochs,
            "train steps": self.args.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.data_loader),
            "train batch size total count": self.state.total_batch_size_count,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Potentially load in the weights and states from a previous save
        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
        )
        if resume_from_checkpoint_path is not None:
            self.accelerator.load_state(resume_from_checkpoint_path)

        progress_bar = tqdm(
            range(0, self.args.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.accelerator.is_local_main_process,
        )

        accelerator = self.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        free_memory()
        for epoch in range(first_epoch, self.args.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.args.train_epochs})")

            self.components.transformer.train()
            models_to_accumulate = [self.components.transformer]

            for step, batch in enumerate(self.data_loader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    loss = self.compute_loss(batch)
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            grad_norm = self.components.transformer.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            grad_norm = accelerator.clip_grad_norm_(
                                self.components.transformer.parameters(), self.args.max_grad_norm
                            )
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()

                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    self.__maybe_save_checkpoint(global_step)

                logs["loss"] = loss.detach().item()
                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix(logs)

                # Maybe run validation
                should_run_validation = (
                    self.args.do_validation and global_step % self.args.validation_steps == 0
                )
                if should_run_validation:
                    del loss
                    free_memory()
                    self.validate(global_step)

                accelerator.log(logs, step=global_step)

                if global_step >= self.args.train_steps:
                    break

            memory_statistics = get_memory_statistics()
            logger.info(
                f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
            )

        accelerator.wait_for_everyone()
        self.__maybe_save_checkpoint(global_step, must_save=True)
        if self.args.do_validation:
            free_memory()
            self.validate(global_step)

        del self.components
        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()

    def validate(self, step: int) -> None:
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.transformer.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        ##### Initialize pipeline #####
        pipe = self.initialize_pipeline()

        if self.state.using_deepspeed:
            # Can't use model_cpu_offload with deepspeed,
            # so we need to move all components in pipe to the device.
            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
            self.__move_components_to_device(
                dtype=self.state.weight_dtype, ignore_list=["transformer"]
            )
        else:
            # If not using deepspeed, use model_cpu_offload to reduce memory usage,
            # or pipe.enable_sequential_cpu_offload() to reduce it even further.
            pipe.enable_model_cpu_offload(device=self.accelerator.device)

            # Convert all model weights to the training dtype.
            # Note: this will change LoRA weights in self.components.transformer to the training dtype, rather than keep them in fp32.
            pipe = pipe.to(dtype=self.state.weight_dtype)

        #################################

        all_processes_artifacts = []
        for i in range(num_validation_samples):
            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                # Skip current validation on all processes but one
                if i % accelerator.num_processes != accelerator.process_index:
                    continue

            prompt = self.state.validation_prompts[i]
            image = self.state.validation_images[i]
            video = self.state.validation_videos[i]

            if image is not None:
                image = preprocess_image_with_resize(
                    image, self.state.train_height, self.state.train_width
                )
                # Convert image tensor (C, H, W) to a PIL image
                image = image.to(torch.uint8)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = Image.fromarray(image)

            if video is not None:
                video = preprocess_video_with_resize(
                    video, self.state.train_frames, self.state.train_height, self.state.train_width
                )
                # Convert video tensor (F, C, H, W) to a list of PIL images
                video = video.round().clamp(0, 255).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            validation_artifacts = self.validation_step(
                {"prompt": prompt, "image": image, "video": video}, pipe
            )

            if (
                self.state.using_deepspeed
                and self.accelerator.deepspeed_plugin.zero_stage == 3
                and not accelerator.is_main_process
            ):
                continue

            prompt_filename = string_to_filename(prompt)[:25]
            # Calculate hash of reversed prompt as a unique identifier
            reversed_prompt = prompt[::-1]
            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            # Use a separate index so the sample index `i` from the outer loop is not shadowed.
            for idx, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                artifacts.update(
                    {f"artifact_{idx}": {"type": artifact_type, "value": artifact_value}}
                )
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for key, value in list(artifacts.items()):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                validation_path = self.args.output_dir / "validation_res"
                validation_path.mkdir(parents=True, exist_ok=True)
                filename = str(validation_path / filename)

                if artifact_type == "image":
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video":
                    logger.debug(f"Saving video to {filename}")
                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    image_artifacts = [
                        artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
                    ]
                    video_artifacts = [
                        artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
                    ]
                    tracker.log(
                        {
                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                        },
                        step=step,
                    )

        ########## Clean up ##########
        if self.state.using_deepspeed:
            del pipe
            # Unload models except those needed for training
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe
            # Reload models needed for training (except those in the unload list)
            self.__move_components_to_device(
                dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
            )
            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # Change trainable weights back to fp32 to keep consistent with the dtype after preparing the model
            cast_training_params([self.components.transformer], dtype=torch.float32)

        free_memory()
        accelerator.wait_for_everyone()
        ################################

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.transformer.train()

    def fit(self):
        self.check_setting()
        self.prepare_models()
        self.prepare_dataset()
        self.prepare_trainable_parameters()
        self.prepare_optimizer()
        self.prepare_for_training()
        if self.args.do_validation:
            self.prepare_for_validation()
        self.prepare_trackers()
        self.train()

    def collate_fn(self, examples: List[Dict[str, Any]]):
        raise NotImplementedError

    def load_components(self) -> Components:
        raise NotImplementedError

    def initialize_pipeline(self) -> DiffusionPipeline:
        raise NotImplementedError

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # shape of input video: [B, C, F, H, W], where B = 1
        # shape of output video: [B, C', F', H', W'], where B = 1
        raise NotImplementedError

    def encode_text(self, text: str) -> torch.Tensor:
        # shape of output text embedding: [batch size, sequence length, embedding dimension]
        raise NotImplementedError

    def compute_loss(self, batch) -> torch.Tensor:
        raise NotImplementedError

    def validation_step(
        self, eval_data: Dict[str, Any], pipe: DiffusionPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        # `eval_data` carries the validation "prompt", "image", and "video";
        # called from validate() together with the inference pipeline.
        raise NotImplementedError

    def __get_training_dtype(self) -> torch.dtype:
        if self.args.mixed_precision == "no":
            return _DTYPE_MAP["fp32"]
        elif self.args.mixed_precision == "fp16":
            return _DTYPE_MAP["fp16"]
        elif self.args.mixed_precision == "bf16":
            return _DTYPE_MAP["bf16"]
        else:
            raise ValueError(f"Invalid mixed precision: {self.args.mixed_precision}")

    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
        ignore_list = set(ignore_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name not in ignore_list:
                    setattr(
                        self.components, name, component.to(self.accelerator.device, dtype=dtype)
                    )

    def __move_components_to_cpu(self, unload_list: List[str] = []):
        unload_list = set(unload_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name in unload_list:
                    setattr(self.components, name, component.to("cpu"))

    def __prepare_saving_loading_hooks(self, transformer_lora_config):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if self.accelerator.is_main_process:
                transformer_lora_layers_to_save = None

                for model in models:
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        model = unwrap_model(self.accelerator, model)
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    else:
                        raise ValueError(f"Unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                self.components.pipeline_cls.save_lora_weights(
                    output_dir,
                    transformer_lora_layers=transformer_lora_layers_to_save,
                )

        def load_model_hook(models, input_dir):
            if not self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()
                    if isinstance(
                        unwrap_model(self.accelerator, model),
                        type(unwrap_model(self.accelerator, self.components.transformer)),
                    ):
                        transformer_ = unwrap_model(self.accelerator, model)
                    else:
                        raise ValueError(
                            f"Unexpected load model: {unwrap_model(self.accelerator, model).__class__}"
                        )
            else:
                transformer_ = unwrap_model(
                    self.accelerator, self.components.transformer
                ).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
                transformer_.add_adapter(transformer_lora_config)

            lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
            transformer_state_dict = {
                k.replace("transformer.", ""): v
                for k, v in lora_state_dict.items()
                if k.startswith("transformer.")
            }
            incompatible_keys = set_peft_model_state_dict(
                transformer_, transformer_state_dict, adapter_name="default"
            )
            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

        self.accelerator.register_save_state_pre_hook(save_model_hook)
        self.accelerator.register_load_state_pre_hook(load_model_hook)

    def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
        if (
            self.accelerator.distributed_type == DistributedType.DEEPSPEED
            or self.accelerator.is_main_process
        ):
            if must_save or global_step % self.args.checkpointing_steps == 0:
                # for training
                save_path = get_intermediate_ckpt_path(
                    checkpointing_limit=self.args.checkpointing_limit,
                    step=global_step,
                    output_dir=self.args.output_dir,
                )
                self.accelerator.save_state(save_path, safe_serialization=True)
finetune/utils/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .checkpointing import *
from .file_utils import *
from .memory_utils import *
from .optimizer_utils import *
from .torch_utils import *
finetune/utils/checkpointing.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import os
from pathlib import Path
from typing import Tuple

from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME

from ..utils.file_utils import delete_files, find_files


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_latest_ckpt_path_to_resume_from(
    resume_from_checkpoint: str | None, num_update_steps_per_epoch: int
) -> Tuple[str | None, int, int, int]:
    if resume_from_checkpoint is None:
        initial_global_step = 0
        global_step = 0
        first_epoch = 0
        resume_from_checkpoint_path = None
    else:
        resume_from_checkpoint_path = Path(resume_from_checkpoint)
        if not resume_from_checkpoint_path.exists():
            logger.info(
                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            initial_global_step = 0
            global_step = 0
            first_epoch = 0
            resume_from_checkpoint_path = None
        else:
            logger.info(f"Resuming from checkpoint {resume_from_checkpoint}")
            # Checkpoint directories are named `checkpoint-<global_step>`
            global_step = int(resume_from_checkpoint_path.name.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch


def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
    # before saving state, check if this save would set us over the `checkpointing_limit`
    if checkpointing_limit is not None:
        checkpoints = find_files(output_dir, prefix="checkpoint")

        # before we save the new checkpoint, we need to have at most `checkpointing_limit - 1` checkpoints
        if len(checkpoints) >= checkpointing_limit:
            num_to_remove = len(checkpoints) - checkpointing_limit + 1
            checkpoints_to_remove = checkpoints[0:num_to_remove]
            delete_files(checkpoints_to_remove)

    logger.info(f"Checkpointing at step {step}")
    save_path = os.path.join(output_dir, f"checkpoint-{step}")
    logger.info(f"Saving state to {save_path}")
    return save_path
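The resume bookkeeping above can be sketched in isolation: the global step is parsed from the `checkpoint-<step>` directory name, and the epoch to resume at is derived from it. This is a minimal standalone sketch (the `resume_bookkeeping` helper name is ours, not part of the repo), not the repo's own API:

```python
from pathlib import Path


def resume_bookkeeping(ckpt_dir: str, num_update_steps_per_epoch: int):
    # Mirrors get_latest_ckpt_path_to_resume_from: parse the global step from
    # the directory name `checkpoint-<step>` and derive the epoch to resume at.
    global_step = int(Path(ckpt_dir).name.split("-")[1])
    first_epoch = global_step // num_update_steps_per_epoch
    return global_step, first_epoch


print(resume_bookkeeping("out/checkpoint-1500", 100))  # (1500, 15)
```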
finetune/utils/file_utils.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import logging
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Union

from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[Path]:
    if not isinstance(dir, Path):
        dir = Path(dir)
    if not dir.exists():
        return []
    checkpoints = os.listdir(dir.as_posix())
    checkpoints = [c for c in checkpoints if c.startswith(prefix)]
    # Sort by the numeric step suffix in `checkpoint-<step>`
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    checkpoints = [dir / c for c in checkpoints]
    return checkpoints


def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
    if not isinstance(dirs, list):
        dirs = [dirs]
    dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
    logger.info(f"Deleting files: {dirs}")
    for dir in dirs:
        if not dir.exists():
            continue
        shutil.rmtree(dir, ignore_errors=True)


def string_to_filename(s: str) -> str:
    return (
        s.replace(" ", "-")
        .replace("/", "-")
        .replace(":", "-")
        .replace(".", "-")
        .replace(",", "-")
        .replace(";", "-")
        .replace("!", "-")
        .replace("?", "-")
    )
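`string_to_filename` is what turns a validation prompt into the `prompt_filename` fragment of saved artifact names. A minimal standalone copy (mirroring the definition above) and a usage example:

```python
# Mirrors string_to_filename from file_utils.py above: characters that are
# unsafe or awkward in file names are replaced with dashes.
def string_to_filename(s: str) -> str:
    return (
        s.replace(" ", "-")
        .replace("/", "-")
        .replace(":", "-")
        .replace(".", "-")
        .replace(",", "-")
        .replace(";", "-")
        .replace("!", "-")
        .replace("?", "-")
    )


print(string_to_filename("a cat runs, fast!"))  # a-cat-runs--fast-
```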
finetune/utils/memory_utils.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import gc
from typing import Any, Dict, Union

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)

    elif torch.mps.is_available():
        memory_allocated = torch.mps.current_allocated_memory()

    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    # Statistics that are unavailable on the current device stay None instead of
    # being passed to round(), which would raise a TypeError.
    def _round(x):
        return round(bytes_to_gigabytes(x), ndigits=precision) if x is not None else None

    return {
        "memory_allocated": _round(memory_allocated),
        "memory_reserved": _round(memory_reserved),
        "max_memory_allocated": _round(max_memory_allocated),
        "max_memory_reserved": _round(max_memory_reserved),
    }


def bytes_to_gigabytes(x: int) -> float:
    if x is not None:
        return x / 1024**3


def free_memory() -> None:
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # TODO(aryan): handle non-cuda devices


def unload_model(model):
    model.to("cpu")


def make_contiguous(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    elif isinstance(x, dict):
        return {k: make_contiguous(v) for k, v in x.items()}
    else:
        return x
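The byte-to-gigabyte conversion behind the memory log lines is a plain binary-unit division (1 GiB = 1024**3 bytes). A minimal standalone copy (mirroring the definition above, no torch required):

```python
# Mirrors bytes_to_gigabytes from memory_utils.py above: converts a raw byte
# count to gibibytes; returns None for missing statistics.
def bytes_to_gigabytes(x):
    if x is not None:
        return x / 1024**3


print(round(bytes_to_gigabytes(3 * 1024**3 + 512 * 1024**2), 3))  # 3.5
```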
finetune/utils/optimizer_utils.py (new file, 191 lines)
@@ -0,0 +1,191 @@
import inspect

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    optimizer_name = optimizer_name.lower()

    # Use the DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            torchao.__version__
        except ImportError:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            )

    if not use_torchao and use_4bit:
        raise ValueError("4-bit optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError(
            "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
        )

    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = (
                AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
            )
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            )

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }

    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError:
            raise ImportError(
                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
            )

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize,
            optimizer_class=optimizer_class,
            offload_gradients=offload_gradients,
            **init_kwargs,
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer


def gradient_norm(parameters):
    norm = 0
    for param in parameters:
        if param.grad is None:
            continue
        local_norm = param.grad.detach().data.norm(2)
        norm += local_norm.item() ** 2
    norm = norm**0.5
    return norm


def max_gradient(parameters):
    max_grad_value = float("-inf")
    for param in parameters:
        if param.grad is None:
            continue
        local_max_grad = param.grad.detach().data.abs().max()
        max_grad_value = max(max_grad_value, local_max_grad.item())
    return max_grad_value
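A torch-free sketch of the same L2-norm accumulation that `gradient_norm` performs, with plain floats standing in for the per-parameter gradient norms (an illustrative analogue, not the training code itself):

```python
def l2_norm(per_param_norms):
    # Accumulate squared local norms, then take the square root,
    # mirroring the loop in gradient_norm above.
    total = 0.0
    for local_norm in per_param_norms:
        total += local_norm ** 2
    return total ** 0.5


print(l2_norm([3.0, 4.0]))  # 5.0
```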
52  finetune/utils/torch_utils.py  Normal file
@@ -0,0 +1,52 @@
from typing import Dict, List, Optional, Union

import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(accelerator: Accelerator, model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    if isinstance(x, torch.Tensor):
        if device is not None:
            x = x.to(device)
        if dtype is not None:
            x = x.to(dtype)
    elif isinstance(x, dict):
        x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x


def expand_tensor_to_dims(tensor, ndim):
    while len(tensor.shape) < ndim:
        tensor = tensor.unsqueeze(-1)
    return tensor


def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # Only upcast trainable parameters to fp32.
            if param.requires_grad:
                param.data = param.to(dtype)
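`expand_tensor_to_dims` simply appends trailing singleton dimensions until the target rank is reached. A torch-free sketch of the shape arithmetic, with tuples standing in for tensor shapes (an assumption for illustration only):

```python
def expand_shape_to_ndim(shape, ndim):
    # Append trailing 1s until the shape has the requested rank,
    # matching the unsqueeze(-1) loop above.
    shape = tuple(shape)
    while len(shape) < ndim:
        shape = shape + (1,)
    return shape


print(expand_shape_to_ndim((2,), 4))  # (2, 1, 1, 1)
```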
@@ -3,40 +3,60 @@ This script demonstrates how to generate a video using the CogVideoX model with
 The script supports different types of video generation, including text-to-video (t2v), image-to-video (i2v),
 and video-to-video (v2v), depending on the input data and different weights.

-- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
-- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
-- image-to-video: THUDM/CogVideoX-5b-I2V
+- text-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
+- video-to-video: THUDM/CogVideoX-5b, THUDM/CogVideoX-2b or THUDM/CogVideoX1.5-5b
+- image-to-video: THUDM/CogVideoX-5b-I2V or THUDM/CogVideoX1.5-5b-I2V

 Running the Script:
 To run the script, use the following command with appropriate arguments:

 ```bash
-$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
+$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
 ```

+You can change `pipe.enable_sequential_cpu_offload()` to `pipe.enable_model_cpu_offload()` to speed up inference, but this will use more GPU memory.
+
 Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
 """

 import argparse
-from typing import Literal
+import logging
+from typing import Literal, Optional

 import torch
 from diffusers import (
-    CogVideoXPipeline,
     CogVideoXDDIMScheduler,
     CogVideoXDPMScheduler,
     CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
     CogVideoXVideoToVideoPipeline,
 )
 from diffusers.utils import export_to_video, load_image, load_video

+
+logging.basicConfig(level=logging.INFO)
+
+# Recommended resolution for each model (height, width)
+RESOLUTION_MAP = {
+    # cogvideox1.5-*
+    "cogvideox1.5-5b-i2v": (768, 1360),
+    "cogvideox1.5-5b": (768, 1360),
+    # cogvideox-*
+    "cogvideox-5b-i2v": (480, 720),
+    "cogvideox-5b": (480, 720),
+    "cogvideox-2b": (480, 720),
+}
+

 def generate_video(
     prompt: str,
     model_path: str,
     lora_path: str = None,
     lora_rank: int = 128,
+    num_frames: int = 81,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
     output_path: str = "./output.mp4",
     image_or_video_path: str = "",
     num_inference_steps: int = 50,
@@ -45,6 +65,7 @@ def generate_video(
     dtype: torch.dtype = torch.bfloat16,
     generate_type: str = Literal["t2v", "i2v", "v2v"],  # i2v: image to video, v2v: video to video
     seed: int = 42,
+    fps: int = 16,
 ):
     """
     Generates a video based on the given prompt and saves it to the specified path.
@@ -56,11 +77,15 @@ def generate_video(
     - lora_rank (int): The rank of the LoRA weights.
     - output_path (str): The path where the generated video will be saved.
     - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
+    - num_frames (int): Number of frames to generate. CogVideoX1.0 generates 49 frames for 6 seconds at 8 fps, while CogVideoX1.5 produces either 81 or 161 frames, corresponding to 5 seconds or 10 seconds at 16 fps.
+    - width (int): The width of the generated video, applicable only for CogVideoX1.5-5B-I2V
+    - height (int): The height of the generated video, applicable only for CogVideoX1.5-5B-I2V
     - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
     - num_videos_per_prompt (int): Number of videos to generate per prompt.
     - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
     - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
     - seed (int): The seed for reproducibility.
+    - fps (int): The frames per second for the generated video.
     """

     # 1. Load the pre-trained CogVideoX pipeline with the specified precision (bfloat16).
@@ -70,6 +95,26 @@ def generate_video(
     image = None
     video = None

+    model_name = model_path.split("/")[-1].lower()
+    desired_resolution = RESOLUTION_MAP[model_name]
+    if width is None or height is None:
+        height, width = desired_resolution
+        logging.info(
+            f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
+        )
+    elif (height, width) != desired_resolution:
+        if generate_type == "i2v":
+            # For i2v models, use the user-defined width and height
+            logging.warning(
+                f"\033[1;31mThe width({width}) and height({height}) are not recommended for {model_name}. The best resolution is {desired_resolution}.\033[0m"
+            )
+        else:
+            # Otherwise, fall back to the recommended width and height
+            logging.warning(
+                f"\033[1;31m{model_name} does not support custom resolutions. Setting back to the default resolution {desired_resolution}.\033[0m"
+            )
+            height, width = desired_resolution
+
     if generate_type == "i2v":
         pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, torch_dtype=dtype)
         image = load_image(image=image_or_video_path)
@@ -81,8 +126,10 @@ def generate_video(

     # If you're using LoRA, add this code
     if lora_path:
-        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
-        pipe.fuse_lora(lora_scale=1 / lora_rank)
+        pipe.load_lora_weights(
+            lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
+        )
+        pipe.fuse_lora(components=["transformer"], lora_scale=1.0)

     # 2. Set the scheduler.
     # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
@@ -90,61 +137,71 @@ def generate_video(
     # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.

     # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    pipe.scheduler = CogVideoXDPMScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing"
+    )

     # 3. Enable CPU offload for the model.
     # Turn this off if you have multiple GPUs or enough GPU memory (such as an H100); inference will take less time
     # if you enable pipe.to("cuda") instead.

     # pipe.to("cuda")

     # pipe.enable_model_cpu_offload()
     pipe.enable_sequential_cpu_offload()

     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()

     # 4. Generate the video frames based on the prompt.
     # `num_frames` is the number of frames to generate.
     # The CogVideoX1.0 default is a 6-second video at 8 fps, plus 1 extra first frame (49 frames in total).
     if generate_type == "i2v":
         video_generate = pipe(
             height=height,
             width=width,
             prompt=prompt,
-            image=image,  # The path of the image to be used as the background of the video
+            image=image,
+            # The path of the image; the video resolution matches the image for CogVideoX1.5-5B-I2V, otherwise it is 720 * 480
             num_videos_per_prompt=num_videos_per_prompt,  # Number of videos to generate per prompt
             num_inference_steps=num_inference_steps,  # Number of inference steps
-            num_frames=49,  # Number of frames to generate, changed to 49 for diffusers version `0.30.3` and after.
-            use_dynamic_cfg=True,  # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
+            num_frames=num_frames,  # Number of frames to generate
+            use_dynamic_cfg=True,  # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
             guidance_scale=guidance_scale,
             generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
         ).frames[0]
     elif generate_type == "t2v":
         video_generate = pipe(
             height=height,
             width=width,
             prompt=prompt,
             num_videos_per_prompt=num_videos_per_prompt,
             num_inference_steps=num_inference_steps,
-            num_frames=49,
+            num_frames=num_frames,
             use_dynamic_cfg=True,
             guidance_scale=guidance_scale,
             generator=torch.Generator().manual_seed(seed),
         ).frames[0]
     else:
         video_generate = pipe(
             height=height,
             width=width,
             prompt=prompt,
             video=video,  # The path of the video to be used as the background of the video
             num_videos_per_prompt=num_videos_per_prompt,
             num_inference_steps=num_inference_steps,
-            # num_frames=49,
+            num_frames=num_frames,
             use_dynamic_cfg=True,
             guidance_scale=guidance_scale,
             generator=torch.Generator().manual_seed(seed),  # Set the seed for reproducibility
         ).frames[0]
     # 5. Export the generated frames to a video file. fps must be 8 for the original CogVideoX models.
-    export_to_video(video_generate, output_path, fps=8)
+    export_to_video(video_generate, output_path, fps=fps)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
-    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt using CogVideoX"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="The description of the video to be generated"
+    )
     parser.add_argument(
         "--image_or_video_path",
         type=str,
@@ -152,23 +209,43 @@ if __name__ == "__main__":
         help="The path of the image to be used as the background of the video",
     )
     parser.add_argument(
-        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
+        "--model_path",
+        type=str,
+        default="THUDM/CogVideoX1.5-5B",
+        help="The path of the pre-trained model to be used",
     )
-    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
+    parser.add_argument(
+        "--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
+    )
     parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
     parser.add_argument(
-        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
-    )
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
-    parser.add_argument(
-        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
-    )
-    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument(
-        "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
+        "--output_path", type=str, default="./output.mp4", help="The path to save the generated video"
     )
     parser.add_argument(
-        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
+        "--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
     )
+    parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
+    parser.add_argument(
+        "--num_frames", type=int, default=81, help="Number of frames to generate"
+    )
+    parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
+    parser.add_argument(
+        "--height", type=int, default=None, help="The height of the generated video"
+    )
+    parser.add_argument(
+        "--fps", type=int, default=16, help="The frames per second for the generated video"
+    )
+    parser.add_argument(
+        "--num_videos_per_prompt",
+        type=int,
+        default=1,
+        help="Number of videos to generate per prompt",
+    )
+    parser.add_argument(
+        "--generate_type", type=str, default="t2v", help="The type of video generation"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="The data type for computation"
+    )
     parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")

@@ -180,6 +257,9 @@ if __name__ == "__main__":
         lora_path=args.lora_path,
         lora_rank=args.lora_rank,
         output_path=args.output_path,
+        num_frames=args.num_frames,
+        width=args.width,
+        height=args.height,
         image_or_video_path=args.image_or_video_path,
         num_inference_steps=args.num_inference_steps,
         guidance_scale=args.guidance_scale,
@@ -187,4 +267,5 @@ if __name__ == "__main__":
         dtype=dtype,
         generate_type=args.generate_type,
         seed=args.seed,
+        fps=args.fps,
     )
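The resolution handling in the cli_demo.py diff above keys `RESOLUTION_MAP` by the lowercased final component of the model path. A small, torch-free sketch of just that lookup (map values copied from the diff; the helper name is illustrative):

```python
# (height, width) pairs copied from the RESOLUTION_MAP in the diff above.
RESOLUTION_MAP = {
    "cogvideox1.5-5b-i2v": (768, 1360),
    "cogvideox1.5-5b": (768, 1360),
    "cogvideox-5b-i2v": (480, 720),
    "cogvideox-5b": (480, 720),
    "cogvideox-2b": (480, 720),
}


def default_resolution(model_path):
    # Mirror the script: take the last path component, lowercase it, look it up.
    model_name = model_path.split("/")[-1].lower()
    return RESOLUTION_MAP[model_name]


print(default_resolution("THUDM/CogVideoX1.5-5B"))  # (768, 1360)
```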
@@ -3,7 +3,7 @@ This script demonstrates how to generate a video from a text prompt using CogVid

 Note:

-Must install the `torchao`,`torch`,`diffusers`,`accelerate` library FROM SOURCE to use the quantization feature.
+Must install the `torchao` and `torch` libraries FROM SOURCE to use the quantization feature.
 Only NVIDIA GPUs like the H100 or newer support FP8 quantization.

 All quantization schemes must be used with NVIDIA GPUs.
@@ -19,7 +19,12 @@ import argparse
 import os
 import torch
 import torch._dynamo
-from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXTransformer3DModel,
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+)
 from diffusers.utils import export_to_video
 from transformers import T5EncoderModel
 from torchao.quantization import quantize_, int8_weight_only
@@ -51,6 +56,9 @@ def generate_video(
     num_videos_per_prompt: int = 1,
     quantization_scheme: str = "fp8",
     dtype: torch.dtype = torch.bfloat16,
+    num_frames: int = 81,
+    fps: int = 8,
+    seed: int = 42,
 ):
     """
     Generates a video based on the given prompt and saves it to the specified path.
@@ -65,10 +73,13 @@ def generate_video(
     - quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
     - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
     """

-    text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
+    text_encoder = T5EncoderModel.from_pretrained(
+        model_path, subfolder="text_encoder", torch_dtype=dtype
+    )
     text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
-    transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+    transformer = CogVideoXTransformer3DModel.from_pretrained(
+        model_path, subfolder="transformer", torch_dtype=dtype
+    )
     transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
     vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
     vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
@@ -79,55 +90,59 @@ def generate_video(
         vae=vae,
         torch_dtype=dtype,
     )
-    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-
-    # Compiling will run faster; the first inference will take ~30 min to compile.
-    # pipe.transformer.to(memory_format=torch.channels_last)
-
-    # For FP8, remove pipe.enable_model_cpu_offload()
+    pipe.scheduler = CogVideoXDPMScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing"
+    )
     pipe.enable_model_cpu_offload()
-
-    # This is not for FP8 and INT8; remove this line for those schemes
-    # pipe.enable_sequential_cpu_offload()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()

     video = pipe(
         prompt=prompt,
         num_videos_per_prompt=num_videos_per_prompt,
         num_inference_steps=num_inference_steps,
-        num_frames=49,
+        num_frames=num_frames,
         use_dynamic_cfg=True,
         guidance_scale=guidance_scale,
-        generator=torch.Generator(device="cuda").manual_seed(42),
+        generator=torch.Generator(device="cuda").manual_seed(seed),
     ).frames[0]

-    export_to_video(video, output_path, fps=8)
+    export_to_video(video, output_path, fps=fps)


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
-    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
-    parser.add_argument(
-        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
-    )
-    parser.add_argument(
-        "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
-    )
-    parser.add_argument(
-        "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
-    )
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
-    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument(
-        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16', 'bfloat16')"
-    )
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt using CogVideoX"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="The description of the video to be generated"
+    )
+    parser.add_argument(
+        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
+    )
+    parser.add_argument(
+        "--output_path", type=str, default="./output.mp4", help="Path to save the generated video"
+    )
+    parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
+    parser.add_argument(
+        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
+    )
+    parser.add_argument(
+        "--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
+    )
     parser.add_argument(
         "--quantization_scheme",
         type=str,
-        default="bf16",
+        default="fp8",
         choices=["int8", "fp8"],
-        help="The quantization scheme to use (int8, fp8)",
+        help="Quantization scheme",
     )
+    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
+    parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video")
+    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

     args = parser.parse_args()
     dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
@@ -140,4 +155,7 @@ if __name__ == "__main__":
         num_videos_per_prompt=args.num_videos_per_prompt,
         quantization_scheme=args.quantization_scheme,
         dtype=dtype,
+        num_frames=args.num_frames,
+        fps=args.fps,
+        seed=args.seed,
     )
@@ -104,18 +104,34 @@ def save_video(tensor, output_path):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
-    parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
+    parser.add_argument(
+        "--model_path", type=str, required=True, help="The path to the CogVideoX model"
+    )
     parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
-    parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
-    parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
     parser.add_argument(
-        "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
+        "--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
     )
     parser.add_argument(
-        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
+        "--output_path", type=str, default=".", help="The path to save the output file"
     )
     parser.add_argument(
-        "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
+        "--mode",
+        type=str,
+        choices=["encode", "decode", "both"],
+        required=True,
+        help="Mode: encode, decode, or both",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="The device to use for computation (e.g., 'cuda' or 'cpu')",
     )
     args = parser.parse_args()
@@ -126,15 +142,21 @@ if __name__ == "__main__":
         assert args.video_path, "Video path must be provided for encoding."
         encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
         torch.save(encoded_output, args.output_path + "/encoded.pt")
-        print(f"Finished encoding the video to a tensor, save it to a file at {args.output_path}/encoded.pt")
+        print(
+            f"Finished encoding the video to a tensor and saved it to {args.output_path}/encoded.pt"
+        )
     elif args.mode == "decode":
         assert args.encoded_path, "Encoded tensor path must be provided for decoding."
         decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
         save_video(decoded_output, args.output_path)
-        print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
+        print(
+            f"Finished decoding the video and saved it to {args.output_path}/output.mp4"
+        )
     elif args.mode == "both":
         assert args.video_path, "Video path must be provided for encoding."
         encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
         torch.save(encoded_output, args.output_path + "/encoded.pt")
-        decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
+        decoded_output = decode_video(
+            args.model_path, args.output_path + "/encoded.pt", dtype, device
+        )
         save_video(decoded_output, args.output_path)
@@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
-    parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
+    parser.add_argument(
+        "--retry_times", type=int, default=3, help="Number of times to retry the conversion"
+    )
     parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
     parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
     args = parser.parse_args()
519  inference/ddim_inversion.py  Normal file
@@ -0,0 +1,519 @@
"""
|
||||
This script performs DDIM inversion for video frames using a pre-trained model and generates
|
||||
a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
|
||||
process video frames, apply the DDIM inverse scheduler, and produce an output video.
|
||||
|
||||
**Please notice that this script is based on the CogVideoX 5B model, and would not generate
|
||||
a good result for 2B variants.**
|
||||
|
||||
Usage:
|
||||
python ddim_inversion.py
|
||||
--model-path /path/to/model
|
||||
--prompt "a prompt"
|
||||
--video-path /path/to/video.mp4
|
||||
--output-path /path/to/output
|
||||
|
||||
For more details about the cli arguments, please run `python ddim_inversion.py --help`.
|
||||
|
||||
Author:
|
||||
LittleNyima <littlenyima[at]163[dot]com>
|
||||
"""
|
||||
|
||||
import argparse
import math
import os
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.cogvideox_transformer_3d import (
    CogVideoXBlock,
    CogVideoXTransformer3DModel,
)
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video

# Must be imported after torch, because importing decord first can sometimes lead to a nasty
# segmentation fault or stack-smashing error. There are very few bug reports, but it happens.
# Look in the decord GitHub issues for more relevant information.
import decord  # isort: skip

class DDIMInversionArguments(TypedDict):
    model_path: str
    prompt: str
    video_path: str
    output_path: str
    guidance_scale: float
    num_inference_steps: int
    skip_frames_start: int
    skip_frames_end: int
    frame_sample_step: Optional[int]
    max_num_frames: int
    width: int
    height: int
    fps: int
    dtype: torch.dtype
    seed: int
    device: torch.device

def get_args() -> DDIMInversionArguments:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path", type=str, required=True, help="Path of the pretrained model"
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="Prompt for the direct sample procedure"
    )
    parser.add_argument(
        "--video_path", type=str, required=True, help="Path of the video for inversion"
    )
    parser.add_argument(
        "--output_path", type=str, default="output", help="Path of the output videos"
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=50, help="Number of inference steps"
    )
    parser.add_argument(
        "--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start"
    )
    parser.add_argument(
        "--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end"
    )
    parser.add_argument(
        "--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames"
    )
    parser.add_argument(
        "--max_num_frames", type=int, default=81, help="Max number of sampled frames"
    )
    parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
    parser.add_argument(
        "--height", type=int, default=480, help="Resized height of the video frames"
    )
    parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
    parser.add_argument(
        "--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model"
    )
    parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference"
    )

    args = parser.parse_args()
    args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
    args.device = torch.device(args.device)

    return DDIMInversionArguments(**vars(args))

class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
    def __init__(self):
        super().__init__()

    def calculate_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn: Attention,
        batch_size: int,
        image_seq_length: int,
        text_seq_length: int,
        attention_mask: Optional[torch.Tensor],
        image_rotary_emb: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(
                query[:, :, text_seq_length:], image_rotary_emb
            )
            if not attn.is_cross_attention:
                if key.size(2) == query.size(2):  # Attention for reference hidden states
                    key[:, :, text_seq_length:] = apply_rotary_emb(
                        key[:, :, text_seq_length:], image_rotary_emb
                    )
                else:  # RoPE should be applied to each group of image tokens
                    key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
                        apply_rotary_emb(
                            key[:, :, text_seq_length : text_seq_length + image_seq_length],
                            image_rotary_emb,
                        )
                    )
                    key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
                        key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
                    )

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        image_seq_length = hidden_states.size(1)
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query, query_reference = query.chunk(2)
        key, key_reference = key.chunk(2)
        value, value_reference = value.chunk(2)
        batch_size = batch_size // 2

        hidden_states, encoder_hidden_states = self.calculate_attention(
            query=query,
            key=torch.cat((key, key_reference), dim=1),
            value=torch.cat((value, value_reference), dim=1),
            attn=attn,
            batch_size=batch_size,
            image_seq_length=image_seq_length,
            text_seq_length=text_seq_length,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(
            query=query_reference,
            key=key_reference,
            value=value_reference,
            attn=attn,
            batch_size=batch_size,
            image_seq_length=image_seq_length,
            text_seq_length=text_seq_length,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        return (
            torch.cat((hidden_states, hidden_states_reference)),
            torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),
        )

class OverrideAttnProcessors:
    def __init__(self, transformer: CogVideoXTransformer3DModel):
        self.transformer = transformer
        self.original_processors = {}

    def __enter__(self):
        for block in self.transformer.transformer_blocks:
            block = cast(CogVideoXBlock, block)
            self.original_processors[id(block)] = block.attn1.get_processor()
            block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())

    def __exit__(self, _0, _1, _2):
        for block in self.transformer.transformer_blocks:
            block = cast(CogVideoXBlock, block)
            block.attn1.set_processor(self.original_processors[id(block)])

def get_video_frames(
    video_path: str,
    width: int,
    height: int,
    skip_frames_start: int,
    skip_frames_end: int,
    max_num_frames: int,
    frame_sample_step: Optional[int],
) -> torch.FloatTensor:
    with decord.bridge.use_torch():
        video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
        video_num_frames = len(video_reader)
        start_frame = min(skip_frames_start, video_num_frames)
        end_frame = max(0, video_num_frames - skip_frames_end)

        if end_frame <= start_frame:
            indices = [start_frame]
        elif end_frame - start_frame <= max_num_frames:
            indices = list(range(start_frame, end_frame))
        else:
            step = frame_sample_step or (end_frame - start_frame) // max_num_frames
            indices = list(range(start_frame, end_frame, step))

        frames = video_reader.get_batch(indices=indices)
        frames = frames[:max_num_frames].float()  # ensure that we don't go over the limit

        # Keep only the first (4k + 1) frames, as this is the count the VAE requires
        selected_num_frames = frames.size(0)
        remainder = (3 + selected_num_frames) % 4
        if remainder != 0:
            frames = frames[:-remainder]
        assert frames.size(0) % 4 == 1

        # Normalize the frames
        transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
        frames = torch.stack(tuple(map(transform, frames)), dim=0)

        return frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]

def encode_video_frames(
    vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
) -> torch.FloatTensor:
    video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
    video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
    return latent_dist * vae.config.scaling_factor


def export_latents_to_video(
    pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
):
    video = pipeline.decode_latents(latents)
    frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
    export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)

# Modified from CogVideoXPipeline.__call__
def sample(
    pipeline: CogVideoXPipeline,
    latents: torch.FloatTensor,
    scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 6,
    use_dynamic_cfg: bool = False,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    reference_latents: torch.FloatTensor = None,
) -> torch.FloatTensor:
    pipeline._guidance_scale = guidance_scale
    pipeline._attention_kwargs = attention_kwargs
    pipeline._interrupt = False

    device = pipeline._execution_device

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
        prompt,
        negative_prompt,
        do_classifier_free_guidance,
        device=device,
    )
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    if reference_latents is not None:
        prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
    pipeline._num_timesteps = len(timesteps)

    # 5. Prepare latents
    latents = latents.to(device=device) * scheduler.init_noise_sigma

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
    if isinstance(
        scheduler, DDIMInverseScheduler
    ):  # Inverse scheduler does not accept extra kwargs
        extra_step_kwargs = {}

    # 7. Create rotary embeds if required
    image_rotary_emb = (
        pipeline._prepare_rotary_positional_embeddings(
            height=latents.size(3) * pipeline.vae_scale_factor_spatial,
            width=latents.size(4) * pipeline.vae_scale_factor_spatial,
            num_frames=latents.size(1),
            device=device,
        )
        if pipeline.transformer.config.use_rotary_positional_embeddings
        else None
    )

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)

    trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipeline.interrupt:
                continue

            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )
            if reference_latents is not None:
                reference = reference_latents[i]
                reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
                latent_model_input = torch.cat([latent_model_input, reference], dim=0)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = pipeline.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=timestep,
                image_rotary_emb=image_rotary_emb,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
            noise_pred = noise_pred.float()

            if reference_latents is not None:  # Recover the original batch size
                noise_pred, _ = noise_pred.chunk(2)

            # perform guidance
            if use_dynamic_cfg:
                pipeline._guidance_scale = 1 + guidance_scale * (
                    (
                        1
                        - math.cos(
                            math.pi
                            * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
                        )
                    )
                    / 2
                )
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the noisy sample x_t-1 -> x_t
            latents = scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
            )[0]
            latents = latents.to(prompt_embeds.dtype)
            trajectory[i] = latents

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update()

    # Offload all models
    pipeline.maybe_free_model_hooks()

    return trajectory

@torch.no_grad()
def ddim_inversion(
    model_path: str,
    prompt: str,
    video_path: str,
    output_path: str,
    guidance_scale: float,
    num_inference_steps: int,
    skip_frames_start: int,
    skip_frames_end: int,
    frame_sample_step: Optional[int],
    max_num_frames: int,
    width: int,
    height: int,
    fps: int,
    dtype: torch.dtype,
    seed: int,
    device: torch.device,
):
    pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
        model_path, torch_dtype=dtype
    ).to(device=device)
    if not pipeline.transformer.config.use_rotary_positional_embeddings:
        raise NotImplementedError("This script supports the CogVideoX 5B model only.")
    video_frames = get_video_frames(
        video_path=video_path,
        width=width,
        height=height,
        skip_frames_start=skip_frames_start,
        skip_frames_end=skip_frames_end,
        max_num_frames=max_num_frames,
        frame_sample_step=frame_sample_step,
    ).to(device=device)
    video_latents = encode_video_frames(vae=pipeline.vae, video_frames=video_frames)
    inverse_scheduler = DDIMInverseScheduler(**pipeline.scheduler.config)
    inverse_latents = sample(
        pipeline=pipeline,
        latents=video_latents,
        scheduler=inverse_scheduler,
        prompt="",
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
    with OverrideAttnProcessors(transformer=pipeline.transformer):
        recon_latents = sample(
            pipeline=pipeline,
            latents=torch.randn_like(video_latents),
            scheduler=pipeline.scheduler,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=device).manual_seed(seed),
            reference_latents=reversed(inverse_latents),
        )
    filename, _ = os.path.splitext(os.path.basename(video_path))
    inverse_video_path = os.path.join(output_path, f"{filename}_inversion.mp4")
    recon_video_path = os.path.join(output_path, f"{filename}_reconstruction.mp4")
    export_latents_to_video(pipeline, inverse_latents[-1], inverse_video_path, fps)
    export_latents_to_video(pipeline, recon_latents[-1], recon_video_path, fps)


if __name__ == "__main__":
    arguments = get_args()
    ddim_inversion(**arguments)
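The (4k + 1) frame-count constraint enforced inside `get_video_frames` above can be isolated into a small pure helper; `trim_to_vae_frame_count` is my name for it, not part of the script, but the arithmetic mirrors the `remainder = (3 + selected_num_frames) % 4` trimming:

```python
def trim_to_vae_frame_count(num_frames: int) -> int:
    """Largest frame count <= num_frames of the form 4k + 1, as the CogVideoX VAE expects."""
    remainder = (3 + num_frames) % 4
    return num_frames - remainder

# 49 is already 4 * 12 + 1, so it is kept; 50, 51, and 52 all trim back down to 49.
print(trim_to_vae_frame_count(49))  # -> 49
print(trim_to_vae_frame_count(52))  # -> 49
```

Any count of this form keeps the assertion `frames.size(0) % 4 == 1` in the script satisfied.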
@@ -41,7 +41,5 @@ pip install -r requirements.txt
 ## Running the code

 ```bash
-python gradio_web_demo.py
+python app.py
 ```
@@ -30,7 +30,7 @@ from datetime import datetime, timedelta

 from diffusers.image_processor import VaeImageProcessor
 from openai import OpenAI
-import moviepy.editor as mp
+from moviepy import VideoFileClip
 import utils
 from rife_model import load_rife_model, rife_inference_with_latents
 from huggingface_hub import hf_hub_download, snapshot_download
@@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

 MODEL = "THUDM/CogVideoX-5b"

-hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
+hf_hub_download(
+    repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
+)
 snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

 pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
-pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe.scheduler = CogVideoXDPMScheduler.from_config(
+    pipe.scheduler.config, timestep_spacing="trailing"
+)
 pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
     MODEL,
     transformer=pipe.transformer,
@@ -271,9 +275,9 @@ def infer(


 def convert_to_gif(video_path):
-    clip = mp.VideoFileClip(video_path)
-    clip = clip.set_fps(8)
-    clip = clip.resize(height=240)
+    clip = VideoFileClip(video_path)
+    clip = clip.with_fps(8)
+    clip = clip.resized(height=240)
     gif_path = video_path.replace(".mp4", ".gif")
     clip.write_gif(gif_path, fps=8)
     return gif_path
@@ -296,8 +300,16 @@ def delete_old_files():


 threading.Thread(target=delete_old_files, daemon=True).start()
-examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
-examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
+examples_videos = [
+    ["example_videos/horse.mp4"],
+    ["example_videos/kitten.mp4"],
+    ["example_videos/train_running.mp4"],
+]
+examples_images = [
+    ["example_images/beach.png"],
+    ["example_images/street.png"],
+    ["example_images/camping.png"],
+]

 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -322,14 +334,26 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
+            with gr.Accordion(
+                "I2V: Image Input (cannot be used simultaneously with video input)", open=False
+            ):
                 image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
-                examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
-            with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
-                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+                examples_component_images = gr.Examples(
+                    examples_images, inputs=[image_input], cache_examples=False
+                )
+            with gr.Accordion(
+                "V2V: Video Input (cannot be used simultaneously with image input)", open=False
+            ):
+                video_input = gr.Video(
+                    label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
+                )
                 strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
-                examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+                examples_component_videos = gr.Examples(
+                    examples_videos, inputs=[video_input], cache_examples=False
+                )
+            prompt = gr.Textbox(
+                label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
+            )

     with gr.Row():
         gr.Markdown(
@@ -340,11 +364,16 @@ with gr.Blocks() as demo:
         with gr.Column():
             with gr.Row():
                 seed_param = gr.Number(
-                    label="Inference Seed (Enter a positive number, -1 for random)", value=-1
+                    label="Inference Seed (Enter a positive number, -1 for random)",
+                    value=-1,
                 )
             with gr.Row():
-                enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
-                enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
+                enable_scale = gr.Checkbox(
+                    label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
+                )
+                enable_rife = gr.Checkbox(
+                    label="Frame Interpolation (8fps -> 16fps)", value=False
+                )
             gr.Markdown(
                 "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling (Super-Resolution).<br> The entire process is based on open-source solutions."
             )
@@ -430,7 +459,7 @@ with gr.Blocks() as demo:
         seed_value,
         scale_status,
         rife_status,
-        progress=gr.Progress(track_tqdm=True)
+        progress=gr.Progress(track_tqdm=True),
     ):
         latents, seed = infer(
             prompt,
@@ -457,7 +486,9 @@ with gr.Blocks() as demo:
             image_pil = VaeImageProcessor.numpy_to_pil(image_np)
             batch_video_frames.append(image_pil)

-        video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
+        video_path = utils.save_video(
+            batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
+        )
         video_update = gr.update(visible=True, value=video_path)
         gif_path = convert_to_gif(video_path)
         gif_update = gr.update(visible=True, value=gif_path)
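The fps passed to `utils.save_video` in the hunk above simply spreads the decoded frames over roughly six seconds of playback. A sketch of that arithmetic in isolation (`playback_fps` is my name, not part of the app):

```python
import math

def playback_fps(num_frames: int, duration_s: int = 6) -> int:
    """fps so that num_frames play back over ~duration_s seconds.

    Mirrors math.ceil((len(batch_video_frames[0]) - 1) / 6) from the demo above.
    """
    return math.ceil((num_frames - 1) / duration_s)

print(playback_fps(49))  # 49 CogVideoX frames over 6 s -> 8 fps
```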
@@ -15,5 +15,5 @@ gradio>=5.4.0
 imageio>=2.34.2
 imageio-ffmpeg>=0.5.1
 openai>=1.45.0
-moviepy>=1.0.3
+moviepy>=2.0.0
 pillow==9.5.0
@@ -3,7 +3,9 @@ from .refine import *

 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )
@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -102,7 +108,9 @@ class IFNet(nn.Module):
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -110,11 +118,16 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         c0 = self.contextnet(img0, flow[:, :2])
         c1 = self.contextnet(img1, flow[:, 2:4])
         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
@@ -3,7 +3,9 @@ from .refine_2R import *

 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )
@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -102,7 +108,9 @@ class IFNet(nn.Module):
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -110,11 +118,16 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         c0 = self.contextnet(img0, flow[:, :2])
         c1 = self.contextnet(img1, flow[:, 2:4])
         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
@@ -61,11 +61,19 @@ class IFBlock(nn.Module):

     def forward(self, x, flow, scale=1):
         x = F.interpolate(
-            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+            x,
+            scale_factor=1.0 / scale,
+            mode="bilinear",
+            align_corners=False,
+            recompute_scale_factor=False,
         )
         flow = (
             F.interpolate(
-                flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+                flow,
+                scale_factor=1.0 / scale,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
             )
             * 1.0
             / scale
@@ -78,11 +86,21 @@ class IFBlock(nn.Module):
         flow = self.conv1(feat)
         mask = self.conv2(feat)
         flow = (
-            F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=scale,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * scale
         )
         mask = F.interpolate(
-            mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+            mask,
+            scale_factor=scale,
+            mode="bilinear",
+            align_corners=False,
+            recompute_scale_factor=False,
         )
         return flow, mask
@ -112,7 +130,11 @@ class IFNet(nn.Module):
|
||||
loss_cons = 0
|
||||
block = [self.block0, self.block1, self.block2]
|
||||
for i in range(3):
|
||||
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
||||
f0, m0 = block[i](
|
||||
torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
|
||||
flow,
|
||||
scale=scale_list[i],
|
||||
)
|
||||
f1, m1 = block[i](
|
||||
torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
|
||||
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
|
||||
|
@@ -3,7 +3,9 @@ from .refine import *

 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )

@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
         for i in range(3):
             if flow != None:
                 flow_d, mask_d = stu[i](
-                    torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
+                    torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
+                    flow,
+                    scale=scale[i],
                 )
                 flow = flow + flow_d
                 mask = mask + mask_d
@@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
             merged.append(merged_student)
         if gt.shape[1] == 3:
             flow_d, mask_d = self.block_tea(
-                torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
+                torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
+                flow,
+                scale=1,
             )
             flow_teacher = flow + flow_d
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         if returnflow:
             return flow
         else:
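The `loss_mask` in the hunks above gates RIFE's distillation term: a pixel's flow is pulled toward the teacher only where the student's reconstruction error exceeds the teacher's by a 0.01 margin. A toy sketch of that rule, with made-up scalar lists standing in for per-pixel mean errors:

```python
def distill_mask(student_err, teacher_err, margin=0.01):
    """Return 1.0 where the student is worse than the teacher by more than margin."""
    return [1.0 if s > t + margin else 0.0 for s, t in zip(student_err, teacher_err)]


# Hypothetical per-pixel errors: only the first pixel is distilled.
student = [0.30, 0.10, 0.25]
teacher = [0.20, 0.15, 0.249]
print(distill_mask(student, teacher))  # [1.0, 0.0, 0.0]
```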
@@ -44,7 +44,9 @@ class Model:
         if torch.cuda.is_available():
             self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
         else:
-            self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")))
+            self.flownet.load_state_dict(
+                convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
+            )

     def save_model(self, path, rank=0):
         if rank == 0:
@@ -29,10 +29,14 @@ def downsample(x):

 def upsample(x):
-    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
+    cc = torch.cat(
+        [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
+    )
     cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
     cc = cc.permute(0, 1, 3, 2)
-    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3)
+    cc = torch.cat(
+        [cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
+    )
     cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
     x_up = cc.permute(0, 1, 3, 2)
     return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
@@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
         self.gauss_kernel = gauss_kernel(channels=channels)

     def forward(self, input, target):
-        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
-        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
+        pyr_input = laplacian_pyramid(
+            img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
+        )
+        pyr_target = laplacian_pyramid(
+            img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
+        )
         return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
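`LapLoss.forward` above reduces to one pattern: an L1 distance per Laplacian-pyramid level, summed over levels. A minimal pure-Python sketch of that reduction, with plain lists standing in for tensors and for the output of `laplacian_pyramid`:

```python
def l1(a, b):
    # Mean absolute difference between two flat lists of the same length.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def lap_loss(pyr_input, pyr_target):
    # Each argument is a list of pyramid levels (one flat list per level),
    # mirroring the zip-and-sum in LapLoss.forward above.
    return sum(l1(a, b) for a, b in zip(pyr_input, pyr_target))


pyr_a = [[0.0, 2.0], [1.0]]
pyr_b = [[1.0, 1.0], [1.0]]
print(lap_loss(pyr_a, pyr_b))  # 1.0 (level 0 contributes 1.0, level 1 contributes 0.0)
```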
@@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
+    gauss = torch.Tensor(
+        [exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
+    )
     return gauss / gauss.sum()

@@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
     _2D_window = _1D_window.mm(_1D_window.t())
     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
-    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+    window = (
+        _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+    )
     return window

@@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,

     # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
     # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
-    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
-    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
+    mu1 = F.conv2d(
+        F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
+    )
+    mu2 = F.conv2d(
+        F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
+    )

     mu1_sq = mu1.pow(2)
     mu2_sq = mu2.pow(2)
     mu1_mu2 = mu1 * mu2

-    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
-    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
-    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
+    sigma1_sq = (
+        F.conv2d(
+            F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu1_sq
+    )
+    sigma2_sq = (
+        F.conv2d(
+            F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu2_sq
+    )
+    sigma12 = (
+        F.conv2d(
+            F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu1_mu2
+    )

     C1 = (0.01 * L) ** 2
     C2 = (0.03 * L) ** 2
@@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
     return ret

-def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
+def ssim_matlab(
+    img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
+):
     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
     if val_range is None:
         if torch.max(img1) > 128:
@@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
     img1 = img1.unsqueeze(1)
     img2 = img2.unsqueeze(1)

-    mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
-    mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
+    mu1 = F.conv3d(
+        F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
+    )
+    mu2 = F.conv3d(
+        F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
+    )

     mu1_sq = mu1.pow(2)
     mu2_sq = mu2.pow(2)
     mu1_mu2 = mu1 * mu2

-    sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
-    sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
-    sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
+    sigma1_sq = (
+        F.conv3d(
+            F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
+        )
+        - mu1_sq
+    )
+    sigma2_sq = (
+        F.conv3d(
+            F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
+        )
+        - mu2_sq
+    )
+    sigma12 = (
+        F.conv3d(
+            F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
+        )
+        - mu1_mu2
+    )

     C1 = (0.01 * L) ** 2
     C2 = (0.03 * L) ** 2
@@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
     mssim = []
     mcs = []
     for _ in range(levels):
-        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
+        sim, cs = ssim(
+            img1,
+            img2,
+            window_size=window_size,
+            size_average=size_average,
+            full=True,
+            val_range=val_range,
+        )
         mssim.append(sim)
         mcs.append(cs)

@@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
         self.window = window
         self.channel = channel

-        _ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
+        _ssim = ssim(
+            img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
+        )
         dssim = (1 - _ssim) / 2
         return dssim
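The reformatting above leaves the SSIM math untouched: both `ssim` and `ssim_matlab` combine local means, variances, and covariance with the stabilizers `C1 = (0.01 * L) ** 2` and `C2 = (0.03 * L) ** 2`. A scalar sketch of that core formula, assuming the local statistics have already been computed (L = 1 dynamic range):

```python
def ssim_core(mu1, mu2, sigma1_sq, sigma2_sq, sigma12, L=1.0):
    # Same stabilizing constants as in the code above.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2
    return ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)
    )


# Identical statistics (an image compared with itself) give SSIM == 1.
print(ssim_core(0.5, 0.5, 0.02, 0.02, 0.02))  # 1.0
```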
@@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
         torch.nn.ConvTranspose2d(
-            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=True,
         ),
         nn.PReLU(out_planes),
     )
@@ -56,25 +61,49 @@ class Contextnet(nn.Module):
     def forward(self, x, flow):
         x = self.conv1(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f1 = warp(x, flow)
         x = self.conv2(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f2 = warp(x, flow)
         x = self.conv3(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f3 = warp(x, flow)
         x = self.conv4(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f4 = warp(x, flow)
@@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
         torch.nn.ConvTranspose2d(
-            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=True,
         ),
         nn.PReLU(out_planes),
     )
@@ -59,19 +64,37 @@ class Contextnet(nn.Module):
         f1 = warp(x, flow)
         x = self.conv2(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f2 = warp(x, flow)
         x = self.conv3(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f3 = warp(x, flow)
         x = self.conv4(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f4 = warp(x, flow)
@@ -9,6 +9,7 @@ import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
+from huggingface_hub import hf_hub_download, snapshot_download

 logger = logging.getLogger(__name__)

 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,7 +20,7 @@ def pad_image(img, scale):
     tmp = max(32, int(32 / scale))
     ph = ((h - 1) // tmp + 1) * tmp
     pw = ((w - 1) // tmp + 1) * tmp
-    padding = (0, pw - w, 0, ph - h)
+    padding = (0, pw - w, 0, ph - h)
     return F.pad(img, padding), padding
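`pad_image` above rounds height and width up to the next multiple of `max(32, 32/scale)` and returns an `F.pad`-style `(left, right, top, bottom)` tuple. The arithmetic on its own, as a standalone sketch:

```python
def padded_size(h, w, scale=1):
    # Same rounding as pad_image above: pad up to the next multiple of tmp.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    # (left, right, top, bottom), matching F.pad's 2-D convention.
    return ph, pw, (0, pw - w, 0, ph - h)


print(padded_size(53, 64))  # (64, 64, (0, 0, 0, 11))
```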
@@ -78,13 +79,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
             # print(f'I1[0] unpadded shape:{I1.shape}')
             I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
             ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-            if padding[3] > 0 and padding[1] >0 :
-
-                frame = I1[:, :, : -padding[3],:-padding[1]]
+            if padding[3] > 0 and padding[1] > 0:
+                frame = I1[:, :, : -padding[3], : -padding[1]]
             elif padding[3] > 0:
-                frame = I1[:, :, : -padding[3],:]
-            elif padding[1] >0:
-                frame = I1[:, :, :,:-padding[1]]
+                frame = I1[:, :, : -padding[3], :]
+            elif padding[1] > 0:
+                frame = I1[:, :, :, : -padding[1]]
             else:
                 frame = I1
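The branchy crop above strips `pad_image`'s right/bottom padding, and the `> 0` guards matter: slicing with `[:-0]` is `[:0]`, which yields an empty result rather than the full axis. A 1-D sketch of that pitfall:

```python
def crop_axis(values, pad):
    # Only slice when there is actually padding to remove,
    # mirroring the padding[i] > 0 guards above.
    return values[:-pad] if pad > 0 else values


row = [1, 2, 3, 4, 5]
print(crop_axis(row, 2))  # [1, 2, 3]
print(crop_axis(row, 0))  # [1, 2, 3, 4, 5]  (row[:-0] would instead be empty)
```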
@@ -102,7 +102,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
             frame = F.interpolate(frame, size=(h, w))
             output.append(frame.to(output_device))
         for i, tmp_frame in enumerate(tmp_output):
-
             # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
             tmp_frame = F.interpolate(tmp_frame, size=(h, w))
             output.append(tmp_frame.to(output_device))
@@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
         frame_rgb = frame[..., ::-1]
         frame_rgb = frame_rgb.copy()
         tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
-        pt_frame_data.append(
-            tensor.permute(2, 0, 1)
-        )  # to [c, h, w,]
+        pt_frame_data.append(tensor.permute(2, 0, 1))  # to [c, h, w,]

     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
@@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
         latent = latents[i]

         frames = ssim_interpolation_rife(model, latent)
-        pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
+        pt_image = torch.stack(
+            [frames[i].squeeze(0) for i in range(len(frames))]
+        )  # (to [f, c, w, h])
         rife_results.append(pt_image)

     return torch.stack(rife_results)
@@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
-        if not "weights_only" in torch.load.__code__.co_varnames:
+        if "weights_only" not in torch.load.__code__.co_varnames:
            logger.warning(
                "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
            )
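The guard above feature-probes `torch.load` through CPython's `co_varnames` before passing `weights_only`. The same capability check can be written with `inspect.signature`; sketched here against a stand-in function (no torch dependency, names are illustrative):

```python
import inspect


def load(path, weights_only=False):  # stand-in for a loader that has the flag
    return path


def legacy_load(path):  # stand-in for an older loader without the flag
    return path


def supports_weights_only(fn):
    return "weights_only" in inspect.signature(fn).parameters


print(supports_weights_only(load))         # True
print(supports_weights_only(legacy_load))  # False
```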
@@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):

 @torch.inference_mode()
 def tiled_scale_multidim(
-    samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
+    samples,
+    function,
+    tile=(64, 64),
+    overlap=8,
+    upscale_amount=4,
+    out_channels=3,
+    output_device="cpu",
+    pbar=None,
 ):
     dims = len(tile)
     print(f"samples dtype:{samples.dtype}")
     output = torch.empty(
-        [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
+        [samples.shape[0], out_channels]
+        + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
         device=output_device,
     )

     for b in range(samples.shape[0]):
         s = samples[b : b + 1]
         out = torch.zeros(
-            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
+            [s.shape[0], out_channels]
+            + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
             device=output_device,
         )
         out_div = torch.zeros(
-            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
+            [s.shape[0], out_channels]
+            + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
             device=output_device,
         )

-        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
+        for it in itertools.product(
+            *map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
+        ):
             s_in = s
             upscaled = []

@@ -142,7 +154,14 @@ def tiled_scale(
     pbar=None,
 ):
     return tiled_scale_multidim(
-        samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar
+        samples,
+        function,
+        (tile_y, tile_x),
+        overlap,
+        upscale_amount,
+        out_channels,
+        output_device,
+        pbar,
    )
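`tiled_scale_multidim` above walks each spatial dimension in steps of `tile - overlap`, and `itertools.product` enumerates every tile origin. The iteration on its own, sketched with an assumed 128x96 input, 64-pixel tiles, and an 8-pixel overlap:

```python
import itertools

shape, tile, overlap = (128, 96), (64, 64), 8

# Same pattern as the for-loop above: per-dimension start positions,
# then their Cartesian product as tile origins.
origins = list(
    itertools.product(*(range(0, dim, t - overlap) for dim, t in zip(shape, tile)))
)
print(origins)  # [(0, 0), (0, 56), (56, 0), (56, 56), (112, 0), (112, 56)]
```

Each consecutive pair of tiles shares an 8-pixel band, which is why the function accumulates into both `out` and `out_div` and divides at the end to blend the overlaps.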
@@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
     return s

-def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
+def upscale_batch_and_concatenate(
+    upscale_model, latents, inf_device, output_device="cpu"
+) -> torch.Tensor:
     upscaled_latents = []
     for i in range(latents.size(0)):
         latent = latents[i]
@@ -207,7 +228,9 @@ class ProgressBar:
     def __init__(self, total, desc=None):
         self.total = total
         self.current = 0
-        self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc)
+        self.b_unit = tqdm.tqdm(
+            total=total, desc="ProgressBar context index: 0" if desc is None else desc
+        )

     def update(self, value):
         if value > self.total:
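The `update` method above guards against overshooting `total`; the branch body is cut off by the hunk, so the clamp below is an assumption. A minimal tqdm-free stand-in showing that guard:

```python
class ToyProgressBar:
    """Stand-in for the tqdm-backed ProgressBar above (no tqdm dependency)."""

    def __init__(self, total):
        self.total = total
        self.current = 0

    def update(self, value):
        if value > self.total:
            value = self.total  # assumed clamp; mirrors the guard above
        self.current = value


bar = ToyProgressBar(10)
bar.update(25)
print(bar.current)  # 10
```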
@@ -20,9 +20,11 @@ from diffusers import CogVideoXPipeline
 from diffusers.utils import export_to_video
 from datetime import datetime, timedelta
 from openai import OpenAI
-import moviepy.editor as mp
+from moviepy import VideoFileClip

-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
+    "cuda"
+)

 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
@@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     return prompt

-def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
+def infer(
+    prompt: str,
+    num_inference_steps: int,
+    guidance_scale: float,
+    progress=gr.Progress(track_tqdm=True),
+):
     torch.cuda.empty_cache()
     video = pipe(
         prompt=prompt,
@@ -117,9 +124,9 @@ def save_video(tensor):

 def convert_to_gif(video_path):
-    clip = mp.VideoFileClip(video_path)
-    clip = clip.set_fps(8)
-    clip = clip.resize(height=240)
+    clip = VideoFileClip(video_path)
+    clip = clip.with_fps(8)
+    clip = clip.resized(height=240)
     gif_path = video_path.replace(".mp4", ".gif")
     clip.write_gif(gif_path, fps=8)
     return gif_path
@@ -151,7 +158,9 @@ with gr.Blocks() as demo:

     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+            prompt = gr.Textbox(
+                label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
+            )

     with gr.Row():
         gr.Markdown(
@@ -176,7 +185,13 @@ with gr.Blocks() as demo:
     download_video_button = gr.File(label="📥 Download Video", visible=False)
     download_gif_button = gr.File(label="📥 Download GIF", visible=False)

-    def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
+    def generate(
+        prompt,
+        num_inference_steps,
+        guidance_scale,
+        model_choice,
+        progress=gr.Progress(track_tqdm=True),
+    ):
         tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
         video_path = save_video(tensor)
         video_update = gr.update(visible=True, value=video_path)
@@ -1,27 +1,43 @@
+[project]
+name = "cogvideo"
+version = "0.1.0"
+readme = "README.md"
+requires-python = ">=3.8"
+dependencies = [
+    "datasets>=2.14.4",
+    "decord>=0.6.0",
+    "deepspeed>=0.16.4",
+    "diffusers>=0.32.2",
+    "opencv-python>=4.11.0.86",
+    "peft>=0.13.2",
+    "pydantic>=2.10.6",
+    "sentencepiece>=0.2.0",
+    "torch>=2.5.1",
+    "torchvision>=0.20.1",
+    "transformers>=4.46.3",
+    "wandb>=0.19.7",
+]
+
+
 [tool.ruff]
-exclude = ['.git', '.mypy_cache', '.ruff_cache', '.venv', 'dist']
-target-version = 'py310'
-line-length = 100
+line-length = 119
+
+[tool.ruff.lint]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2

 [tool.ruff.format]
-line-ending = 'lf'
-quote-style = 'preserve'
+# Like Black, use double quotes for strings.
+quote-style = "double"

 # Like Black, indent with spaces, rather than tabs.
 indent-style = "space"

 # Like Black, respect magic trailing commas.
 skip-magic-trailing-comma = false

 # Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
-
-[tool.ruff.lint]
-select = ["E", "W"]
-ignore = [
-    "E999",
-    "EXE001",
-    "UP009",
-    "F401",
-    "TID252",
-    "F403",
-    "F841",
-    "E501",
-    "W293",
-]
@@ -1,14 +1,15 @@
-diffusers>=0.31.0
-accelerate>=1.0.1
-transformers>=4.46.1
+diffusers>=0.32.1
+accelerate>=1.1.1
+transformers>=4.46.2
 numpy==1.26.0
 torch>=2.5.0
 torchvision>=0.20.0
 sentencepiece>=0.2.0
 SwissArmyTransformer>=0.4.12
-gradio>=5.4.0
+gradio>=5.5.0
 imageio>=2.35.1
 imageio-ffmpeg>=0.5.1
-openai>=1.53.0
-moviepy>=1.0.3
+openai>=1.54.0
+moviepy>=2.0.0
 scikit-video>=1.1.11
 pydantic>=2.10.3
@@ -4,4 +4,3 @@
 <p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
 <p> Scan the QR code to follow the official account and join the "CogVLM Discussion Group" </p>
 </div>
|
||||
# Contribution Guide
|
||||
|
||||
There may still be many incomplete aspects in this project.
|
||||
|
||||
We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
|
||||
and are willing to submit a PR and share it with the community, upon review, we
|
||||
will acknowledge your contribution on the project homepage.
|
||||
|
||||
## Model Algorithms
|
||||
|
||||
- Support for model quantization inference (Int4 quantization project)
|
||||
- Optimization of model fine-tuning data loading (replacing the existing decord tool)
|
||||
|
||||
## Model Engineering
|
||||
|
||||
- Model fine-tuning examples / Best prompt practices
|
||||
- Inference adaptation on different devices (e.g., MLX framework)
|
||||
- Any tools related to the model
|
||||
- Any minimal fully open-source project using the CogVideoX open-source model
|
||||
|
||||
## Code Standards
|
||||
|
||||
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
|
||||
style. You can organize the code according to the following specifications:
|
||||
|
||||
1. Install the `ruff` tool
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
Then, run the `ruff` tool
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
Once your code meets the standard, there should be no errors.
|
||||
|
||||
## Naming Conventions
|
||||
1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
|
||||
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
|
||||
|
@ -1,47 +0,0 @@
|
||||
# コントリビューションガイド
|
||||
|
||||
本プロジェクトにはまだ多くの未完成の部分があります。
|
||||
|
||||
以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
|
||||
|
||||
## モデルアルゴリズム
|
||||
|
||||
- モデル量子化推論のサポート (Int4量子化プロジェクト)
|
||||
- モデルのファインチューニングデータロードの最適化(既存のdecordツールの置き換え)
|
||||
|
||||
## モデルエンジニアリング
|
||||
|
||||
- モデルのファインチューニング例 / 最適なプロンプトの実践
|
||||
- 異なるデバイスでの推論適応(例: MLXフレームワーク)
|
||||
- モデルに関連するツール
|
||||
- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
|
||||
|
||||
## コード標準
|
||||
|
||||
良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
|
||||
設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
|
||||
|
||||
1. `ruff` ツールをインストールする
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
次に、`ruff` ツールを実行します
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
コードが標準に準拠したら、エラーはなくなるはずです。
|
||||
|
||||
## 命名規則
|
||||
|
||||
1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
|
||||
2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。
|
@ -1,44 +0,0 @@
|
||||
# 贡献指南
|
||||
|
||||
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
|
||||
|
||||
## 模型算法
|
||||
|
||||
- 模型量化推理支持 (Int4量化工程)
|
||||
- 模型微调数据载入优化支持(替换现有的decord工具)
|
||||
|
||||
## 模型工程
|
||||
|
||||
- 模型微调示例 / 最佳提示词实践
|
||||
- 不同设备上的推理适配(MLX等框架)
|
||||
- 任何模型周边工具
|
||||
- 任何使用CogVideoX开源模型制作的最小完整开源项目
|
||||
|
||||
## 代码规范
|
||||
|
||||
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
|
||||
|
||||
1. 安装`ruff`工具
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
接着,运行`ruff`工具
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
如果您的代码符合规范,应该不会出现任何的错误。
|
||||
|
||||
## 命名规范
|
||||
|
||||
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
|
||||
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
|
Binary file not shown. (Before: 4.6 KiB; After: 195 KiB)

sat/README.md (209 lines changed)
@ -1,29 +1,42 @@
# SAT CogVideoX

[中文阅读](./README_zh.md)

[日本語で読む](./README_ja.md)

This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights. This code is the framework the team used to train the model; it has few comments and requires careful study.

If you are interested in the `CogVideoX1.0` version of the model, please check the SAT folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0). This branch only supports the `CogVideoX1.5` series models.

## Inference Model

### 1. Make sure you have installed all dependencies in this folder

```shell
pip install -r requirements.txt
```

### 2. Download the Model Weights

First, download the model weights from the SAT mirror.

#### CogVideoX1.5 Model

```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```

This command downloads three models: Transformers, VAE, and T5 Encoder.

#### CogVideoX Model

For the CogVideoX-2B model, download as follows:

```shell
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```

Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):

+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)

Arrange the model files in the following structure (some entries are elided here):

```
.
...
└── 3d-vae.pt
```

Since the model weight files are large, using `git lfs` is recommended. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.

```shell
git lfs install
```

Next, clone the T5 model, which is used only as an encoder and does not require training or fine-tuning.

> The model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) may also be used.

```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```

This yields a safetensor-format T5 file that can be loaded without error during Deepspeed fine-tuning. The resulting folder looks like this (some entries elided):

```
├── added_tokens.json
...
0 directories, 8 files
```

### 3. Modify the `configs/cogvideox_*.yaml` file

```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  log_keys:
    - txt
  # ...
      ucg_rate: 0.1
      target: sgm.modules.encoders.modules.FrozenT5Embedder
      params:
        model_dir: "t5-v1_1-xxl" # absolute path to the CogVideoX-2b/t5-v1_1-xxl weight folder
        max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
      ignore_keys: [ 'loss' ]

  loss_config:
    # ...
      num_steps: 50
```

### 4. Modify the `configs/inference.yaml` file

```yaml
args:
  latent_channels: 16
  mode: inference
  load: "{absolute_path/to/your}/transformer" # Absolute path to the CogVideoX-2b-sat/transformer folder
  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for the full model without the lora adapter

  batch_size: 1
  input_type: txt # Choose "txt" for plain-text file input, or "cli" for command-line input
  input_file: configs/test.txt # Plain text file, can be edited
  sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22; for CogVideoX-5B / 2B it must be 13, 11, or 9
  sampling_fps: 8
  fp16: True # For CogVideoX-2B
  # bf16: True # For CogVideoX-5B
  output_dir: outputs/
  force_inference: True
```
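To make the `sampling_num_frames` constraint concrete, here is a small illustrative checker; the helper function and the model-name keys are invented for this sketch and are not part of the repository:

```python
# Allowed sampling_num_frames values, per the comment in configs/inference.yaml above.
VALID_NUM_FRAMES = {
    "CogVideoX1.5-5B": {42, 22},
    "CogVideoX-5B": {13, 11, 9},
    "CogVideoX-2B": {13, 11, 9},
}

def is_valid_num_frames(model_name: str, num_frames: int) -> bool:
    """Return True if num_frames is one of the allowed values for the model."""
    return num_frames in VALID_NUM_FRAMES.get(model_name, set())

print(is_valid_num_frames("CogVideoX-2B", 13))     # True
print(is_valid_num_frames("CogVideoX1.5-5B", 13))  # False
```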

+ If you use a text file to hold multiple prompts, modify `configs/test.txt` as needed, with one prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement; set `OPENAI_API_KEY` as an environment variable first.
+ To use command-line input, modify:

```yaml
input_type: cli
```

This allows you to enter prompts from the command line.

To modify the output video location, change:

```yaml
output_dir: outputs/
```

The default location is the `.outputs/` folder.
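Since `configs/test.txt` is simply one prompt per line, reading such a file can be sketched as follows; the `load_prompts` helper is illustrative and not the repository's actual loader:

```python
import tempfile

def load_prompts(path: str) -> list[str]:
    """Read one prompt per line, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

# Write a small example prompt file and read it back.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
    f.write("A cat playing piano in a jazz bar\n\nA drone shot of a coastline at sunset\n")
    prompt_path = f.name

print(load_prompts(prompt_path))
```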
|
||||
|
||||
### 5. Run the inference code to perform inference.
|
||||
### 5. Run the Inference Code to Perform Inference
|
||||
|
||||
```shell
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
@ -288,95 +303,95 @@ bash inference.sh

## Fine-tuning the Model

### Preparing the Dataset

The dataset should be structured as follows:

```
.
├── labels
│   ├── 1.txt
│   ├── 2.txt
│   ├── ...
└── videos
    ├── 1.mp4
    ├── 2.mp4
    ├── ...
```

Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos and labels should correspond one-to-one; generally, avoid using one video with multiple labels.

For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
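The one-to-one pairing between `labels/` and `videos/` can be verified with a short script; this checker is illustrative and not part of the repository:

```python
import os
import tempfile

def dataset_is_paired(root: str) -> bool:
    """Return True if every label stem has a matching video stem and vice versa."""
    labels = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "labels"))}
    videos = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "videos"))}
    return labels == videos

# Build a tiny example dataset layout in a temporary directory.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "labels"))
os.makedirs(os.path.join(root, "videos"))
for stem in ("1", "2"):
    open(os.path.join(root, "labels", stem + ".txt"), "w").close()
    open(os.path.join(root, "videos", stem + ".mp4"), "w").close()

print(dataset_is_paired(root))  # True
```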

### Modifying the Configuration File

We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part; the `VAE` part is not modified, and `T5` is only used as an encoder. Modify `configs/sft.yaml` (for full-parameter fine-tuning) as follows:

```yaml
# checkpoint_activations: True ## using gradient checkpointing (both `checkpoint_activations` entries in the config file need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to the Transformer model
no_load_rng: True # Whether to load the random seed
train_iters: 1000 # Training iterations
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Proportion for training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
only_log_video_latents: True # Avoid memory overhead from VAE decoding
deepspeed:
  bf16:
    enabled: False # For CogVideoX-2B set to False; for CogVideoX-5B set to True
  fp16:
    enabled: True # For CogVideoX-2B set to True; for CogVideoX-5B set to False
```

To use Lora fine-tuning, you also need to modify the `cogvideox_<model_parameters>_lora` file. Here's an example using `CogVideoX-2B`:

```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  not_trainable_prefixes: [ 'all' ] ## Uncomment
  log_keys:
    - txt

  lora_config: ## Uncomment
    target: sat.model.finetune.lora2.LoraMixin
    params:
      r: 256
```

### Modifying the Run Script

Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the configuration file. Below are two examples:

1. To use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```

2. To use the `CogVideoX-2B` model with full-parameter fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```

### Fine-tuning and Validation

Run the fine-tuning script to start fine-tuning:

```
bash finetune_single_gpu.sh # Single GPU
bash finetune_multi_gpus.sh # Multi GPUs
```

### Using the Fine-tuned Model

The fine-tuned model cannot be merged. Here's how to modify the inference script `inference.sh`:

```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
```

Then, run the code:

```
bash inference.sh
```

### Converting to Huggingface Diffusers-compatible Weights

The SAT weight format is different from Huggingface's format and requires conversion. Run:

```shell
python ../tools/convert_weight_sat2hf.py
```

### Exporting Lora Weights from SAT to Huggingface Diffusers

Support is provided for exporting Lora weights from SAT to the Huggingface Diffusers format. After training with the above steps, you'll find the SAT checkpoint with Lora weights at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.

Export command:

```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

The following model structures were modified during training. The list below shows the mapping between the SAT structure and the HF (Hugging Face) Lora structure; as you can see, Lora adds a low-rank weight to the model's attention structure (middle entries are elided here):

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
...
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using `export_sat_lora_weight.py`, you can convert the SAT checkpoint into the HF-format Lora structure.
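The mapping above can be applied mechanically when renaming checkpoint keys. A minimal sketch follows; only the two keys shown above are included, and the checkpoint dictionary with its tensor values is a stand-in:

```python
# Partial SAT -> HF Lora key mapping taken from the list above (middle entries elided).
SAT_TO_HF = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}

def rename_lora_keys(state_dict: dict) -> dict:
    """Rename known SAT Lora keys to HF Diffusers names, leaving other keys unchanged."""
    return {SAT_TO_HF.get(key, key): value for key, value in state_dict.items()}

fake_checkpoint = {
    "attention.query_key_value.matrix_A.0": [0.0],
    "other.weight": [1.0],
}
print(sorted(rename_lora_keys(fake_checkpoint)))
```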
sat/README_ja.md
# SAT CogVideoX

[Read this in English.](./README.md)

[中文阅读](./README_zh.md)

This folder contains inference code using [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights. If you are interested in the `CogVideoX1.0` version of the model, see the SAT folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0); this branch only supports the `CogVideoX1.5` series models.

## Inference Model

### 1. Make sure you have installed all dependencies in this folder

```shell
pip install -r requirements.txt
```

### 2. Download the Model Weights

First, download the model weights from the SAT mirror.

#### CogVideoX1.5 Model

```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```

This downloads three models: Transformers, VAE, and T5 Encoder.

#### CogVideoX Model

For the CogVideoX-2B model, download as follows:

```shell
mkdir CogVideoX-2b-sat
cd CogVideoX-2b-sat
wget https://cloud.tsinghua.edu.cn/f/fdba7608a49c463ba754/?dl=1
mv 'index.html?dl=1' transformer.zip
unzip transformer.zip
```

Download the `transformers` file for the CogVideoX-5B model (the VAE file is the same as for 2B):

+ [CogVideoX-5B](https://cloud.tsinghua.edu.cn/d/fcef5b3904294a6885e5/?p=%2F&mode=list)
+ [CogVideoX-5B-I2V](https://cloud.tsinghua.edu.cn/d/5cc62a2d6e7d45c0a2f6/?p=%2F1&mode=list)

Arrange the model files in the following structure (some entries are elided here):

```
.
...
└── 3d-vae.pt
```

Since the model weight files are large, using `git lfs` is recommended. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.

```shell
git lfs install
```

Next, clone the T5 model, which is used only as an encoder and does not require training or fine-tuning.

> The model files on [Modelscope](https://modelscope.cn/models/ZhipuAI/CogVideoX-2b) may also be used.

```shell
git clone https://huggingface.co/THUDM/CogVideoX-2b.git # Download model from Huggingface
# git clone https://www.modelscope.cn/ZhipuAI/CogVideoX-2b.git # Download from Modelscope
mkdir t5-v1_1-xxl
mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
```

This yields a safetensor-format T5 file that can be loaded without error during Deepspeed fine-tuning. The resulting folder looks like this (some entries elided):

```
├── added_tokens.json
...
0 directories, 8 files
```

### 3. Modify the `configs/cogvideox_*.yaml` file

```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  log_keys:
    - txt
  # ...
    num_attention_heads: 30

    transformer_args:
      checkpoint_activations: True ## using gradient checkpointing
      vocab_size: 1
      max_sequence_length: 64
      layernorm_order: pre
  # ...
      ucg_rate: 0.1
      target: sgm.modules.encoders.modules.FrozenT5Embedder
      params:
        model_dir: "t5-v1_1-xxl" # absolute path to the CogVideoX-2b/t5-v1_1-xxl weight folder
        max_length: 226

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "CogVideoX-2b-sat/vae/3d-vae.pt" # absolute path to the CogVideoX-2b-sat/vae/3d-vae.pt file
      ignore_keys: [ 'loss' ]

  loss_config:
    # ...
      num_steps: 50
```

### 4. Modify the `configs/inference.yaml` file

```yaml
args:
  # ...
  # load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for the full model without the lora adapter

  batch_size: 1
  input_type: txt # Choose "txt" for plain-text file input, or "cli" for command-line input
  input_file: configs/test.txt # Plain text file, can be edited
  sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22; for CogVideoX-5B / 2B it must be 13, 11, or 9
  sampling_fps: 8
  fp16: True # For CogVideoX-2B
  # bf16: True # For CogVideoX-5B
  output_dir: outputs/
  force_inference: True
```

+ If you use a text file to hold multiple prompts, modify `configs/test.txt` as needed, with one prompt per line. If you are unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
+ To use command-line input, modify:

```yaml
input_type: cli
```

This allows you to enter prompts from the command line.

To modify the output video location, change:

```yaml
output_dir: outputs/
```

The default location is the `.outputs/` folder.

### 5. Run the Inference Code to Perform Inference

```shell
bash inference.sh
```
|
||||
|
||||
@ -289,7 +301,7 @@ bash inference.sh
|
||||
|
||||
### データセットの準備
|
||||
|
||||
データセットの形式は次のようになります:
|
||||
データセットは以下の構造である必要があります:
|
||||
|
||||
```
|
||||
.
|
||||
@ -303,123 +315,224 @@ bash inference.sh
|
||||
├── ...
|
||||
```
|
||||
|
||||
各 txt ファイルは対応するビデオファイルと同じ名前であり、そのビデオのラベルを含んでいます。各ビデオはラベルと一対一で対応する必要があります。通常、1つのビデオに複数のラベルを持たせることはありません。
|
||||
各txtファイルは対応するビデオファイルと同じ名前で、ビデオのラベルを含んでいます。ビデオとラベルは一対一で対応させる必要があります。通常、1つのビデオに複数のラベルを使用することは避けてください。
|
||||
|
||||
スタイルファインチューニングの場合、少なくとも50本のスタイルが似たビデオとラベルを準備し、フィッティングを容易にします。
|
||||
スタイルのファインチューニングの場合、スタイルが似たビデオとラベルを少なくとも50本準備し、フィッティングを促進します。
|
||||
|
||||
### 設定ファイルの変更
|
||||
### 設定ファイルの編集
|
||||
|
||||
`Lora` とフルパラメータ微調整の2つの方法をサポートしています。両方の微調整方法は、`transformer` 部分のみを微調整し、`VAE`
|
||||
部分には変更を加えないことに注意してください。`T5` はエンコーダーとしてのみ使用されます。以下のように `configs/sft.yaml` (
|
||||
フルパラメータ微調整用) ファイルを変更してください。
|
||||
``` `Lora`と全パラメータのファインチューニングの2種類をサポートしています。どちらも`transformer`部分のみをファインチューニングし、`VAE`部分は変更されず、`T5`はエンコーダーとしてのみ使用されます。
|
||||
``` 以下のようにして`configs/sft.yaml`(全量ファインチューニング)ファイルを編集してください:
|
||||
|
||||
```
|
||||
# checkpoint_activations: True ## 勾配チェックポイントを使用する場合 (設定ファイル内の2つの checkpoint_activations を True に設定する必要があります)
|
||||
# checkpoint_activations: True ## using gradient checkpointing (configファイル内の2つの`checkpoint_activations`を両方Trueに設定)
|
||||
model_parallel_size: 1 # モデル並列サイズ
|
||||
experiment_name: lora-disney # 実験名 (変更しないでください)
|
||||
mode: finetune # モード (変更しないでください)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformer モデルのパス
|
||||
no_load_rng: True # 乱数シードを読み込むかどうか
|
||||
experiment_name: lora-disney # 実験名(変更不要)
|
||||
mode: finetune # モード(変更不要)
|
||||
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Transformerモデルのパス
|
||||
no_load_rng: True # 乱数シードをロードするかどうか
|
||||
train_iters: 1000 # トレーニングイテレーション数
|
||||
eval_iters: 1 # 評価イテレーション数
|
||||
eval_interval: 100 # 評価間隔
|
||||
eval_batch_size: 1 # 評価バッチサイズ
|
||||
eval_iters: 1 # 検証イテレーション数
|
||||
eval_interval: 100 # 検証間隔
|
||||
eval_batch_size: 1 # 検証バッチサイズ
|
||||
save: ckpts # モデル保存パス
|
||||
save_interval: 100 # モデル保存間隔
|
||||
save_interval: 100 # 保存間隔
|
||||
log_interval: 20 # ログ出力間隔
|
||||
train_data: [ "your train data path" ]
|
||||
valid_data: [ "your val data path" ] # トレーニングデータと評価データは同じでも構いません
|
||||
split: 1,0,0 # トレーニングセット、評価セット、テストセットの割合
|
||||
num_workers: 8 # データローダーのワーカースレッド数
|
||||
force_train: True # チェックポイントをロードするときに欠落したキーを許可 (T5 と VAE は別々にロードされます)
|
||||
only_log_video_latents: True # VAE のデコードによるメモリオーバーヘッドを回避
|
||||
valid_data: [ "your val data path" ] # トレーニングセットと検証セットは同じでも構いません
|
||||
split: 1,0,0 # トレーニングセット、検証セット、テストセットの割合
|
||||
num_workers: 8 # データローダーのワーカー数
|
||||
force_train: True # チェックポイントをロードする際に`missing keys`を許可(T5とVAEは別途ロード)
|
||||
only_log_video_latents: True # VAEのデコードによるメモリ使用量を抑える
|
||||
deepspeed:
|
||||
bf16:
|
||||
enabled: False # CogVideoX-2B の場合は False に設定し、CogVideoX-5B の場合は True に設定
|
||||
enabled: False # CogVideoX-2B 用は False、CogVideoX-5B 用は True に設定
|
||||
fp16:
|
||||
enabled: True # CogVideoX-2B の場合は True に設定し、CogVideoX-5B の場合は False に設定
|
||||
enabled: True # CogVideoX-2B 用は True、CogVideoX-5B 用は False に設定
|
||||
```
|
||||
|
||||
Lora 微調整を使用したい場合は、`cogvideox_<model_parameters>_lora` ファイルも変更する必要があります。
|
||||
```yaml
|
||||
args:
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "{absolute_path/to/your}/transformer" # Absolute path to CogVideoX-2b-sat/transformer folder
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
|
||||
ここでは、`CogVideoX-2B` を参考にします。
|
||||
|
||||
```
|
||||
model:
|
||||
scale_factor: 1.15258426
|
||||
disable_first_stage_autocast: true
|
||||
not_trainable_prefixes: [ 'all' ] ## コメントを解除
|
||||
log_keys:
|
||||
- txt'
|
||||
|
||||
lora_config: ## コメントを解除
|
||||
target: sat.model.finetune.lora2.LoraMixin
|
||||
params:
|
||||
r: 256
|
||||
batch_size: 1
|
||||
input_type: txt # You can choose "txt" for plain text input or change to "cli" for command-line input
|
||||
input_file: configs/test.txt # Plain text file, can be edited
|
||||
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22. For CogVideoX-5B / 2B, it must be 13, 11, or 9.
|
||||
sampling_fps: 8
|
||||
fp16: True # For CogVideoX-2B
|
||||
# bf16: True # For CogVideoX-5B
|
||||
output_dir: outputs/
|
||||
force_inference: True
|
||||
```
|
||||
|
||||
### 実行スクリプトの変更
|
||||
|
||||
設定ファイルを選択するために `finetune_single_gpu.sh` または `finetune_multi_gpus.sh` を編集します。以下に2つの例を示します。
|
||||
|
||||
1. `CogVideoX-2B` モデルを使用し、`Lora` 手法を利用する場合は、`finetune_single_gpu.sh` または `finetune_multi_gpus.sh`
|
||||
を変更する必要があります。
|
||||
+ If using a text file to save multiple prompts, modify `configs/test.txt` as needed. One prompt per line. If you are
|
||||
unsure how to write prompts, use [this code](../inference/convert_demo.py) to call an LLM for refinement.
|
||||
+ To use command-line input, modify:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
|
||||
input_type: cli
|
||||
```
|
||||
|
||||
2. `CogVideoX-2B` モデルを使用し、`フルパラメータ微調整` 手法を利用する場合は、`finetune_single_gpu.sh`
|
||||
または `finetune_multi_gpus.sh` を変更する必要があります。
|
||||
This allows you to enter prompts from the command line.
|
||||
|
||||
To modify the output video location, change:
|
||||
|
||||
```
|
||||
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
|
||||
output_dir: outputs/
|
||||
```
|
||||
|
||||
### 微調整と評価
|
||||
The default location is the `.outputs/` folder.
|
||||
|
||||
推論コードを実行して微調整を開始します。
|
||||
|
||||
```
|
||||
bash finetune_single_gpu.sh # シングルGPU
|
||||
bash finetune_multi_gpus.sh # マルチGPU
|
||||
```
|
||||
|
||||
### 微調整後のモデルの使用
|
||||
|
||||
微調整されたモデルは統合できません。ここでは、推論設定ファイル `inference.sh` を変更する方法を示します。
|
||||
|
||||
```
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model_parameters>_lora.yaml configs/inference.yaml --seed 42"
|
||||
```
|
||||
|
||||
その後、次のコードを実行します。
|
||||
### 5. Run the Inference Code to Perform Inference
|
||||
|
||||
```
|
||||
bash inference.sh
|
||||
```
|
||||
|
||||
### Huggingface Diffusers サポートのウェイトに変換
|
||||
## Fine-tuning the Model
|
||||
|
||||
SAT ウェイト形式は Huggingface のウェイト形式と異なり、変換が必要です。次のコマンドを実行してください:
|
||||
### Preparing the Dataset
|
||||
|
||||
```shell
|
||||
The dataset should be structured as follows:
|
||||
|
||||
```
|
||||
.
|
||||
├── labels
|
||||
│ ├── 1.txt
|
||||
│ ├── 2.txt
|
||||
│ ├── ...
|
||||
└── videos
|
||||
├── 1.mp4
|
||||
├── 2.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
Each txt file should have the same name as the corresponding video file and contain the label for that video. The videos
|
||||
and labels should correspond one-to-one. Generally, avoid using one video with multiple labels.
|
||||
|
||||
For style fine-tuning, prepare at least 50 videos and labels with a similar style to facilitate fitting.
|
||||
|
||||
### Modifying the Configuration File

We support two fine-tuning methods: `Lora` and full-parameter fine-tuning. Note that both methods only fine-tune the `transformer` part; the `VAE` is not modified, and `T5` is used only as an encoder.

Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:

```yaml
# checkpoint_activations: True ## Enable gradient checkpointing (both `checkpoint_activations` settings in the config files need to be set to True)
model_parallel_size: 1 # Model parallel size
experiment_name: lora-disney # Experiment name (do not change)
mode: finetune # Mode (do not change)
load: "{your_CogVideoX-2b-sat_path}/transformer" ## Path to the Transformer model
no_load_rng: True # Whether to load the random seed
train_iters: 1000 # Number of training iterations
eval_iters: 1 # Number of evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
valid_data: [ "your val data path" ] # Training and validation sets can be the same
split: 1,0,0 # Proportions for the training, validation, and test sets
num_workers: 8 # Number of data loader workers
force_train: True # Allow missing keys when loading the checkpoint (T5 and VAE are loaded separately)
only_log_video_latents: True # Avoid the memory overhead of VAE decoding
deepspeed:
  bf16:
    enabled: False # Set to False for CogVideoX-2B and True for CogVideoX-5B
  fp16:
    enabled: True # Set to True for CogVideoX-2B and False for CogVideoX-5B
```

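The run scripts pass two files via `--base`, and values from the later file layer over the earlier one. A minimal sketch of that kind of layered merging for nested dicts; the `merge_configs` helper is hypothetical, illustrating OmegaConf-style recursive merge semantics rather than the repo's actual loader:

```python
def merge_configs(*cfgs: dict) -> dict:
    # Later configs override earlier ones; nested dicts merge recursively,
    # mirroring how `--base a.yaml b.yaml` layers config files.
    out: dict = {}
    for cfg in cfgs:
        for k, v in cfg.items():
            if isinstance(v, dict) and isinstance(out.get(k), dict):
                out[k] = merge_configs(out[k], v)
            else:
                out[k] = v
    return out
```

So a lora config and `sft.yaml` can each set different keys under the same section without clobbering each other.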

To use Lora fine-tuning, you also need to modify the `cogvideox_<model parameters>_lora` file:

Here's an example using `CogVideoX-2B`:


```yaml
model:
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  not_trainable_prefixes: [ 'all' ] ## Uncomment to unlock
  log_keys:
    - txt

  lora_config: ## Uncomment to unlock
    target: sat.model.finetune.lora2.LoraMixin
    params:
      r: 256
```

|
||||
### Modify the Run Script

Edit `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` and select the config file. Below are two examples:

1. If you want to use the `CogVideoX-2B` model with `Lora`, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b_lora.yaml configs/sft.yaml --seed $RANDOM"
```

2. If you want to use the `CogVideoX-2B` model with full-parameter fine-tuning, modify `finetune_single_gpu.sh` or `finetune_multi_gpus.sh` as follows:

```
run_cmd="torchrun --standalone --nproc_per_node=8 train_video.py --base configs/cogvideox_2b.yaml configs/sft.yaml --seed $RANDOM"
```

|
||||
### Fine-tuning and Validation

Run the following script to start fine-tuning:

```
bash finetune_single_gpu.sh # Single GPU
bash finetune_multi_gpus.sh # Multi GPUs
```

|
||||
### Using the Fine-tuned Model

The fine-tuned model cannot be merged. Here is how to modify the inference configuration file `inference.sh`:

```
run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parameters>_lora.yaml configs/inference.yaml --seed 42"
```

Then, run the code:

```
bash inference.sh
```

### Converting to Huggingface Diffusers-compatible Weights

The SAT weight format differs from Huggingface's format and requires conversion. Run:

```
python ../tools/convert_weight_sat2hf.py
```


### Exporting Lora Weights from SAT to Huggingface Diffusers

Support is provided for exporting Lora weights from SAT to the Huggingface Diffusers format. After training with the above steps, you'll find the SAT checkpoint with Lora weights at `{args.save}/1000/1000/mp_rank_00_model_states.pt`.

The export script `export_sat_lora_weight.py` is located in the CogVideoX repository under `tools/`. After exporting, use `load_cogvideox_lora.py` for inference.

Export command:

```bash
python tools/export_sat_lora_weight.py --sat_pt_path {args.save}/{experiment_name}-09-09-21-10/1000/mp_rank_00_model_states.pt --lora_save_directory {args.save}/export_hf_lora_weights_1/
```

The following model structures were modified during training. The mapping below shows how they correspond to the HF (Hugging Face) Lora structure; as it shows, Lora adds low-rank weights to the model's attention mechanism.

```
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
```

Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.


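The mapping above reflects the basic LoRA idea: the effective weight is `W + B @ A` with a small rank `r`, and `B` starts at zero so fine-tuning begins exactly from the pretrained `W`. A tiny numeric sketch in pure Python, illustrative only and not from the repo:

```python
def matmul(X, Y):
    # Naive matrix multiply for small illustrative matrices.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_effective(W, A, B):
    # LoRA: W_eff = W + B @ A, where A is (r x d_in) and B is (d_out x r).
    # B is initialized to zero, so W_eff == W before any training happens.
    BA = matmul(B, A)
    return [[w + d for w, d in zip(wrow, drow)] for wrow, drow in zip(W, BA)]
```

With `r: 256` as in the `lora_config` above, only the small `A` and `B` matrices are trained while the full attention weight stays frozen.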
@@ -1,12 +1,11 @@
# SAT CogVideoX-2B
# SAT CogVideoX

[Read this in English.](./README_zh)
[Read this in English.](./README.md)

[Read in Japanese](./README_ja.md)

This folder contains inference code that uses [SAT](https://github.com/THUDM/SwissArmyTransformer) weights, along with fine-tuning code for SAT weights.

This code is the framework the team used to train the model. It has few comments and requires careful study.

If you are interested in the `CogVideoX1.0` models, see the SAT folder [here](https://github.com/THUDM/CogVideo/releases/tag/v1.0); this branch supports only the `CogVideoX1.5` series of models.

## Inference Models

@@ -20,6 +19,15 @@ pip install -r requirements.txt

First, go to the SAT mirror and download the model weights.

#### CogVideoX1.5 Models

```shell
git lfs install
git clone https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT
```

This downloads three models: the Transformer, the VAE, and the T5 Encoder.

#### CogVideoX Models

For the CogVideoX-2B model, download as follows:

```shell
@@ -82,11 +90,11 @@ mv CogVideoX-2b/text_encoder/* CogVideoX-2b/tokenizer/* t5-v1_1-xxl
0 directories, 8 files
```

### 3. Modify the files in `configs/cogvideox_2b.yaml`.
### 3. Modify the files in `configs/cogvideox_*.yaml`.

```yaml
model:
  scale_factor: 1.15258426
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  log_keys:
    - txt
```

@@ -253,7 +261,7 @@ args:

```yaml
batch_size: 1
input_type: txt # Use a plain txt file as input, or change this to cli for command-line input
input_file: configs/test.txt # Plain text file, which you can edit
sampling_num_frames: 13 # Must be 13, 11 or 9
sampling_num_frames: 13 # For CogVideoX1.5-5B it must be 42 or 22; for CogVideoX-5B / 2B it must be 13, 11, or 9
sampling_fps: 8
fp16: True # For CogVideoX-2B
# bf16: True # For CogVideoX-5B
```

@@ -346,7 +354,7 @@ used only as the Encoder.

```yaml
model:
  scale_factor: 1.15258426
  scale_factor: 1.55258426
  disable_first_stage_autocast: true
  not_trainable_prefixes: [ 'all' ] ## Uncomment
  log_keys:
```

@@ -18,7 +18,10 @@ def add_model_config_args(parser):
    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
    group.add_argument(
        "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
        "--model-parallel-size",
        type=int,
        default=1,
        help="size of the model parallel. only use if you are an expert.",
    )
    group.add_argument("--force-pretrain", action="store_true")
    group.add_argument("--device", type=int, default=-1)
@@ -36,6 +39,7 @@ def add_sampling_config_args(parser):
    group.add_argument("--input-dir", type=str, default=None)
    group.add_argument("--input-type", type=str, default="cli")
    group.add_argument("--input-file", type=str, default="input.txt")
    group.add_argument("--sampling-image-size", type=list, default=[768, 1360])
    group.add_argument("--final-size", type=int, default=2048)
    group.add_argument("--sdedit", action="store_true")
    group.add_argument("--grid-num-rows", type=int, default=1)
@@ -73,10 +77,15 @@ def get_args(args_list=None, parser=None):
    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")

    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
    assert (args.train_iters is None) or (
        args.epochs is None
    ), "only one of train_iters and epochs should be set."
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
        print_rank0(
            "No train_iters (recommended) or epochs specified, use default 10k iters.",
            level="WARNING",
        )

    args.cuda = torch.cuda.is_available()

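The assertion and fallback above reduce to a simple rule: at most one of `train_iters` / `epochs` may be set, and when neither is, the code defaults to 10,000 iterations. A standalone sketch of that rule (hypothetical helper, mirroring the logic shown):

```python
def resolve_train_length(train_iters=None, epochs=None):
    # The two options are mutually exclusive; default to 10k iterations
    # when neither is provided, as in the diff above.
    assert (train_iters is None) or (epochs is None), "only one of train_iters and epochs should be set."
    if train_iters is None and epochs is None:
        train_iters = 10000
    return train_iters, epochs
```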
@@ -212,7 +221,10 @@ def initialize_distributed(args):
    args.master_port = os.getenv("MASTER_PORT", default_master_port)
    init_method += args.master_ip + ":" + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=init_method,
    )

    # Set the model-parallel / data-parallel communicators.
@@ -231,7 +243,10 @@ def initialize_distributed(args):
    import deepspeed

    deepspeed.init_distributed(
        dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        dist_backend=args.distributed_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=init_method,
    )
    # # It seems that it has no negative influence to configure it even without using checkpointing.
    # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
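Both `torch.distributed` and DeepSpeed receive the same `init_method` string, assembled from the `MASTER_ADDR` / `MASTER_PORT` environment variables. A sketch of that assembly; the `tcp://` scheme prefix is an assumption here, since the prefix is built outside the hunks shown:

```python
import os

def build_init_method(default_port: str = "6000") -> str:
    # init_method += master_ip + ":" + master_port, as in the diff above;
    # the "tcp://" scheme is assumed for illustration.
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", default_port)
    return "tcp://" + master_ip + ":" + master_port
```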
@@ -261,7 +276,9 @@ def process_config_to_args(args):

    args_config = config.pop("args", OmegaConf.create())
    for key in args_config:
        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
            args_config[key], omegaconf.ListConfig
        ):
            arg = OmegaConf.to_object(args_config[key])
        else:
            arg = args_config[key]

sat/configs/cogvideox1.5_5b.yaml (new file, 149 lines)
@@ -0,0 +1,149 @@
model:
  scale_factor: 0.7
  disable_first_stage_autocast: true
  latent_input: true
  log_keys:
    - txt

  denoiser_config:
    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
    params:
      num_idx: 1000
      quantize_c_noise: False

      weighting_config:
        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
      scaling_config:
        target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

  network_config:
    target: dit_video_concat.DiffusionTransformer
    params:
      time_embed_dim: 512
      elementwise_affine: True
      num_frames: 81 # for 5 seconds; 161 for 10 seconds
      time_compressed_rate: 4
      latent_width: 300
      latent_height: 300
      num_layers: 42
      patch_size: [2, 2, 2]
      in_channels: 16
      out_channels: 16
      hidden_size: 3072
      adm_in_channels: 256
      num_attention_heads: 48

      transformer_args:
        checkpoint_activations: True
        vocab_size: 1
        max_sequence_length: 64
        layernorm_order: pre
        skip_init: false
        model_parallel_size: 1
        is_decoder: false

      modules:
        pos_embed_config:
          target: dit_video_concat.Rotary3DPositionEmbeddingMixin
          params:
            hidden_size_head: 64
            text_length: 224

        patch_embed_config:
          target: dit_video_concat.ImagePatchEmbeddingMixin
          params:
            text_hidden_size: 4096

        adaln_layer_config:
          target: dit_video_concat.AdaLNMixin
          params:
            qk_ln: True

        final_layer_config:
          target: dit_video_concat.FinalLayerMixin

  conditioner_config:
    target: sgm.modules.GeneralConditioner
    params:
      emb_models:
        - is_trainable: false
          input_key: txt
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
            model_dir: "google/t5-v1_1-xxl"
            max_length: 224

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "cogvideox-5b-sat/vae/3d-vae.pt"
      ignore_keys: ['loss']

      loss_config:
        target: torch.nn.Identity

      regularizer_config:
        target: vae_modules.regularizers.DiagonalGaussianRegularizer

      encoder_config:
        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
        params:
          double_z: true
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 2, 4]
          attn_resolutions: []
          num_res_blocks: 3
          dropout: 0.0
          gather_norm: True

      decoder_config:
        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
        params:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 2, 4]
          attn_resolutions: []
          num_res_blocks: 3
          dropout: 0.0
          gather_norm: True

  loss_fn_config:
    target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
    params:
      offset_noise_level: 0
      sigma_sampler_config:
        target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
        params:
          uniform_sampling: True
          group_num: 40
          num_idx: 1000
          discretization_config:
            target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

  sampler_config:
    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
    params:
      num_steps: 50
      verbose: True

      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

      guider_config:
        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
        params:
          scale: 6
          exp: 5
          num_steps: 50
sat/configs/cogvideox1.5_5b_i2v.yaml (new file, 159 lines)
@@ -0,0 +1,159 @@
model:
  scale_factor: 0.7
  disable_first_stage_autocast: true
  latent_input: false
  noised_image_input: true
  noised_image_all_concat: false
  noised_image_dropout: 0.05
  augmentation_dropout: 0.15
  log_keys:
    - txt

  denoiser_config:
    target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
    params:
      num_idx: 1000
      quantize_c_noise: False

      weighting_config:
        target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
      scaling_config:
        target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

  network_config:
    target: dit_video_concat.DiffusionTransformer
    params:
      ofs_embed_dim: 512
      time_embed_dim: 512
      elementwise_affine: True
      num_frames: 81 # for 5 seconds; 161 for 10 seconds
      time_compressed_rate: 4
      latent_width: 300
      latent_height: 300
      num_layers: 42
      patch_size: [2, 2, 2]
      in_channels: 32
      out_channels: 16
      hidden_size: 3072
      adm_in_channels: 256
      num_attention_heads: 48

      transformer_args:
        checkpoint_activations: True
        vocab_size: 1
        max_sequence_length: 64
        layernorm_order: pre
        skip_init: false
        model_parallel_size: 1
        is_decoder: false

      modules:
        pos_embed_config:
          target: dit_video_concat.Rotary3DPositionEmbeddingMixin
          params:
            hidden_size_head: 64
            text_length: 224

        patch_embed_config:
          target: dit_video_concat.ImagePatchEmbeddingMixin
          params:
            text_hidden_size: 4096

        adaln_layer_config:
          target: dit_video_concat.AdaLNMixin
          params:
            qk_ln: True

        final_layer_config:
          target: dit_video_concat.FinalLayerMixin

  conditioner_config:
    target: sgm.modules.GeneralConditioner
    params:
      emb_models:
        - is_trainable: false
          input_key: txt
          ucg_rate: 0.1
          target: sgm.modules.encoders.modules.FrozenT5Embedder
          params:
            model_dir: "google/t5-v1_1-xxl"
            max_length: 224

  first_stage_config:
    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
    params:
      cp_size: 1
      ckpt_path: "cogvideox-5b-i2v-sat/vae/3d-vae.pt"
      ignore_keys: ['loss']

      loss_config:
        target: torch.nn.Identity

      regularizer_config:
        target: vae_modules.regularizers.DiagonalGaussianRegularizer

      encoder_config:
        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
        params:
          double_z: true
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 2, 4]
          attn_resolutions: []
          num_res_blocks: 3
          dropout: 0.0
          gather_norm: True

      decoder_config:
        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
        params:
          double_z: True
          z_channels: 16
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 2, 4]
          attn_resolutions: []
          num_res_blocks: 3
          dropout: 0.0
          gather_norm: True

  loss_fn_config:
    target: sgm.modules.diffusionmodules.loss.VideoDiffusionLoss
    params:
      fixed_frames: 0
      offset_noise_level: 0.0
      sigma_sampler_config:
        target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
        params:
          uniform_sampling: True
          group_num: 40
          num_idx: 1000
          discretization_config:
            target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

  sampler_config:
    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
    params:
      fixed_frames: 0
      num_steps: 50
      verbose: True

      discretization_config:
        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization

      guider_config:
        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
        params:
          scale: 6
          exp: 5
          num_steps: 50
@ -1,16 +1,14 @@
|
||||
args:
|
||||
image2video: False # True for image2video, False for text2video
|
||||
# image2video: True # True for image2video, False for text2video
|
||||
latent_channels: 16
|
||||
mode: inference
|
||||
load: "{your CogVideoX SAT folder}/transformer" # This is for Full model without lora adapter
|
||||
# load: "{your lora folder} such as zRzRzRzRzRzRzR/lora-disney-08-20-13-28" # This is for Full model without lora adapter
|
||||
batch_size: 1
|
||||
input_type: txt
|
||||
input_file: configs/test.txt
|
||||
sampling_image_size: [480, 720]
|
||||
sampling_num_frames: 13 # Must be 13, 11 or 9
|
||||
sampling_fps: 8
|
||||
# fp16: True # For CogVideoX-2B
|
||||
bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V
|
||||
output_dir: outputs/
|
||||
sampling_image_size: [768, 1360] # remove this for I2V
|
||||
sampling_num_frames: 22 # 42 for 10 seconds and 22 for 5 seconds
|
||||
sampling_fps: 16
|
||||
bf16: True
|
||||
output_dir: outputs
|
||||
force_inference: True
|
@@ -1,4 +1,4 @@
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glisten with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.
@@ -56,7 +56,9 @@ def read_video(
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
        raise ValueError(
            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
        )

    info = {}
    audio_frames = []
@@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
        super().__init__(
            path,
            partial(
                process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
                process_fn_video,
                num_frames=num_frames,
                image_size=image_size,
                fps=fps,
                skip_frms_num=skip_frms_num,
            ),
            seed,
            meta_names=meta_names,
@@ -400,7 +406,9 @@ class SFTDataset(Dataset):
                indices = np.arange(start, end, (end - start) // num_frames).astype(int)
                temp_frms = vr.get_batch(np.arange(start, end_safty))
                assert temp_frms is not None
                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                tensor_frms = (
                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                )
                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
            else:
                if ori_vlen > self.max_num_frames:
@@ -410,7 +418,11 @@ class SFTDataset(Dataset):
                    indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                    temp_frms = vr.get_batch(np.arange(start, end))
                    assert temp_frms is not None
                    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                    tensor_frms = (
                        torch.from_numpy(temp_frms)
                        if type(temp_frms) is not torch.Tensor
                        else temp_frms
                    )
                    tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
                else:

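The `np.arange(start, end, max((end - start) // num_frames, 1))` pattern above selects roughly evenly spaced frame indices from a clip. A pure-Python sketch of the same index selection:

```python
def sample_frame_indices(start: int, end: int, num_frames: int):
    # Evenly spaced indices in [start, end); the step is floored and
    # clamped to at least 1 so very short clips still yield indices.
    step = max((end - start) // num_frames, 1)
    return list(range(start, end, step))
```

Note that when the clip is shorter than `num_frames`, this returns every frame rather than exactly `num_frames` indices.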
@@ -423,11 +435,17 @@ class SFTDataset(Dataset):

                start = int(self.skip_frms_num)
                end = int(ori_vlen - self.skip_frms_num)
                num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k+1
                num_frames = nearest_smaller_4k_plus_1(
                    end - start
                )  # 3D VAE requires the number of frames to be 4k+1
                end = int(start + num_frames)
                temp_frms = vr.get_batch(np.arange(start, end))
                assert temp_frms is not None
                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                tensor_frms = (
                    torch.from_numpy(temp_frms)
                    if type(temp_frms) is not torch.Tensor
                    else temp_frms
                )

                tensor_frms = pad_last_frame(
                    tensor_frms, self.max_num_frames

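The comment above states the 3D VAE constraint: frame counts must have the form 4k + 1. A plausible sketch of what `nearest_smaller_4k_plus_1` computes under that assumption (the repo's actual implementation is not shown in this hunk):

```python
def nearest_smaller_4k_plus_1(n: int) -> int:
    # Largest value <= n of the form 4k + 1, matching the 3D VAE's
    # requirement on the number of frames.
    return (n - 1) // 4 * 4 + 1
```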

@@ -10,7 +10,6 @@ import torch
from torch import nn

from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
    default,
@@ -42,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
        latent_input = model_config.get("latent_input", False)
        disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
        no_cond_log = model_config.get("disable_first_stage_autocast", False)
        not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
        not_trainable_prefixes = model_config.get(
            "not_trainable_prefixes", ["first_stage_model", "conditioner"]
        )
        compile_model = model_config.get("compile_model", False)
        en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
        lr_scale = model_config.get("lr_scale", None)
@@ -77,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
        )

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
        self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
        self.sampler = (
            instantiate_from_config(sampler_config) if sampler_config is not None else None
        )
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG)
        )

        self._init_first_stage(first_stage_config)

        self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
        self.loss_fn = (
            instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
        )

        self.latent_input = latent_input
        self.scale_factor = scale_factor
@@ -90,27 +97,37 @@ class SATVideoDiffusionEngine(nn.Module):
        self.no_cond_log = no_cond_log
        self.device = args.device

    # put lora add here
    def disable_untrainable_params(self):
        total_trainable = 0
        for n, p in self.named_parameters():
            if p.requires_grad == False:
                continue
            flag = False
            for prefix in self.not_trainable_prefixes:
                if n.startswith(prefix) or prefix == "all":
                    flag = True
                    break
        if self.lora_train:
            for n, p in self.named_parameters():
                if p.requires_grad == False:
                    continue
                if 'lora_layer' not in n:
                    p.lr_scale = 0
                else:
                    total_trainable += p.numel()
        else:
            for n, p in self.named_parameters():
                if p.requires_grad == False:
                    continue
                flag = False
                for prefix in self.not_trainable_prefixes:
                    if n.startswith(prefix) or prefix == "all":
                        flag = True
                        break

                lora_prefix = ["matrix_A", "matrix_B"]
                for prefix in lora_prefix:
                    if prefix in n:
                        flag = False
                        break
            lora_prefix = ['matrix_A', 'matrix_B']
            for prefix in lora_prefix:
                if prefix in n:
                    flag = False
                    break

                if flag:
                    p.requires_grad_(False)
                else:
                    total_trainable += p.numel()
            if flag:
                p.requires_grad_(False)
            else:
                total_trainable += p.numel()

        print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")

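The freezing logic above reduces to a per-parameter rule: names matching a not-trainable prefix (or the special value `"all"`) are frozen, except LoRA matrices (`matrix_A` / `matrix_B`), which stay trainable. A standalone sketch of that rule (hypothetical helper, not from the repo):

```python
def is_trainable(name, not_trainable_prefixes, lora_markers=("matrix_A", "matrix_B")):
    # Frozen if the name matches a prefix (or the prefix is "all"),
    # unless it contains a LoRA marker, which always stays trainable.
    frozen = any(name.startswith(p) or p == "all" for p in not_trainable_prefixes)
    if any(m in name for m in lora_markers):
        frozen = False
    return not frozen
```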
@@ -142,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.lr_scale is not None:
            lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
            lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
            lr_x = F.interpolate(
                x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
            )
            lr_x = F.interpolate(
                lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
            )
            lr_z = self.encode_first_stage(lr_x, batch)
            batch["lr_input"] = lr_z

@@ -179,14 +200,31 @@ class SATVideoDiffusionEngine(nn.Module):
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
                all_out.append(out)
        for n in range(n_rounds):
            z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
            latent_time = z_now.shape[2]  # check the time latent
            fake_cp_size = min(10, latent_time // 2)
            recons = []
            start_frame = 0
            for i in range(fake_cp_size):
                end_frame = (
                    start_frame
                    + latent_time // fake_cp_size
                    + (1 if i < latent_time % fake_cp_size else 0)
                )

                use_cp = True if i == 0 else False
                clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
                with torch.no_grad():
                    recon = self.first_stage_model.decode(
                        z_now[:, :, start_frame:end_frame].contiguous(),
                        clear_fake_cp_cache=clear_fake_cp_cache,
                        use_cp=use_cp,
                    )
                recons.append(recon)
                start_frame = end_frame
            recons = torch.cat(recons, dim=2)
            all_out.append(recons)
        out = torch.cat(all_out, dim=0)
        return out

||||
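The decode loop above splits `latent_time` frames into `fake_cp_size` nearly equal chunks, giving each of the first `latent_time % fake_cp_size` chunks one extra frame. A standalone sketch of that arithmetic (the `chunk_bounds` helper is hypothetical, not part of the repo):

```python
def chunk_bounds(latent_time: int, fake_cp_size: int):
    # Each chunk gets latent_time // fake_cp_size frames; the first
    # latent_time % fake_cp_size chunks get one extra frame.
    bounds, start = [], 0
    for i in range(fake_cp_size):
        end = start + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
        bounds.append((start, end))
        start = end
    return bounds

print(chunk_bounds(13, 6))  # → [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
```

The chunks tile the whole latent sequence with sizes differing by at most one frame.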
@@ -218,6 +256,7 @@ class SATVideoDiffusionEngine(nn.Module):
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        ofs=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
@@ -241,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
            self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
        )

        samples = self.sampler(
            denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
        )
        samples = samples.to(self.dtype)
        return samples

@@ -255,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
        log = dict()

        for embedder in self.conditioner.embedders:
            if (
                (self.log_keys is None) or (embedder.input_key in self.log_keys)
            ) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
@@ -331,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
            image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
            c["concat"] = image
            uc["concat"] = image
            samples = self.sample(
                c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
            )  # b t c h w
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            if only_log_video_latents:
                latents = 1.0 / self.scale_factor * samples
@@ -341,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            log["samples"] = samples
        else:
            samples = self.sample(
                c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
            )  # b t c h w
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            if only_log_video_latents:
                latents = 1.0 / self.scale_factor * samples
@@ -1,11 +1,12 @@
from functools import partial
from einops import rearrange, repeat
from functools import reduce
from operator import mul
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
@@ -13,38 +14,35 @@ from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config

from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sat.ops.layernorm import LayerNorm, RMSNorm


class ImagePatchEmbeddingMixin(BaseMixin):
    def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
        if text_hidden_size is not None:
            self.text_proj = nn.Linear(text_hidden_size, hidden_size)
        else:
            self.text_proj = None

    def word_embedding_forward(self, input_ids, **kwargs):
        # 3D patchify: fold each (o, p, q) patch into the channel dimension
        images = kwargs["images"]  # (b,t,c,h,w)
        emb = rearrange(images, "b t c h w -> b (t h w) c")
        emb = rearrange(
            emb,
            "b (t o h p w q) c -> b (t h w) (c o p q)",
            t=kwargs["rope_T"],
            h=kwargs["rope_H"],
            w=kwargs["rope_W"],
            o=self.patch_size[0],
            p=self.patch_size[1],
            q=self.patch_size[2],
        )
        emb = self.proj(emb)

        if self.text_proj is not None:
            text_emb = self.text_proj(kwargs["encoder_outputs"])
@@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed(
    grid_size: int of the grid height and width
    t_size: int of the temporal size
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim]
        (w/ or w/o cls_token)
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
@@ -95,12 +94,13 @@ def get_3d_sincos_pos_embed(

    # concatenate in [T, H, W] order
    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(
        pos_embed_temporal, grid_height * grid_width, axis=1
    )  # [T, H*W, D // 4]
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    # pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    return pos_embed  # [T, H*W, D]

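The split above allocates one quarter of `embed_dim` to the temporal embedding and three quarters to the spatial one, then broadcasts both to a common `[T, H*W, D]` grid before concatenating along channels. A shape-only sketch with placeholder zero embeddings (the real function fills them with sin/cos values):

```python
import numpy as np

embed_dim, t_size, grid = 64, 4, 3
embed_dim_spatial = embed_dim // 4 * 3   # 3/4 of channels encode (H, W)
embed_dim_temporal = embed_dim // 4      # 1/4 of channels encode T

# Placeholder zero embeddings standing in for the sincos tables.
pos_embed_temporal = np.zeros((t_size, embed_dim_temporal))
pos_embed_spatial = np.zeros((grid * grid, embed_dim_spatial))

pos_embed_temporal = np.repeat(pos_embed_temporal[:, np.newaxis, :], grid * grid, axis=1)
pos_embed_spatial = np.repeat(pos_embed_spatial[np.newaxis, :, :], t_size, axis=0)
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
print(pos_embed.shape)  # → (4, 9, 64)
```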
@@ -162,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
        self.width = width
        self.spatial_length = height * width
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
            requires_grad=False,
        )

    def position_embedding_forward(self, position_ids, **kwargs):
@@ -171,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
    def reinit(self, parent_model=None):
        del self.transformer.position_embeddings
        pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
        self.pos_embedding.data[:, -self.spatial_length :].copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0)
        )


class Basic3DPositionEmbeddingMixin(BaseMixin):
@@ -194,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
        self.spatial_length = height * width
        self.num_patches = height * width * compressed_num_frames
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
            requires_grad=False,
        )
        self.height_interpolation = height_interpolation
        self.width_interpolation = width_interpolation
@@ -259,6 +263,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
        text_length,
        theta=10000,
        rot_v=False,
        height_interpolation=1.0,
        width_interpolation=1.0,
        time_interpolation=1.0,
        learnable_pos_embed=False,
    ):
        super().__init__()
@@ -284,32 +291,38 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
        freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
        freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

        freqs = broadcat(
            (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
            dim=-1,
        )

        freqs = freqs.contiguous()
        freqs_sin = freqs.sin()
        freqs_cos = freqs.cos()
        self.register_buffer("freqs_sin", freqs_sin)
        self.register_buffer("freqs_cos", freqs_cos)

        self.text_length = text_length
        if learnable_pos_embed:
            num_patches = height * width * compressed_num_frames + text_length
            self.pos_embedding = nn.Parameter(
                torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
            )
        else:
            self.pos_embedding = None

    def rotary(self, t, **kwargs):
        def reshape_freq(freqs):
            freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
            freqs = rearrange(freqs, "t h w d -> (t h w) d")
            freqs = freqs.unsqueeze(0).unsqueeze(0)
            return freqs

        freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
        freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)

        return t * freqs_cos + rotate_half(t) * freqs_sin

    def position_embedding_forward(self, position_ids, **kwargs):
        if self.pos_embedding is not None:
            return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
        else:
            return None

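`rotate_half` is not shown in this hunk; assuming the standard interleaved-pair definition (which matches the `"... n -> ... (n r)"`, `r=2` duplication of each frequency above), the rotation `t * cos + rotate_half(t) * sin` rotates each 2-D pair by its angle and therefore preserves norms:

```python
import numpy as np

def rotate_half(x):
    # Assumed interleaved-pair form: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = pairs[..., 0], pairs[..., 1]
    return np.stack((-x2, x1), axis=-1).reshape(x.shape)

t = np.random.randn(1, 1, 4, 8)
theta = 0.3
rotated = t * np.cos(theta) + rotate_half(t) * np.sin(theta)
# Each 2-D pair (a, b) becomes (a*cos - b*sin, b*cos + a*sin): a plane rotation.
print(np.allclose(np.linalg.norm(rotated, axis=-1), np.linalg.norm(t, axis=-1)))  # → True
```

In the mixin the angle varies per position and per channel (via `freqs_cos`/`freqs_sin`), but the per-pair rotation structure is the same.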
@@ -326,10 +339,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
    ):
        attention_fn_default = HOOKS_DEFAULT["attention_fn"]

        query_layer = torch.cat(
            (
                query_layer[:, :, : kwargs["text_length"]],
                self.rotary(query_layer[:, :, kwargs["text_length"] :], **kwargs),
            ),
            dim=2,
        )
        key_layer = torch.cat(
            (
                key_layer[:, :, : kwargs["text_length"]],
                self.rotary(key_layer[:, :, kwargs["text_length"] :], **kwargs),
            ),
            dim=2,
        )
        if self.rot_v:
            value_layer = torch.cat(
                (
                    value_layer[:, :, : kwargs["text_length"]],
                    self.rotary(value_layer[:, :, kwargs["text_length"] :], **kwargs),
                ),
                dim=2,
            )

        return attention_fn_default(
            query_layer,
@@ -347,21 +411,25 @@ def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def unpatchify(x, c, patch_size, w, h, **kwargs):
    """
    x: (N, T/2 * S, patch_size**3 * C)
    imgs: (N, T, H, W, C)

    patch_size is unpacked into three separate dimensions (o, p, q) for depth (o),
    height (p), and width (q), so the patch size can differ along each axis.
    """
    imgs = rearrange(
        x,
        "b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
        c=c,
        o=patch_size[0],
        p=patch_size[1],
        q=patch_size[2],
        t=kwargs["rope_T"],
        h=kwargs["rope_H"],
        w=kwargs["rope_W"],
    )

    return imgs

@@ -382,15 +450,16 @@ class FinalLayerMixin(BaseMixin):
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
        self.spatial_length = latent_width * latent_height // patch_size**2
        self.latent_width = latent_width
        self.latent_height = latent_height
        self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
        )

    def final_forward(self, logits, **kwargs):
        x, emb = (
            logits[:, kwargs["text_length"] :, :],
            kwargs["emb"],
        )  # x: (b,(t n),d) -- keep only the image tokens after the text prefix
        shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
@@ -398,10 +467,9 @@ class FinalLayerMixin(BaseMixin):
        return unpatchify(
            x,
            c=self.out_channels,
            patch_size=self.patch_size,
            w=kwargs["rope_W"],
            h=kwargs["rope_H"],
            **kwargs,
        )

@@ -440,8 +508,6 @@ class SwiGLUMixin(BaseMixin):
class AdaLNMixin(BaseMixin):
    def __init__(
        self,
        width,
        height,
        hidden_size,
        num_layers,
        time_embed_dim,
@@ -452,12 +518,13 @@ class AdaLNMixin(BaseMixin):
    ):
        super().__init__()
        self.num_layers = num_layers
        self.width = width
        self.height = height
        self.compressed_num_frames = compressed_num_frames

        self.adaLN_modulations = nn.ModuleList(
            [
                nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
                for _ in range(num_layers)
            ]
        )

        self.qk_ln = qk_ln
@@ -517,7 +584,9 @@ class AdaLNMixin(BaseMixin):
        img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
        text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)

        attention_input = torch.cat(
            (text_attention_input, img_attention_input), dim=1
        )  # (b,n_t+t*n_i,d)
        attention_output = layer.attention(attention_input, mask, **kwargs)
        text_attention_output = attention_output[:, :text_length]  # (b,n,d)
        img_attention_output = attention_output[:, text_length:]  # (b,(t n),d)
@@ -541,9 +610,13 @@ class AdaLNMixin(BaseMixin):
        img_mlp_output = layer.fourth_layernorm(img_mlp_output)

        img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output  # vision (b,(t n),d)
        text_hidden_states = (
            text_hidden_states + text_gate_mlp * text_mlp_output
        )  # language (b,n,d)

        hidden_states = torch.cat(
            (text_hidden_states, img_hidden_states), dim=1
        )  # (b,(n_t+t*n_i),d)
        return hidden_states

    def reinit(self, parent_model=None):
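The adaLN path above relies on `modulate`, which broadcasts a per-sample (shift, scale) pair over the sequence dimension. A minimal numpy sketch of that broadcast:

```python
import numpy as np

def modulate(x, shift, scale):
    # Broadcast per-sample shift/scale (batch, hidden) over the sequence dim.
    return x * (1 + scale[:, None]) + shift[:, None]

x = np.ones((2, 5, 4))            # (batch, seq, hidden)
shift = np.zeros((2, 4))
scale = np.full((2, 4), 0.5)
out = modulate(x, shift, scale)
print(out[0, 0])  # → [1.5 1.5 1.5 1.5]
```

With `scale = 0` and `shift = 0` the operation is the identity, which is why the timestep embedding can safely be zero-initialized in adaLN-style blocks.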
@@ -611,7 +684,7 @@ class DiffusionTransformer(BaseModel):
        time_interpolation=1.0,
        use_SwiGLU=False,
        use_RMSNorm=False,
        ofs_embed_dim=None,
        **kwargs,
    ):
        self.latent_width = latent_width
@@ -619,12 +692,13 @@ class DiffusionTransformer(BaseModel):
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.time_compressed_rate = time_compressed_rate
        self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.model_channels = hidden_size
        self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
        self.ofs_embed_dim = ofs_embed_dim
        self.num_classes = num_classes
        self.adm_in_channels = adm_in_channels
        self.input_time = input_time
@@ -636,7 +710,6 @@ class DiffusionTransformer(BaseModel):
        self.width_interpolation = width_interpolation
        self.time_interpolation = time_interpolation
        self.inner_hidden_size = hidden_size * 4
        try:
            self.dtype = str_to_dtype[kwargs.pop("dtype")]
        except:
@@ -651,7 +724,9 @@ class DiffusionTransformer(BaseModel):
        if use_RMSNorm:
            kwargs["layernorm"] = RMSNorm
        else:
            kwargs["layernorm"] = partial(
                LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
            )

        transformer_args.num_layers = num_layers
        transformer_args.hidden_size = hidden_size
@@ -664,12 +739,13 @@ class DiffusionTransformer(BaseModel):

        if use_SwiGLU:
            self.add_mixin(
                "swiglu",
                SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
                reinit=True,
            )

    def _build_modules(self, module_configs):
        model_channels = self.hidden_size
        # time_embed_dim = model_channels * 4
        time_embed_dim = self.time_embed_dim
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
@@ -677,6 +753,13 @@ class DiffusionTransformer(BaseModel):
            linear(time_embed_dim, time_embed_dim),
        )

        if self.ofs_embed_dim is not None:
            self.ofs_embed = nn.Sequential(
                linear(self.ofs_embed_dim, self.ofs_embed_dim),
                nn.SiLU(),
                linear(self.ofs_embed_dim, self.ofs_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -701,9 +784,6 @@ class DiffusionTransformer(BaseModel):
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

@@ -712,10 +792,13 @@ class DiffusionTransformer(BaseModel):
            "pos_embed",
            instantiate_from_config(
                pos_embed_config,
                height=self.latent_height // self.patch_size[1],
                width=self.latent_width // self.patch_size[2],
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
                hidden_size=self.hidden_size,
                height_interpolation=self.height_interpolation,
                width_interpolation=self.width_interpolation,
                time_interpolation=self.time_interpolation,
            ),
            reinit=True,
        )
@@ -737,8 +820,6 @@ class DiffusionTransformer(BaseModel):
            "adaln_layer",
            instantiate_from_config(
                adaln_layer_config,
                hidden_size=self.hidden_size,
                num_layers=self.num_layers,
                compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
@@ -749,7 +830,6 @@ class DiffusionTransformer(BaseModel):
            )
        else:
            raise NotImplementedError

        final_layer_config = module_configs["final_layer_config"]
        self.add_mixin(
            "final_layer",
@@ -765,44 +845,55 @@ class DiffusionTransformer(BaseModel):
            ),
            reinit=True,
        )

        if "lora_config" in module_configs:
            lora_config = module_configs["lora_config"]
            self.add_mixin(
                "lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
            )
        return

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        b, t, d, h, w = x.shape
        if x.dtype != self.dtype:
            x = x.to(self.dtype)

        # Not used during inference
        if "concat_images" in kwargs and kwargs["concat_images"] is not None:
            if kwargs["concat_images"].shape[0] != x.shape[0]:
                concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
            else:
                concat_images = kwargs["concat_images"]
            x = torch.cat([x, concat_images], dim=2)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
        )
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            # assert y.shape[0] == x.shape[0]
            assert x.shape[0] % y.shape[0] == 0
            y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
            emb = emb + self.label_emb(y)

        if self.ofs_embed_dim is not None:
            ofs_emb = timestep_embedding(
                kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
            )
            ofs_emb = self.ofs_embed(ofs_emb)
            emb = emb + ofs_emb

        kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
        kwargs["images"] = x
        kwargs["emb"] = emb
        kwargs["encoder_outputs"] = context
        kwargs["text_length"] = context.shape[1]

        kwargs["rope_T"] = t // self.patch_size[0]
        kwargs["rope_H"] = h // self.patch_size[1]
        kwargs["rope_W"] = w // self.patch_size[2]

        kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
            (1, 1)
        ).to(x.dtype)
        output = super().forward(**kwargs)[0]
        return output
@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml configs/inference.yaml --seed $RANDOM"

echo ${run_cmd}
eval ${run_cmd}

@@ -1,16 +1,11 @@
SwissArmyTransformer>=0.4.12
omegaconf>=2.3.0
pytorch_lightning>=2.4.0
kornia>=0.7.3
beartype>=0.19.0
fsspec>=2024.2.0
safetensors>=0.4.5
scipy>=1.14.1
decord>=0.6.0
wandb>=0.18.5
deepspeed>=0.15.3

@@ -4,23 +4,20 @@ import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
from PIL import Image
import imageio

import torch
import numpy as np
from einops import rearrange, repeat
import torchvision.transforms as TT


from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu

from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args


def read_from_cli():
@@ -54,8 +51,60 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "fps":
            batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
        elif key == "fps_id":
            batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
        elif key == "motion_bucket_id":
            batch[key] = (
                torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
            )
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
                device, dtype=torch.half
            )
        elif key == "cond_aug":
            batch[key] = repeat(
                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
                "1 -> b",
                b=math.prod(N),
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0])
        else:
            batch[key] = value_dict[key]

@@ -68,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
    return batch, batch_uc


def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
    os.makedirs(save_path, exist_ok=True)

    for i, vid in enumerate(video_batch):
@@ -83,37 +134,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i
            writer.append_data(frame)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
@ -127,40 +147,67 @@ def sampling_main(args, model_cls):
|
||||
data_iter = read_from_cli()
|
||||
elif args.input_type == "txt":
|
||||
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
|
||||
print("rank and world_size", rank, world_size)
|
||||
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
image_size = [480, 720]
|
||||
|
||||
if args.image2video:
|
||||
chained_trainsforms = []
|
||||
chained_trainsforms.append(TT.ToTensor())
|
||||
transform = TT.Compose(chained_trainsforms)
|
||||
|
||||
sample_func = model.sample
|
||||
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
|
||||
num_samples = [1]
|
||||
force_uc_zero_embeddings = ["txt"]
|
||||
device = model.device
|
||||
T, C = args.sampling_num_frames, args.latent_channels
|
||||
with torch.no_grad():
|
||||
for text, cnt in tqdm(data_iter):
|
||||
if args.image2video:
|
||||
# use with input image shape
|
||||
text, image_path = text.split("@@")
|
||||
assert os.path.exists(image_path), image_path
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
(img_W, img_H) = image.size
|
||||
|
||||
def nearest_multiple_of_16(n):
|
||||
lower_multiple = (n // 16) * 16
|
||||
upper_multiple = (n // 16 + 1) * 16
|
||||
if abs(n - lower_multiple) < abs(n - upper_multiple):
|
||||
return lower_multiple
|
||||
else:
|
||||
return upper_multiple
|
||||
|
||||
if img_H < img_W:
|
||||
H = 96
|
||||
W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8
|
||||
else:
|
||||
W = 96
|
||||
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
|
||||
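The rounding step above snaps the long side to a multiple of 16 in pixel space, then divides back down to latent units (each latent cell covers 8 pixels). A standalone sketch, with `latent_hw` as a hypothetical helper name for illustration:

```python
def nearest_multiple_of_16(n):
    """Snap n to the nearest multiple of 16; exact ties round up."""
    lower = (n // 16) * 16
    upper = lower + 16
    return lower if abs(n - lower) < abs(n - upper) else upper


def latent_hw(img_h, img_w, short_side=96):
    """Latent-space (H, W) with the short side pinned to `short_side`.

    Rounding happens in pixel space (latent * 8), then divides back down,
    mirroring the image2video branch above.
    """
    if img_h < img_w:
        h = short_side
        w = int(nearest_multiple_of_16(img_w / img_h * h * 8)) // 8
    else:
        w = short_side
        h = int(nearest_multiple_of_16(img_h / img_w * w * 8)) // 8
    return h, w
```

A 480x720 landscape input keeps H = 96 and yields W = 144, since 1.5 x 96 x 8 = 1152 is already a multiple of 16.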
                chained_trainsforms = []
                chained_trainsforms.append(
                    TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
                )
                chained_trainsforms.append(TT.ToTensor())
                transform = TT.Compose(chained_trainsforms)
                image = transform(image).unsqueeze(0).to("cuda")
                image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0)
                image = image * 2.0 - 1.0
                image = image.unsqueeze(2).to(torch.bfloat16)
                image = model.encode_first_stage(image, None)
                image = image / model.scale_factor
                image = image.permute(0, 2, 1, 3, 4).contiguous()
                pad_shape = (image.shape[0], T - 1, C, H // F, W // F)
                image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
                pad_shape = (image.shape[0], T - 1, C, H, W)
                image = torch.concat(
                    [image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
                )
            else:
                image_size = args.sampling_image_size
                H, W = image_size[0], image_size[1]
                F = 8  # 8x downsampled
                image = None

            text_cast = [text]
            mp_size = mpu.get_model_parallel_world_size()
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast_object_list(
                text_cast, src=src, group=mpu.get_model_parallel_group()
            )
            text = text_cast[0]
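The broadcast source above is plain integer arithmetic: within each model-parallel group, every rank must receive the prompt from the group's first rank, whose global rank is a multiple of the model-parallel world size. A minimal sketch with no `torch.distributed` (`broadcast_src` is a hypothetical name):

```python
def broadcast_src(global_rank, mp_size):
    """Global rank of the first rank in this rank's model-parallel group."""
    group_index = global_rank // mp_size  # which MP group this rank is in
    return group_index * mp_size          # that group's leader rank
```

With `mp_size=4`, ranks 0-3 all read from rank 0, ranks 4-7 from rank 4, and so on.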
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
@@ -168,7 +215,9 @@ def sampling_main(args, model_cls):
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                num_samples,
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
@@ -187,57 +236,50 @@ def sampling_main(args, model_cls):
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            if args.image2video and image is not None:
            if args.image2video:
                c["concat"] = image
                uc["concat"] = image

            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                )
                if args.image2video:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H, W),
                        ofs=torch.tensor([2.0]).to("cuda"),
                    )
                else:
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                    ).to("cuda")

                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    if i == loop_num - 1:
                        clear_fake_cp_cache = True
                    else:
                        clear_fake_cp_cache = False
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
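The serial-decode loop above hands the VAE contiguous latent windows: the first chunk covers 3 latent frames, each later chunk covers 2, so all T frames are decoded exactly once when T is odd, and the context-parallel cache is cleared only on the last chunk. A pure-Python sketch of just the windowing (hypothetical helper name):

```python
def decode_windows(t):
    """(start_frame, end_frame, clear_cache) per decoder chunk, as above."""
    windows = []
    loop_num = (t - 1) // 2
    for i in range(loop_num):
        if i == 0:
            start_frame, end_frame = 0, 3
        else:
            start_frame, end_frame = i * 2 + 1, i * 2 + 3
        # cached context-parallel state is dropped only on the final chunk
        clear_cache = i == loop_num - 1
        windows.append((start_frame, end_frame, clear_cache))
    return windows
```

For T = 13 this yields six contiguous windows covering all 13 latent frames.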
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = os.path.join(
                    args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
                )
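The output path above turns the prompt into a filesystem-safe directory component: spaces become underscores, slashes are dropped, and the sanitized text (not the whole name) is truncated to 120 characters before being prefixed with the sample counter. A small sketch under a hypothetical helper name:

```python
import os


def sample_save_path(output_dir, cnt, text, index):
    """Reproduce the directory naming above; [:120] binds to the sanitized
    prompt text only, not to the counter prefix."""
    slug = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120]
    return os.path.join(output_dir, slug, str(index))
```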
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
                if args.only_save_latents:
                    samples_z = 1.0 / model.scale_factor * samples_z
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    os.makedirs(save_path, exist_ok=True)
                    torch.save(samples_z, os.path.join(save_path, "latent.pt"))
                    with open(os.path.join(save_path, "text.txt"), "w") as f:
                        f.write(text)
                else:
                    samples_x = model.decode_first_stage(samples_z).to(torch.float32)
                    samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                    save_path = os.path.join(
                        args.output_dir,
                        str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
                        str(index),
                    )
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


if __name__ == "__main__":

@@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = (n - self.lr_warm_up_steps[cycle]) / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi)
            )
            self.last_f = f
            return f

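For a single cycle, the multiplier computed above is a linear ramp from `f_start` to `f_max` over the warmup steps, followed by a cosine decay from `f_max` down to `f_min`. A standalone single-cycle sketch (hypothetical function name, no per-cycle bookkeeping):

```python
import math


def lr_multiplier(n, warm_up, cycle_length, f_start, f_max, f_min):
    """Warmup-then-cosine LR multiplier at step n within one cycle."""
    if n < warm_up:
        # linear warmup: f_start -> f_max
        return (f_max - f_start) / warm_up * n + f_start
    # cosine decay: f_max -> f_min, with t clamped at the cycle end
    t = min((n - warm_up) / (cycle_length - warm_up), 1.0)
    return f_min + 0.5 * (f_max - f_min) * (1 + math.cos(t * math.pi))
```

At the warmup boundary the two pieces meet at `f_max`, and the cosine term lands exactly on `f_min` at the end of the cycle.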
@@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:

@@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
        x = self.decoder(z, **kwargs)
        return x

    def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log

    def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
    def inner_training_step(
        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
    ) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        additional_decode_kwargs = {
            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
        }
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
@@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
            ae_params, num_ae_params = self.get_param_groups(
                self.trainable_ae_params, self.ae_optimizer_args
            )
            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
            disc_params, num_disc_params = self.get_param_groups(
                self.trainable_disc_params, self.disc_optimizer_args
            )
            logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
@@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
            opt_disc = self.instantiate_optimizer_from_config(
                disc_params, self.learning_rate, self.optimizer_config
            )
            opts.append(opt_disc)

        return opts

    @torch.no_grad()
    def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
    def log_images(
        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
    ) -> dict:
        log = dict()
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
        additional_decode_kwargs.update(
            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        )

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log["inputs"] = x
@@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
        diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
        diff_ema.clamp_(0, 1.0)
        log["diff_ema"] = 2.0 * diff_ema - 1.0
        log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
        log["diff_boost_ema"] = (
            2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
        )
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
@@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
        params = super().get_autoencoder_params()
        return params

    def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
@@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
    def log_videos(
        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
    ) -> dict:
        return self.log_images(batch, additional_log_kwargs, **kwargs)

    def get_input(self, batch: dict) -> torch.Tensor:
@@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
        batch = batch[self.input_key]

        global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
        torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
        torch.distributed.broadcast(
            batch, src=global_src_rank, group=get_context_parallel_group()
        )

        batch = _conv_split(batch, dim=2, kernel_size=1)
        return batch

@@ -94,7 +94,11 @@ class FeedForward(nn.Module):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

@@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
@@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
@@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
        # new
        with sdp_kernel(**BACKEND_MAP[self.backend]):
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask
            )  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
@@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
                if not self.disable_self_attn
                else 0,
            )
            + x
        )
@@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
        sdp_backend=None,
    ):
        super().__init__()
        print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
        print(
            f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
        )
        from omegaconf import ListConfig

        if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

@@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
        yield from ()

    @torch.no_grad()
    def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
@@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            inputs, reconstructions = map(
@@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
            input_frames = pick_video_frame(inputs, frame_indices)
            recon_frames = pick_video_frame(reconstructions, frame_indices)

            p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
            p_loss = self.perceptual_loss(
                input_frames.contiguous(), recon_frames.contiguous()
            ).mean()
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
@@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
            logits_fake = self.discriminator(reconstructions.contiguous())
            g_loss = -torch.mean(logits_fake)
            if self.training:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                d_weight = self.calculate_adaptive_weight(
                    nll_loss, g_loss, last_layer=last_layer
                )
            else:
                d_weight = torch.tensor(1.0)
        else:

@@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
        if self.perceptual_weight > 0.0:
            image_reconstructions = self.decoder.decode(latent_predictions)
            image_targets = self.decoder.decode(latent_inputs)
            perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
            loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
            perceptual_loss = self.perceptual_loss(
                image_targets.contiguous(), image_reconstructions.contiguous()
            )
            loss = (
                self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
            )
            log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()

        if self.perceptual_weight_on_inputs > 0.0:
            image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
            image_reconstructions = default(
                image_reconstructions, self.decoder.decode(latent_predictions)
            )
            if self.scale_input_to_tgt_size:
                image_inputs = torch.nn.functional.interpolate(
                    image_inputs,
@@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
                    antialias=True,
                )

            perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
            perceptual_loss2 = self.perceptual_loss(
                image_inputs.contiguous(), image_reconstructions.contiguous()
            )
            loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
            log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
            return loss, log