Yuxuan Zhang 2025-03-22 15:14:06 +08:00
parent b9b0539dbe
commit 39c6562dc8
144 changed files with 2619 additions and 1217 deletions

View File

@ -0,0 +1,28 @@
# Contribution Guide
We welcome contributions to this repository. To keep the code style consistent and the code quality high, we have prepared the following contribution guidelines.
## What We Accept
+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.
## Code Style Guide
Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:
1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors. You can also run `pre-commit install` once so that these checks run automatically before every commit.
## Naming Conventions
- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`; a short example follows below.
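As an illustration of these conventions, here is a short, hypothetical example (the names are illustrative only and do not come from this repository):
```python
# Preferred: descriptive snake_case names, as required by PEP8
def compute_video_latents(video_frames, vae_model):
    """Encode a batch of video frames into latent space."""
    encoded_latents = vae_model.encode(video_frames)
    return encoded_latents


# Avoid: meaningless names such as a, b, c
def f(a, b):
    c = b.encode(a)
    return c
```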

View File

@ -30,14 +30,14 @@ body:
If you have code snippets, error messages, stack traces, please provide them here as well.
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
placeholder: |
Steps to reproduce the behavior/复现Bug的步骤:
1.
2.
3.
@ -48,4 +48,4 @@ body:
required: true
attributes:
label: Expected behavior / 期待表现
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"

View File

@ -29,6 +29,6 @@ body:
attributes:
label: Your contribution / 您的贡献
description: |
Your PR link or any other link you can help with.
您的PR链接或者其他您能提供帮助的链接。

View File

@ -1,34 +0,0 @@
# Raise valuable PR / 提出有价值的PR
## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if it contains multiple ideas or optimizations, they should be split into separate PRs.
用户在提交PR时候应该注意以下几点:
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性如果具有多个不同的想法和优化方案应该分配到不同的PR中。
## 不应该提出的PR / PRs that should not be proposed
If a developer proposes a PR about any of the following, it may be closed or rejected:
1. PRs that do not describe the proposed improvement.
2. PRs that combine multiple issues of different types.
3. PRs that largely duplicate already existing PRs.
如果开发者提出关于以下方面的PR则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# Check your PR / 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue? / 您的PR是否仅针对一个问题

.gitignore (vendored, 2 changed lines)
View File

@ -34,4 +34,4 @@ CogVideo-1.0
**/train_results
**/train_res*
**/uv.lock

.pre-commit-config.yaml (new file, 19 lines)
View File

@ -0,0 +1,19 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.5
hooks:
- id: ruff
args: [--fix, --respect-gitignore, --config=pyproject.toml]
- id: ruff-format
args: [--config=pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements

View File

@ -68,4 +68,4 @@ Note that the license is subject to update to a more comprehensive version. For
本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。

View File

@ -14,7 +14,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
📚 View the <a href="https://arxiv.org/abs/2408.06072" target="_blank">paper</a> and <a href="https://zhipu-ai.feishu.cn/wiki/DHCjw1TrJiTyeukfc9RceoSRnCh" target="_blank">user guide</a>
</p>
<p align="center">
👋 Join our <a href="resources/WECHAT.md" target="_blank">WeChat</a> and <a href="https://discord.gg/dCGfUsagrD" target="_blank">Discord</a>
</p>
<p align="center">
📍 Visit <a href="https://chatglm.cn/video?lang=en?fr=osm_cogvideo">QingYing</a> and <a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9">API Platform</a> to experience larger-scale commercial video generation models.
@ -22,12 +22,12 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
## Project Updates
- 🔥🔥 **News**: ```2025/03/16```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
- 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
- 🔥 **News**: ```2024/11/08```: We have released the CogVideoX1.5 model. CogVideoX1.5 is an upgraded version of the open-source model CogVideoX.
The CogVideoX1.5-5B series supports 10-second videos with higher resolution, and CogVideoX1.5-5B-I2V supports video generation at any resolution.
The SAT code has already been updated, while the diffusers version is still under adaptation. Download the SAT version code [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
- 🔥 **News**: ```2024/10/13```: A more cost-effective fine-tuning framework for `CogVideoX-5B` that works with a single
4090 GPU, [cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory), has been released. It supports
@ -47,10 +47,10 @@ The SAT code has already been updated, while the diffusers version is still unde
model [CogVLM2-Caption](https://huggingface.co/THUDM/cogvlm2-llama3-caption), used in the training process of
CogVideoX to convert video data into text descriptions, has been open-sourced. You are welcome to download and use it.
- 🔥 ```2024/8/27```: We have open-sourced a larger model in the CogVideoX series, **CogVideoX-5B**. We have
significantly optimized the model's inference performance, greatly lowering the inference threshold.
You can run **CogVideoX-2B** on older GPUs like `GTX 1080TI`, and **CogVideoX-5B** on desktop GPUs like `RTX 3060`. Please strictly
follow the [requirements](requirements.txt) to update and install dependencies, and refer
to [cli_demo](inference/cli_demo.py) for inference code. Additionally, the open-source license for
the **CogVideoX-2B** model has been changed to the **Apache 2.0 License**.
- 🔥 ```2024/8/6```: We have open-sourced **3D Causal VAE**, used for **CogVideoX-2B**, which can reconstruct videos with
almost no loss.
@ -252,7 +252,7 @@ models we currently offer, along with their foundational information.
<tr>
<td style="text-align: center;">Position Encoding</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
@ -444,8 +444,6 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
}
```
We welcome your contributions! You can click [here](resources/contribute.md) for more information.
## Model-License
The code in this repository is released under the [Apache 2.0 License](LICENSE).

View File

@ -21,7 +21,7 @@
</p>
## 更新とニュース
- 🔥🔥 ```2025/03/16```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
- **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。
- **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAMビデオメモリで動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
- **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
@ -243,7 +243,7 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の
<tr>
<td style="text-align: center;">位置エンコーディング</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
@ -418,8 +418,6 @@ CogVideoのデモは [https://models.aminer.cn/cogvideo](https://models.aminer.c
}
```
あなたの貢献をお待ちしています!詳細は[こちら](resources/contribute_ja.md)をクリックしてください。
## ライセンス契約
このリポジトリのコードは [Apache 2.0 License](LICENSE) の下で公開されています。

View File

@ -14,7 +14,7 @@
📚 查看 <a href="https://arxiv.org/abs/2408.06072" target="_blank">论文</a><a href="https://zhipu-ai.feishu.cn/wiki/DHCjw1TrJiTyeukfc9RceoSRnCh" target="_blank">使用文档</a>
</p>
<p align="center">
👋 加入我们的 <a href="resources/WECHAT.md" target="_blank">微信</a><a href="https://discord.gg/dCGfUsagrD" target="_blank">Discord</a>
</p>
<p align="center">
📍 前往<a href="https://chatglm.cn/video?fr=osm_cogvideox"> 清影</a><a href="https://open.bigmodel.cn/?utm_campaign=open&_channel_track_key=OWTVNma9"> API平台</a> 体验更大规模的商业版视频生成模型。
@ -22,11 +22,11 @@
## 项目更新
- 🔥🔥 **News**: ```2025/03/16```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
- 🔥🔥 **News**: ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
- 🔥 **News**: ```2025/02/28```: DDIM Inverse 已经在`CogVideoX-5B` 和 `CogVideoX1.5 -5B` 支持,查看 [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。
- 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本仅需调整部分参数仅可沿用之前的代码。
- 🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。
CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨率,其中 `CogVideoX1.5-5B-I2V` 支持 **任意分辨率** 的视频生成SAT代码已经更新。`diffusers`版本还在适配中。SAT版本代码前往 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 下载。
- 🔥**News**: ```2024/10/13```: 成本更低单卡4090可微调 `CogVideoX-5B`
的微调框架[cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)已经推出,多种分辨率微调,欢迎使用。
@ -234,7 +234,7 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源
<tr>
<td style="text-align: center;">位置编码</td>
<td colspan="2" style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_sincos_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed</td>
<td style="text-align: center;">3d_rope_pos_embed + learnable_pos_embed</td>
</tr>
@ -398,8 +398,6 @@ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.amine
}
```
我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
## 模型协议
本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。

View File

@ -143,4 +143,4 @@ When using SFT training, please note:
+ The original repository uses `lora_alpha` set to 1. We found that this value performed poorly in several runs,
possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` to
be equal to the rank or `rank // 2`.
+ It's advised to use a rank of 64 or higher; a configuration sketch follows below.
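As a sketch of this recommendation, assuming a `peft`-style `LoraConfig` like the one the `diffusers`-based trainer builds internally (the target modules below mirror the trainer's CLI defaults; the exact wiring in the training script may differ):
```python
from peft import LoraConfig

rank = 64  # a rank of 64 or higher is advised
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # or rank // 2; avoid the original repository's lora_alpha=1
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # trainer defaults
    lora_dropout=0.0,
)
# The adapter would then be attached to the CogVideoX transformer, e.g.:
# transformer.add_adapter(lora_config)
```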

View File

@ -121,4 +121,4 @@ SFTトレーニングを使用する際に注意すべき点
+ 25本以上の動画を使用することで、新しい概念やスタイルのトレーニングが最適です。
+ `--id_token` で指定できる識別子トークンを使用すると、トレーニング効果がより良くなります。これはDreamboothトレーニングに似ていますが、このトークンを使用しない通常のファインチューニングでも問題なく動作します。
+ 元のリポジトリでは `lora_alpha` が1に設定されていますが、この値は多くの実行で効果が悪かったため、モデルのバックエンドやトレーニング設定の違いが影響している可能性があります。私たちの推奨は、`lora_alpha` を rank と同じか、`rank // 2` に設定することです。
+ rank は64以上に設定することをお勧めします。

View File

@ -18,4 +18,4 @@ same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -35,4 +35,4 @@
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -39,4 +39,4 @@
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -40,4 +40,4 @@
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -48,4 +48,4 @@
"gradient_clipping": "auto",
"steps_per_print": 2000,
"wall_clock_breakdown": false
}

View File

@ -26,7 +26,11 @@ class BucketSampler(Sampler):
"""
def __init__(
self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
self,
data_source: Dataset,
batch_size: int = 8,
shuffle: bool = True,
drop_last: bool = False,
) -> None:
self.data_source = data_source
self.batch_size = batch_size
@ -48,7 +52,11 @@ class BucketSampler(Sampler):
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]
f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
f, h, w = (
video_metadata["num_frames"],
video_metadata["height"],
video_metadata["width"],
)
self.buckets[(f, h, w)].append(data)
if len(self.buckets[(f, h, w)]) == self.batch_size:

View File

@ -115,7 +115,9 @@ class BaseI2VDataset(Dataset):
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
video_latent_dir = (
cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
)
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@ -136,7 +138,9 @@ class BaseI2VDataset(Dataset):
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
logger.info(
f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
)
if encoded_video_path.exists():
encoded_video = load_file(encoded_video_path)["encoded_video"]
@ -177,7 +181,9 @@ class BaseI2VDataset(Dataset):
},
}
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
def preprocess(
self, video_path: Path | None, image_path: Path | None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Loads and preprocesses a video and an image.
If either path is None, no preprocessing will be done for that input.
@ -249,13 +255,19 @@ class I2VDatasetWithResize(BaseI2VDataset):
self.height = height
self.width = width
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transforms = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
self.__image_transforms = self.__frame_transforms
@override
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
def preprocess(
self, video_path: Path | None, image_path: Path | None
) -> Tuple[torch.Tensor, torch.Tensor]:
if video_path is not None:
video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
video = preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width
)
else:
video = None
if image_path is not None:
@ -293,7 +305,9 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transforms = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
self.__image_transforms = self.__frame_transforms
@override

View File

@ -11,7 +11,12 @@ from typing_extensions import override
from finetune.constants import LOG_LEVEL, LOG_NAME
from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
from .utils import (
load_prompts,
load_videos,
preprocess_video_with_buckets,
preprocess_video_with_resize,
)
if TYPE_CHECKING:
@ -93,7 +98,9 @@ class BaseT2VDataset(Dataset):
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
video_latent_dir = (
cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
)
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@ -114,7 +121,9 @@ class BaseT2VDataset(Dataset):
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
logger.info(
f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
)
if encoded_video_path.exists():
# encoded_video = torch.load(encoded_video_path, weights_only=True)
@ -202,7 +211,9 @@ class T2VDatasetWithResize(BaseT2VDataset):
self.height = height
self.width = width
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transform = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
@override
def preprocess(self, video_path: Path) -> torch.Tensor:
@ -240,7 +251,9 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
for b in video_resolution_buckets
]
self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transform = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
@override
def preprocess(self, video_path: Path) -> torch.Tensor:

View File

@ -24,12 +24,16 @@ def load_prompts(prompt_path: Path) -> List[str]:
def load_videos(video_path: Path) -> List[Path]:
with open(video_path, "r", encoding="utf-8") as file:
return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
return [
video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
]
def load_images(image_path: Path) -> List[Path]:
with open(image_path, "r", encoding="utf-8") as file:
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
return [
image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
]
def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
@ -169,7 +173,9 @@ def preprocess_video_with_buckets(
video_num_frames = len(video_reader)
resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
if len(resolution_buckets) == 0:
raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
raise ValueError(
f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
)
nearest_frame_bucket = min(
resolution_buckets,
@ -181,7 +187,9 @@ def preprocess_video_with_buckets(
frames = frames[:nearest_frame_bucket].float()
frames = frames.permute(0, 3, 1, 2).contiguous()
nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
nearest_res = min(
resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
)
nearest_res = (nearest_res[1], nearest_res[2])
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

View File

@ -32,13 +32,19 @@ class CogVideoXI2VLoraTrainer(Trainer):
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
return components
@ -73,7 +79,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
prompt_embedding = self.components.text_encoder(
prompt_token_ids.to(self.accelerator.device)
)[0]
return prompt_embedding
@override
@ -122,22 +130,34 @@ class CogVideoXI2VLoraTrainer(Trainer):
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
images = images.unsqueeze(2)
# Add noise to images
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
image_noise_sigma = torch.normal(
mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
)
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
noisy_images = (
images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
)
image_latent_dist = self.components.vae.encode(
noisy_images.to(dtype=self.components.vae.dtype)
).latent_dist
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
0,
self.components.scheduler.config.num_train_timesteps,
(batch_size,),
device=self.accelerator.device,
)
timesteps = timesteps.long()
# from [B, C, F, H, W] to [B, F, C, H, W]
latent = latent.permute(0, 2, 1, 3, 4)
image_latents = image_latents.permute(0, 2, 1, 3, 4)
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
assert (latent.shape[0], *latent.shape[2:]) == (
image_latents.shape[0],
*image_latents.shape[2:],
)
# Padding image_latents to the same frame number as latent
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
@ -169,7 +189,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
# Predict noise, For CogVideoX1.5 Only.
ofs_emb = (
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
None
if self.state.transformer_config.ofs_embed_dim is None
else latent.new_full((1,), fill_value=2.0)
)
predicted_noise = self.components.transformer(
hidden_states=latent_img_noisy,
@ -181,7 +203,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
)[0]
# Denoise
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
latent_pred = self.components.scheduler.get_velocity(
predicted_noise, latent_noisy, timesteps
)
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
weights = 1 / (1 - alphas_cumprod)
@ -228,7 +252,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
base_num_frames = (
num_frames + transformer_config.patch_size_t - 1
) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,

View File

@ -31,13 +31,19 @@ class CogVideoXT2VLoraTrainer(Trainer):
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
components.text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder"
)
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer"
)
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
model_path, subfolder="scheduler"
)
return components
@ -72,7 +78,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
return_tensors="pt",
)
prompt_token_ids = prompt_token_ids.input_ids
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
prompt_embedding = self.components.text_encoder(
prompt_token_ids.to(self.accelerator.device)
)[0]
return prompt_embedding
@override
@ -115,7 +123,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
# Sample a random timestep for each sample
timesteps = torch.randint(
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
0,
self.components.scheduler.config.num_train_timesteps,
(batch_size,),
device=self.accelerator.device,
)
timesteps = timesteps.long()
@ -150,7 +161,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
)[0]
# Denoise
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
latent_pred = self.components.scheduler.get_velocity(
predicted_noise, latent_added_noise, timesteps
)
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
weights = 1 / (1 - alphas_cumprod)
@ -196,7 +209,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
if transformer_config.patch_size_t is None:
base_num_frames = num_frames
else:
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
base_num_frames = (
num_frames + transformer_config.patch_size_t - 1
) // transformer_config.patch_size_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=transformer_config.attention_head_dim,
crops_coords=None,

View File

@ -52,6 +52,8 @@ def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Tra
print(f"\nSupported training types for '{model_type}' are:")
for supported_type in SUPPORTED_MODELS[model_type]:
print(f"{supported_type}")
raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
raise ValueError(
f"Training type '{training_type}' is not supported for model '{model_type}'"
)
return SUPPORTED_MODELS[model_type][training_type]

View File

@ -115,14 +115,18 @@ class Args(BaseModel):
def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
raise ValueError(
"validation_images must be specified when do_validation is True and model_type is i2v"
)
return v
@field_validator("validation_videos")
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
values = info.data
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
raise ValueError(
"validation_videos must be specified when do_validation is True and model_type is v2v"
)
return v
@field_validator("validation_steps")
@ -148,7 +152,9 @@ class Args(BaseModel):
model_name = info.data.get("model_name", "")
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
if (height, width) != (480, 720):
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
raise ValueError(
"For cogvideox-5b models, height must be 480 and width must be 720"
)
return v
@ -221,7 +227,9 @@ class Args(BaseModel):
# LoRA parameters
parser.add_argument("--rank", type=int, default=128)
parser.add_argument("--lora_alpha", type=int, default=64)
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
parser.add_argument(
"--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
)
# Checkpointing
parser.add_argument("--checkpointing_steps", type=int, default=200)

View File

@ -8,7 +8,10 @@ import cv2
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
"--datadir",
type=str,
required=True,
help="Root directory containing videos.txt and video subdirectory",
)
return parser.parse_args()

View File

@ -88,7 +88,9 @@ class Trainer:
def _init_distributed(self):
logging_dir = Path(self.args.output_dir, "logs")
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
project_config = ProjectConfiguration(
project_dir=self.args.output_dir, logging_dir=logging_dir
)
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
init_process_group_kwargs = InitProcessGroupKwargs(
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
@ -183,7 +185,9 @@ class Trainer:
# Prepare VAE and text encoder for encoding
self.components.vae.requires_grad_(False)
self.components.text_encoder.requires_grad_(False)
self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
self.components.vae = self.components.vae.to(
self.accelerator.device, dtype=self.state.weight_dtype
)
self.components.text_encoder = self.components.text_encoder.to(
self.accelerator.device, dtype=self.state.weight_dtype
)
@ -263,7 +267,9 @@ class Trainer:
# For LoRA, we only want to train the LoRA weights
# For SFT, we want to train all the parameters
trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
trainable_parameters = list(
filter(lambda p: p.requires_grad, self.components.transformer.parameters())
)
transformer_parameters_with_lr = {
"params": trainable_parameters,
"lr": self.args.learning_rate,
@ -287,7 +293,9 @@ class Trainer:
use_deepspeed=use_deepspeed_opt,
)
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(self.data_loader) / self.args.gradient_accumulation_steps
)
if self.args.train_steps is None:
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
self.state.overwrote_max_train_steps = True
@ -322,12 +330,16 @@ class Trainer:
self.lr_scheduler = lr_scheduler
def prepare_for_training(self) -> None:
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
self.accelerator.prepare(
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
)
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
num_update_steps_per_epoch = math.ceil(
len(self.data_loader) / self.args.gradient_accumulation_steps
)
if self.state.overwrote_max_train_steps:
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
@ -364,7 +376,9 @@ class Trainer:
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
self.state.total_batch_size_count = (
self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
self.args.batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
)
info = {
"trainable parameters": self.state.num_trainable_parameters,
@ -454,7 +468,9 @@ class Trainer:
progress_bar.set_postfix(logs)
# Maybe run validation
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
should_run_validation = (
self.args.do_validation and global_step % self.args.validation_steps == 0
)
if should_run_validation:
del loss
free_memory()
@ -466,7 +482,9 @@ class Trainer:
break
memory_statistics = get_memory_statistics()
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
logger.info(
f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
)
accelerator.wait_for_everyone()
self.__maybe_save_checkpoint(global_step, must_save=True)
@ -504,7 +522,9 @@ class Trainer:
# Can't use model_cpu_offload in deepspeed,
# so we need to move all components in pipe to device
# pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
self.__move_components_to_device(
dtype=self.state.weight_dtype, ignore_list=["transformer"]
)
else:
# if not using deepspeed, use model_cpu_offload to further reduce memory usage
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
@ -528,7 +548,9 @@ class Trainer:
video = self.state.validation_videos[i]
if image is not None:
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
image = preprocess_image_with_resize(
image, self.state.train_height, self.state.train_width
)
# Convert image tensor (C, H, W) to PIL images
image = image.to(torch.uint8)
image = image.permute(1, 2, 0).cpu().numpy()
@ -546,7 +568,9 @@ class Trainer:
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
main_process_only=False,
)
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
validation_artifacts = self.validation_step(
{"prompt": prompt, "image": image, "video": video}, pipe
)
if (
self.state.using_deepspeed
@ -565,7 +589,9 @@ class Trainer:
"video": {"type": "video", "value": video},
}
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
artifacts.update(
{f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}
)
logger.debug(
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
main_process_only=False,
@ -600,8 +626,12 @@ class Trainer:
tracker_key = "validation"
for tracker in accelerator.trackers:
if tracker.name == "wandb":
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
image_artifacts = [
artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
]
video_artifacts = [
artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
]
tracker.log(
{
tracker_key: {"images": image_artifacts, "videos": video_artifacts},
@ -618,7 +648,9 @@ class Trainer:
pipe.remove_all_hooks()
del pipe
# Load models except those not needed for training
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
self.__move_components_to_device(
dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
)
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
# Change trainable weights back to fp32 to keep with dtype after prepare the model
@ -687,7 +719,9 @@ class Trainer:
for name, component in components.items():
if not isinstance(component, type) and hasattr(component, "to"):
if name not in ignore_list:
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
setattr(
self.components, name, component.to(self.accelerator.device, dtype=dtype)
)
def __move_components_to_cpu(self, unload_list: List[str] = []):
unload_list = set(unload_list)
@ -732,11 +766,13 @@ class Trainer:
):
transformer_ = unwrap_model(self.accelerator, model)
else:
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
raise ValueError(
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
)
else:
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
self.args.model_path, subfolder="transformer"
)
transformer_ = unwrap_model(
self.accelerator, self.components.transformer
).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
@ -745,7 +781,9 @@ class Trainer:
for k, v in lora_state_dict.items()
if k.startswith("transformer.")
}
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
incompatible_keys = set_peft_model_state_dict(
transformer_, transformer_state_dict, adapter_name="default"
)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
@ -759,7 +797,10 @@ class Trainer:
self.accelerator.register_load_state_pre_hook(load_model_hook)
def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
if (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
or self.accelerator.is_main_process
):
if must_save or global_step % self.args.checkpointing_steps == 0:
# for training
save_path = get_intermediate_ckpt_path(

View File

@ -23,7 +23,9 @@ def get_latest_ckpt_path_to_resume_from(
else:
resume_from_checkpoint_path = Path(resume_from_checkpoint)
if not resume_from_checkpoint_path.exists():
logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
logger.info(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
initial_global_step = 0
global_step = 0
first_epoch = 0

View File

@ -55,7 +55,9 @@ def unload_model(model):
model.to("cpu")
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
def make_contiguous(
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if isinstance(x, torch.Tensor):
return x.contiguous()
elif isinstance(x, dict):

View File

@ -67,7 +67,9 @@ def get_optimizer(
optimizer_name = "adamw"
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
raise ValueError(
"`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
)
if use_8bit:
try:
@ -81,7 +83,9 @@ def get_optimizer(
if use_torchao:
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
optimizer_class = (
AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
)
else:
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
@ -109,7 +113,9 @@ def get_optimizer(
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
raise ImportError(
"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
)
optimizer_class = prodigyopt.Prodigy
@ -133,7 +139,9 @@ def get_optimizer(
try:
import came_pytorch
except ImportError:
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
raise ImportError(
"To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
)
optimizer_class = came_pytorch.CAME
@ -151,7 +159,10 @@ def get_optimizer(
init_kwargs.update({"fused": True})
optimizer = CPUOffloadOptimizer(
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
params_to_optimize,
optimizer_class=optimizer_class,
offload_gradients=offload_gradients,
**init_kwargs,
)
else:
optimizer = optimizer_class(params_to_optimize, **init_kwargs)

View File

@ -99,7 +99,9 @@ def generate_video(
desired_resolution = RESOLUTION_MAP[model_name]
if width is None or height is None:
height, width = desired_resolution
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
logging.info(
f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
)
elif (height, width) != desired_resolution:
if generate_type == "i2v":
# For i2v models, use user-defined width and height
@ -124,7 +126,9 @@ def generate_video(
# If you're using with lora, add this code
if lora_path:
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
pipe.load_lora_weights(
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
)
pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
# 2. Set Scheduler.
@ -133,7 +137,9 @@ def generate_video(
# using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
# 3. Enable CPU offload for the model.
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
@ -190,8 +196,12 @@ def generate_video(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using CogVideoX"
)
parser.add_argument(
"--prompt", type=str, required=True, help="The description of the video to be generated"
)
parser.add_argument(
"--image_or_video_path",
type=str,
@ -199,20 +209,44 @@ if __name__ == "__main__":
help="The path of the image to be used as the background of the video",
)
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use"
"--model_path",
type=str,
default="THUDM/CogVideoX1.5-5B",
help="Path of the pre-trained model use",
)
parser.add_argument(
"--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
)
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path save generated video"
)
parser.add_argument(
"--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
)
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
parser.add_argument(
"--num_frames", type=int, default=81, help="Number of steps for the inference process"
)
parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
parser.add_argument(
"--height", type=int, default=None, help="The height of the generated video"
)
parser.add_argument(
"--fps", type=int, default=16, help="The frames per second for the generated video"
)
parser.add_argument(
"--num_videos_per_prompt",
type=int,
default=1,
help="Number of videos to generate per prompt",
)
parser.add_argument(
"--generate_type", type=str, default="t2v", help="The type of video generation"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation"
)
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
args = parser.parse_args()

View File

@ -19,7 +19,12 @@ import argparse
import os
import torch
import torch._dynamo
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXTransformer3DModel,
CogVideoXPipeline,
CogVideoXDPMScheduler,
)
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only
@ -68,9 +73,13 @@ def generate_video(
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
"""
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder = T5EncoderModel.from_pretrained(
model_path, subfolder="text_encoder", torch_dtype=dtype
)
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
transformer = CogVideoXTransformer3DModel.from_pretrained(
model_path, subfolder="transformer", torch_dtype=dtype
)
transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
@ -81,7 +90,9 @@ def generate_video(
vae=vae,
torch_dtype=dtype,
)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@ -100,16 +111,34 @@ def generate_video(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model")
parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt")
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')")
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt using CogVideoX"
)
parser.add_argument(
"--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme"
"--prompt", type=str, required=True, help="The description of the video to be generated"
)
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
)
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="Path to save generated video"
)
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
parser.add_argument(
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
)
parser.add_argument(
"--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
)
parser.add_argument(
"--quantization_scheme",
type=str,
default="fp8",
choices=["int8", "fp8"],
help="Quantization scheme",
)
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")

View File

@ -104,18 +104,34 @@ def save_video(tensor, output_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
parser.add_argument(
"--model_path", type=str, required=True, help="The path to the CogVideoX model"
)
parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
parser.add_argument(
"--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
"--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
)
parser.add_argument(
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
"--output_path", type=str, default=".", help="The path to save the output file"
)
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
"--mode",
type=str,
choices=["encode", "decode", "both"],
required=True,
help="Mode: encode, decode, or both",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
help="The data type for computation (e.g., 'float16' or 'bfloat16')",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="The device to use for computation (e.g., 'cuda' or 'cpu')",
)
args = parser.parse_args()
@ -126,15 +142,21 @@ if __name__ == "__main__":
assert args.video_path, "Video path must be provided for encoding."
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
torch.save(encoded_output, args.output_path + "/encoded.pt")
print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt")
print(
f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt"
)
elif args.mode == "decode":
assert args.encoded_path, "Encoded tensor path must be provided for decoding."
decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
save_video(decoded_output, args.output_path)
print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
print(
f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4"
)
elif args.mode == "both":
assert args.video_path, "Video path must be provided for encoding."
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
torch.save(encoded_output, args.output_path + "/encoded.pt")
decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
decoded_output = decode_video(
args.model_path, args.output_path + "/encoded.pt", dtype, device
)
save_video(decoded_output, args.output_path)

View File

@ -35,9 +35,9 @@ Video descriptions must have the same num of words as examples below. Extra word
"""
sys_prompt_i2v = """
**Objective**: **Give a highly descriptive video caption based on input image and user input. **. As an expert, delve deep into the image with a discerning eye, leveraging rich creativity, meticulous thought. When describing the details of an image, include appropriate dynamic information to ensure that the video caption contains reasonable actions and plots. If user input is not empty, then the caption should be expanded according to the user's input.
**Note**: The input image is the first frame of the video, and the output video caption should describe the motion starting from the current image. User input is optional and can be empty.
**Note**: Don't contain camera transitions!!! Don't contain screen switching!!! Don't contain perspective shifts !!!
@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
parser.add_argument(
"--retry_times", type=int, default=3, help="Number of times to retry the conversion"
)
parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
args = parser.parse_args()

View File

@ -30,7 +30,10 @@ import torchvision.transforms as T
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
from diffusers.models.transformers.cogvideox_transformer_3d import (
CogVideoXBlock,
CogVideoXTransformer3DModel,
)
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video
@ -62,22 +65,48 @@ class DDIMInversionArguments(TypedDict):
def get_args() -> DDIMInversionArguments:
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
parser.add_argument(
"--model_path", type=str, required=True, help="Path of the pretrained model"
)
parser.add_argument(
"--prompt", type=str, required=True, help="Prompt for the direct sample procedure"
)
parser.add_argument(
"--video_path", type=str, required=True, help="Path of the video for inversion"
)
parser.add_argument(
"--output_path", type=str, default="output", help="Path of the output videos"
)
parser.add_argument(
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
"--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start"
)
parser.add_argument(
"--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end"
)
parser.add_argument(
"--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames"
)
parser.add_argument(
"--max_num_frames", type=int, default=81, help="Max number of sampled frames"
)
parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
parser.add_argument(
"--height", type=int, default=480, help="Resized height of the video frames"
)
parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
parser.add_argument(
"--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model"
)
parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
parser.add_argument(
"--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference"
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
@ -116,13 +145,20 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
query[:, :, text_seq_length:] = apply_rotary_emb(
query[:, :, text_seq_length:], image_rotary_emb
)
if not attn.is_cross_attention:
if key.size(2) == query.size(2): # Attention for reference hidden states
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
key[:, :, text_seq_length:] = apply_rotary_emb(
key[:, :, text_seq_length:], image_rotary_emb
)
else: # RoPE should be applied to each group of image tokens
key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
apply_rotary_emb(
key[:, :, text_seq_length : text_seq_length + image_seq_length],
image_rotary_emb,
)
)
key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
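
For context on the slicing being rewrapped here: RoPE is applied only to the image tokens, i.e. the positions after `text_seq_length`, while the text tokens are left untouched. A toy, self-contained sketch of that call pattern (shapes are invented and an identity rotation stands in for the pipeline's real 3D RoPE frequencies):

```python
# Toy shapes; identity rotation (cos=1, sin=0) just to show the call pattern.
import torch
from diffusers.models.embeddings import apply_rotary_emb

B, H, S_text, S_img, D = 1, 2, 4, 8, 16
query = torch.randn(B, H, S_text + S_img, D)

cos = torch.ones(S_img, D)   # a real run would use the pipeline's 3D RoPE tables here
sin = torch.zeros(S_img, D)

query[:, :, S_text:] = apply_rotary_emb(query[:, :, S_text:], (cos, sin))
print(query.shape)  # text tokens untouched, image tokens (trivially) rotated
```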
@ -162,8 +198,12 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@ -260,14 +300,18 @@ def get_video_frames(
return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
def encode_video_frames(vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor) -> torch.FloatTensor:
def encode_video_frames(
vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
) -> torch.FloatTensor:
video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
return latent_dist * vae.config.scaling_factor
def export_latents_to_video(pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int):
def export_latents_to_video(
pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
):
video = pipeline.decode_latents(latents)
frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
@ -320,7 +364,9 @@ def sample(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
if isinstance(
scheduler, DDIMInverseScheduler
): # Inverse scheduler does not accept extra kwargs
extra_step_kwargs = {}
# 7. Create rotary embeds if required
@ -344,7 +390,9 @@ def sample(
if pipeline.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
if reference_latents is not None:
reference = reference_latents[i]
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
@ -371,18 +419,31 @@ def sample(
# perform guidance
if use_dynamic_cfg:
pipeline._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
(
1
- math.cos(
math.pi
* ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
)
)
/ 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + pipeline.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the noisy sample x_t-1 -> x_t
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.to(prompt_embeds.dtype)
trajectory[i] = latents
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
):
progress_bar.update()
# Offload all models
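
The deeply nested expression in this hunk is the pipeline's dynamic classifier-free-guidance schedule. Written flat, as a sketch of the same arithmetic (`t` is the raw scheduler timestep, as in `t.item()` above):

```python
import math

def dynamic_guidance(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Same expression as in the hunk above, un-nested for readability.
    ramp = ((num_inference_steps - t) / num_inference_steps) ** 5.0
    return 1 + guidance_scale * (1 - math.cos(math.pi * ramp)) / 2
```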
@ -410,7 +471,9 @@ def ddim_inversion(
seed: int,
device: torch.device,
):
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device)
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
model_path, torch_dtype=dtype
).to(device=device)
if not pipeline.transformer.config.use_rotary_positional_embeddings:
raise NotImplementedError("This script supports CogVideoX 5B model only.")
video_frames = get_video_frames(
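
The guard above keys off `use_rotary_positional_embeddings` in the transformer config. If you only want to check that flag without loading the whole pipeline, one possible probe (treat this as a sketch; the hub id matches the 5B checkpoint used elsewhere in this repo):

```python
# Sketch: read only the transformer config to see whether RoPE is enabled.
from diffusers import CogVideoXTransformer3DModel

config = CogVideoXTransformer3DModel.load_config(
    "THUDM/CogVideoX-5b", subfolder="transformer"
)
print(config["use_rotary_positional_embeddings"])  # True for the 5B checkpoint
```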

View File

@ -35,7 +35,7 @@ Set the following environment variables in your system:
## Installation
```bash
pip install -r requirements.txt
pip install -r requirements.txt
```
## Running the code
@ -43,5 +43,3 @@ pip install -r requirements.txt
```bash
python app.py
```

View File

@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = "THUDM/CogVideoX-5b"
hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
hf_hub_download(
repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
)
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
)
pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
MODEL,
transformer=pipe.transformer,
@ -296,8 +300,16 @@ def delete_old_files():
threading.Thread(target=delete_old_files, daemon=True).start()
examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
examples_videos = [
["example_videos/horse.mp4"],
["example_videos/kitten.mp4"],
["example_videos/train_running.mp4"],
]
examples_images = [
["example_images/beach.png"],
["example_images/street.png"],
["example_images/camping.png"],
]
with gr.Blocks() as demo:
gr.Markdown("""
@ -317,19 +329,31 @@ with gr.Blocks() as demo:
"></a>
</div>
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
This demo is for academic research and experimental use only.
This demo is for academic research and experimental use only.
</div>
""")
with gr.Row():
with gr.Column():
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
with gr.Accordion(
"I2V: Image Input (cannot be used simultaneously with video input)", open=False
):
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
examples_component_images = gr.Examples(
examples_images, inputs=[image_input], cache_examples=False
)
with gr.Accordion(
"V2V: Video Input (cannot be used simultaneously with image input)", open=False
):
video_input = gr.Video(
label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
)
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
examples_component_videos = gr.Examples(
examples_videos, inputs=[video_input], cache_examples=False
)
prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row():
gr.Markdown(
@ -340,11 +364,16 @@ with gr.Blocks() as demo:
with gr.Column():
with gr.Row():
seed_param = gr.Number(
label="Inference Seed (Enter a positive number, -1 for random)", value=-1
label="Inference Seed (Enter a positive number, -1 for random)",
value=-1,
)
with gr.Row():
enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
enable_scale = gr.Checkbox(
label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
)
enable_rife = gr.Checkbox(
label="Frame Interpolation (8fps -> 16fps)", value=False
)
gr.Markdown(
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br>&nbsp;&nbsp;&nbsp;&nbsp;The entire process is based on open-source solutions."
)
@ -430,7 +459,7 @@ with gr.Blocks() as demo:
seed_value,
scale_status,
rife_status,
progress=gr.Progress(track_tqdm=True)
progress=gr.Progress(track_tqdm=True),
):
latents, seed = infer(
prompt,
@ -457,7 +486,9 @@ with gr.Blocks() as demo:
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
batch_video_frames.append(image_pil)
video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
video_path = utils.save_video(
batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
)
video_update = gr.update(visible=True, value=video_path)
gif_path = convert_to_gif(video_path)
gif_update = gr.update(visible=True, value=gif_path)

View File

@ -16,4 +16,4 @@ imageio>=2.34.2
imageio-ffmpeg>=0.5.1
openai>=1.45.0
moviepy>=2.0.0
pillow==9.5.0
pillow==9.5.0

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
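
The wrapped expression above is the usual optical-flow resizing rule: flow vectors are measured in pixels, so downscaling the flow map by `1/scale` also requires shrinking the vectors by `1/scale`. A standalone sketch:

```python
import torch
import torch.nn.functional as F

def resize_flow(flow: torch.Tensor, scale: float) -> torch.Tensor:
    # Resizing the map by 1/scale must also shrink the vectors by 1/scale —
    # exactly what the wrapped expression in IFBlock does.
    return F.interpolate(
        flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
    ) / scale

flow = torch.zeros(1, 4, 64, 64)   # two stacked (dx, dy) fields, as in IFBlock
print(resize_flow(flow, 2).shape)  # torch.Size([1, 4, 32, 32])
```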
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@ -3,7 +3,9 @@ from .refine_2R import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@ -61,11 +61,19 @@ class IFBlock(nn.Module):
def forward(self, x, flow, scale=1):
x = F.interpolate(
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
x,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
flow = (
F.interpolate(
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
flow,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 1.0
/ scale
@ -78,11 +86,21 @@ class IFBlock(nn.Module):
flow = self.conv1(feat)
mask = self.conv2(feat)
flow = (
F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* scale
)
mask = F.interpolate(
mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
mask,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
return flow, mask
@ -112,7 +130,11 @@ class IFNet(nn.Module):
loss_cons = 0
block = [self.block0, self.block1, self.block2]
for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
f0, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
flow,
scale=scale_list[i],
)
f1, m1 = block[i](
torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
torch.cat((flow[:, 2:4], flow[:, :2]), 1),

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes),
)
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
for i in range(3):
if flow != None:
flow_d, mask_d = stu[i](
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
flow,
scale=scale[i],
)
flow = flow + flow_d
mask = mask + mask_d
@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
merged.append(merged_student)
if gt.shape[1] == 3:
flow_d, mask_d = self.block_tea(
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
flow,
scale=1,
)
flow_teacher = flow + flow_d
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else:
flow_teacher = None
merged_teacher = None
@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
(
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float()
.detach()
)
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
if returnflow:
return flow
else:

View File

@ -44,7 +44,9 @@ class Model:
if torch.cuda.is_available():
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
else:
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")))
self.flownet.load_state_dict(
convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
)
def save_model(self, path, rank=0):
if rank == 0:

View File

@ -29,10 +29,14 @@ def downsample(x):
def upsample(x):
cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
cc = torch.cat(
[x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
cc = cc.permute(0, 1, 3, 2)
cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3)
cc = torch.cat(
[cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
x_up = cc.permute(0, 1, 3, 2)
return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
self.gauss_kernel = gauss_kernel(channels=channels)
def forward(self, input, target):
pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
pyr_input = laplacian_pyramid(
img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
)
pyr_target = laplacian_pyramid(
img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
)
return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

View File

@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
gauss = torch.Tensor(
[exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
)
return gauss / gauss.sum()
@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
window = (
_3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
)
return window
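
As a quick sanity check of the window built above: the Gaussian taps are normalized to sum to one, so the depthwise convolutions in the SSIM computation below act as proper local weighted means. A self-contained restatement:

```python
# Self-contained check (intentionally duplicates the helper above).
import torch
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor(
        [exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
    )
    return gauss / gauss.sum()

w = gaussian(11, 1.5)
print(round(float(w.sum()), 6))  # 1.0 — normalized
print(int(w.argmax()))           # 5  — peaked at the window centre
```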
@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu1 = F.conv2d(
F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu2 = F.conv2d(
F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
sigma1_sq = (
F.conv2d(
F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_sq
)
sigma2_sq = (
F.conv2d(
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu2_sq
)
sigma12 = (
F.conv2d(
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
return ret
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
def ssim_matlab(
img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu1 = F.conv3d(
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu2 = F.conv3d(
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
sigma1_sq = (
F.conv3d(
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_sq
)
sigma2_sq = (
F.conv3d(
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu2_sq
)
sigma12 = (
F.conv3d(
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
mssim = []
mcs = []
for _ in range(levels):
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
sim, cs = ssim(
img1,
img2,
window_size=window_size,
size_average=size_average,
full=True,
val_range=val_range,
)
mssim.append(sim)
mcs.append(cs)
@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
self.window = window
self.channel = channel
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
_ssim = ssim(
img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
)
dssim = (1 - _ssim) / 2
return dssim

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
),
nn.PReLU(out_planes),
)
@ -56,25 +61,49 @@ class Contextnet(nn.Module):
def forward(self, x, flow):
x = self.conv1(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f1 = warp(x, flow)
x = self.conv2(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f2 = warp(x, flow)
x = self.conv3(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f3 = warp(x, flow)
x = self.conv4(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f4 = warp(x, flow)

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
),
nn.PReLU(out_planes),
)
@ -59,19 +64,37 @@ class Contextnet(nn.Module):
f1 = warp(x, flow)
x = self.conv2(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f2 = warp(x, flow)
x = self.conv3(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f3 = warp(x, flow)
x = self.conv4(x)
flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5
)
f4 = warp(x, flow)

View File

@ -9,6 +9,7 @@ import logging
import skvideo.io
from rife.RIFE_HDv3 import Model
from huggingface_hub import hf_hub_download, snapshot_download
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -19,7 +20,7 @@ def pad_image(img, scale):
tmp = max(32, int(32 / scale))
ph = ((h - 1) // tmp + 1) * tmp
pw = ((w - 1) // tmp + 1) * tmp
padding = (0, pw - w, 0, ph - h)
padding = (0, pw - w, 0, ph - h)
return F.pad(img, padding), padding
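
A hedged usage check for `pad_image` (run inside this module, or after importing it): height and width are padded up to the next multiple of `max(32, 32/scale)`, and the padding tuple is returned so callers can crop it off again later.

```python
import torch

img = torch.zeros(1, 3, 480, 720)
padded, padding = pad_image(img, scale=1)  # helper defined above
print(padded.shape, padding)               # torch.Size([1, 3, 480, 736]) (0, 16, 0, 0)
```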
@ -45,15 +46,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
for b in range(samples.shape[0]):
frame = samples[b : b + 1]
_, _, h, w = frame.shape
I0 = samples[b : b + 1]
I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:]
I0, padding = pad_image(I0, upscale_amount)
I0 = I0.to(torch.float)
I1, _ = pad_image(I1, upscale_amount)
I1 = I1.to(torch.float)
# [c, h, w]
I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False)
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
@ -70,21 +71,20 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
# print(f'I0 shape:{I0.shape}')
# print(f'I1 shape:{I1.shape}')
I1 = make_inference(model, I0, I1, upscale_amount, 1)
# print(f'I0 shape:{I0.shape}')
# print(f'I1[0] shape:{I1[0].shape}')
# print(f'I1[0] shape:{I1[0].shape}')
I1 = I1[0]
# print(f'I1[0] unpadded shape:{I1.shape}')
# print(f'I1[0] unpadded shape:{I1.shape}')
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
if padding[3] > 0 and padding[1] >0 :
frame = I1[:, :, : -padding[3],:-padding[1]]
if padding[3] > 0 and padding[1] > 0:
frame = I1[:, :, : -padding[3], : -padding[1]]
elif padding[3] > 0:
frame = I1[:, :, : -padding[3],:]
elif padding[1] >0:
frame = I1[:, :, :,:-padding[1]]
frame = I1[:, :, : -padding[3], :]
elif padding[1] > 0:
frame = I1[:, :, :, : -padding[1]]
else:
frame = I1
@ -101,8 +101,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device))
for i, tmp_frame in enumerate(tmp_output):
for i, tmp_frame in enumerate(tmp_output):
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
output.append(tmp_frame.to(output_device))
@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
frame_rgb = frame[..., ::-1]
frame_rgb = frame_rgb.copy()
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
pt_frame_data.append(
tensor.permute(2, 0, 1)
) # to [c, h, w,]
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
pt_frame = pt_frame.to(device)
@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
latent = latents[i]
frames = ssim_interpolation_rife(model, latent)
pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h])
pt_image = torch.stack(
[frames[i].squeeze(0) for i in range(len(frames))]
) # (to [f, c, w, h])
rife_results.append(pt_image)
return torch.stack(rife_results)
@ -179,6 +178,6 @@ def rife_inference_with_latents(model, latents):
# if __name__ == "__main__":
# snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
# model = load_rife_model("model_rife")
# video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4")
# print(video_path)
# print(video_path)

View File

@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if not "weights_only" in torch.load.__code__.co_varnames:
if "weights_only" not in torch.load.__code__.co_varnames:
logger.warning(
"Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
)
@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode()
def tiled_scale_multidim(
samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
samples,
function,
tile=(64, 64),
overlap=8,
upscale_amount=4,
out_channels=3,
output_device="cpu",
pbar=None,
):
dims = len(tile)
print(f"samples dtype:{samples.dtype}")
output = torch.empty(
[samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
[samples.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
device=output_device,
)
for b in range(samples.shape[0]):
s = samples[b : b + 1]
out = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
[s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device,
)
out_div = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
[s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device,
)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
for it in itertools.product(
*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
):
s_in = s
upscaled = []
@ -142,7 +154,14 @@ def tiled_scale(
pbar=None,
):
return tiled_scale_multidim(
samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar
samples,
function,
(tile_y, tile_x),
overlap,
upscale_amount,
out_channels,
output_device,
pbar,
)
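
The tiling loop above walks the spatial dimensions in strides of `tile - overlap`. A tiny standalone illustration of the coordinates it produces:

```python
import itertools

shape, tile, overlap = (128, 96), (64, 64), 8
for y, x in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(shape, tile))):
    print(y, x)
# visits (0,0) (0,56) (56,0) (56,56) (112,0) (112,56)
```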
@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
return s
def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
def upscale_batch_and_concatenate(
upscale_model, latents, inf_device, output_device="cpu"
) -> torch.Tensor:
upscaled_latents = []
for i in range(latents.size(0)):
latent = latents[i]
@ -207,7 +228,9 @@ class ProgressBar:
def __init__(self, total, desc=None):
self.total = total
self.current = 0
self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc)
self.b_unit = tqdm.tqdm(
total=total, desc="ProgressBar context index: 0" if desc is None else desc
)
def update(self, value):
if value > self.total:

View File

@ -22,7 +22,9 @@ from datetime import datetime, timedelta
from openai import OpenAI
from moviepy import VideoFileClip
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
"cuda"
)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
return prompt
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
def infer(
prompt: str,
num_inference_steps: int,
guidance_scale: float,
progress=gr.Progress(track_tqdm=True),
):
torch.cuda.empty_cache()
video = pipe(
prompt=prompt,
@ -151,7 +158,9 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row():
gr.Markdown(
@ -176,7 +185,13 @@ with gr.Blocks() as demo:
download_video_button = gr.File(label="📥 Download Video", visible=False)
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
def generate(
prompt,
num_inference_steps,
guidance_scale,
model_choice,
progress=gr.Progress(track_tqdm=True),
):
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
video_path = save_video(tensor)
video_update = gr.update(visible=True, value=video_path)

View File

@ -4,4 +4,3 @@
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
<p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
</div>

View File

@ -1,49 +0,0 @@
# Contribution Guide
There may still be many incomplete aspects in this project.
We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
and are willing to submit a PR and share it with the community, upon review, we
will acknowledge your contribution on the project homepage.
## Model Algorithms
- Support for model quantization inference (Int4 quantization project)
- Optimization of model fine-tuning data loading (replacing the existing decord tool)
## Model Engineering
- Model fine-tuning examples / Best prompt practices
- Inference adaptation on different devices (e.g., MLX framework)
- Any tools related to the model
- Any minimal fully open-source project using the CogVideoX open-source model
## Code Standards
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
style. You can organize the code according to the following specifications:
1. Install the `ruff` tool
```shell
pip install ruff
```
Then, run the `ruff` tool
```shell
ruff check tools sat inference
```
Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
```shell
ruff format tools sat inference
```
Once your code meets the standard, there should be no errors.
## Naming Conventions
1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.

View File

@ -1,47 +0,0 @@
# コントリビューションガイド
本プロジェクトにはまだ多くの未完成の部分があります。
以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
## モデルアルゴリズム
- モデル量子化推論のサポート (Int4量子化プロジェクト)
- モデルのファインチューニングデータロードの最適化既存のdecordツールの置き換え
## モデルエンジニアリング
- モデルのファインチューニング例 / 最適なプロンプトの実践
- 異なるデバイスでの推論適応(例: MLXフレームワーク
- モデルに関連するツール
- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
## コード標準
良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
1. `ruff` ツールをインストールする
```shell
pip install ruff
```
次に、`ruff` ツールを実行します
```shell
ruff check tools sat inference
```
コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
```shell
ruff format tools sat inference
```
コードが標準に準拠したら、エラーはなくなるはずです。
## 命名規則
1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。

View File

@ -1,44 +0,0 @@
# 贡献指南
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区在通过审核后我们将在项目首页感谢您的贡献。
## 模型算法
- 模型量化推理支持 (Int4量化工程)
- 模型微调数据载入优化支持(替换现有的decord工具)
## 模型工程
- 模型微调示例 / 最佳提示词实践
- 不同设备上的推理适配(MLX等框架)
- 任何模型周边工具
- 任何使用CogVideoX开源模型制作的最小完整开源项目
## 代码规范
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
1. 安装`ruff`工具
```shell
pip install ruff
```
接着,运行`ruff`工具
```shell
ruff check tools sat inference
```
检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
```shell
ruff format tools sat inference
```
如果您的代码符合规范,应该不会出现任何的错误。
## 命名规范
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。

View File

@ -20,7 +20,7 @@ Videos 1-8:
## CogVideoX-2B
Videos 1-4:
Videos 1-4:
1. A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.

View File

@ -64,7 +64,7 @@ Arrange the model files in the following structure:
└── 3d-vae.pt
```
Since model weight files are large, its recommended to use `git lfs`.
Since model weight files are large, it's recommended to use `git lfs`.
See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation.
```
@ -339,7 +339,7 @@ Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
@ -411,7 +411,7 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parame
Then, run the code:
```
bash inference.sh
bash inference.sh
```
### Converting to Huggingface Diffusers-compatible Weights
@ -452,4 +452,4 @@ Lora adds a low-rank weight to the attention structure of the model.
```
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
![alt text](../resources/hf_lora_weights.png)
![alt text](../resources/hf_lora_weights.png)

View File

@ -335,7 +335,7 @@ bash inference.sh
eval_iters: 1 # 検証イテレーション数
eval_interval: 100 # 検証間隔
eval_batch_size: 1 # 検証バッチサイズ
save: ckpts # モデル保存パス
save: ckpts # モデル保存パス
save_interval: 100 # 保存間隔
log_interval: 20 # ログ出力間隔
train_data: [ "your train data path" ]
@ -433,7 +433,7 @@ Modify the files in `configs/sft.yaml` (full fine-tuning) as follows:
eval_iters: 1 # Evaluation iterations
eval_interval: 100 # Evaluation interval
eval_batch_size: 1 # Evaluation batch size
save: ckpts # Model save path
save: ckpts # Model save path
save_interval: 100 # Save interval
log_interval: 20 # Log output interval
train_data: [ "your train data path" ]
@ -505,7 +505,7 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox_<model parame
Then, run the code:
```
bash inference.sh
bash inference.sh
```
### Converting to Huggingface Diffusers-compatible Weights
@ -546,4 +546,4 @@ Lora adds a low-rank weight to the attention structure of the model.
```
Using `export_sat_lora_weight.py` will convert these to the HF format Lora structure.
![alt text](../resources/hf_lora_weights.png)
![alt text](../resources/hf_lora_weights.png)

View File

@ -332,7 +332,7 @@ Encoder 使用。
eval_iters: 1 # 验证迭代次数
eval_interval: 100 # 验证间隔
eval_batch_size: 1 # 验证集 batch size
save: ckpts # 模型保存路径
save: ckpts # 模型保存路径
save_interval: 100 # 模型保存间隔
log_interval: 20 # 日志输出间隔
train_data: [ "your train data path" ]
@ -403,7 +403,7 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox_<模型参数
然后,执行代码:
```
bash inference.sh
bash inference.sh
```
### 转换到 Huggingface Diffusers 库支持的权重

View File

@ -18,7 +18,10 @@ def add_model_config_args(parser):
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
group.add_argument(
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
"--model-parallel-size",
type=int,
default=1,
help="size of the model parallel. only use if you are an expert.",
)
group.add_argument("--force-pretrain", action="store_true")
group.add_argument("--device", type=int, default=-1)
@ -74,10 +77,15 @@ def get_args(args_list=None, parser=None):
if not args.train_data:
print_rank0("No training data specified", level="WARNING")
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
assert (args.train_iters is None) or (
args.epochs is None
), "only one of train_iters and epochs should be set."
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
print_rank0(
"No train_iters (recommended) or epochs specified, use default 10k iters.",
level="WARNING",
)
args.cuda = torch.cuda.is_available()
@ -213,7 +221,10 @@ def initialize_distributed(args):
args.master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
)
# Set the model-parallel / data-parallel communicators.
@ -232,7 +243,10 @@ def initialize_distributed(args):
import deepspeed
deepspeed.init_distributed(
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
dist_backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
)
# # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
@ -262,7 +276,9 @@ def process_config_to_args(args):
args_config = config.pop("args", OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
args_config[key], omegaconf.ListConfig
):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]
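
The branch above converts nested OmegaConf nodes back into plain Python containers before copying them onto `args`; for example:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"args": {"sampling_fps": 16, "output_dir": "outputs"}})
print(OmegaConf.to_object(cfg.args))  # {'sampling_fps': 16, 'output_dir': 'outputs'}
```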

View File

@ -156,4 +156,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -151,4 +151,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -157,4 +157,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -3,7 +3,7 @@ model:
disable_first_stage_autocast: true
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
@ -150,4 +150,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -156,4 +156,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -162,4 +162,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -156,4 +156,4 @@ model:
params:
scale: 6
exp: 5
num_steps: 50
num_steps: 50

View File

@ -11,4 +11,4 @@ args:
sampling_fps: 16
bf16: True
output_dir: outputs
force_inference: True
force_inference: True

View File

@ -62,4 +62,4 @@ deepspeed:
activation_checkpointing:
partition_activations: false
contiguous_memory_optimization: false
wall_clock_breakdown: false
wall_clock_breakdown: false

View File

@ -1,4 +1,4 @@
In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.
A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
A street artist, clad in a worn-out denim jacket and a colorful banana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.
A street artist, clad in a worn-out denim jacket and a colorful banana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.

View File

@ -56,7 +56,9 @@ def read_video(
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {}
audio_frames = []
@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
super().__init__(
path,
partial(
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
process_fn_video,
num_frames=num_frames,
image_size=image_size,
fps=fps,
skip_frms_num=skip_frms_num,
),
seed,
meta_names=meta_names,
@ -400,7 +406,9 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, (end - start) // num_frames).astype(int)
temp_frms = vr.get_batch(np.arange(start, end_safty))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
if ori_vlen > self.max_num_frames:
@ -410,7 +418,11 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
@ -423,11 +435,17 @@ class SFTDataset(Dataset):
start = int(self.skip_frms_num)
end = int(ori_vlen - self.skip_frms_num)
num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1
num_frames = nearest_smaller_4k_plus_1(
end - start
) # 3D VAE requires the number of frames to be 4k+1
end = int(start + num_frames)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = pad_last_frame(
tensor_frms, self.max_num_frames
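
The comment in this hunk notes that the 3D VAE needs `4k + 1` frames. The actual `nearest_smaller_4k_plus_1` helper is defined elsewhere in the file and not shown here; a minimal sketch of what such a helper can look like:

```python
def nearest_smaller_4k_plus_1(n: int) -> int:
    # Largest m <= n with m % 4 == 1, i.e. a frame count the 3D VAE accepts.
    return n - ((n - 1) % 4)

print([nearest_smaller_4k_plus_1(n) for n in (49, 50, 51, 52, 53)])
# [49, 49, 49, 49, 53]
```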

View File

@ -41,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("disable_first_stage_autocast", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
not_trainable_prefixes = model_config.get(
"not_trainable_prefixes", ["first_stage_model", "conditioner"]
)
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
@ -76,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
self.sampler = (
instantiate_from_config(sampler_config) if sampler_config is not None else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.loss_fn = (
instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
)
self.latent_input = latent_input
self.scale_factor = scale_factor
@ -151,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(
x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
)
lr_x = F.interpolate(
lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
)
lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z
@ -195,7 +207,11 @@ class SATVideoDiffusionEngine(nn.Module):
recons = []
start_frame = 0
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
end_frame = (
start_frame
+ latent_time // fake_cp_size
+ (1 if i < latent_time % fake_cp_size else 0)
)
use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
@ -264,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
)
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
samples = self.sampler(
denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
)
samples = samples.to(self.dtype)
return samples
@ -278,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
log = dict()
for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
@ -354,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
c["concat"] = image
uc["concat"] = image
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
@ -364,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples
else:
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples

View File

@ -94,7 +94,9 @@ def get_3d_sincos_pos_embed(
# concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
pos_embed_temporal = np.repeat(
pos_embed_temporal, grid_height * grid_width, axis=1
) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
@ -160,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
self.width = width
self.spatial_length = height * width
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False,
)
def position_embedding_forward(self, position_ids, **kwargs):
@ -169,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
self.pos_embedding.data[:, -self.spatial_length :].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
class Basic3DPositionEmbeddingMixin(BaseMixin):
@ -192,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False,
)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
@ -285,7 +291,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = broadcat(
(freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
dim=-1,
)
freqs = freqs.contiguous()
self.freqs_sin = freqs.sin().cuda()
@ -293,7 +302,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
self.pos_embedding = nn.Parameter(
torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
)
else:
self.pos_embedding = None
@ -440,16 +451,26 @@ class FinalLayerMixin(BaseMixin):
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
)
def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x: (b,(t n),d); only the trailing image-token part of x is taken
x, emb = (
logits[:, kwargs["text_length"] :, :],
kwargs["emb"],
) # x: (b,(t n),d); only the trailing image-token part of x is taken
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return unpatchify(
x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
x,
c=self.out_channels,
patch_size=self.patch_size,
w=kwargs["rope_W"],
h=kwargs["rope_H"],
**kwargs,
)
def reinit(self, parent_model=None):
@ -500,7 +521,10 @@ class AdaLNMixin(BaseMixin):
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
[
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
for _ in range(num_layers)
]
)
self.qk_ln = qk_ln
@ -560,7 +584,9 @@ class AdaLNMixin(BaseMixin):
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
attention_input = torch.cat(
(text_attention_input, img_attention_input), dim=1
) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
@ -584,9 +610,13 @@ class AdaLNMixin(BaseMixin):
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
text_hidden_states = (
text_hidden_states + text_gate_mlp * text_mlp_output
) # language (b,n,d)
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
hidden_states = torch.cat(
(text_hidden_states, img_hidden_states), dim=1
) # (b,(n_t+t*n_i),d)
return hidden_states
def reinit(self, parent_model=None):
@ -694,7 +724,9 @@ class DiffusionTransformer(BaseModel):
if use_RMSNorm:
kwargs["layernorm"] = RMSNorm
else:
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
kwargs["layernorm"] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
)
transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
@ -707,7 +739,9 @@ class DiffusionTransformer(BaseModel):
if use_SwiGLU:
self.add_mixin(
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
"swiglu",
SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
reinit=True,
)
def _build_modules(self, module_configs):
@ -813,7 +847,9 @@ class DiffusionTransformer(BaseModel):
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
self.add_mixin(
"lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@ -829,7 +865,9 @@ class DiffusionTransformer(BaseModel):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
@ -838,7 +876,9 @@ class DiffusionTransformer(BaseModel):
emb = emb + self.label_emb(y)
if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
ofs_emb = timestep_embedding(
kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb
@ -852,6 +892,8 @@ class DiffusionTransformer(BaseModel):
kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
(1, 1)
).to(x.dtype)
output = super().forward(**kwargs)[0]
return output

View File

@ -7,4 +7,4 @@ run_cmd="PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone
echo ${run_cmd}
eval ${run_cmd}
echo "DONE on `hostname`"
echo "DONE on `hostname`"

View File

@ -9,4 +9,4 @@ run_cmd="$environs python train_video.py --base configs/cogvideox_2b_lora.yaml c
echo ${run_cmd}
eval ${run_cmd}
echo "DONE on `hostname`"
echo "DONE on `hostname`"

View File

@ -9,4 +9,4 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml co
echo ${run_cmd}
eval ${run_cmd}
echo "DONE on `hostname`"
echo "DONE on `hostname`"

View File

@ -8,4 +8,4 @@ safetensors>=0.4.5
scipy>=1.14.1
decord>=0.6.0
wandb>=0.18.5
deepspeed>=0.15.3
deepspeed>=0.15.3

View File

@ -19,6 +19,7 @@ from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
def read_from_cli():
cnt = 0
try:
@ -50,34 +51,50 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id":
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
)
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
@ -100,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
return batch, batch_uc
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch):
@ -160,7 +179,9 @@ def sampling_main(args, model_cls):
W = 96
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
chained_trainsforms = []
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
chained_trainsforms.append(
TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
)
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
image = transform(image).unsqueeze(0).to("cuda")
@ -170,7 +191,9 @@ def sampling_main(args, model_cls):
image = image / model.scale_factor
image = image.permute(0, 2, 1, 3, 4).contiguous()
pad_shape = (image.shape[0], T - 1, C, H, W)
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
image = torch.concat(
[image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
)
else:
image_size = args.sampling_image_size
H, W = image_size[0], image_size[1]
@ -181,12 +204,20 @@ def sampling_main(args, model_cls):
mp_size = mpu.get_model_parallel_world_size()
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
torch.distributed.broadcast_object_list(
text_cast, src=src, group=mpu.get_model_parallel_group()
)
text = text_cast[0]
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
value_dict = {
"prompt": text,
"negative_prompt": "",
"num_frames": torch.tensor(T).unsqueeze(0),
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
@ -212,7 +243,11 @@ def sampling_main(args, model_cls):
for index in range(args.batch_size):
if args.image2video:
samples_z = sample_func(
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
c,
uc=uc,
batch_size=1,
shape=(T, C, H, W),
ofs=torch.tensor([2.0]).to("cuda"),
)
else:
samples_z = sample_func(
@ -226,7 +261,9 @@ def sampling_main(args, model_cls):
if args.only_save_latents:
samples_z = 1.0 / model.scale_factor * samples_z
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
)
os.makedirs(save_path, exist_ok=True)
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
@ -237,7 +274,9 @@ def sampling_main(args, model_cls):
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

View File

@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
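Written out, the schedule these re-wrapped lines compute is (with $n_{\text{warm}}^{(c)}$ standing for `lr_warm_up_steps[cycle]` and $L^{(c)}$ for `cycle_lengths[cycle]`):

$$
f(n) =
\begin{cases}
\dfrac{f_{\max}^{(c)} - f_{\text{start}}^{(c)}}{n_{\text{warm}}^{(c)}}\, n + f_{\text{start}}^{(c)}, & n < n_{\text{warm}}^{(c)},\\[8pt]
f_{\min}^{(c)} + \dfrac{1}{2}\bigl(f_{\max}^{(c)} - f_{\min}^{(c)}\bigr)\bigl(1 + \cos(\pi t)\bigr), & t = \min\!\left(\dfrac{n - n_{\text{warm}}^{(c)}}{L^{(c)} - n_{\text{warm}}^{(c)}},\, 1\right).
\end{cases}
$$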
@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:

View File

@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
x = self.decoder(z, **kwargs)
return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
params = super().get_autoencoder_params()
return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
def log_videos(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
torch.distributed.broadcast(
batch, src=global_src_rank, group=get_context_parallel_group()
)
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch

View File

@ -94,7 +94,11 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
# new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
sdp_backend=None,
):
super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
]
)
if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

View File

@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
yield from ()
@torch.no_grad()
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
def log_images(
self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
# calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4:
@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2:
inputs, reconstructions = map(
@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
p_loss = self.perceptual_loss(
input_frames.contiguous(), recon_frames.contiguous()
).mean()
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
if self.training:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else:
d_weight = torch.tensor(1.0)
else:

View File

@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
perceptual_loss = self.perceptual_loss(
image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate(
image_inputs,
@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
antialias=True,
)
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log

View File

@ -45,7 +45,9 @@ def hinge_gen_loss(fake):
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
return torch_grad(
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
)[0].detach()
def pick_video_frame(video, frame_indices):
@ -126,7 +128,8 @@ class DiscriminatorBlock(nn.Module):
self.downsample = (
nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(filters * 4, filters, 1),
)
if downsample
else None
@ -185,11 +188,18 @@ class Discriminator(nn.Module):
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = nn.Sequential(
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
Residual(
LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
@ -363,7 +373,9 @@ class Discriminator3D(nn.Module):
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
@ -458,7 +470,9 @@ class Discriminator3DWithfirstframe(nn.Module):
)
attn_block = nn.Sequential(
Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
@ -581,11 +595,17 @@ class VideoAutoencoderLoss(nn.Module):
input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices)
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
perceptual_loss = self.perceptual_model(
input_frames.contiguous(), recon_frames.contiguous()
).mean()
else:
perceptual_loss = self.zero
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
if (
global_step >= self.disc_start
or not self.training
or self.adversarial_loss_weight == 0
):
gen_loss = self.zero
adaptive_weight = 0
else:
@ -598,9 +618,13 @@ class VideoAutoencoderLoss(nn.Module):
adaptive_weight = 1
if self.perceptual_weight > 0 and last_layer is not None:
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
perceptual_loss, last_layer
).norm(p=2)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-3
)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():

View File

@ -1 +1 @@
vgg.pth
vgg.pth

View File

@ -20,4 +20,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -48,7 +48,9 @@ class LPIPS(nn.Module):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
res = [
spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))
]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
@ -118,7 +120,9 @@ class vgg16(torch.nn.Module):
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
vgg_outputs = namedtuple(
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out

View File

@ -55,4 +55,4 @@ Redistributions in binary form must reproduce the above copyright notice, this l
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -35,7 +35,9 @@ class NLayerDiscriminator(nn.Module):
norm_layer = nn.BatchNorm2d
else:
norm_layer = ActNorm
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
if (
type(norm_layer) == functools.partial
): # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d

View File

@ -11,6 +11,7 @@ def hinge_d_loss(logits_real, logits_fake):
def vanilla_d_loss(logits_real, logits_fake):
d_loss = 0.5 * (
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
torch.mean(torch.nn.functional.softplus(-logits_real))
+ torch.mean(torch.nn.functional.softplus(logits_fake))
)
return d_loss
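Written as an expectation over the batch, the re-wrapped `vanilla_d_loss` above is the standard non-saturating discriminator objective:

$$
L_D = \tfrac{1}{2}\Bigl(\mathbb{E}\bigl[\operatorname{softplus}(-D(x_{\mathrm{real}}))\bigr] + \mathbb{E}\bigl[\operatorname{softplus}(D(x_{\mathrm{fake}}))\bigr]\Bigr).
$$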

View File

@ -147,7 +147,9 @@ def hinge_gen_loss(fake):
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
return torch_grad(
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
)[0].detach()
# helper decorators
@ -223,7 +225,10 @@ class SqueezeExcite(Module):
dim_hidden = max(dim_hidden_min, dim_out // 2)
self.net = nn.Sequential(
nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
nn.Conv2d(dim, dim_hidden, 1),
nn.LeakyReLU(0.1),
nn.Conv2d(dim_hidden, dim_out, 1),
nn.Sigmoid(),
)
nn.init.zeros_(self.net[-2].weight)
@ -282,7 +287,10 @@ class RMSNorm(Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
return (
F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma
+ self.bias
)
class AdaptiveRMSNorm(Module):
@ -353,7 +361,8 @@ class Attention(Module):
self.norm = RMSNorm(dim)
self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
)
assert num_memory_kv > 0
@ -361,7 +370,9 @@ class Attention(Module):
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
self.to_out = nn.Sequential(
Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
)
@beartype
def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
@ -455,7 +466,9 @@ class FeedForward(Module):
super().__init__()
conv_klass = nn.Conv2d if images else nn.Conv3d
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
rmsnorm_klass = (
RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
)
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
@ -463,7 +476,9 @@ class FeedForward(Module):
self.norm = maybe_adaptive_norm_klass(dim)
self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
self.net = Sequential(
conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)
)
@beartype
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
@ -525,7 +540,8 @@ class DiscriminatorBlock(Module):
self.downsample = (
nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(filters * 4, filters, 1),
)
if downsample
else None
@ -584,11 +600,18 @@ class Discriminator(Module):
is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock(
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
)
attn_block = Sequential(
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
Residual(
LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
)
@ -628,7 +651,16 @@ class Discriminator(Module):
class Conv3DMod(Module):
@beartype
def __init__(
self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
self,
dim,
*,
spatial_kernel,
time_kernel,
causal=True,
dim_out=None,
demod=True,
eps=1e-8,
pad_mode="zeros",
):
super().__init__()
dim_out = default(dim_out, dim)
@ -644,7 +676,9 @@ class Conv3DMod(Module):
self.pad_mode = pad_mode
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
self.weights = nn.Parameter(
torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))
)
self.demod = demod
@ -675,7 +709,11 @@ class Conv3DMod(Module):
weights = weights * (cond + 1)
if self.demod:
inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
inv_norm = (
reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum")
.clamp(min=self.eps)
.rsqrt()
)
weights = weights * inv_norm
fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
@ -742,7 +780,9 @@ class SpatialUpsample2x(Module):
dim_out = default(dim_out, dim)
conv = nn.Conv2d(dim, dim_out * 4, 1)
self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
self.net = nn.Sequential(
conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)
)
self.init_conv_(conv)
@ -808,7 +848,12 @@ def SameConv2d(dim_in, dim_out, kernel_size):
class CausalConv3d(Module):
@beartype
def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode="constant",
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
@ -830,7 +875,9 @@ class CausalConv3d(Module):
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
self.conv = nn.Conv3d(
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
)
def forward(self, x):
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
@ -855,7 +902,13 @@ def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: s
@beartype
class ResidualUnitMod(Module):
def __init__(
self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
self,
dim,
kernel_size: Union[int, Tuple[int, int, int]],
*,
dim_cond,
pad_mode: str = "constant",
demod=True,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
@ -892,7 +945,15 @@ class ResidualUnitMod(Module):
class CausalConvTranspose3d(Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
*,
time_stride,
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
@ -908,7 +969,9 @@ class CausalConvTranspose3d(Module):
stride = (time_stride, 1, 1)
padding = (0, height_pad, width_pad)
self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
self.conv = nn.ConvTranspose3d(
chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs
)
def forward(self, x):
assert x.ndim == 5
@ -936,7 +999,9 @@ LossBreakdown = namedtuple(
],
)
DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
DiscrLossBreakdown = namedtuple(
"DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]
)
class VideoTokenizer(Module):
@ -1050,10 +1115,14 @@ class VideoTokenizer(Module):
has_cond = True
encoder_layer = ResidualUnitMod(
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
dim,
residual_conv_kernel_size,
dim_cond=int(dim_cond * dim_cond_expansion_factor),
)
decoder_layer = ResidualUnitMod(
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
dim,
residual_conv_kernel_size,
dim_cond=int(dim_cond * dim_cond_expansion_factor),
)
dim_out = dim
@ -1080,15 +1149,25 @@ class VideoTokenizer(Module):
elif layer_type == "attend_space":
attn_kwargs = dict(
dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
dim=dim,
dim_head=attn_dim_head,
heads=attn_heads,
dropout=attn_dropout,
flash=flash_attn,
)
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
encoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
)
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
decoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
)
elif layer_type == "linear_attend_space":
linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
linear_attn_kwargs = dict(
dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads
)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
@ -1136,9 +1215,13 @@ class VideoTokenizer(Module):
flash=flash_attn,
)
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
encoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
)
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
decoder_layer = Sequential(
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
)
elif layer_type == "cond_linear_attend_space":
has_cond = True
@ -1153,11 +1236,13 @@ class VideoTokenizer(Module):
)
encoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
Residual(LinearSpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim, dim_cond=dim_cond)),
)
decoder_layer = Sequential(
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
Residual(LinearSpaceAttention(**attn_kwargs)),
Residual(FeedForward(dim, dim_cond=dim_cond)),
)
elif layer_type == "cond_attend_time":
@ -1283,7 +1368,9 @@ class VideoTokenizer(Module):
# discriminator
discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
discr_kwargs = default(
discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)
)
self.discr = Discriminator(**discr_kwargs)
@ -1380,8 +1467,16 @@ class VideoTokenizer(Module):
self.load_state_dict(state_dict, strict=strict)
@beartype
def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
def encode(
self,
video: Tensor,
quantize=False,
cond: Optional[Tensor] = None,
video_contains_first_frame=True,
):
encode_first_frame_separately = (
self.separate_first_frame_encoding and video_contains_first_frame
)
# whether to pad video or not
@ -1389,12 +1484,16 @@ class VideoTokenizer(Module):
video_len = video.shape[2]
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
video_packed_shape = [
torch.Size([self.time_padding]),
torch.Size([]),
torch.Size([video_len - 1]),
]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
assert (
(not self.has_cond) or exists(cond)
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
if exists(cond):
@ -1431,7 +1530,9 @@ class VideoTokenizer(Module):
return maybe_quantize(video)
@beartype
def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
def decode_from_code_indices(
self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
):
assert codes.dtype in (torch.long, torch.int32)
if codes.ndim == 2:
@ -1444,18 +1545,24 @@ class VideoTokenizer(Module):
quantized = self.quantizers.indices_to_codes(codes)
return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
return self.decode(
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
)
@beartype
def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
def decode(
self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
):
decode_first_frame_separately = (
self.separate_first_frame_encoding and video_contains_first_frame
)
batch = quantized.shape[0]
# conditioning, if needed
assert (not self.has_cond) or exists(
cond
assert (
(not self.has_cond) or exists(cond)
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
if exists(cond):
@ -1558,14 +1665,18 @@ class VideoTokenizer(Module):
aux_losses = self.zero
quantizer_loss_breakdown = None
else:
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(
x, return_loss_breakdown=True
)
if return_codes and not return_recon:
return codes
# decoder
recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
recon_video = self.decode(
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
)
if return_codes:
return codes, recon_video
@ -1613,7 +1724,9 @@ class VideoTokenizer(Module):
multiscale_real_logits = discr(video)
multiscale_fake_logits = discr(recon_video.detach())
multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
multiscale_discr_loss = hinge_discr_loss(
multiscale_fake_logits, multiscale_real_logits
)
multiscale_discr_losses.append(multiscale_discr_loss)
else:
@ -1634,7 +1747,9 @@ class VideoTokenizer(Module):
+ sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
)
discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
discr_loss_breakdown = DiscrLossBreakdown(
discr_loss, multiscale_discr_losses, gradient_penalty_loss
)
return total_loss, discr_loss_breakdown
@ -1669,7 +1784,9 @@ class VideoTokenizer(Module):
norm_grad_wrt_perceptual_loss = None
if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
perceptual_loss, last_dec_layer
).norm(p=2)
# per-frame image discriminator
@ -1686,7 +1803,9 @@ class VideoTokenizer(Module):
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-3
)
adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any():
@ -1713,8 +1832,12 @@ class VideoTokenizer(Module):
multiscale_adaptive_weight = 1.0
if exists(norm_grad_wrt_perceptual_loss):
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
multiscale_gen_loss, last_dec_layer
).norm(p=2)
multiscale_adaptive_weight = (
norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
)
multiscale_adaptive_weight.clamp_(max=1e3)
multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
@ -1730,10 +1853,13 @@ class VideoTokenizer(Module):
if self.has_multiscale_discrs:
weighted_multiscale_gen_losses = sum(
loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
loss * weight
for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
)
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
total_loss = (
total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
)
# loss breakdown

View File

@ -26,7 +26,9 @@ class IdentityRegularizer(AbstractRegularizer):
yield from ()
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
def measure_perplexity(
predicted_indices: torch.Tensor, num_centroids: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
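The hunk above shows only the signature split plus the first line of the body. For orientation, the full perplexity computation this helper performs — following the deep-vector-quantization source it cites — looks roughly like the sketch below; everything after the one-hot line is an assumption based on that reference, not code taken from this diff.

```python
import torch
import torch.nn.functional as F

def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int):
    # One-hot encode the predicted code indices and average over the batch
    # to get the empirical usage probability of each centroid.
    encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
    avg_probs = encodings.mean(0)
    # Perplexity of that usage distribution; it equals num_centroids when
    # every cluster is used exactly equally, matching the comment above.
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = (avg_probs > 0).sum()
    return perplexity, cluster_use
```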

View File

@ -79,13 +79,19 @@ class FSQ(Module):
self.dim = default(dim, len(_levels) * num_codebooks)
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
self.project_in = (
nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
)
self.project_out = (
nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
)
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
implicit_codebook = self.indices_to_codes(
torch.arange(self.codebook_size), project_out=False
)
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
@ -153,7 +159,9 @@ class FSQ(Module):
z = rearrange(z, "b d ... -> b ... d")
z, ps = pack_one(z, "b * d")
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
assert (
z.shape[-1] == self.dim
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
z = self.project_in(z)

View File

@ -78,7 +78,9 @@ class LFQ(Module):
# some assert validations
assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
assert exists(dim) or exists(
codebook_size
), "either dim or codebook_size must be specified for LFQ"
assert (
not exists(codebook_size) or log2(codebook_size).is_integer()
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
@ -195,7 +197,9 @@ class LFQ(Module):
x = rearrange(x, "b d ... -> b ... d")
x, ps = pack_one(x, "b * d")
assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
assert (
x.shape[-1] == self.dim
), f"expected dimension of {self.dim} but received {x.shape[-1]}"
x = self.project_in(x)
@ -299,7 +303,9 @@ class LFQ(Module):
# complete aux loss
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
aux_loss = (
entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
)
ret = Return(x, indices, aux_loss)
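The power-of-two assertion above reflects what lookup-free quantization does: each latent channel is binarized by its sign, so a `d`-dimensional latent addresses a codebook of size `2**d` without any embedding lookup. A minimal sketch under that assumption (hypothetical helper name):

```python
import torch

def lfq_quantize_sketch(x: torch.Tensor):
    # Binarize each channel by sign; the resulting bit pattern is the code index.
    codes = torch.where(x > 0, 1.0, -1.0)
    bits = (codes > 0).long()
    powers = 2 ** torch.arange(x.shape[-1], device=x.device)
    indices = (bits * powers).sum(dim=-1)        # integer code per vector
    quantized = x + (codes - x).detach()         # straight-through estimator
    return quantized, indices
```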

View File

@ -33,7 +33,9 @@ class AbstractQuantizer(AbstractRegularizer):
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
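The `unknown_index == "random"` branch above belongs to the usual codebook-remapping trick: when inference is restricted to a subset of "used" codes, any index outside that subset is mapped either to a random used code or to one dedicated extra index. A small self-contained illustration with made-up values:

```python
import torch

used = torch.tensor([3, 7, 42, 99])                  # hypothetical kept codes
indices = torch.tensor([7, 5, 99, 3])                # raw codebook indices
match = (indices[:, None] == used[None, :]).long()   # (n, len(used))
new = match.argmax(-1)                               # position within `used`
unknown = match.sum(-1) < 1                          # indices not in `used`
new[unknown] = torch.randint(0, len(used), (int(unknown.sum()),))
print(new)  # e.g. tensor([1, 2, 3, 0]) -- the second entry is random
```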
@ -50,7 +52,9 @@ class AbstractQuantizer(AbstractRegularizer):
return back.reshape(ishape)
@abstractmethod
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
raise NotImplementedError()
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
@ -239,7 +243,8 @@ class VectorQuantizer(AbstractQuantizer):
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
- 2
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
)
min_encoding_indices = torch.argmin(d, dim=1)
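The reformatted distance computation above is the pairwise expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e, which lets the `argmin` over the whole codebook run as a single matrix multiply. A quick sanity check of the identity with made-up shapes:

```python
import torch

z_flattened = torch.randn(8, 16)    # (num_vectors, e_dim)
codebook = torch.randn(512, 16)     # (n_e, e_dim), stands in for embedding.weight
d = (
    torch.sum(z_flattened**2, dim=1, keepdim=True)
    + torch.sum(codebook**2, dim=1)
    - 2 * z_flattened @ codebook.t()
)
assert torch.allclose(d, torch.cdist(z_flattened, codebook) ** 2, atol=1e-3)
```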
@ -267,15 +272,21 @@ class VectorQuantizer(AbstractQuantizer):
if self.sane_index_shape:
if do_reshape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
else:
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
min_encoding_indices = rearrange(
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
)
loss_dict["min_encoding_indices"] = min_encoding_indices
return z_q, loss_dict
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
def get_codebook_entry(
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
) -> torch.Tensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
assert shape is not None, "Need to give shape for remap"
@ -448,6 +459,8 @@ class VectorQuantizerWithInputProjection(VectorQuantizer):
elif len(in_shape) == 5:
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
else:
raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
raise NotImplementedError(
f"rearranging not available for {len(in_shape)}-dimensional input."
)
return z_q, loss_dict

View File

@ -248,7 +248,9 @@ def make_time_attn(
"vanilla",
"vanilla-xformers",
], f"attn_type {attn_type} not supported for spatio-temporal attention"
print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
print(
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
)
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
print(
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "

View File

@ -125,9 +125,13 @@ class ResnetBlock3D(nn.Module):
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
self.conv_shortcut = CausalConv3d(
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nin_shortcut = torch.nn.Conv3d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb, zq):
h = x
@ -161,7 +165,9 @@ class AttnBlock2D(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, zq):
h_ = x
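`AttnBlock2D.forward` is cut off here; the usual continuation of such a block (assumed, not copied from this file) flattens the spatial grid and runs single-head attention over the 1x1-conv projections `q`, `k`, `v` defined above:

```python
import torch

def spatial_attention_sketch(q, k, v):
    # q, k, v: (b, c, h, w) projections from 1x1 convolutions.
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)           # (b, hw, c)
    k = k.reshape(b, c, h * w)                            # (b, c, hw)
    attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)
    v = v.reshape(b, c, h * w)
    out = torch.bmm(v, attn.permute(0, 2, 1))             # (b, c, hw)
    return out.reshape(b, c, h, w)
```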
@ -380,7 +386,11 @@ class NewDecoder3D(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,

View File

@ -148,9 +148,13 @@ class ResnetBlock3D(nn.Module):
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
self.conv_shortcut = CausalConv3d(
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nin_shortcut = torch.nn.Conv3d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb, zq):
@ -185,7 +189,9 @@ class AttnBlock2D(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, zq):
h_ = x
@ -261,7 +267,11 @@ class MOVQDecoder3D(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,
@ -420,7 +430,11 @@ class NewDecoder3D(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
# self.conv_in = torch.nn.Conv3d(z_channels,

View File

@ -51,7 +51,12 @@ def nonlinearity(x):
class CausalConv3d(nn.Module):
@beartype
def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
pad_mode="constant",
**kwargs,
):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
@ -75,11 +80,20 @@ class CausalConv3d(nn.Module):
stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
self.conv = nn.Conv3d(
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
)
def forward(self, x):
if self.pad_mode == "constant":
causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad)
causal_padding_3d = (
self.time_pad,
0,
self.width_pad,
self.width_pad,
self.height_pad,
self.height_pad,
)
x = F.pad(x, causal_padding_3d, mode="constant", value=0)
elif self.pad_mode == "first":
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
@ -91,7 +105,9 @@ class CausalConv3d(nn.Module):
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
if reflect_x.shape[2] < self.time_pad:
reflect_x = torch.cat(
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2])
+ [reflect_x],
dim=2,
)
x = torch.cat([reflect_x, x], dim=2)
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
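The padding tuples above implement causal convolution over time: the temporal axis is padded only on the front (by `(kernel_t - 1) * dilation` frames), so each output frame depends only on current and past frames. Note that `F.pad` consumes (left, right) pairs starting from the last dimension. A minimal illustration assuming a `(B, C, T, H, W)` layout and made-up sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 4, 8, 8)                 # (B, C, T, H, W)
time_pad, height_pad, width_pad = 2, 1, 1
# first pair pads W, second pads H, last pair pads T (front only -> causal)
x_padded = F.pad(x, (width_pad, width_pad, height_pad, height_pad, time_pad, 0))
print(x_padded.shape)  # torch.Size([1, 3, 6, 10, 10])
```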
@ -110,7 +126,9 @@ class Upsample3D(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.compress_time = compress_time
def forward(self, x):
@ -149,7 +167,9 @@ class DownSample3D(nn.Module):
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.conv = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, padding=0
)
self.compress_time = compress_time
def forward(self, x):
@ -182,7 +202,14 @@ class DownSample3D(nn.Module):
class ResnetBlock3D(nn.Module):
def __init__(
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
pad_mode="constant",
):
super().__init__()
self.in_channels = in_channels
@ -214,9 +241,13 @@ class ResnetBlock3D(nn.Module):
# kernel_size=3,
# stride=1,
# padding=1)
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
self.conv_shortcut = CausalConv3d(
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
)
else:
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nin_shortcut = torch.nn.Conv3d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
def forward(self, x, temb):
@ -251,7 +282,9 @@ class AttnBlock2D(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
@ -365,12 +398,20 @@ class Encoder3D(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode,
)
# remove attention block
# self.mid.attn_1 = AttnBlock2D(block_in)
self.mid.block_2 = ResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
pad_mode=pad_mode,
)
# end

View File

@ -80,7 +80,9 @@ class Upsample(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
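For reference, the interpolate-then-conv pattern shown above (2x nearest-neighbor upsampling followed by a 3x3 convolution) is the standard alternative to transposed convolutions, which tend to produce checkerboard artifacts. A tiny usage sketch with made-up shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
x = torch.randn(1, 64, 16, 16)
y = conv(F.interpolate(x, scale_factor=2.0, mode="nearest"))
print(y.shape)  # torch.Size([1, 64, 32, 32])
```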
@ -95,7 +97,9 @@ class Downsample(nn.Module):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x):
if self.with_conv:
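The "no asymmetric padding" comment refers to how such a `Downsample.forward` usually continues (assumed, not shown in this hunk): pad one pixel on the right and bottom only, then apply the stride-2, padding-0 convolution so the spatial size is exactly halved:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 16, 16)
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)   # right/bottom only
conv = torch.nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)
print(conv(x).shape)  # torch.Size([1, 64, 8, 8])
```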
@ -134,9 +138,13 @@ class ResnetBlock(nn.Module):
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb, zq):
h = x
@ -170,7 +178,9 @@ class AttnBlock(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, zq):
h_ = x
@ -232,7 +242,11 @@ class MOVQDecoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

View File

@ -15,7 +15,16 @@ class VectorQuantizer2(nn.Module):
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
def __init__(
self,
n_e,
e_dim,
beta,
remap=None,
unknown_index="random",
sane_index_shape=False,
legacy=True,
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
@ -51,7 +60,9 @@ class VectorQuantizer2(nn.Module):
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
@ -78,7 +89,8 @@ class VectorQuantizer2(nn.Module):
d = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.embedding.weight**2, dim=1)
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
- 2
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
)
min_encoding_indices = torch.argmin(d, dim=1)
@ -88,9 +100,13 @@ class VectorQuantizer2(nn.Module):
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q = z + (z_q - z).detach()
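The reformatted losses above are the standard VQ-VAE objective: a codebook term `||sg[z] - e||^2` and a beta-weighted commitment term `||z - sg[e]||^2` (the `legacy` flag only swaps which term carries beta, per the NOTE at the top of the class), followed by the straight-through trick so gradients bypass the non-differentiable nearest-neighbor lookup. A compact sketch with a hypothetical helper name:

```python
import torch

def vq_losses_sketch(z: torch.Tensor, z_q: torch.Tensor, beta: float = 0.25):
    codebook_loss = torch.mean((z_q - z.detach()) ** 2)     # moves embeddings toward z
    commitment_loss = torch.mean((z_q.detach() - z) ** 2)   # keeps z close to its code
    loss = codebook_loss + beta * commitment_loss
    z_q_st = z + (z_q - z).detach()                         # straight-through estimator
    return loss, z_q_st
```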
@ -104,7 +120,9 @@ class VectorQuantizer2(nn.Module):
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
@ -184,7 +202,9 @@ class GumbelQuantize(nn.Module):
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)

Some files were not shown because too many files have changed in this diff.