diff --git a/.github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md b/.github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..65d04d3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,28 @@ +# Contribution Guide + +We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines. + +## What We Accept + ++ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks). ++ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below. ++ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below. + +## Code Style Guide + +Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below: + +1. Install the required dependencies: +```shell + pip install ruff pre-commit +``` +2. Then, run the following command: +```shell + pre-commit run --all-files +``` +If your code complies with the standards, you should not see any errors. + +## Naming Conventions + +- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English. +- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`. diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml index 19271ef..4f1f0dc 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -30,14 +30,14 @@ body: If you have code snippets, error messages, stack traces, please provide them here as well. Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. - + 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 placeholder: | Steps to reproduce the behavior/复现Bug的步骤: - + 1. 2. 3. @@ -48,4 +48,4 @@ body: required: true attributes: label: Expected behavior / 期待表现 - description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" \ No newline at end of file + description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" diff --git a/.github/ISSUE_TEMPLATE/feature-request.yaml b/.github/ISSUE_TEMPLATE/feature-request.yaml index 7e09bee..513e96c 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.yaml +++ b/.github/ISSUE_TEMPLATE/feature-request.yaml @@ -29,6 +29,6 @@ body: attributes: label: Your contribution / 您的贡献 description: | - + Your PR link or any other link you can help with. 
- 您的PR链接或者其他您能提供帮助的链接。 \ No newline at end of file + 您的PR链接或者其他您能提供帮助的链接。 diff --git a/.github/PULL_REQUEST_TEMPLATE/pr_template.md b/.github/PULL_REQUEST_TEMPLATE/pr_template.md deleted file mode 100644 index 0c3140a..0000000 --- a/.github/PULL_REQUEST_TEMPLATE/pr_template.md +++ /dev/null @@ -1,34 +0,0 @@ -# Raise valuable PR / 提出有价值的PR - -## Caution / 注意事项: -Users should keep the following points in mind when submitting PRs: - -1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md). -2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs. - -用户在提交PR时候应该注意以下几点: - -1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。 -2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。 - -## 不应该提出的PR / PRs that should not be proposed - -If a developer proposes a PR about any of the following, it may be closed or Rejected. - -1. those that don't describe improvement options. -2. multiple issues of different types combined in one PR. -3. The proposed PR is highly duplicative of already existing PRs. - -如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。 - -1. 没有说明改进方案的。 -2. 多个不同类型的问题合并在一个PR中的。 -3. 提出的PR与已经存在的PR高度重复的。 - - -# 检查您的PR -- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分? -- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。 -- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。 -- [ ] Did you write new required tests? / 您是否编写了新的必要测试? -- [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题 \ No newline at end of file diff --git a/.gitignore b/.gitignore index a667345..ad4bbeb 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,4 @@ CogVideo-1.0 **/train_results **/train_res* -**/uv.lock \ No newline at end of file +**/uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..fff4c90 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.5 + hooks: + - id: ruff + args: [--fix, --respect-gitignore, --config=pyproject.toml] + - id: ruff-format + args: [--config=pyproject.toml] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-toml + - id: check-case-conflict + - id: check-merge-conflict + - id: debug-statements diff --git a/MODEL_LICENSE b/MODEL_LICENSE index 3ca0c74..7f8da3b 100644 --- a/MODEL_LICENSE +++ b/MODEL_LICENSE @@ -68,4 +68,4 @@ Note that the license is subject to update to a more comprehensive version. For 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 -请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 \ No newline at end of file +请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 diff --git a/README.md b/README.md index d459b94..13c4e4e 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Experience the CogVideoX-5B model online at paper and user guide
- 👋 Join our WeChat and Discord + 👋 Join our WeChat and Discord
📍 Visit QingYing and API Platform to experience larger-scale commercial video generation models. @@ -22,12 +22,12 @@ Experience the CogVideoX-5B model online at Position Encoding 3d_rope_pos_embed - 3d_sincos_pos_embed + 3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed @@ -444,8 +444,6 @@ hands-on practice on text-to-video generation. *The original input is in Chinese } ``` -We welcome your contributions! You can click [here](resources/contribute.md) for more information. - ## Model-License The code in this repository is released under the [Apache 2.0 License](LICENSE). diff --git a/README_ja.md b/README_ja.md index 5659101..857ae6a 100644 --- a/README_ja.md +++ b/README_ja.md @@ -21,7 +21,7 @@
## 更新とニュース -- 🔥🔥 ```2025/03/16```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。 +- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。 - **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。 - **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAM(ビデオメモリ)で動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。 - **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。 @@ -243,7 +243,7 @@ CogVideoXは、[清影](https://chatglm.cn/video?fr=osm_cogvideox) と同源の 位置エンコーディング 3d_rope_pos_embed - 3d_sincos_pos_embed + 3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed @@ -418,8 +418,6 @@ CogVideoのデモは [https://models.aminer.cn/cogvideo](https://models.aminer.c } ``` -あなたの貢献をお待ちしています!詳細は[こちら](resources/contribute_ja.md)をクリックしてください。 - ## ライセンス契約 このリポジトリのコードは [Apache 2.0 License](LICENSE) の下で公開されています。 diff --git a/README_zh.md b/README_zh.md index b982902..202aa05 100644 --- a/README_zh.md +++ b/README_zh.md @@ -14,7 +14,7 @@ 📚 查看 论文使用文档
- 👋 加入我们的 微信Discord + 👋 加入我们的 微信Discord
📍 前往 清影 API平台 体验更大规模的商业版视频生成模型。 @@ -22,11 +22,11 @@ ## 项目更新 -- 🔥🔥 **News**: ```2025/03/16```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。 +- 🔥🔥 **News**: ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。 - 🔥 **News**: ```2025/02/28```: DDIM Inverse 已经在`CogVideoX-5B` 和 `CogVideoX1.5 -5B` 支持,查看 [here](inference/ddim_inversion.py). - 🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。 - 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本,仅需调整部分参数仅可沿用之前的代码。 -- 🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。 +- 🔥 **News**: ```2024/11/08```: 我们发布 `CogVideoX1.5` 模型。CogVideoX1.5 是 CogVideoX 开源模型的升级版本。 CogVideoX1.5-5B 系列模型支持 **10秒** 长度的视频和更高的分辨率,其中 `CogVideoX1.5-5B-I2V` 支持 **任意分辨率** 的视频生成,SAT代码已经更新。`diffusers`版本还在适配中。SAT版本代码前往 [这里](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT) 下载。 - 🔥**News**: ```2024/10/13```: 成本更低,单卡4090可微调 `CogVideoX-5B` 的微调框架[cogvideox-factory](https://github.com/a-r-r-o-w/cogvideox-factory)已经推出,多种分辨率微调,欢迎使用。 @@ -234,7 +234,7 @@ CogVideoX是 [清影](https://chatglm.cn/video?fr=osm_cogvideox) 同源的开源 位置编码 3d_rope_pos_embed - 3d_sincos_pos_embed + 3d_sincos_pos_embed 3d_rope_pos_embed 3d_rope_pos_embed + learnable_pos_embed @@ -398,8 +398,6 @@ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.amine } ``` -我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。 - ## 模型协议 本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。 diff --git a/finetune/README.md b/finetune/README.md index 736e6b9..bd86c34 100644 --- a/finetune/README.md +++ b/finetune/README.md @@ -143,4 +143,4 @@ When using SFT training, please note: + The original repository uses `lora_alpha` set to 1. We found that this value performed poorly in several runs, possibly due to differences in the model backend and training settings. Our recommendation is to set `lora_alpha` to be equal to the rank or `rank // 2`. -+ It's advised to use a rank of 64 or higher. \ No newline at end of file ++ It's advised to use a rank of 64 or higher. 
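The LoRA guidance above (a rank of 64 or higher, with `lora_alpha` equal to the rank or `rank // 2`) maps directly onto an adapter configuration. A minimal sketch, assuming `peft`'s `LoraConfig` API and the `target_modules` defaults used by the fine-tuning arguments elsewhere in this patch:

```python
# Minimal sketch (assumes peft's LoraConfig API), following the advice above:
# use a rank of at least 64 and set lora_alpha equal to the rank (or rank // 2)
# instead of the lora_alpha=1 used by the original repository.
from peft import LoraConfig

rank = 64
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # or rank // 2
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections, as in args.py
)
```

With `lora_alpha` equal to the rank, the effective LoRA scaling factor `lora_alpha / r` stays at 1.0, so the learning rate does not need to be re-tuned when the rank is changed.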
diff --git a/finetune/README_ja.md b/finetune/README_ja.md index 2866428..e4d4641 100644 --- a/finetune/README_ja.md +++ b/finetune/README_ja.md @@ -121,4 +121,4 @@ SFTトレーニングを使用する際に注意すべき点: + 25本以上の動画を使用することで、新しい概念やスタイルのトレーニングが最適です。 + `--id_token` で指定できる識別子トークンを使用すると、トレーニング効果がより良くなります。これはDreamboothトレーニングに似ていますが、このトークンを使用しない通常のファインチューニングでも問題なく動作します。 + 元のリポジトリでは `lora_alpha` が1に設定されていますが、この値は多くの実行で効果が悪かったため、モデルのバックエンドやトレーニング設定の違いが影響している可能性があります。私たちの推奨は、`lora_alpha` を rank と同じか、`rank // 2` に設定することです。 -+ rank は64以上に設定することをお勧めします。 \ No newline at end of file ++ rank は64以上に設定することをお勧めします。 diff --git a/finetune/accelerate_config.yaml b/finetune/accelerate_config.yaml index c46cdc8..7b7750c 100644 --- a/finetune/accelerate_config.yaml +++ b/finetune/accelerate_config.yaml @@ -18,4 +18,4 @@ same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false -use_cpu: false \ No newline at end of file +use_cpu: false diff --git a/finetune/configs/zero2.yaml b/finetune/configs/zero2.yaml index 96afa13..b056bd4 100644 --- a/finetune/configs/zero2.yaml +++ b/finetune/configs/zero2.yaml @@ -35,4 +35,4 @@ "gradient_clipping": "auto", "steps_per_print": 2000, "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/finetune/configs/zero2_offload.yaml b/finetune/configs/zero2_offload.yaml index b542665..24fdcb4 100644 --- a/finetune/configs/zero2_offload.yaml +++ b/finetune/configs/zero2_offload.yaml @@ -39,4 +39,4 @@ "gradient_clipping": "auto", "steps_per_print": 2000, "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/finetune/configs/zero3.yaml b/finetune/configs/zero3.yaml index 8f73fe8..69c5fd5 100644 --- a/finetune/configs/zero3.yaml +++ b/finetune/configs/zero3.yaml @@ -40,4 +40,4 @@ "gradient_clipping": "auto", "steps_per_print": 2000, "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/finetune/configs/zero3_offload.yaml b/finetune/configs/zero3_offload.yaml index 9a2c502..58e6529 100644 --- a/finetune/configs/zero3_offload.yaml +++ b/finetune/configs/zero3_offload.yaml @@ -48,4 +48,4 @@ "gradient_clipping": "auto", "steps_per_print": 2000, "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/finetune/datasets/bucket_sampler.py b/finetune/datasets/bucket_sampler.py index 8bc1dde..3e8c5c9 100644 --- a/finetune/datasets/bucket_sampler.py +++ b/finetune/datasets/bucket_sampler.py @@ -26,7 +26,11 @@ class BucketSampler(Sampler): """ def __init__( - self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False + self, + data_source: Dataset, + batch_size: int = 8, + shuffle: bool = True, + drop_last: bool = False, ) -> None: self.data_source = data_source self.batch_size = batch_size @@ -48,7 +52,11 @@ class BucketSampler(Sampler): def __iter__(self): for index, data in enumerate(self.data_source): video_metadata = data["video_metadata"] - f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] + f, h, w = ( + video_metadata["num_frames"], + video_metadata["height"], + video_metadata["width"], + ) self.buckets[(f, h, w)].append(data) if len(self.buckets[(f, h, w)]) == self.batch_size: diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py index cad6331..f1f1d48 100644 --- a/finetune/datasets/i2v_dataset.py +++ b/finetune/datasets/i2v_dataset.py @@ -115,7 +115,9 @@ class BaseI2VDataset(Dataset): train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution) cache_dir = 
self.trainer.args.data_root / "cache" - video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + video_latent_dir = ( + cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + ) prompt_embeddings_dir = cache_dir / "prompt_embeddings" video_latent_dir.mkdir(parents=True, exist_ok=True) prompt_embeddings_dir.mkdir(parents=True, exist_ok=True) @@ -136,7 +138,9 @@ class BaseI2VDataset(Dataset): # [1, seq_len, hidden_size] -> [seq_len, hidden_size] prompt_embedding = prompt_embedding[0] save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path) - logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False) + logger.info( + f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False + ) if encoded_video_path.exists(): encoded_video = load_file(encoded_video_path)["encoded_video"] @@ -177,7 +181,9 @@ class BaseI2VDataset(Dataset): }, } - def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]: + def preprocess( + self, video_path: Path | None, image_path: Path | None + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Loads and preprocesses a video and an image. If either path is None, no preprocessing will be done for that input. @@ -249,13 +255,19 @@ class I2VDatasetWithResize(BaseI2VDataset): self.height = height self.width = width - self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]) + self.__frame_transforms = transforms.Compose( + [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] + ) self.__image_transforms = self.__frame_transforms @override - def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]: + def preprocess( + self, video_path: Path | None, image_path: Path | None + ) -> Tuple[torch.Tensor, torch.Tensor]: if video_path is not None: - video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width) + video = preprocess_video_with_resize( + video_path, self.max_num_frames, self.height, self.width + ) else: video = None if image_path is not None: @@ -293,7 +305,9 @@ class I2VDatasetWithBuckets(BaseI2VDataset): ) for b in video_resolution_buckets ] - self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]) + self.__frame_transforms = transforms.Compose( + [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] + ) self.__image_transforms = self.__frame_transforms @override diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py index d123ccf..9967a5c 100644 --- a/finetune/datasets/t2v_dataset.py +++ b/finetune/datasets/t2v_dataset.py @@ -11,7 +11,12 @@ from typing_extensions import override from finetune.constants import LOG_LEVEL, LOG_NAME -from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize +from .utils import ( + load_prompts, + load_videos, + preprocess_video_with_buckets, + preprocess_video_with_resize, +) if TYPE_CHECKING: @@ -93,7 +98,9 @@ class BaseT2VDataset(Dataset): train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution) cache_dir = self.trainer.args.data_root / "cache" - video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + video_latent_dir = ( + cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str + ) prompt_embeddings_dir = cache_dir / 
"prompt_embeddings" video_latent_dir.mkdir(parents=True, exist_ok=True) prompt_embeddings_dir.mkdir(parents=True, exist_ok=True) @@ -114,7 +121,9 @@ class BaseT2VDataset(Dataset): # [1, seq_len, hidden_size] -> [seq_len, hidden_size] prompt_embedding = prompt_embedding[0] save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path) - logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False) + logger.info( + f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False + ) if encoded_video_path.exists(): # encoded_video = torch.load(encoded_video_path, weights_only=True) @@ -202,7 +211,9 @@ class T2VDatasetWithResize(BaseT2VDataset): self.height = height self.width = width - self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]) + self.__frame_transform = transforms.Compose( + [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] + ) @override def preprocess(self, video_path: Path) -> torch.Tensor: @@ -240,7 +251,9 @@ class T2VDatasetWithBuckets(BaseT2VDataset): for b in video_resolution_buckets ] - self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]) + self.__frame_transform = transforms.Compose( + [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)] + ) @override def preprocess(self, video_path: Path) -> torch.Tensor: diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py index 9f29d4a..e2667d6 100644 --- a/finetune/datasets/utils.py +++ b/finetune/datasets/utils.py @@ -24,12 +24,16 @@ def load_prompts(prompt_path: Path) -> List[str]: def load_videos(video_path: Path) -> List[Path]: with open(video_path, "r", encoding="utf-8") as file: - return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0] + return [ + video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0 + ] def load_images(image_path: Path) -> List[Path]: with open(image_path, "r", encoding="utf-8") as file: - return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0] + return [ + image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0 + ] def load_images_from_videos(videos_path: List[Path]) -> List[Path]: @@ -169,7 +173,9 @@ def preprocess_video_with_buckets( video_num_frames = len(video_reader) resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames] if len(resolution_buckets) == 0: - raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}") + raise ValueError( + f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}" + ) nearest_frame_bucket = min( resolution_buckets, @@ -181,7 +187,9 @@ def preprocess_video_with_buckets( frames = frames[:nearest_frame_bucket].float() frames = frames.permute(0, 3, 1, 2).contiguous() - nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])) + nearest_res = min( + resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]) + ) nearest_res = (nearest_res[1], nearest_res[2]) frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0) diff --git a/finetune/models/cogvideox_i2v/lora_trainer.py b/finetune/models/cogvideox_i2v/lora_trainer.py index a024189..793cf76 100644 --- a/finetune/models/cogvideox_i2v/lora_trainer.py +++ b/finetune/models/cogvideox_i2v/lora_trainer.py @@ -32,13 +32,19 @@ 
class CogVideoXI2VLoraTrainer(Trainer): components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") - components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder") + components.text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder" + ) - components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer") + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer" + ) components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") - components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + components.scheduler = CogVideoXDPMScheduler.from_pretrained( + model_path, subfolder="scheduler" + ) return components @@ -73,7 +79,9 @@ class CogVideoXI2VLoraTrainer(Trainer): return_tensors="pt", ) prompt_token_ids = prompt_token_ids.input_ids - prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0] + prompt_embedding = self.components.text_encoder( + prompt_token_ids.to(self.accelerator.device) + )[0] return prompt_embedding @override @@ -122,22 +130,34 @@ class CogVideoXI2VLoraTrainer(Trainer): # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W] images = images.unsqueeze(2) # Add noise to images - image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device) + image_noise_sigma = torch.normal( + mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device + ) image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype) - noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] - image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist + noisy_images = ( + images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None] + ) + image_latent_dist = self.components.vae.encode( + noisy_images.to(dtype=self.components.vae.dtype) + ).latent_dist image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor # Sample a random timestep for each sample timesteps = torch.randint( - 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device + 0, + self.components.scheduler.config.num_train_timesteps, + (batch_size,), + device=self.accelerator.device, ) timesteps = timesteps.long() # from [B, C, F, H, W] to [B, F, C, H, W] latent = latent.permute(0, 2, 1, 3, 4) image_latents = image_latents.permute(0, 2, 1, 3, 4) - assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) + assert (latent.shape[0], *latent.shape[2:]) == ( + image_latents.shape[0], + *image_latents.shape[2:], + ) # Padding image_latents to the same frame number as latent padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:]) @@ -169,7 +189,9 @@ class CogVideoXI2VLoraTrainer(Trainer): # Predict noise, For CogVideoX1.5 Only. 
ofs_emb = ( - None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0) + None + if self.state.transformer_config.ofs_embed_dim is None + else latent.new_full((1,), fill_value=2.0) ) predicted_noise = self.components.transformer( hidden_states=latent_img_noisy, @@ -181,7 +203,9 @@ class CogVideoXI2VLoraTrainer(Trainer): )[0] # Denoise - latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps) + latent_pred = self.components.scheduler.get_velocity( + predicted_noise, latent_noisy, timesteps + ) alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps] weights = 1 / (1 - alphas_cumprod) @@ -228,7 +252,9 @@ class CogVideoXI2VLoraTrainer(Trainer): if transformer_config.patch_size_t is None: base_num_frames = num_frames else: - base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t + base_num_frames = ( + num_frames + transformer_config.patch_size_t - 1 + ) // transformer_config.patch_size_t freqs_cos, freqs_sin = get_3d_rotary_pos_embed( embed_dim=transformer_config.attention_head_dim, diff --git a/finetune/models/cogvideox_t2v/lora_trainer.py b/finetune/models/cogvideox_t2v/lora_trainer.py index 62582c8..5f0ec1c 100644 --- a/finetune/models/cogvideox_t2v/lora_trainer.py +++ b/finetune/models/cogvideox_t2v/lora_trainer.py @@ -31,13 +31,19 @@ class CogVideoXT2VLoraTrainer(Trainer): components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer") - components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder") + components.text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder" + ) - components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer") + components.transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer" + ) components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae") - components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + components.scheduler = CogVideoXDPMScheduler.from_pretrained( + model_path, subfolder="scheduler" + ) return components @@ -72,7 +78,9 @@ class CogVideoXT2VLoraTrainer(Trainer): return_tensors="pt", ) prompt_token_ids = prompt_token_ids.input_ids - prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0] + prompt_embedding = self.components.text_encoder( + prompt_token_ids.to(self.accelerator.device) + )[0] return prompt_embedding @override @@ -115,7 +123,10 @@ class CogVideoXT2VLoraTrainer(Trainer): # Sample a random timestep for each sample timesteps = torch.randint( - 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device + 0, + self.components.scheduler.config.num_train_timesteps, + (batch_size,), + device=self.accelerator.device, ) timesteps = timesteps.long() @@ -150,7 +161,9 @@ class CogVideoXT2VLoraTrainer(Trainer): )[0] # Denoise - latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps) + latent_pred = self.components.scheduler.get_velocity( + predicted_noise, latent_added_noise, timesteps + ) alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps] weights = 1 / (1 - alphas_cumprod) @@ -196,7 +209,9 @@ class CogVideoXT2VLoraTrainer(Trainer): if transformer_config.patch_size_t is None: base_num_frames = num_frames else: - base_num_frames = 
(num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t + base_num_frames = ( + num_frames + transformer_config.patch_size_t - 1 + ) // transformer_config.patch_size_t freqs_cos, freqs_sin = get_3d_rotary_pos_embed( embed_dim=transformer_config.attention_head_dim, crops_coords=None, diff --git a/finetune/models/utils.py b/finetune/models/utils.py index ef3ea5b..2028672 100644 --- a/finetune/models/utils.py +++ b/finetune/models/utils.py @@ -52,6 +52,8 @@ def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Tra print(f"\nSupported training types for '{model_type}' are:") for supported_type in SUPPORTED_MODELS[model_type]: print(f" • {supported_type}") - raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'") + raise ValueError( + f"Training type '{training_type}' is not supported for model '{model_type}'" + ) return SUPPORTED_MODELS[model_type][training_type] diff --git a/finetune/schemas/args.py b/finetune/schemas/args.py index a01ea18..bba7d01 100644 --- a/finetune/schemas/args.py +++ b/finetune/schemas/args.py @@ -115,14 +115,18 @@ class Args(BaseModel): def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None: values = info.data if values.get("do_validation") and values.get("model_type") == "i2v" and not v: - raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v") + raise ValueError( + "validation_images must be specified when do_validation is True and model_type is i2v" + ) return v @field_validator("validation_videos") def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None: values = info.data if values.get("do_validation") and values.get("model_type") == "v2v" and not v: - raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v") + raise ValueError( + "validation_videos must be specified when do_validation is True and model_type is v2v" + ) return v @field_validator("validation_steps") @@ -148,7 +152,9 @@ class Args(BaseModel): model_name = info.data.get("model_name", "") if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]: if (height, width) != (480, 720): - raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720") + raise ValueError( + "For cogvideox-5b models, height must be 480 and width must be 720" + ) return v @@ -221,7 +227,9 @@ class Args(BaseModel): # LoRA parameters parser.add_argument("--rank", type=int, default=128) parser.add_argument("--lora_alpha", type=int, default=64) - parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]) + parser.add_argument( + "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"] + ) # Checkpointing parser.add_argument("--checkpointing_steps", type=int, default=200) diff --git a/finetune/scripts/extract_images.py b/finetune/scripts/extract_images.py index 42ce8e2..eca6c21 100644 --- a/finetune/scripts/extract_images.py +++ b/finetune/scripts/extract_images.py @@ -8,7 +8,10 @@ import cv2 def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory" + "--datadir", + type=str, + required=True, + help="Root directory containing videos.txt and video subdirectory", ) return parser.parse_args() diff --git a/finetune/trainer.py b/finetune/trainer.py index 46171e9..5746fee 
100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -88,7 +88,9 @@ class Trainer: def _init_distributed(self): logging_dir = Path(self.args.output_dir, "logs") - project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) + project_config = ProjectConfiguration( + project_dir=self.args.output_dir, logging_dir=logging_dir + ) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) init_process_group_kwargs = InitProcessGroupKwargs( backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) @@ -183,7 +185,9 @@ class Trainer: # Prepare VAE and text encoder for encoding self.components.vae.requires_grad_(False) self.components.text_encoder.requires_grad_(False) - self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype) + self.components.vae = self.components.vae.to( + self.accelerator.device, dtype=self.state.weight_dtype + ) self.components.text_encoder = self.components.text_encoder.to( self.accelerator.device, dtype=self.state.weight_dtype ) @@ -263,7 +267,9 @@ class Trainer: # For LoRA, we only want to train the LoRA weights # For SFT, we want to train all the parameters - trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters())) + trainable_parameters = list( + filter(lambda p: p.requires_grad, self.components.transformer.parameters()) + ) transformer_parameters_with_lr = { "params": trainable_parameters, "lr": self.args.learning_rate, @@ -287,7 +293,9 @@ class Trainer: use_deepspeed=use_deepspeed_opt, ) - num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(self.data_loader) / self.args.gradient_accumulation_steps + ) if self.args.train_steps is None: self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch self.state.overwrote_max_train_steps = True @@ -322,12 +330,16 @@ class Trainer: self.lr_scheduler = lr_scheduler def prepare_for_training(self) -> None: - self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare( - self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler + self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = ( + self.accelerator.prepare( + self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler + ) ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
- num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(self.data_loader) / self.args.gradient_accumulation_steps + ) if self.state.overwrote_max_train_steps: self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -364,7 +376,9 @@ class Trainer: logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") self.state.total_batch_size_count = ( - self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps + self.args.batch_size + * self.accelerator.num_processes + * self.args.gradient_accumulation_steps ) info = { "trainable parameters": self.state.num_trainable_parameters, @@ -454,7 +468,9 @@ class Trainer: progress_bar.set_postfix(logs) # Maybe run validation - should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0 + should_run_validation = ( + self.args.do_validation and global_step % self.args.validation_steps == 0 + ) if should_run_validation: del loss free_memory() @@ -466,7 +482,9 @@ class Trainer: break memory_statistics = get_memory_statistics() - logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}") + logger.info( + f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}" + ) accelerator.wait_for_everyone() self.__maybe_save_checkpoint(global_step, must_save=True) @@ -504,7 +522,9 @@ class Trainer: # Can't using model_cpu_offload in deepspeed, # so we need to move all components in pipe to device # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype) - self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"]) + self.__move_components_to_device( + dtype=self.state.weight_dtype, ignore_list=["transformer"] + ) else: # if not using deepspeed, use model_cpu_offload to further reduce memory usage # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage @@ -528,7 +548,9 @@ class Trainer: video = self.state.validation_videos[i] if image is not None: - image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width) + image = preprocess_image_with_resize( + image, self.state.train_height, self.state.train_width + ) # Convert image tensor (C, H, W) to PIL images image = image.to(torch.uint8) image = image.permute(1, 2, 0).cpu().numpy() @@ -546,7 +568,9 @@ class Trainer: f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. 
Prompt: {prompt}", main_process_only=False, ) - validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe) + validation_artifacts = self.validation_step( + {"prompt": prompt, "image": image, "video": video}, pipe + ) if ( self.state.using_deepspeed @@ -565,7 +589,9 @@ class Trainer: "video": {"type": "video", "value": video}, } for i, (artifact_type, artifact_value) in enumerate(validation_artifacts): - artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}) + artifacts.update( + {f"artifact_{i}": {"type": artifact_type, "value": artifact_value}} + ) logger.debug( f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", main_process_only=False, @@ -600,8 +626,12 @@ class Trainer: tracker_key = "validation" for tracker in accelerator.trackers: if tracker.name == "wandb": - image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] - video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + image_artifacts = [ + artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image) + ] + video_artifacts = [ + artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video) + ] tracker.log( { tracker_key: {"images": image_artifacts, "videos": video_artifacts}, @@ -618,7 +648,9 @@ class Trainer: pipe.remove_all_hooks() del pipe # Load models except those not needed for training - self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST) + self.__move_components_to_device( + dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST + ) self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype) # Change trainable weights back to fp32 to keep with dtype after prepare the model @@ -687,7 +719,9 @@ class Trainer: for name, component in components.items(): if not isinstance(component, type) and hasattr(component, "to"): if name not in ignore_list: - setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype)) + setattr( + self.components, name, component.to(self.accelerator.device, dtype=dtype) + ) def __move_components_to_cpu(self, unload_list: List[str] = []): unload_list = set(unload_list) @@ -732,11 +766,13 @@ class Trainer: ): transformer_ = unwrap_model(self.accelerator, model) else: - raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}") + raise ValueError( + f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}" + ) else: - transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained( - self.args.model_path, subfolder="transformer" - ) + transformer_ = unwrap_model( + self.accelerator, self.components.transformer + ).__class__.from_pretrained(self.args.model_path, subfolder="transformer") transformer_.add_adapter(transformer_lora_config) lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir) @@ -745,7 +781,9 @@ class Trainer: for k, v in lora_state_dict.items() if k.startswith("transformer.") } - incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + incompatible_keys = set_peft_model_state_dict( + transformer_, transformer_state_dict, adapter_name="default" + ) if incompatible_keys is not None: # check only for unexpected keys unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) @@ -759,7 +797,10 @@ class Trainer: 
self.accelerator.register_load_state_pre_hook(load_model_hook) def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False): - if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process: + if ( + self.accelerator.distributed_type == DistributedType.DEEPSPEED + or self.accelerator.is_main_process + ): if must_save or global_step % self.args.checkpointing_steps == 0: # for training save_path = get_intermediate_ckpt_path( diff --git a/finetune/utils/checkpointing.py b/finetune/utils/checkpointing.py index 775038c..9c1ccb5 100644 --- a/finetune/utils/checkpointing.py +++ b/finetune/utils/checkpointing.py @@ -23,7 +23,9 @@ def get_latest_ckpt_path_to_resume_from( else: resume_from_checkpoint_path = Path(resume_from_checkpoint) if not resume_from_checkpoint_path.exists(): - logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") + logger.info( + f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." + ) initial_global_step = 0 global_step = 0 first_epoch = 0 diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py index 0c88d70..f247e7d 100644 --- a/finetune/utils/memory_utils.py +++ b/finetune/utils/memory_utils.py @@ -55,7 +55,9 @@ def unload_model(model): model.to("cpu") -def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: +def make_contiguous( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: if isinstance(x, torch.Tensor): return x.contiguous() elif isinstance(x, dict): diff --git a/finetune/utils/optimizer_utils.py b/finetune/utils/optimizer_utils.py index d24aa3f..94aeafd 100644 --- a/finetune/utils/optimizer_utils.py +++ b/finetune/utils/optimizer_utils.py @@ -67,7 +67,9 @@ def get_optimizer( optimizer_name = "adamw" if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: - raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") + raise ValueError( + "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers." 
+ ) if use_8bit: try: @@ -81,7 +83,9 @@ def get_optimizer( if use_torchao: from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit - optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + optimizer_class = ( + AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW + ) else: optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW @@ -109,7 +113,9 @@ def get_optimizer( try: import prodigyopt except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + raise ImportError( + "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" + ) optimizer_class = prodigyopt.Prodigy @@ -133,7 +139,9 @@ def get_optimizer( try: import came_pytorch except ImportError: - raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") + raise ImportError( + "To use CAME, please install the came-pytorch library: `pip install came-pytorch`" + ) optimizer_class = came_pytorch.CAME @@ -151,7 +159,10 @@ def get_optimizer( init_kwargs.update({"fused": True}) optimizer = CPUOffloadOptimizer( - params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs + params_to_optimize, + optimizer_class=optimizer_class, + offload_gradients=offload_gradients, + **init_kwargs, ) else: optimizer = optimizer_class(params_to_optimize, **init_kwargs) diff --git a/inference/cli_demo.py b/inference/cli_demo.py index 37dfcfc..f262d33 100644 --- a/inference/cli_demo.py +++ b/inference/cli_demo.py @@ -99,7 +99,9 @@ def generate_video( desired_resolution = RESOLUTION_MAP[model_name] if width is None or height is None: height, width = desired_resolution - logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m") + logging.info( + f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m" + ) elif (height, width) != desired_resolution: if generate_type == "i2v": # For i2v models, use user-defined width and height @@ -124,7 +126,9 @@ def generate_video( # If you're using with lora, add this code if lora_path: - pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1") + pipe.load_lora_weights( + lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1" + ) pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank) # 2. Set Scheduler. @@ -133,7 +137,9 @@ def generate_video( # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V. # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + pipe.scheduler = CogVideoXDPMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) # 3. Enable CPU offload for the model. 
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference @@ -190,8 +196,12 @@ def generate_video( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") - parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") + parser = argparse.ArgumentParser( + description="Generate a video from a text prompt using CogVideoX" + ) + parser.add_argument( + "--prompt", type=str, required=True, help="The description of the video to be generated" + ) parser.add_argument( "--image_or_video_path", type=str, @@ -199,20 +209,44 @@ if __name__ == "__main__": help="The path of the image to be used as the background of the video", ) parser.add_argument( - "--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use" + "--model_path", + type=str, + default="THUDM/CogVideoX1.5-5B", + help="Path of the pre-trained model use", + ) + parser.add_argument( + "--lora_path", type=str, default=None, help="The path of the LoRA weights to be used" ) - parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used") parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights") - parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video") - parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument( + "--output_path", type=str, default="./output.mp4", help="The path save generated video" + ) + parser.add_argument( + "--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance" + ) parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") - parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process") + parser.add_argument( + "--num_frames", type=int, default=81, help="Number of steps for the inference process" + ) parser.add_argument("--width", type=int, default=None, help="The width of the generated video") - parser.add_argument("--height", type=int, default=None, help="The height of the generated video") - parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video") - parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") - parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation") - parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation") + parser.add_argument( + "--height", type=int, default=None, help="The height of the generated video" + ) + parser.add_argument( + "--fps", type=int, default=16, help="The frames per second for the generated video" + ) + parser.add_argument( + "--num_videos_per_prompt", + type=int, + default=1, + help="Number of videos to generate per prompt", + ) + parser.add_argument( + "--generate_type", type=str, default="t2v", help="The type of video generation" + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", help="The data type for computation" + ) parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") args = parser.parse_args() diff --git a/inference/cli_demo_quantization.py b/inference/cli_demo_quantization.py index 1ea3358..a91c04c 100644 
--- a/inference/cli_demo_quantization.py +++ b/inference/cli_demo_quantization.py @@ -19,7 +19,12 @@ import argparse import os import torch import torch._dynamo -from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, + CogVideoXPipeline, + CogVideoXDPMScheduler, +) from diffusers.utils import export_to_video from transformers import T5EncoderModel from torchao.quantization import quantize_, int8_weight_only @@ -68,9 +73,13 @@ def generate_video( - quantization_scheme (str): The quantization scheme to use ('int8', 'fp8'). - dtype (torch.dtype): The data type for computation (default is torch.bfloat16). """ - text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype) + text_encoder = T5EncoderModel.from_pretrained( + model_path, subfolder="text_encoder", torch_dtype=dtype + ) text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme) - transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype) + transformer = CogVideoXTransformer3DModel.from_pretrained( + model_path, subfolder="transformer", torch_dtype=dtype + ) transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme) vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype) vae = quantize_model(part=vae, quantization_scheme=quantization_scheme) @@ -81,7 +90,9 @@ def generate_video( vae=vae, torch_dtype=dtype, ) - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + pipe.scheduler = CogVideoXDPMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) pipe.enable_model_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() @@ -100,16 +111,34 @@ def generate_video( if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") - parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") - parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model") - parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video") - parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") - parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale") - parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt") - parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')") + parser = argparse.ArgumentParser( + description="Generate a video from a text prompt using CogVideoX" + ) parser.add_argument( - "--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme" + "--prompt", type=str, required=True, help="The description of the video to be generated" + ) + parser.add_argument( + "--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model" + ) + parser.add_argument( + "--output_path", type=str, default="./output.mp4", help="Path to save generated video" + ) + parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps") + parser.add_argument( + "--guidance_scale", 
type=float, default=6.0, help="Classifier-free guidance scale" + ) + parser.add_argument( + "--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt" + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')" + ) + parser.add_argument( + "--quantization_scheme", + type=str, + default="fp8", + choices=["int8", "fp8"], + help="Quantization scheme", ) parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video") parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video") diff --git a/inference/cli_vae_demo.py b/inference/cli_vae_demo.py index 07e949e..508aedb 100644 --- a/inference/cli_vae_demo.py +++ b/inference/cli_vae_demo.py @@ -104,18 +104,34 @@ def save_video(tensor, output_path): if __name__ == "__main__": parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo") - parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model") + parser.add_argument( + "--model_path", type=str, required=True, help="The path to the CogVideoX model" + ) parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)") - parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)") - parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file") parser.add_argument( - "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both" + "--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)" ) parser.add_argument( - "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" + "--output_path", type=str, default=".", help="The path to save the output file" ) parser.add_argument( - "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')" + "--mode", + type=str, + choices=["encode", "decode", "both"], + required=True, + help="Mode: encode, decode, or both", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="The data type for computation (e.g., 'float16' or 'bfloat16')", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="The device to use for computation (e.g., 'cuda' or 'cpu')", ) args = parser.parse_args() @@ -126,15 +142,21 @@ if __name__ == "__main__": assert args.video_path, "Video path must be provided for encoding." encoded_output = encode_video(args.model_path, args.video_path, dtype, device) torch.save(encoded_output, args.output_path + "/encoded.pt") - print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt") + print( + f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt" + ) elif args.mode == "decode": assert args.encoded_path, "Encoded tensor path must be provided for decoding." decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device) save_video(decoded_output, args.output_path) - print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4") + print( + f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4" + ) elif args.mode == "both": assert args.video_path, "Video path must be provided for encoding." 
encoded_output = encode_video(args.model_path, args.video_path, dtype, device) torch.save(encoded_output, args.output_path + "/encoded.pt") - decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device) + decoded_output = decode_video( + args.model_path, args.output_path + "/encoded.pt", dtype, device + ) save_video(decoded_output, args.output_path) diff --git a/inference/convert_demo.py b/inference/convert_demo.py index b741581..2c423fc 100644 --- a/inference/convert_demo.py +++ b/inference/convert_demo.py @@ -35,9 +35,9 @@ Video descriptions must have the same num of words as examples below. Extra word """ sys_prompt_i2v = """ -**Objective**: **Give a highly descriptive video caption based on input image and user input. **. As an expert, delve deep into the image with a discerning eye, leveraging rich creativity, meticulous thought. When describing the details of an image, include appropriate dynamic information to ensure that the video caption contains reasonable actions and plots. If user input is not empty, then the caption should be expanded according to the user's input. +**Objective**: **Give a highly descriptive video caption based on input image and user input. **. As an expert, delve deep into the image with a discerning eye, leveraging rich creativity, meticulous thought. When describing the details of an image, include appropriate dynamic information to ensure that the video caption contains reasonable actions and plots. If user input is not empty, then the caption should be expanded according to the user's input. -**Note**: The input image is the first frame of the video, and the output video caption should describe the motion starting from the current image. User input is optional and can be empty. +**Note**: The input image is the first frame of the video, and the output video caption should describe the motion starting from the current image. User input is optional and can be empty. **Note**: Don't contain camera transitions!!! Don't contain screen switching!!! Don't contain perspective shifts !!! 
@@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert") - parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion") + parser.add_argument( + "--retry_times", type=int, default=3, help="Number of times to retry the conversion" + ) parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)") parser.add_argument("--image_path", type=str, default=None, help="Path to the image file") args = parser.parse_args() diff --git a/inference/ddim_inversion.py b/inference/ddim_inversion.py index e932bf4..1ca682e 100644 --- a/inference/ddim_inversion.py +++ b/inference/ddim_inversion.py @@ -30,7 +30,10 @@ import torchvision.transforms as T from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0 from diffusers.models.autoencoders import AutoencoderKLCogVideoX from diffusers.models.embeddings import apply_rotary_emb -from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel +from diffusers.models.transformers.cogvideox_transformer_3d import ( + CogVideoXBlock, + CogVideoXTransformer3DModel, +) from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler from diffusers.utils import export_to_video @@ -62,22 +65,48 @@ class DDIMInversionArguments(TypedDict): def get_args() -> DDIMInversionArguments: parser = argparse.ArgumentParser() - parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model") - parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure") - parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion") - parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos") - parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale") - parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps") - parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start") - parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end") - parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames") - parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames") + parser.add_argument( + "--model_path", type=str, required=True, help="Path of the pretrained model" + ) + parser.add_argument( + "--prompt", type=str, required=True, help="Prompt for the direct sample procedure" + ) + parser.add_argument( + "--video_path", type=str, required=True, help="Path of the video for inversion" + ) + parser.add_argument( + "--output_path", type=str, default="output", help="Path of the output videos" + ) + parser.add_argument( + "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale" + ) + parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of inference steps" + ) + parser.add_argument( + "--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start" + ) + 
parser.add_argument( + "--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end" + ) + parser.add_argument( + "--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames" + ) + parser.add_argument( + "--max_num_frames", type=int, default=81, help="Max number of sampled frames" + ) parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames") - parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames") + parser.add_argument( + "--height", type=int, default=480, help="Resized height of the video frames" + ) parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos") - parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model") + parser.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model" + ) parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator") - parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference") + parser.add_argument( + "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference" + ) args = parser.parse_args() args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 @@ -116,13 +145,20 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0): # Apply RoPE if needed if image_rotary_emb is not None: - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + query[:, :, text_seq_length:] = apply_rotary_emb( + query[:, :, text_seq_length:], image_rotary_emb + ) if not attn.is_cross_attention: if key.size(2) == query.size(2): # Attention for reference hidden states - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + key[:, :, text_seq_length:] = apply_rotary_emb( + key[:, :, text_seq_length:], image_rotary_emb + ) else: # RoPE should be applied to each group of image tokens - key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb( - key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb + key[:, :, text_seq_length : text_seq_length + image_seq_length] = ( + apply_rotary_emb( + key[:, :, text_seq_length : text_seq_length + image_seq_length], + image_rotary_emb, + ) ) key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb( key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb @@ -162,8 +198,12 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0): ) if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -260,14 +300,18 @@ def get_video_frames( return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W] -def encode_video_frames(vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor) -> torch.FloatTensor: +def encode_video_frames( + vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor +) -> torch.FloatTensor: video_frames = 
video_frames.to(device=vae.device, dtype=vae.dtype) video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W] latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2) return latent_dist * vae.config.scaling_factor -def export_latents_to_video(pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int): +def export_latents_to_video( + pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int +): video = pipeline.decode_latents(latents) frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil") export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps) @@ -320,7 +364,9 @@ def sample( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) - if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs + if isinstance( + scheduler, DDIMInverseScheduler + ): # Inverse scheduler does not accept extra kwargs extra_step_kwargs = {} # 7. Create rotary embeds if required @@ -344,7 +390,9 @@ def sample( if pipeline.interrupt: continue - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) if reference_latents is not None: reference = reference_latents[i] reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference @@ -371,18 +419,31 @@ def sample( # perform guidance if use_dynamic_cfg: pipeline._guidance_scale = 1 + guidance_scale * ( - (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ( + 1 + - math.cos( + math.pi + * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0 + ) + ) + / 2 ) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + pipeline.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + pipeline.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the noisy sample x_t-1 -> x_t - latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] latents = latents.to(prompt_embeds.dtype) trajectory[i] = latents - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0 + ): progress_bar.update() # Offload all models @@ -410,7 +471,9 @@ def ddim_inversion( seed: int, device: torch.device, ): - pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device) + pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained( + model_path, torch_dtype=dtype + ).to(device=device) if not pipeline.transformer.config.use_rotary_positional_embeddings: raise NotImplementedError("This script supports CogVideoX 5B model only.") video_frames = get_video_frames( diff --git a/inference/gradio_composite_demo/README.md b/inference/gradio_composite_demo/README.md index d19c3a5..50dfde7 100644 --- a/inference/gradio_composite_demo/README.md +++ b/inference/gradio_composite_demo/README.md @@ -35,7 +35,7 @@ Set the following environment variables in your system: ## Installation ```bash -pip 
install -r requirements.txt +pip install -r requirements.txt ``` ## Running the code @@ -43,5 +43,3 @@ pip install -r requirements.txt ```bash python app.py ``` - - diff --git a/inference/gradio_composite_demo/app.py b/inference/gradio_composite_demo/app.py index 6856cad..085371f 100644 --- a/inference/gradio_composite_demo/app.py +++ b/inference/gradio_composite_demo/app.py @@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu" MODEL = "THUDM/CogVideoX-5b" -hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran") +hf_hub_download( + repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran" +) snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device) -pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") +pipe.scheduler = CogVideoXDPMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" +) pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained( MODEL, transformer=pipe.transformer, @@ -296,8 +300,16 @@ def delete_old_files(): threading.Thread(target=delete_old_files, daemon=True).start() -examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]] -examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]] +examples_videos = [ + ["example_videos/horse.mp4"], + ["example_videos/kitten.mp4"], + ["example_videos/train_running.mp4"], +] +examples_images = [ + ["example_images/beach.png"], + ["example_images/street.png"], + ["example_images/camping.png"], +] with gr.Blocks() as demo: gr.Markdown(""" @@ -317,19 +329,31 @@ with gr.Blocks() as demo: ">

-    ⚠️ This demo is for academic research and experimental use only. 
+    ⚠️ This demo is for academic research and experimental use only.
""") with gr.Row(): with gr.Column(): - with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False): + with gr.Accordion( + "I2V: Image Input (cannot be used simultaneously with video input)", open=False + ): image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)") - examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False) - with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False): - video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)") + examples_component_images = gr.Examples( + examples_images, inputs=[image_input], cache_examples=False + ) + with gr.Accordion( + "V2V: Video Input (cannot be used simultaneously with image input)", open=False + ): + video_input = gr.Video( + label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)" + ) strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength") - examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False) - prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) + examples_component_videos = gr.Examples( + examples_videos, inputs=[video_input], cache_examples=False + ) + prompt = gr.Textbox( + label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5 + ) with gr.Row(): gr.Markdown( @@ -340,11 +364,16 @@ with gr.Blocks() as demo: with gr.Column(): with gr.Row(): seed_param = gr.Number( - label="Inference Seed (Enter a positive number, -1 for random)", value=-1 + label="Inference Seed (Enter a positive number, -1 for random)", + value=-1, ) with gr.Row(): - enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False) - enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False) + enable_scale = gr.Checkbox( + label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False + ) + enable_rife = gr.Checkbox( + label="Frame Interpolation (8fps -> 16fps)", value=False + ) gr.Markdown( "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).
    The entire process is based on open-source solutions." ) @@ -430,7 +459,7 @@ with gr.Blocks() as demo: seed_value, scale_status, rife_status, - progress=gr.Progress(track_tqdm=True) + progress=gr.Progress(track_tqdm=True), ): latents, seed = infer( prompt, @@ -457,7 +486,9 @@ with gr.Blocks() as demo: image_pil = VaeImageProcessor.numpy_to_pil(image_np) batch_video_frames.append(image_pil) - video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)) + video_path = utils.save_video( + batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6) + ) video_update = gr.update(visible=True, value=video_path) gif_path = convert_to_gif(video_path) gif_update = gr.update(visible=True, value=gif_path) diff --git a/inference/gradio_composite_demo/requirements.txt b/inference/gradio_composite_demo/requirements.txt index b843892..50ce31c 100644 --- a/inference/gradio_composite_demo/requirements.txt +++ b/inference/gradio_composite_demo/requirements.txt @@ -16,4 +16,4 @@ imageio>=2.34.2 imageio-ffmpeg>=0.5.1 openai>=1.45.0 moviepy>=2.0.0 -pillow==9.5.0 \ No newline at end of file +pillow==9.5.0 diff --git a/inference/gradio_composite_demo/rife/IFNet.py b/inference/gradio_composite_demo/rife/IFNet.py index 7b74fbf..c395783 100644 --- a/inference/gradio_composite_demo/rife/IFNet.py +++ b/inference/gradio_composite_demo/rife/IFNet.py @@ -3,7 +3,9 @@ from .refine import * def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( - torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + torch.nn.ConvTranspose2d( + in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1 + ), nn.PReLU(out_planes), ) @@ -46,7 +48,11 @@ class IFBlock(nn.Module): if scale != 1: x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) if flow != None: - flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + flow = ( + F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + * 1.0 + / scale + ) x = torch.cat((x, flow), 1) x = self.conv0(x) x = self.convblock(x) + x @@ -102,7 +108,9 @@ class IFNet(nn.Module): warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) mask_teacher = torch.sigmoid(mask + mask_d) - merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * ( + 1 - mask_teacher + ) else: flow_teacher = None merged_teacher = None @@ -110,11 +118,16 @@ class IFNet(nn.Module): merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) if gt.shape[1] == 3: loss_mask = ( - ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + ( + (merged[i] - gt).abs().mean(1, True) + > (merged_teacher - gt).abs().mean(1, True) + 0.01 + ) .float() .detach() ) - loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + loss_distill += ( + ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask + ).mean() c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) diff --git a/inference/gradio_composite_demo/rife/IFNet_2R.py b/inference/gradio_composite_demo/rife/IFNet_2R.py 
index 0317b86..8ec9eaa 100644 --- a/inference/gradio_composite_demo/rife/IFNet_2R.py +++ b/inference/gradio_composite_demo/rife/IFNet_2R.py @@ -3,7 +3,9 @@ from .refine_2R import * def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( - torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + torch.nn.ConvTranspose2d( + in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1 + ), nn.PReLU(out_planes), ) @@ -46,7 +48,11 @@ class IFBlock(nn.Module): if scale != 1: x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) if flow != None: - flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + flow = ( + F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + * 1.0 + / scale + ) x = torch.cat((x, flow), 1) x = self.conv0(x) x = self.convblock(x) + x @@ -102,7 +108,9 @@ class IFNet(nn.Module): warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) mask_teacher = torch.sigmoid(mask + mask_d) - merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * ( + 1 - mask_teacher + ) else: flow_teacher = None merged_teacher = None @@ -110,11 +118,16 @@ class IFNet(nn.Module): merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) if gt.shape[1] == 3: loss_mask = ( - ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + ( + (merged[i] - gt).abs().mean(1, True) + > (merged_teacher - gt).abs().mean(1, True) + 0.01 + ) .float() .detach() ) - loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + loss_distill += ( + ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask + ).mean() c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) diff --git a/inference/gradio_composite_demo/rife/IFNet_HDv3.py b/inference/gradio_composite_demo/rife/IFNet_HDv3.py index 57f8003..ad4a727 100644 --- a/inference/gradio_composite_demo/rife/IFNet_HDv3.py +++ b/inference/gradio_composite_demo/rife/IFNet_HDv3.py @@ -61,11 +61,19 @@ class IFBlock(nn.Module): def forward(self, x, flow, scale=1): x = F.interpolate( - x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + x, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, ) flow = ( F.interpolate( - flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + flow, + scale_factor=1.0 / scale, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, ) * 1.0 / scale @@ -78,11 +86,21 @@ class IFBlock(nn.Module): flow = self.conv1(feat) mask = self.conv2(feat) flow = ( - F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=scale, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * scale ) mask = F.interpolate( - mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + mask, + scale_factor=scale, + mode="bilinear", + align_corners=False, + 
recompute_scale_factor=False, ) return flow, mask @@ -112,7 +130,11 @@ class IFNet(nn.Module): loss_cons = 0 block = [self.block0, self.block1, self.block2] for i in range(3): - f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) + f0, m0 = block[i]( + torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), + flow, + scale=scale_list[i], + ) f1, m1 = block[i]( torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), diff --git a/inference/gradio_composite_demo/rife/IFNet_m.py b/inference/gradio_composite_demo/rife/IFNet_m.py index b28acd3..8547ef8 100644 --- a/inference/gradio_composite_demo/rife/IFNet_m.py +++ b/inference/gradio_composite_demo/rife/IFNet_m.py @@ -3,7 +3,9 @@ from .refine import * def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( - torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), + torch.nn.ConvTranspose2d( + in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1 + ), nn.PReLU(out_planes), ) @@ -46,7 +48,11 @@ class IFBlock(nn.Module): if scale != 1: x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) if flow != None: - flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + flow = ( + F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + * 1.0 + / scale + ) x = torch.cat((x, flow), 1) x = self.conv0(x) x = self.convblock(x) + x @@ -83,7 +89,9 @@ class IFNet_m(nn.Module): for i in range(3): if flow != None: flow_d, mask_d = stu[i]( - torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i] + torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), + flow, + scale=scale[i], ) flow = flow + flow_d mask = mask + mask_d @@ -97,13 +105,17 @@ class IFNet_m(nn.Module): merged.append(merged_student) if gt.shape[1] == 3: flow_d, mask_d = self.block_tea( - torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1 + torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), + flow, + scale=1, ) flow_teacher = flow + flow_d warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) mask_teacher = torch.sigmoid(mask + mask_d) - merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) + merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * ( + 1 - mask_teacher + ) else: flow_teacher = None merged_teacher = None @@ -111,11 +123,16 @@ class IFNet_m(nn.Module): merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) if gt.shape[1] == 3: loss_mask = ( - ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) + ( + (merged[i] - gt).abs().mean(1, True) + > (merged_teacher - gt).abs().mean(1, True) + 0.01 + ) .float() .detach() ) - loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() + loss_distill += ( + ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask + ).mean() if returnflow: return flow else: diff --git a/inference/gradio_composite_demo/rife/RIFE_HDv3.py b/inference/gradio_composite_demo/rife/RIFE_HDv3.py index 6123e31..182c78e 100644 --- a/inference/gradio_composite_demo/rife/RIFE_HDv3.py +++ 
b/inference/gradio_composite_demo/rife/RIFE_HDv3.py @@ -44,7 +44,9 @@ class Model: if torch.cuda.is_available(): self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path)))) else: - self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))) + self.flownet.load_state_dict( + convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")) + ) def save_model(self, path, rank=0): if rank == 0: diff --git a/inference/gradio_composite_demo/rife/laplacian.py b/inference/gradio_composite_demo/rife/laplacian.py index 6e72e51..dac3994 100644 --- a/inference/gradio_composite_demo/rife/laplacian.py +++ b/inference/gradio_composite_demo/rife/laplacian.py @@ -29,10 +29,14 @@ def downsample(x): def upsample(x): - cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) + cc = torch.cat( + [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3 + ) cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) cc = cc.permute(0, 1, 3, 2) - cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3) + cc = torch.cat( + [cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3 + ) cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) x_up = cc.permute(0, 1, 3, 2) return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1])) @@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module): self.gauss_kernel = gauss_kernel(channels=channels) def forward(self, input, target): - pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) - pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) + pyr_input = laplacian_pyramid( + img=input, kernel=self.gauss_kernel, max_levels=self.max_levels + ) + pyr_target = laplacian_pyramid( + img=target, kernel=self.gauss_kernel, max_levels=self.max_levels + ) return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) diff --git a/inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py b/inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py index 3e2baaf..f1e4ec4 100644 --- a/inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py +++ b/inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py @@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + gauss = torch.Tensor( + [exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)] + ) return gauss / gauss.sum() @@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) - window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + window = ( + _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) + ) return window @@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, # mu1 = F.conv2d(img1, window, padding=padd, groups=channel) # mu2 = F.conv2d(img2, window, padding=padd, groups=channel) - mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) - mu2 = 
F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) + mu1 = F.conv2d( + F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel + ) + mu2 = F.conv2d( + F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel + ) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 - sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq - sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq - sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2 + sigma1_sq = ( + F.conv2d( + F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d( + F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel + ) + - mu2_sq + ) + sigma12 = ( + F.conv2d( + F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel + ) + - mu1_mu2 + ) C1 = (0.01 * L) ** 2 C2 = (0.03 * L) ** 2 @@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, return ret -def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): +def ssim_matlab( + img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None +): # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). if val_range is None: if torch.max(img1) > 128: @@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full img1 = img1.unsqueeze(1) img2 = img2.unsqueeze(1) - mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) - mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + mu1 = F.conv3d( + F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1 + ) + mu2 = F.conv3d( + F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1 + ) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 - sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq - sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq - sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2 + sigma1_sq = ( + F.conv3d( + F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1 + ) + - mu1_sq + ) + sigma2_sq = ( + F.conv3d( + F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1 + ) + - mu2_sq + ) + sigma12 = ( + F.conv3d( + F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1 + ) + - mu1_mu2 + ) C1 = (0.01 * L) ** 2 C2 = (0.03 * L) ** 2 @@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal mssim = [] mcs = [] for _ in range(levels): - sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) + sim, cs = ssim( + img1, + img2, + window_size=window_size, + size_average=size_average, + full=True, + val_range=val_range, + ) mssim.append(sim) mcs.append(cs) @@ -187,7 +238,9 @@ class SSIM(torch.nn.Module): self.window = window self.channel = channel - 
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) + _ssim = ssim( + img1, img2, window=window, window_size=self.window_size, size_average=self.size_average + ) dssim = (1 - _ssim) / 2 return dssim diff --git a/inference/gradio_composite_demo/rife/refine.py b/inference/gradio_composite_demo/rife/refine.py index 2f9becb..e956318 100644 --- a/inference/gradio_composite_demo/rife/refine.py +++ b/inference/gradio_composite_demo/rife/refine.py @@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( torch.nn.ConvTranspose2d( - in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True + in_channels=in_planes, + out_channels=out_planes, + kernel_size=4, + stride=2, + padding=1, + bias=True, ), nn.PReLU(out_planes), ) @@ -56,25 +61,49 @@ class Contextnet(nn.Module): def forward(self, x, flow): x = self.conv1(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f1 = warp(x, flow) x = self.conv2(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f2 = warp(x, flow) x = self.conv3(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f3 = warp(x, flow) x = self.conv4(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f4 = warp(x, flow) diff --git a/inference/gradio_composite_demo/rife/refine_2R.py b/inference/gradio_composite_demo/rife/refine_2R.py index c6cc2c0..6bc1515 100644 --- a/inference/gradio_composite_demo/rife/refine_2R.py +++ b/inference/gradio_composite_demo/rife/refine_2R.py @@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): return nn.Sequential( torch.nn.ConvTranspose2d( - in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True + in_channels=in_planes, + out_channels=out_planes, + kernel_size=4, + stride=2, + padding=1, + bias=True, ), nn.PReLU(out_planes), ) @@ -59,19 +64,37 @@ class Contextnet(nn.Module): f1 = warp(x, flow) x = self.conv2(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f2 = warp(x, flow) x = self.conv3(x) flow = ( - F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f3 = warp(x, flow) x = self.conv4(x) flow = ( - F.interpolate(flow, 
scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) + F.interpolate( + flow, + scale_factor=0.5, + mode="bilinear", + align_corners=False, + recompute_scale_factor=False, + ) * 0.5 ) f4 = warp(x, flow) diff --git a/inference/gradio_composite_demo/rife_model.py b/inference/gradio_composite_demo/rife_model.py index e1783e3..c21a713 100644 --- a/inference/gradio_composite_demo/rife_model.py +++ b/inference/gradio_composite_demo/rife_model.py @@ -9,6 +9,7 @@ import logging import skvideo.io from rife.RIFE_HDv3 import Model from huggingface_hub import hf_hub_download, snapshot_download + logger = logging.getLogger(__name__) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -19,7 +20,7 @@ def pad_image(img, scale): tmp = max(32, int(32 / scale)) ph = ((h - 1) // tmp + 1) * tmp pw = ((w - 1) // tmp + 1) * tmp - padding = (0, pw - w, 0, ph - h) + padding = (0, pw - w, 0, ph - h) return F.pad(img, padding), padding @@ -45,15 +46,15 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi for b in range(samples.shape[0]): frame = samples[b : b + 1] _, _, h, w = frame.shape - + I0 = samples[b : b + 1] I1 = samples[b + 1 : b + 2] if b + 2 < samples.shape[0] else samples[-1:] - + I0, padding = pad_image(I0, upscale_amount) I0 = I0.to(torch.float) I1, _ = pad_image(I1, upscale_amount) I1 = I1.to(torch.float) - + # [c, h, w] I0_small = F.interpolate(I0, (32, 32), mode="bilinear", align_corners=False) I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) @@ -70,21 +71,20 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi # print(f'I0 shape:{I0.shape}') # print(f'I1 shape:{I1.shape}') I1 = make_inference(model, I0, I1, upscale_amount, 1) - + # print(f'I0 shape:{I0.shape}') - # print(f'I1[0] shape:{I1[0].shape}') + # print(f'I1[0] shape:{I1[0].shape}') I1 = I1[0] - - # print(f'I1[0] unpadded shape:{I1.shape}') + + # print(f'I1[0] unpadded shape:{I1.shape}') I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False) ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3]) - if padding[3] > 0 and padding[1] >0 : - - frame = I1[:, :, : -padding[3],:-padding[1]] + if padding[3] > 0 and padding[1] > 0: + frame = I1[:, :, : -padding[3], : -padding[1]] elif padding[3] > 0: - frame = I1[:, :, : -padding[3],:] - elif padding[1] >0: - frame = I1[:, :, :,:-padding[1]] + frame = I1[:, :, : -padding[3], :] + elif padding[1] > 0: + frame = I1[:, :, :, : -padding[1]] else: frame = I1 @@ -101,8 +101,7 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi frame = F.interpolate(frame, size=(h, w)) output.append(frame.to(output_device)) - for i, tmp_frame in enumerate(tmp_output): - + for i, tmp_frame in enumerate(tmp_output): # tmp_frame, _ = pad_image(tmp_frame, upscale_amount) tmp_frame = F.interpolate(tmp_frame, size=(h, w)) output.append(tmp_frame.to(output_device)) @@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path): frame_rgb = frame[..., ::-1] frame_rgb = frame_rgb.copy() tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0 - pt_frame_data.append( - tensor.permute(2, 0, 1) - ) # to [c, h, w,] + pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,] pt_frame = torch.from_numpy(np.stack(pt_frame_data)) pt_frame = pt_frame.to(device) @@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents): latent = latents[i] frames = ssim_interpolation_rife(model, latent) - pt_image = 
torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h]) + pt_image = torch.stack( + [frames[i].squeeze(0) for i in range(len(frames))] + ) # (to [f, c, w, h]) rife_results.append(pt_image) return torch.stack(rife_results) @@ -179,6 +178,6 @@ def rife_inference_with_latents(model, latents): # if __name__ == "__main__": # snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife") # model = load_rife_model("model_rife") - + # video_path = rife_inference_with_path(model, "/mnt/ceph/develop/jiawei/CogVideo/output/20241003_130720.mp4") -# print(video_path) \ No newline at end of file +# print(video_path) diff --git a/inference/gradio_composite_demo/utils.py b/inference/gradio_composite_demo/utils.py index 01c04d4..d39f227 100644 --- a/inference/gradio_composite_demo/utils.py +++ b/inference/gradio_composite_demo/utils.py @@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16): if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): sd = safetensors.torch.load_file(ckpt, device=device.type) else: - if not "weights_only" in torch.load.__code__.co_varnames: + if "weights_only" not in torch.load.__code__.co_varnames: logger.warning( "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely." ) @@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): @torch.inference_mode() def tiled_scale_multidim( - samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None + samples, + function, + tile=(64, 64), + overlap=8, + upscale_amount=4, + out_channels=3, + output_device="cpu", + pbar=None, ): dims = len(tile) print(f"samples dtype:{samples.dtype}") output = torch.empty( - [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), + [samples.shape[0], out_channels] + + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device, ) for b in range(samples.shape[0]): s = samples[b : b + 1] out = torch.zeros( - [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), + [s.shape[0], out_channels] + + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device, ) out_div = torch.zeros( - [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), + [s.shape[0], out_channels] + + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device, ) - for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): + for it in itertools.product( + *map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile)) + ): s_in = s upscaled = [] @@ -142,7 +154,14 @@ def tiled_scale( pbar=None, ): return tiled_scale_multidim( - samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar + samples, + function, + (tile_y, tile_x), + overlap, + upscale_amount, + out_channels, + output_device, + pbar, ) @@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu" return s -def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor: +def upscale_batch_and_concatenate( + upscale_model, latents, inf_device, output_device="cpu" +) -> torch.Tensor: upscaled_latents = [] for i in range(latents.size(0)): latent = latents[i] @@ -207,7 +228,9 @@ class ProgressBar: def __init__(self, total, desc=None): self.total = 
total self.current = 0 - self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc) + self.b_unit = tqdm.tqdm( + total=total, desc="ProgressBar context index: 0" if desc is None else desc + ) def update(self, value): if value > self.total: diff --git a/inference/gradio_web_demo.py b/inference/gradio_web_demo.py index 8e1e0e4..955c3b7 100644 --- a/inference/gradio_web_demo.py +++ b/inference/gradio_web_demo.py @@ -22,7 +22,9 @@ from datetime import datetime, timedelta from openai import OpenAI from moviepy import VideoFileClip -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to( + "cuda" +) pipe.vae.enable_slicing() pipe.vae.enable_tiling() @@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str: return prompt -def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)): +def infer( + prompt: str, + num_inference_steps: int, + guidance_scale: float, + progress=gr.Progress(track_tqdm=True), +): torch.cuda.empty_cache() video = pipe( prompt=prompt, @@ -151,7 +158,9 @@ with gr.Blocks() as demo: with gr.Row(): with gr.Column(): - prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) + prompt = gr.Textbox( + label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5 + ) with gr.Row(): gr.Markdown( @@ -176,7 +185,13 @@ with gr.Blocks() as demo: download_video_button = gr.File(label="📥 Download Video", visible=False) download_gif_button = gr.File(label="📥 Download GIF", visible=False) - def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)): + def generate( + prompt, + num_inference_steps, + guidance_scale, + model_choice, + progress=gr.Progress(track_tqdm=True), + ): tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress) video_path = save_video(tensor) video_update = gr.update(visible=True, value=video_path) diff --git a/resources/WECHAT.md b/resources/WECHAT.md index 7f9620d..8c56211 100644 --- a/resources/WECHAT.md +++ b/resources/WECHAT.md @@ -4,4 +4,3 @@

扫码关注公众号,加入「 CogVideoX 交流群」

Scan the QR code to follow the official account and join the "CogVideoX Discussion Group"

- diff --git a/resources/contribute.md b/resources/contribute.md deleted file mode 100644 index 0d3640f..0000000 --- a/resources/contribute.md +++ /dev/null @@ -1,49 +0,0 @@ -# Contribution Guide - -There may still be many incomplete aspects in this project. - -We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above -and are willing to submit a PR and share it with the community, upon review, we -will acknowledge your contribution on the project homepage. - -## Model Algorithms - -- Support for model quantization inference (Int4 quantization project) -- Optimization of model fine-tuning data loading (replacing the existing decord tool) - -## Model Engineering - -- Model fine-tuning examples / Best prompt practices -- Inference adaptation on different devices (e.g., MLX framework) -- Any tools related to the model -- Any minimal fully open-source project using the CogVideoX open-source model - -## Code Standards - -Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code -style. You can organize the code according to the following specifications: - -1. Install the `ruff` tool - -```shell -pip install ruff -``` - -Then, run the `ruff` tool - -```shell -ruff check tools sat inference -``` - -Check the code style. If there are issues, you can automatically fix them using the `ruff format` command. - -```shell -ruff format tools sat inference -``` - -Once your code meets the standard, there should be no errors. - -## Naming Conventions -1. Please use English names, do not use Pinyin or other language names. All comments should be in English. -2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c. - diff --git a/resources/contribute_ja.md b/resources/contribute_ja.md deleted file mode 100644 index 80ddc27..0000000 --- a/resources/contribute_ja.md +++ /dev/null @@ -1,47 +0,0 @@ -# コントリビューションガイド - -本プロジェクトにはまだ多くの未完成の部分があります。 - -以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。 - -## モデルアルゴリズム - -- モデル量子化推論のサポート (Int4量子化プロジェクト) -- モデルのファインチューニングデータロードの最適化(既存のdecordツールの置き換え) - -## モデルエンジニアリング - -- モデルのファインチューニング例 / 最適なプロンプトの実践 -- 異なるデバイスでの推論適応(例: MLXフレームワーク) -- モデルに関連するツール -- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト - -## コード標準 - -良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml` -設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。 - -1. `ruff` ツールをインストールする - -```shell -pip install ruff -``` - -次に、`ruff` ツールを実行します - -```shell -ruff check tools sat inference -``` - -コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。 - -```shell -ruff format tools sat inference -``` - -コードが標準に準拠したら、エラーはなくなるはずです。 - -## 命名規則 - -1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。 -2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。 diff --git a/resources/contribute_zh.md b/resources/contribute_zh.md deleted file mode 100644 index 4b95254..0000000 --- a/resources/contribute_zh.md +++ /dev/null @@ -1,44 +0,0 @@ -# 贡献指南 - -本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。 - -## 模型算法 - -- 模型量化推理支持 (Int4量化工程) -- 模型微调数据载入优化支持(替换现有的decord工具) - -## 模型工程 - -- 模型微调示例 / 最佳提示词实践 -- 不同设备上的推理适配(MLX等框架) -- 任何模型周边工具 -- 任何使用CogVideoX开源模型制作的最小完整开源项目 - -## 代码规范 - -良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码: - -1. 
安装`ruff`工具 - -```shell -pip install ruff -``` - -接着,运行`ruff`工具 - -```shell -ruff check tools sat inference -``` - -检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。 - -```shell -ruff format tools sat inference -``` - -如果您的代码符合规范,应该不会出现任何的错误。 - -## 命名规范 - -- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。 -- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。 \ No newline at end of file diff --git a/resources/galary_prompt.md b/resources/galary_prompt.md index a738bb2..b848211 100644 --- a/resources/galary_prompt.md +++ b/resources/galary_prompt.md @@ -20,7 +20,7 @@ Videos 1-8: ## CogVideoX-2B -Videos 1-4: +Videos 1-4: 1. A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting. diff --git a/sat/README.md b/sat/README.md index b07f52f..e8b90e8 100644 --- a/sat/README.md +++ b/sat/README.md @@ -64,7 +64,7 @@ Arrange the model files in the following structure: └── 3d-vae.pt ``` -Since model weight files are large, it’s recommended to use `git lfs`. +Since model weight files are large, it’s recommended to use `git lfs`. See [here](https://github.com/git-lfs/git-lfs?tab=readme-ov-file#installing) for `git lfs` installation. ``` @@ -339,7 +339,7 @@ Modify the files in `configs/sft.yaml` (full fine-tuning) as follows: eval_iters: 1 # Evaluation iterations eval_interval: 100 # Evaluation interval eval_batch_size: 1 # Evaluation batch size - save: ckpts # Model save path + save: ckpts # Model save path save_interval: 100 # Save interval log_interval: 20 # Log output interval train_data: [ "your train data path" ] @@ -411,7 +411,7 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox_ self.max_num_frames: @@ -410,7 +418,11 @@ class SFTDataset(Dataset): indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None - tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + tensor_frms = ( + torch.from_numpy(temp_frms) + if type(temp_frms) is not torch.Tensor + else temp_frms + ) tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] else: @@ -423,11 +435,17 @@ class SFTDataset(Dataset): start = int(self.skip_frms_num) end = int(ori_vlen - self.skip_frms_num) - num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1 + num_frames = nearest_smaller_4k_plus_1( + end - start + ) # 3D VAE requires the number of frames to be 4k+1 end = int(start + num_frames) temp_frms = vr.get_batch(np.arange(start, end)) assert temp_frms is not None - tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms + tensor_frms = ( + torch.from_numpy(temp_frms) + if type(temp_frms) is not torch.Tensor + else temp_frms + ) tensor_frms = pad_last_frame( tensor_frms, self.max_num_frames diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index b9c0552..9a09288 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -41,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module): latent_input = 
model_config.get("latent_input", False) disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) no_cond_log = model_config.get("disable_first_stage_autocast", False) - not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"]) + not_trainable_prefixes = model_config.get( + "not_trainable_prefixes", ["first_stage_model", "conditioner"] + ) compile_model = model_config.get("compile_model", False) en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) lr_scale = model_config.get("lr_scale", None) @@ -76,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module): ) self.denoiser = instantiate_from_config(denoiser_config) - self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None - self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) + self.sampler = ( + instantiate_from_config(sampler_config) if sampler_config is not None else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, UNCONDITIONAL_CONFIG) + ) self._init_first_stage(first_stage_config) - self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None + self.loss_fn = ( + instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None + ) self.latent_input = latent_input self.scale_factor = scale_factor @@ -151,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module): def shared_step(self, batch: Dict) -> Any: x = self.get_input(batch) if self.lr_scale is not None: - lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False) - lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False) + lr_x = F.interpolate( + x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False + ) + lr_x = F.interpolate( + lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False + ) lr_z = self.encode_first_stage(lr_x, batch) batch["lr_input"] = lr_z @@ -195,7 +207,11 @@ class SATVideoDiffusionEngine(nn.Module): recons = [] start_frame = 0 for i in range(fake_cp_size): - end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0) + end_frame = ( + start_frame + + latent_time // fake_cp_size + + (1 if i < latent_time % fake_cp_size else 0) + ) use_cp = True if i == 0 else False clear_fake_cp_cache = True if i == fake_cp_size - 1 else False @@ -264,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module): self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs ) - samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs) + samples = self.sampler( + denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs + ) samples = samples.to(self.dtype) return samples @@ -278,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module): log = dict() for embedder in self.conditioner.embedders: - if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: x = batch[embedder.input_key][:n] if isinstance(x, torch.Tensor): if x.dim() == 1: @@ -354,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module): image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1) c["concat"] = image uc["concat"] = image - samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, 
**sampling_kwargs) # b t c h w + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) # b t c h w samples = samples.permute(0, 2, 1, 3, 4).contiguous() if only_log_video_latents: latents = 1.0 / self.scale_factor * samples @@ -364,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module): samples = samples.permute(0, 2, 1, 3, 4).contiguous() log["samples"] = samples else: - samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) # b t c h w samples = samples.permute(0, 2, 1, 3, 4).contiguous() if only_log_video_latents: latents = 1.0 / self.scale_factor * samples diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py index 82c1d56..c90cf1e 100644 --- a/sat/dit_video_concat.py +++ b/sat/dit_video_concat.py @@ -94,7 +94,9 @@ def get_3d_sincos_pos_embed( # concate: [T, H, W] order pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] - pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4] + pos_embed_temporal = np.repeat( + pos_embed_temporal, grid_height * grid_width, axis=1 + ) # [T, H*W, D // 4] pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] @@ -160,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin): self.width = width self.spatial_length = height * width self.pos_embedding = nn.Parameter( - torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False + torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), + requires_grad=False, ) def position_embedding_forward(self, position_ids, **kwargs): @@ -169,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin): def reinit(self, parent_model=None): del self.transformer.position_embeddings pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) - self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + self.pos_embedding.data[:, -self.spatial_length :].copy_( + torch.from_numpy(pos_embed).float().unsqueeze(0) + ) class Basic3DPositionEmbeddingMixin(BaseMixin): @@ -192,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin): self.spatial_length = height * width self.num_patches = height * width * compressed_num_frames self.pos_embedding = nn.Parameter( - torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False + torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), + requires_grad=False, ) self.height_interpolation = height_interpolation self.width_interpolation = width_interpolation @@ -285,7 +291,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) - freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + freqs = broadcat( + (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), + dim=-1, + ) freqs = freqs.contiguous() self.freqs_sin = freqs.sin().cuda() @@ -293,7 +302,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): self.text_length = text_length if learnable_pos_embed: num_patches = height * width * compressed_num_frames + text_length - self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True) + self.pos_embedding = nn.Parameter( + torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True + ) else: self.pos_embedding = None @@ -440,16 +451,26 @@ class FinalLayerMixin(BaseMixin): self.out_channels = out_channels self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True) + ) def final_forward(self, logits, **kwargs): - x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d),只取了x中后面images的部分 + x, emb = ( + logits[:, kwargs["text_length"] :, :], + kwargs["emb"], + ) # x:(b,(t n),d),只取了x中后面images的部分 shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return unpatchify( - x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs + x, + c=self.out_channels, + patch_size=self.patch_size, + w=kwargs["rope_W"], + h=kwargs["rope_H"], + **kwargs, ) def reinit(self, parent_model=None): @@ -500,7 +521,10 @@ class AdaLNMixin(BaseMixin): self.compressed_num_frames = compressed_num_frames self.adaLN_modulations = nn.ModuleList( - [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)] + [ + nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) + for _ in range(num_layers) + ] ) self.qk_ln = qk_ln @@ -560,7 +584,9 @@ class AdaLNMixin(BaseMixin): img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa) - attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d) + attention_input = torch.cat( + (text_attention_input, img_attention_input), dim=1 + ) # (b,n_t+t*n_i,d) attention_output = layer.attention(attention_input, mask, **kwargs) text_attention_output = attention_output[:, :text_length] # (b,n,d) img_attention_output = attention_output[:, text_length:] # (b,(t n),d) @@ -584,9 +610,13 @@ class AdaLNMixin(BaseMixin): img_mlp_output = layer.fourth_layernorm(img_mlp_output) img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d) - text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) + text_hidden_states = ( + text_hidden_states + text_gate_mlp * text_mlp_output + ) # language (b,n,d) - hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) + hidden_states = torch.cat( + (text_hidden_states, img_hidden_states), dim=1 + ) # (b,(n_t+t*n_i),d) return hidden_states def reinit(self, parent_model=None): @@ -694,7 +724,9 @@ class 
DiffusionTransformer(BaseModel): if use_RMSNorm: kwargs["layernorm"] = RMSNorm else: - kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6) + kwargs["layernorm"] = partial( + LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6 + ) transformer_args.num_layers = num_layers transformer_args.hidden_size = hidden_size @@ -707,7 +739,9 @@ class DiffusionTransformer(BaseModel): if use_SwiGLU: self.add_mixin( - "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True + "swiglu", + SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), + reinit=True, ) def _build_modules(self, module_configs): @@ -813,7 +847,9 @@ class DiffusionTransformer(BaseModel): ) if "lora_config" in module_configs: lora_config = module_configs["lora_config"] - self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) + self.add_mixin( + "lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True + ) return def forward(self, x, timesteps=None, context=None, y=None, **kwargs): @@ -829,7 +865,9 @@ class DiffusionTransformer(BaseModel): assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False, dtype=self.dtype + ) emb = self.time_embed(t_emb) if self.num_classes is not None: @@ -838,7 +876,9 @@ class DiffusionTransformer(BaseModel): emb = emb + self.label_emb(y) if self.ofs_embed_dim is not None: - ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype) + ofs_emb = timestep_embedding( + kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype + ) ofs_emb = self.ofs_embed(ofs_emb) emb = emb + ofs_emb @@ -852,6 +892,8 @@ class DiffusionTransformer(BaseModel): kwargs["rope_H"] = h // self.patch_size[1] kwargs["rope_W"] = w // self.patch_size[2] - kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones( + (1, 1) + ).to(x.dtype) output = super().forward(**kwargs)[0] return output diff --git a/sat/finetune_multi_gpus.sh b/sat/finetune_multi_gpus.sh index a33f7cc..6a65c94 100644 --- a/sat/finetune_multi_gpus.sh +++ b/sat/finetune_multi_gpus.sh @@ -7,4 +7,4 @@ run_cmd="PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True torchrun --standalone echo ${run_cmd} eval ${run_cmd} -echo "DONE on `hostname`" \ No newline at end of file +echo "DONE on `hostname`" diff --git a/sat/finetune_single_gpu.sh b/sat/finetune_single_gpu.sh index 1359172..a8af46e 100644 --- a/sat/finetune_single_gpu.sh +++ b/sat/finetune_single_gpu.sh @@ -9,4 +9,4 @@ run_cmd="$environs python train_video.py --base configs/cogvideox_2b_lora.yaml c echo ${run_cmd} eval ${run_cmd} -echo "DONE on `hostname`" \ No newline at end of file +echo "DONE on `hostname`" diff --git a/sat/inference.sh b/sat/inference.sh index 9904433..490264d 100755 --- a/sat/inference.sh +++ b/sat/inference.sh @@ -9,4 +9,4 @@ run_cmd="$environs python sample_video.py --base configs/cogvideox1.5_5b.yaml co echo ${run_cmd} eval ${run_cmd} -echo "DONE on `hostname`" \ No newline at end of file +echo "DONE on `hostname`" diff --git a/sat/requirements.txt b/sat/requirements.txt index 3c1c501..7fb5da2 100644 --- 
a/sat/requirements.txt +++ b/sat/requirements.txt @@ -8,4 +8,4 @@ safetensors>=0.4.5 scipy>=1.14.1 decord>=0.6.0 wandb>=0.18.5 -deepspeed>=0.15.3 \ No newline at end of file +deepspeed>=0.15.3 diff --git a/sat/sample_video.py b/sat/sample_video.py index c34e6a7..03d3b68 100644 --- a/sat/sample_video.py +++ b/sat/sample_video.py @@ -19,6 +19,7 @@ from sat import mpu from diffusion_video import SATVideoDiffusionEngine from arguments import get_args + def read_from_cli(): cnt = 0 try: @@ -50,34 +51,50 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda for key in keys: if key == "txt": - batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() - batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + ) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1) + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( - torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1) + torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]) + .to(device) + .repeat(*N, 1) ) elif key == "aesthetic_score": - batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1) + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) ) elif key == "fps": batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) elif key == "fps_id": batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) elif key == "motion_bucket_id": - batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N)) + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N)) + ) elif key == "pool_image": - batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half) + batch[key] = repeat(value_dict[key], "1 ... 
-> b ...", b=math.prod(N)).to( + device, dtype=torch.half + ) elif key == "cond_aug": batch[key] = repeat( torch.tensor([value_dict["cond_aug"]]).to("cuda"), @@ -100,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda return batch, batch_uc -def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): +def save_video_as_grid_and_mp4( + video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None +): os.makedirs(save_path, exist_ok=True) for i, vid in enumerate(video_batch): @@ -160,7 +179,9 @@ def sampling_main(args, model_cls): W = 96 H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8 chained_trainsforms = [] - chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)) + chained_trainsforms.append( + TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1) + ) chained_trainsforms.append(TT.ToTensor()) transform = TT.Compose(chained_trainsforms) image = transform(image).unsqueeze(0).to("cuda") @@ -170,7 +191,9 @@ def sampling_main(args, model_cls): image = image / model.scale_factor image = image.permute(0, 2, 1, 3, 4).contiguous() pad_shape = (image.shape[0], T - 1, C, H, W) - image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) + image = torch.concat( + [image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1 + ) else: image_size = args.sampling_image_size H, W = image_size[0], image_size[1] @@ -181,12 +204,20 @@ def sampling_main(args, model_cls): mp_size = mpu.get_model_parallel_world_size() global_rank = torch.distributed.get_rank() // mp_size src = global_rank * mp_size - torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast_object_list( + text_cast, src=src, group=mpu.get_model_parallel_group() + ) text = text_cast[0] - value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)} + value_dict = { + "prompt": text, + "negative_prompt": "", + "num_frames": torch.tensor(T).unsqueeze(0), + } batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, ) for key in batch: if isinstance(batch[key], torch.Tensor): @@ -212,7 +243,11 @@ def sampling_main(args, model_cls): for index in range(args.batch_size): if args.image2video: samples_z = sample_func( - c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda") + c, + uc=uc, + batch_size=1, + shape=(T, C, H, W), + ofs=torch.tensor([2.0]).to("cuda"), ) else: samples_z = sample_func( @@ -226,7 +261,9 @@ def sampling_main(args, model_cls): if args.only_save_latents: samples_z = 1.0 / model.scale_factor * samples_z save_path = os.path.join( - args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + args.output_dir, + str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], + str(index), ) os.makedirs(save_path, exist_ok=True) torch.save(samples_z, os.path.join(save_path, "latent.pt")) @@ -237,7 +274,9 @@ def sampling_main(args, model_cls): samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() save_path = os.path.join( - args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + args.output_dir, + str(cnt) + "_" + 
text.replace(" ", "_").replace("/", "")[:120], + str(index), ) if mpu.get_model_parallel_rank() == 0: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) diff --git a/sat/sgm/lr_scheduler.py b/sat/sgm/lr_scheduler.py index b45db69..25032ff 100644 --- a/sat/sgm/lr_scheduler.py +++ b/sat/sgm/lr_scheduler.py @@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2: n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: - print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] self.last_f = f return f else: - t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi) + ) self.last_f = f return f @@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: - print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] self.last_f = f return f else: diff --git a/sat/sgm/models/autoencoder.py b/sat/sgm/models/autoencoder.py index 0b21318..1cbcf8a 100644 --- a/sat/sgm/models/autoencoder.py +++ b/sat/sgm/models/autoencoder.py @@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder): x = self.decoder(z, **kwargs) return x - def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True) dec = self.decode(z, **additional_decode_kwargs) return z, dec, reg_log - def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: + def inner_training_step( + self, batch: dict, batch_idx: int, optimizer_idx: int = 0 + ) -> torch.Tensor: x = self.get_input(batch) - additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + additional_decode_kwargs = { + key: batch[key] for key in self.additional_decode_keys.intersection(batch) + } z, xrec, regularization_log = self(x, **additional_decode_kwargs) if hasattr(self.loss, "forward_keys"): extra_info = { @@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder): if self.trainable_ae_params is None: ae_params = self.get_autoencoder_params() else: - ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) + ae_params, 
num_ae_params = self.get_param_groups( + self.trainable_ae_params, self.ae_optimizer_args + ) logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") if self.trainable_disc_params is None: disc_params = self.get_discriminator_params() else: - disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) + disc_params, num_disc_params = self.get_param_groups( + self.trainable_disc_params, self.disc_optimizer_args + ) logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") opt_ae = self.instantiate_optimizer_from_config( ae_params, @@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder): ) opts = [opt_ae] if len(disc_params) > 0: - opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) + opt_disc = self.instantiate_optimizer_from_config( + disc_params, self.learning_rate, self.optimizer_config + ) opts.append(opt_disc) return opts @torch.no_grad() - def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + def log_images( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: log = dict() additional_decode_kwargs = {} x = self.get_input(batch) - additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) + additional_decode_kwargs.update( + {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + ) _, xrec, _ = self(x, **additional_decode_kwargs) log["inputs"] = x @@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder): diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema.clamp_(0, 1.0) log["diff_ema"] = 2.0 * diff_ema - 1.0 - log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + log["diff_boost_ema"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + ) if additional_log_kwargs: additional_decode_kwargs.update(additional_log_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs) @@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine): params = super().get_autoencoder_params() return params - def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) z = self.quant_conv(z) @@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine): if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + def log_videos( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: return self.log_images(batch, additional_log_kwargs, **kwargs) def get_input(self, batch: dict) -> torch.Tensor: @@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine): batch = batch[self.input_key] global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size - torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) + torch.distributed.broadcast( + batch, src=global_src_rank, group=get_context_parallel_group() + ) batch = _conv_split(batch, dim=2, kernel_size=1) return batch diff --git a/sat/sgm/modules/attention.py b/sat/sgm/modules/attention.py index 
bb24157..22ef06b 100644 --- a/sat/sgm/modules/attention.py +++ b/sat/sgm/modules/attention.py @@ -94,7 +94,11 @@ class FeedForward(nn.Module): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) @@ -126,7 +130,9 @@ class LinearAttention(nn.Module): def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) @@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -244,7 +252,9 @@ class CrossAttention(nn.Module): # new with sdp_kernel(**BACKEND_MAP[self.backend]): # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask + ) # scale is dim_head ** -0.5 per default del q, k, v out = rearrange(out, "b h n d -> b n (h d)", h=h) @@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module): self.norm1(x), context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self + if not self.disable_self_attn + else 0, ) + x ) @@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module): sdp_backend=None, ): super().__init__() - print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + print( + f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" + ) from omegaconf import ListConfig if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): @@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module): ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) else: # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) diff --git a/sat/sgm/modules/autoencoding/losses/discriminator_loss.py b/sat/sgm/modules/autoencoding/losses/discriminator_loss.py index b5b144a..2b13d67 100644 --- a/sat/sgm/modules/autoencoding/losses/discriminator_loss.py +++ b/sat/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -87,7 +87,9 @@ class 
GeneralLPIPSWithDiscriminator(nn.Module): yield from () @torch.no_grad() - def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]: + def log_images( + self, inputs: torch.Tensor, reconstructions: torch.Tensor + ) -> Dict[str, torch.Tensor]: # calc logits of real/fake logits_real = self.discriminator(inputs.contiguous().detach()) if len(logits_real.shape) < 4: @@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module): weights: Union[None, float, torch.Tensor] = None, ) -> Tuple[torch.Tensor, dict]: if self.scale_input_to_tgt_size: - inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True) + inputs = torch.nn.functional.interpolate( + inputs, reconstructions.shape[2:], mode="bicubic", antialias=True + ) if self.dims > 2: inputs, reconstructions = map( @@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module): input_frames = pick_video_frame(inputs, frame_indices) recon_frames = pick_video_frame(reconstructions, frame_indices) - p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean() + p_loss = self.perceptual_loss( + input_frames.contiguous(), recon_frames.contiguous() + ).mean() rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) @@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) if self.training: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) else: d_weight = torch.tensor(1.0) else: diff --git a/sat/sgm/modules/autoencoding/losses/lpips.py b/sat/sgm/modules/autoencoding/losses/lpips.py index fed01d6..6b7656d 100644 --- a/sat/sgm/modules/autoencoding/losses/lpips.py +++ b/sat/sgm/modules/autoencoding/losses/lpips.py @@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module): if self.perceptual_weight > 0.0: image_reconstructions = self.decoder.decode(latent_predictions) image_targets = self.decoder.decode(latent_inputs) - perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous()) - loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() + perceptual_loss = self.perceptual_loss( + image_targets.contiguous(), image_reconstructions.contiguous() + ) + loss = ( + self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() + ) log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() if self.perceptual_weight_on_inputs > 0.0: - image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions)) + image_reconstructions = default( + image_reconstructions, self.decoder.decode(latent_predictions) + ) if self.scale_input_to_tgt_size: image_inputs = torch.nn.functional.interpolate( image_inputs, @@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module): antialias=True, ) - perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous()) + perceptual_loss2 = self.perceptual_loss( + image_inputs.contiguous(), image_reconstructions.contiguous() + ) loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() return loss, log diff --git a/sat/sgm/modules/autoencoding/losses/video_loss.py 
b/sat/sgm/modules/autoencoding/losses/video_loss.py index 01c4c60..0094727 100644 --- a/sat/sgm/modules/autoencoding/losses/video_loss.py +++ b/sat/sgm/modules/autoencoding/losses/video_loss.py @@ -45,7 +45,9 @@ def hinge_gen_loss(fake): @autocast(enabled=False) @beartype def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter): - return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach() + return torch_grad( + outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True + )[0].detach() def pick_video_frame(video, frame_indices): @@ -126,7 +128,8 @@ class DiscriminatorBlock(nn.Module): self.downsample = ( nn.Sequential( - Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1) + Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), + nn.Conv2d(filters * 4, filters, 1), ) if downsample else None @@ -185,11 +188,18 @@ class Discriminator(nn.Module): is_not_last = ind != (len(layer_dims_in_out) - 1) block = DiscriminatorBlock( - in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample + in_chan, + out_chan, + downsample=is_not_last, + antialiased_downsample=antialiased_downsample, ) attn_block = nn.Sequential( - Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)), + Residual( + LinearSpaceAttention( + dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head + ) + ), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), ) @@ -363,7 +373,9 @@ class Discriminator3D(nn.Module): ) attn_block = nn.Sequential( Residual( - LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) + LinearSpaceAttention( + dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head + ) ), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), ) @@ -458,7 +470,9 @@ class Discriminator3DWithfirstframe(nn.Module): ) attn_block = nn.Sequential( Residual( - LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) + LinearSpaceAttention( + dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head + ) ), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), ) @@ -581,11 +595,17 @@ class VideoAutoencoderLoss(nn.Module): input_frames = pick_video_frame(inputs, frame_indices) recon_frames = pick_video_frame(reconstructions, frame_indices) - perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean() + perceptual_loss = self.perceptual_model( + input_frames.contiguous(), recon_frames.contiguous() + ).mean() else: perceptual_loss = self.zero - if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0: + if ( + global_step >= self.disc_start + or not self.training + or self.adversarial_loss_weight == 0 + ): gen_loss = self.zero adaptive_weight = 0 else: @@ -598,9 +618,13 @@ class VideoAutoencoderLoss(nn.Module): adaptive_weight = 1 if self.perceptual_weight > 0 and last_layer is not None: - norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2) + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss( + perceptual_loss, last_layer + ).norm(p=2) norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2) - adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3) + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp( + 
min=1e-3 + ) adaptive_weight.clamp_(max=1e3) if torch.isnan(adaptive_weight).any(): diff --git a/sat/sgm/modules/autoencoding/lpips/loss/.gitignore b/sat/sgm/modules/autoencoding/lpips/loss/.gitignore index a92958a..1396025 100644 --- a/sat/sgm/modules/autoencoding/lpips/loss/.gitignore +++ b/sat/sgm/modules/autoencoding/lpips/loss/.gitignore @@ -1 +1 @@ -vgg.pth \ No newline at end of file +vgg.pth diff --git a/sat/sgm/modules/autoencoding/lpips/loss/LICENSE b/sat/sgm/modules/autoencoding/lpips/loss/LICENSE index 924cfc8..842c363 100644 --- a/sat/sgm/modules/autoencoding/lpips/loss/LICENSE +++ b/sat/sgm/modules/autoencoding/lpips/loss/LICENSE @@ -20,4 +20,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/sat/sgm/modules/autoencoding/lpips/loss/lpips.py b/sat/sgm/modules/autoencoding/lpips/loss/lpips.py index a0249cf..94c32c5 100644 --- a/sat/sgm/modules/autoencoding/lpips/loss/lpips.py +++ b/sat/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -48,7 +48,9 @@ class LPIPS(nn.Module): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns)) + ] val = res[0] for l in range(1, len(self.chns)): val += res[l] @@ -118,7 +120,9 @@ class vgg16(torch.nn.Module): h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out diff --git a/sat/sgm/modules/autoencoding/lpips/model/LICENSE b/sat/sgm/modules/autoencoding/lpips/model/LICENSE index 4b356e6..d75f0ee 100644 --- a/sat/sgm/modules/autoencoding/lpips/model/LICENSE +++ b/sat/sgm/modules/autoencoding/lpips/model/LICENSE @@ -55,4 +55,4 @@ Redistributions in binary form must reproduce the above copyright notice, this l Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/sat/sgm/modules/autoencoding/lpips/model/model.py b/sat/sgm/modules/autoencoding/lpips/model/model.py index ee13bab..5d767fc 100644 --- a/sat/sgm/modules/autoencoding/lpips/model/model.py +++ b/sat/sgm/modules/autoencoding/lpips/model/model.py @@ -35,7 +35,9 @@ class NLayerDiscriminator(nn.Module): norm_layer = nn.BatchNorm2d else: norm_layer = ActNorm - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d diff --git a/sat/sgm/modules/autoencoding/lpips/vqperceptual.py b/sat/sgm/modules/autoencoding/lpips/vqperceptual.py index 1e4944b..6195f0a 100644 --- a/sat/sgm/modules/autoencoding/lpips/vqperceptual.py +++ b/sat/sgm/modules/autoencoding/lpips/vqperceptual.py @@ -11,6 +11,7 @@ def hinge_d_loss(logits_real, logits_fake): def vanilla_d_loss(logits_real, logits_fake): d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake)) + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) ) return d_loss diff --git a/sat/sgm/modules/autoencoding/magvit2_pytorch.py b/sat/sgm/modules/autoencoding/magvit2_pytorch.py index 5888952..f115c8a 100644 --- a/sat/sgm/modules/autoencoding/magvit2_pytorch.py +++ b/sat/sgm/modules/autoencoding/magvit2_pytorch.py @@ -147,7 +147,9 @@ def hinge_gen_loss(fake): @autocast(enabled=False) @beartype def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter): - return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach() + return torch_grad( + outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True + )[0].detach() # helper decorators @@ -223,7 +225,10 @@ class SqueezeExcite(Module): dim_hidden = max(dim_hidden_min, dim_out // 2) self.net = nn.Sequential( - nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid() + nn.Conv2d(dim, dim_hidden, 1), + nn.LeakyReLU(0.1), + nn.Conv2d(dim_hidden, dim_out, 1), + nn.Sigmoid(), ) nn.init.zeros_(self.net[-2].weight) @@ -282,7 +287,10 @@ class RMSNorm(Module): self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + return ( + F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + + self.bias + ) class AdaptiveRMSNorm(Module): @@ -353,7 +361,8 @@ class Attention(Module): self.norm = RMSNorm(dim) self.to_qkv 
= nn.Sequential( - nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads) + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads), ) assert num_memory_kv > 0 @@ -361,7 +370,9 @@ class Attention(Module): self.attend = Attend(causal=causal, dropout=dropout, flash=flash) - self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)) + self.to_out = nn.Sequential( + Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) @beartype def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None): @@ -455,7 +466,9 @@ class FeedForward(Module): super().__init__() conv_klass = nn.Conv2d if images else nn.Conv3d - rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond) + rmsnorm_klass = ( + RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond) + ) maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images) @@ -463,7 +476,9 @@ class FeedForward(Module): self.norm = maybe_adaptive_norm_klass(dim) - self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)) + self.net = Sequential( + conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1) + ) @beartype def forward(self, x: Tensor, *, cond: Optional[Tensor] = None): @@ -525,7 +540,8 @@ class DiscriminatorBlock(Module): self.downsample = ( nn.Sequential( - Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1) + Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), + nn.Conv2d(filters * 4, filters, 1), ) if downsample else None @@ -584,11 +600,18 @@ class Discriminator(Module): is_not_last = ind != (len(layer_dims_in_out) - 1) block = DiscriminatorBlock( - in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample + in_chan, + out_chan, + downsample=is_not_last, + antialiased_downsample=antialiased_downsample, ) attn_block = Sequential( - Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)), + Residual( + LinearSpaceAttention( + dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head + ) + ), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), ) @@ -628,7 +651,16 @@ class Discriminator(Module): class Conv3DMod(Module): @beartype def __init__( - self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros" + self, + dim, + *, + spatial_kernel, + time_kernel, + causal=True, + dim_out=None, + demod=True, + eps=1e-8, + pad_mode="zeros", ): super().__init__() dim_out = default(dim_out, dim) @@ -644,7 +676,9 @@ class Conv3DMod(Module): self.pad_mode = pad_mode self.padding = (*((spatial_kernel // 2,) * 4), *time_padding) - self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))) + self.weights = nn.Parameter( + torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)) + ) self.demod = demod @@ -675,7 +709,11 @@ class Conv3DMod(Module): weights = weights * (cond + 1) if self.demod: - inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt() + inv_norm = ( + reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum") + .clamp(min=self.eps) + .rsqrt() + ) weights = weights * inv_norm fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w") @@ -742,7 
+780,9 @@ class SpatialUpsample2x(Module): dim_out = default(dim_out, dim) conv = nn.Conv2d(dim, dim_out * 4, 1) - self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)) + self.net = nn.Sequential( + conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2) + ) self.init_conv_(conv) @@ -808,7 +848,12 @@ def SameConv2d(dim_in, dim_out, kernel_size): class CausalConv3d(Module): @beartype def __init__( - self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode="constant", + **kwargs, ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -830,7 +875,9 @@ class CausalConv3d(Module): stride = (stride, 1, 1) dilation = (dilation, 1, 1) - self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.conv = nn.Conv3d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) def forward(self, x): pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant" @@ -855,7 +902,13 @@ def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: s @beartype class ResidualUnitMod(Module): def __init__( - self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True + self, + dim, + kernel_size: Union[int, Tuple[int, int, int]], + *, + dim_cond, + pad_mode: str = "constant", + demod=True, ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -892,7 +945,15 @@ class ResidualUnitMod(Module): class CausalConvTranspose3d(Module): - def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + *, + time_stride, + **kwargs, + ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -908,7 +969,9 @@ class CausalConvTranspose3d(Module): stride = (time_stride, 1, 1) padding = (0, height_pad, width_pad) - self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs) + self.conv = nn.ConvTranspose3d( + chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs + ) def forward(self, x): assert x.ndim == 5 @@ -936,7 +999,9 @@ LossBreakdown = namedtuple( ], ) -DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]) +DiscrLossBreakdown = namedtuple( + "DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"] +) class VideoTokenizer(Module): @@ -1050,10 +1115,14 @@ class VideoTokenizer(Module): has_cond = True encoder_layer = ResidualUnitMod( - dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor) + dim, + residual_conv_kernel_size, + dim_cond=int(dim_cond * dim_cond_expansion_factor), ) decoder_layer = ResidualUnitMod( - dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor) + dim, + residual_conv_kernel_size, + dim_cond=int(dim_cond * dim_cond_expansion_factor), ) dim_out = dim @@ -1080,15 +1149,25 @@ class VideoTokenizer(Module): elif layer_type == "attend_space": attn_kwargs = dict( - dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn + dim=dim, + dim_head=attn_dim_head, + heads=attn_heads, + dropout=attn_dropout, + flash=flash_attn, ) - encoder_layer = 
Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + encoder_layer = Sequential( + Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)) + ) - decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + decoder_layer = Sequential( + Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)) + ) elif layer_type == "linear_attend_space": - linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads) + linear_attn_kwargs = dict( + dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads + ) encoder_layer = Sequential( Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim)) @@ -1136,9 +1215,13 @@ class VideoTokenizer(Module): flash=flash_attn, ) - encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + encoder_layer = Sequential( + Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)) + ) - decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))) + decoder_layer = Sequential( + Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)) + ) elif layer_type == "cond_linear_attend_space": has_cond = True @@ -1153,11 +1236,13 @@ class VideoTokenizer(Module): ) encoder_layer = Sequential( - Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond)) + Residual(LinearSpaceAttention(**attn_kwargs)), + Residual(FeedForward(dim, dim_cond=dim_cond)), ) decoder_layer = Sequential( - Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond)) + Residual(LinearSpaceAttention(**attn_kwargs)), + Residual(FeedForward(dim, dim_cond=dim_cond)), ) elif layer_type == "cond_attend_time": @@ -1283,7 +1368,9 @@ class VideoTokenizer(Module): # discriminator - discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)) + discr_kwargs = default( + discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512) + ) self.discr = Discriminator(**discr_kwargs) @@ -1380,8 +1467,16 @@ class VideoTokenizer(Module): self.load_state_dict(state_dict, strict=strict) @beartype - def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True): - encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame + def encode( + self, + video: Tensor, + quantize=False, + cond: Optional[Tensor] = None, + video_contains_first_frame=True, + ): + encode_first_frame_separately = ( + self.separate_first_frame_encoding and video_contains_first_frame + ) # whether to pad video or not @@ -1389,12 +1484,16 @@ class VideoTokenizer(Module): video_len = video.shape[2] video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2) - video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])] + video_packed_shape = [ + torch.Size([self.time_padding]), + torch.Size([]), + torch.Size([video_len - 1]), + ] # conditioning, if needed - assert (not self.has_cond) or exists( - cond + assert ( + (not self.has_cond) or exists(cond) ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified" if exists(cond): @@ -1431,7 +1530,9 @@ class VideoTokenizer(Module): return maybe_quantize(video) @beartype - def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True): + def 
decode_from_code_indices( + self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True + ): assert codes.dtype in (torch.long, torch.int32) if codes.ndim == 2: @@ -1444,18 +1545,24 @@ class VideoTokenizer(Module): quantized = self.quantizers.indices_to_codes(codes) - return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame) + return self.decode( + quantized, cond=cond, video_contains_first_frame=video_contains_first_frame + ) @beartype - def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True): - decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame + def decode( + self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True + ): + decode_first_frame_separately = ( + self.separate_first_frame_encoding and video_contains_first_frame + ) batch = quantized.shape[0] # conditioning, if needed - assert (not self.has_cond) or exists( - cond + assert ( + (not self.has_cond) or exists(cond) ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified" if exists(cond): @@ -1558,14 +1665,18 @@ class VideoTokenizer(Module): aux_losses = self.zero quantizer_loss_breakdown = None else: - (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True) + (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers( + x, return_loss_breakdown=True + ) if return_codes and not return_recon: return codes # decoder - recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame) + recon_video = self.decode( + quantized, cond=cond, video_contains_first_frame=video_contains_first_frame + ) if return_codes: return codes, recon_video @@ -1613,7 +1724,9 @@ class VideoTokenizer(Module): multiscale_real_logits = discr(video) multiscale_fake_logits = discr(recon_video.detach()) - multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits) + multiscale_discr_loss = hinge_discr_loss( + multiscale_fake_logits, multiscale_real_logits + ) multiscale_discr_losses.append(multiscale_discr_loss) else: @@ -1634,7 +1747,9 @@ class VideoTokenizer(Module): + sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight ) - discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss) + discr_loss_breakdown = DiscrLossBreakdown( + discr_loss, multiscale_discr_losses, gradient_penalty_loss + ) return total_loss, discr_loss_breakdown @@ -1669,7 +1784,9 @@ class VideoTokenizer(Module): norm_grad_wrt_perceptual_loss = None if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs): - norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss( + perceptual_loss, last_dec_layer + ).norm(p=2) # per-frame image discriminator @@ -1686,7 +1803,9 @@ class VideoTokenizer(Module): if exists(norm_grad_wrt_perceptual_loss): norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) - adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3) + adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp( + min=1e-3 + ) adaptive_weight.clamp_(max=1e3) if torch.isnan(adaptive_weight).any(): @@ -1713,8 +1832,12 @@ class VideoTokenizer(Module): multiscale_adaptive_weight = 1.0 if 
exists(norm_grad_wrt_perceptual_loss): - norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2) - multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5) + norm_grad_wrt_gen_loss = grad_layer_wrt_loss( + multiscale_gen_loss, last_dec_layer + ).norm(p=2) + multiscale_adaptive_weight = ( + norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5) + ) multiscale_adaptive_weight.clamp_(max=1e3) multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight) @@ -1730,10 +1853,13 @@ class VideoTokenizer(Module): if self.has_multiscale_discrs: weighted_multiscale_gen_losses = sum( - loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights) + loss * weight + for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights) ) - total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight + total_loss = ( + total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight + ) # loss breakdown diff --git a/sat/sgm/modules/autoencoding/regularizers/base.py b/sat/sgm/modules/autoencoding/regularizers/base.py index 7f455be..fa28c2c 100644 --- a/sat/sgm/modules/autoencoding/regularizers/base.py +++ b/sat/sgm/modules/autoencoding/regularizers/base.py @@ -26,7 +26,9 @@ class IdentityRegularizer(AbstractRegularizer): yield from () -def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: +def measure_perplexity( + predicted_indices: torch.Tensor, num_centroids: int +) -> Tuple[torch.Tensor, torch.Tensor]: # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) diff --git a/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py b/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py index 5a20dd6..f0a9898 100644 --- a/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py +++ b/sat/sgm/modules/autoencoding/regularizers/finite_scalar_quantization.py @@ -79,13 +79,19 @@ class FSQ(Module): self.dim = default(dim, len(_levels) * num_codebooks) has_projections = self.dim != effective_codebook_dim - self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() - self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.project_in = ( + nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + ) + self.project_out = ( + nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + ) self.has_projections = has_projections self.codebook_size = self._levels.prod().item() - implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + implicit_codebook = self.indices_to_codes( + torch.arange(self.codebook_size), project_out=False + ) self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor: @@ -153,7 +159,9 @@ class FSQ(Module): z = rearrange(z, "b d ... -> b ... 
d") z, ps = pack_one(z, "b * d") - assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + assert ( + z.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" z = self.project_in(z) diff --git a/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py b/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py index beca888..f70dee8 100644 --- a/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py +++ b/sat/sgm/modules/autoencoding/regularizers/lookup_free_quantization.py @@ -78,7 +78,9 @@ class LFQ(Module): # some assert validations - assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ" + assert exists(dim) or exists( + codebook_size + ), "either dim or codebook_size must be specified for LFQ" assert ( not exists(codebook_size) or log2(codebook_size).is_integer() ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})" @@ -195,7 +197,9 @@ class LFQ(Module): x = rearrange(x, "b d ... -> b ... d") x, ps = pack_one(x, "b * d") - assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}" + assert ( + x.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but received {x.shape[-1]}" x = self.project_in(x) @@ -299,7 +303,9 @@ class LFQ(Module): # complete aux loss - aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight + aux_loss = ( + entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight + ) ret = Return(x, indices, aux_loss) diff --git a/sat/sgm/modules/autoencoding/regularizers/quantize.py b/sat/sgm/modules/autoencoding/regularizers/quantize.py index 583f488..11615c0 100644 --- a/sat/sgm/modules/autoencoding/regularizers/quantize.py +++ b/sat/sgm/modules/autoencoding/regularizers/quantize.py @@ -33,7 +33,9 @@ class AbstractQuantizer(AbstractRegularizer): new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) else: new[unknown] = self.unknown_index return new.reshape(ishape) @@ -50,7 +52,9 @@ class AbstractQuantizer(AbstractRegularizer): return back.reshape(ishape) @abstractmethod - def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: raise NotImplementedError() def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: @@ -239,7 +243,8 @@ class VectorQuantizer(AbstractQuantizer): d = ( torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + - 2 + * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -267,15 +272,21 @@ class VectorQuantizer(AbstractQuantizer): if self.sane_index_shape: if do_reshape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) else: - 
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]) + min_encoding_indices = rearrange( + min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] + ) loss_dict["min_encoding_indices"] = min_encoding_indices return z_q, loss_dict - def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: # shape specifying (batch, height, width, channel) if self.remap is not None: assert shape is not None, "Need to give shape for remap" @@ -448,6 +459,8 @@ class VectorQuantizerWithInputProjection(VectorQuantizer): elif len(in_shape) == 5: z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]) else: - raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.") + raise NotImplementedError( + f"rearranging not available for {len(in_shape)}-dimensional input." + ) return z_q, loss_dict diff --git a/sat/sgm/modules/autoencoding/temporal_ae.py b/sat/sgm/modules/autoencoding/temporal_ae.py index a45ef9d..3050345 100644 --- a/sat/sgm/modules/autoencoding/temporal_ae.py +++ b/sat/sgm/modules/autoencoding/temporal_ae.py @@ -248,7 +248,9 @@ def make_time_attn( "vanilla", "vanilla-xformers", ], f"attn_type {attn_type} not supported for spatio-temporal attention" - print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels") + print( + f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" + ) if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": print( f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. 
" diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py index 0f9a469..847555e 100644 --- a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py +++ b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d.py @@ -125,9 +125,13 @@ class ResnetBlock3D(nn.Module): self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_shortcut = CausalConv3d( + in_channels, out_channels, kernel_size=3, pad_mode=pad_mode + ) else: - self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, temb, zq): h = x @@ -161,7 +165,9 @@ class AttnBlock2D(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, zq): h_ = x @@ -380,7 +386,11 @@ class NewDecoder3D(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in # self.conv_in = torch.nn.Conv3d(z_channels, diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py index 1b9a663..f4bb268 100644 --- a/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py +++ b/sat/sgm/modules/autoencoding/vqvae/movq_dec_3d_dev.py @@ -148,9 +148,13 @@ class ResnetBlock3D(nn.Module): # kernel_size=3, # stride=1, # padding=1) - self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_shortcut = CausalConv3d( + in_channels, out_channels, kernel_size=3, pad_mode=pad_mode + ) else: - self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode) def forward(self, x, temb, zq): @@ -185,7 +189,9 @@ class AttnBlock2D(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, zq): h_ = x @@ -261,7 +267,11 @@ class MOVQDecoder3D(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 
1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in # self.conv_in = torch.nn.Conv3d(z_channels, @@ -420,7 +430,11 @@ class NewDecoder3D(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in # self.conv_in = torch.nn.Conv3d(z_channels, diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py b/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py index 99b358d..d69b369 100644 --- a/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py +++ b/sat/sgm/modules/autoencoding/vqvae/movq_enc_3d.py @@ -51,7 +51,12 @@ def nonlinearity(x): class CausalConv3d(nn.Module): @beartype def __init__( - self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + pad_mode="constant", + **kwargs, ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -75,11 +80,20 @@ class CausalConv3d(nn.Module): stride = (stride, 1, 1) dilation = (dilation, 1, 1) - self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.conv = nn.Conv3d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) def forward(self, x): if self.pad_mode == "constant": - causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad) + causal_padding_3d = ( + self.time_pad, + 0, + self.width_pad, + self.width_pad, + self.height_pad, + self.height_pad, + ) x = F.pad(x, causal_padding_3d, mode="constant", value=0) elif self.pad_mode == "first": pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2) @@ -91,7 +105,9 @@ class CausalConv3d(nn.Module): reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2]) if reflect_x.shape[2] < self.time_pad: reflect_x = torch.cat( - [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2 + [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + + [reflect_x], + dim=2, ) x = torch.cat([reflect_x, x], dim=2) causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) @@ -110,7 +126,9 @@ class Upsample3D(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) self.compress_time = compress_time def forward(self, x): @@ -149,7 +167,9 @@ class DownSample3D(nn.Module): out_channels = in_channels if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=0 + ) self.compress_time = compress_time def forward(self, x): @@ -182,7 +202,14 @@ class DownSample3D(nn.Module): class ResnetBlock3D(nn.Module): def __init__( - 
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant" + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + pad_mode="constant", ): super().__init__() self.in_channels = in_channels @@ -214,9 +241,13 @@ class ResnetBlock3D(nn.Module): # kernel_size=3, # stride=1, # padding=1) - self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + self.conv_shortcut = CausalConv3d( + in_channels, out_channels, kernel_size=3, pad_mode=pad_mode + ) else: - self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode) def forward(self, x, temb): @@ -251,7 +282,9 @@ class AttnBlock2D(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -365,12 +398,20 @@ class Encoder3D(nn.Module): # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock3D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + pad_mode=pad_mode, ) # remove attention block # self.mid.attn_1 = AttnBlock2D(block_in) self.mid.block_2 = ResnetBlock3D( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + pad_mode=pad_mode, ) # end diff --git a/sat/sgm/modules/autoencoding/vqvae/movq_modules.py b/sat/sgm/modules/autoencoding/vqvae/movq_modules.py index 2773b0f..323a04e 100644 --- a/sat/sgm/modules/autoencoding/vqvae/movq_modules.py +++ b/sat/sgm/modules/autoencoding/vqvae/movq_modules.py @@ -80,7 +80,9 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -95,7 +97,9 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, x): if self.with_conv: @@ -134,9 +138,13 @@ class ResnetBlock(nn.Module): self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv_shortcut = torch.nn.Conv2d( + 
in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, temb, zq): h = x @@ -170,7 +178,9 @@ class AttnBlock(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, zq): h_ = x @@ -232,7 +242,11 @@ class MOVQDecoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) diff --git a/sat/sgm/modules/autoencoding/vqvae/quantize.py b/sat/sgm/modules/autoencoding/vqvae/quantize.py index 54ea128..3f6b98c 100644 --- a/sat/sgm/modules/autoencoding/vqvae/quantize.py +++ b/sat/sgm/modules/autoencoding/vqvae/quantize.py @@ -15,7 +15,16 @@ class VectorQuantizer2(nn.Module): # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. 
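# --- Editor's illustrative sketch, not part of the patch. ---
# What the NOTE on VectorQuantizer2 refers to: the VQ objective combines a commitment
# term ||sg[z_q] - z||^2 (gradients flow into the encoder) with a codebook term
# ||z_q - sg[z]||^2 (gradients flow into the embeddings). In the original VQ-VAE
# formulation beta weights the commitment term; the "legacy" default applies it to the
# codebook term instead. This mirrors the legacy/non-legacy branch later in this file's diff.
import torch

def vq_embedding_loss(z: torch.Tensor, z_q: torch.Tensor, beta: float, legacy: bool = True) -> torch.Tensor:
    commitment = torch.mean((z_q.detach() - z) ** 2)  # updates the encoder only
    codebook = torch.mean((z_q - z.detach()) ** 2)    # updates the codebook only
    if legacy:
        return commitment + beta * codebook           # buggy-but-default placement
    return beta * commitment + codebook               # legacy=False, paper placement
# --- End of editor's sketch. ---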
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): + def __init__( + self, + n_e, + e_dim, + beta, + remap=None, + unknown_index="random", + sane_index_shape=False, + legacy=True, + ): super().__init__() self.n_e = n_e self.e_dim = e_dim @@ -51,7 +60,9 @@ class VectorQuantizer2(nn.Module): new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) else: new[unknown] = self.unknown_index return new.reshape(ishape) @@ -78,7 +89,8 @@ class VectorQuantizer2(nn.Module): d = ( torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) + - 2 + * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) ) min_encoding_indices = torch.argmin(d, dim=1) @@ -88,9 +100,13 @@ class VectorQuantizer2(nn.Module): # compute loss for embedding if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean( + (z_q - z.detach()) ** 2 + ) # preserve gradients z_q = z + (z_q - z).detach() @@ -104,7 +120,9 @@ class VectorQuantizer2(nn.Module): min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) return z_q, loss, (perplexity, min_encodings, min_encoding_indices) @@ -184,7 +202,9 @@ class GumbelQuantize(nn.Module): new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) else: new[unknown] = self.unknown_index return new.reshape(ishape) diff --git a/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py b/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py index e42154f..3ac2c25 100644 --- a/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py +++ b/sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py @@ -40,7 +40,9 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -55,7 +57,9 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, x): if self.with_conv: @@ -68,7 +72,9 @@ class Downsample(nn.Module): 
class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + def __init__( + self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512 + ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -84,9 +90,13 @@ class ResnetBlock(nn.Module): self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, temb): h = x @@ -120,7 +130,9 @@ class AttnBlock(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -194,7 +206,10 @@ class Encoder(nn.Module): for i_block in range(self.num_res_blocks): block.append( ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, ) ) block_in = block_out @@ -326,7 +341,11 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) @@ -350,7 +369,10 @@ class Decoder(nn.Module): for i_block in range(self.num_res_blocks + 1): block.append( ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, ) ) block_in = block_out diff --git a/sat/sgm/modules/cp_enc_dec.py b/sat/sgm/modules/cp_enc_dec.py index 931e657..b010834 100644 --- a/sat/sgm/modules/cp_enc_dec.py +++ b/sat/sgm/modules/cp_enc_dec.py @@ -136,9 +136,9 @@ def _conv_split(input_, dim, kernel_size): if cp_rank == 0: output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) else: - output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose( - dim, 0 - ) + output = input_.transpose(dim, 0)[ + cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) output = output.contiguous() # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) diff --git a/sat/sgm/modules/diffusionmodules/denoiser.py 
b/sat/sgm/modules/diffusionmodules/denoiser.py index 3cc01e3..2dc56cd 100644 --- a/sat/sgm/modules/diffusionmodules/denoiser.py +++ b/sat/sgm/modules/diffusionmodules/denoiser.py @@ -35,7 +35,9 @@ class Denoiser(nn.Module): sigma = append_dims(sigma, input.ndim) c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip + return ( + network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip + ) class DiscreteDenoiser(Denoiser): @@ -50,7 +52,9 @@ class DiscreteDenoiser(Denoiser): flip=True, ): super().__init__(weighting_config, scaling_config) - sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) self.sigmas = sigmas # self.register_buffer("sigmas", sigmas) self.quantize_c_noise = quantize_c_noise diff --git a/sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/sat/sgm/modules/diffusionmodules/denoiser_scaling.py index 2cb9643..05362a0 100644 --- a/sat/sgm/modules/diffusionmodules/denoiser_scaling.py +++ b/sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -6,7 +6,9 @@ import torch class DenoiserScaling(ABC): @abstractmethod - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: pass @@ -14,7 +16,9 @@ class EDMScaling: def __init__(self, sigma_data: float = 0.5): self.sigma_data = sigma_data - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 @@ -23,7 +27,9 @@ class EDMScaling: class EpsScaling: - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = torch.ones_like(sigma, device=sigma.device) c_out = -sigma c_in = 1 / (sigma**2 + 1.0) ** 0.5 @@ -32,7 +38,9 @@ class EpsScaling: class VScaling: - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = 1.0 / (sigma**2 + 1.0) c_out = -sigma / (sigma**2 + 1.0) ** 0.5 c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 @@ -41,7 +49,9 @@ class VScaling: class VScalingWithEDMcNoise(DenoiserScaling): - def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = 1.0 / (sigma**2 + 1.0) c_out = -sigma / (sigma**2 + 1.0) ** 0.5 c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 diff --git a/sat/sgm/modules/diffusionmodules/discretizer.py b/sat/sgm/modules/diffusionmodules/discretizer.py index a86b7d8..466238f 100644 --- a/sat/sgm/modules/diffusionmodules/discretizer.py +++ 
b/sat/sgm/modules/diffusionmodules/discretizer.py @@ -52,7 +52,9 @@ class LegacyDDPMDiscretization(Discretization): ): super().__init__() self.num_timesteps = num_timesteps - betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + betas = make_beta_schedule( + "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end + ) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.to_torch = partial(torch.tensor, dtype=torch.float32) @@ -85,14 +87,18 @@ class ZeroSNRDDPMDiscretization(Discretization): if keep_start and not post_shift: linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) self.num_timesteps = num_timesteps - betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + betas = make_beta_schedule( + "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end + ) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.to_torch = partial(torch.tensor, dtype=torch.float32) # SNR shift if not post_shift: - self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) + self.alphas_cumprod = self.alphas_cumprod / ( + shift_scale + (1 - shift_scale) * self.alphas_cumprod + ) self.post_shift = post_shift self.shift_scale = shift_scale @@ -113,11 +119,14 @@ class ZeroSNRDDPMDiscretization(Discretization): alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T - alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) + alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / ( + alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T + ) if self.post_shift: alphas_cumprod_sqrt = ( - alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) + alphas_cumprod_sqrt**2 + / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) ) ** 0.5 if return_idx: diff --git a/sat/sgm/modules/diffusionmodules/guiders.py b/sat/sgm/modules/diffusionmodules/guiders.py index 7ce657c..4401bd7 100644 --- a/sat/sgm/modules/diffusionmodules/guiders.py +++ b/sat/sgm/modules/diffusionmodules/guiders.py @@ -15,7 +15,9 @@ class Guider(ABC): def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: pass - def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: + def prepare_inputs( + self, x: torch.Tensor, s: float, c: Dict, uc: Dict + ) -> Tuple[torch.Tensor, float, Dict]: pass @@ -57,7 +59,8 @@ class DynamicCFG(VanillaCFG): def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): super().__init__(scale, dyn_thresh_config) scale_schedule = ( - lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 + lambda scale, sigma, step_index: 1 + + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 ) self.scale_schedule = partial(scale_schedule, scale) self.dyn_thresh = instantiate_from_config( diff --git a/sat/sgm/modules/diffusionmodules/lora.py b/sat/sgm/modules/diffusionmodules/lora.py index 7ccd72a..cf73b75 100644 --- a/sat/sgm/modules/diffusionmodules/lora.py +++ b/sat/sgm/modules/diffusionmodules/lora.py @@ -20,7 +20,9 @@ from torch import nn class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + def __init__( + self, in_features, out_features, rank=4, network_alpha=None, 
device=None, dtype=None + ): super().__init__() self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) @@ -50,11 +52,20 @@ class LoRALinearLayer(nn.Module): class LoRAConv2dLayer(nn.Module): def __init__( - self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + self, + in_features, + out_features, + rank=4, + kernel_size=(1, 1), + stride=(1, 1), + padding=0, + network_alpha=None, ): super().__init__() - self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.down = nn.Conv2d( + in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False + ) # according to the official kohya_ss trainer kernel_size are always fixed for the up layer # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) @@ -85,7 +96,9 @@ class LoRACompatibleConv(nn.Conv2d): A convolutional layer that can be used with LoRA. """ - def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs): + def __init__( + self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs + ): super().__init__(*args, **kwargs) self.lora_layer = lora_layer self.scale = scale @@ -144,7 +157,13 @@ class LoRACompatibleConv(nn.Conv2d): # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 return F.conv2d( - hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + hidden_states, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, ) else: return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) @@ -155,7 +174,9 @@ class LoRACompatibleLinear(nn.Linear): A Linear layer that can be used with LoRA. """ - def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs): + def __init__( + self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs + ): super().__init__(*args, **kwargs) self.lora_layer = lora_layer self.scale = scale @@ -197,7 +218,9 @@ class LoRACompatibleLinear(nn.Linear): w_up = self.w_up.to(device=device).float() w_down = self.w_down.to(device).float() - unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + unfused_weight = fused_weight.float() - ( + self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0] + ) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None @@ -252,7 +275,9 @@ def _find_modules_v2( # Get the targets we should replace all linears under if ancestor_class is not None: - ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class) + ancestors = ( + module for module in model.modules() if module.__class__.__name__ in ancestor_class + ) else: # this, incase you want to naively iterate over all modules. 
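# --- Editor's illustrative sketch, not part of the patch. ---
# The low-rank update behind LoRALinearLayer / LoRACompatibleLinear in this file: a frozen
# base projection W is augmented with trainable down/up matrices so that
# y = W x + scale * up(down(x)). The network_alpha / rank rescaling below is the
# conventional kohya-style choice and is an assumption of this sketch, not a quote
# from the module.
from typing import Optional

import torch
from torch import nn

class TinyLoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 4,
                 network_alpha: Optional[float] = None):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)  # stands in for the frozen weight
        self.base.requires_grad_(False)
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)                    # LoRA starts as a no-op
        self.scale = network_alpha / rank if network_alpha is not None else 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.up(self.down(x))

# usage: y = TinyLoRALinear(320, 320, rank=4, network_alpha=4.0)(torch.randn(2, 320))
# --- End of editor's sketch. ---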
ancestors = [module for module in model.modules()] @@ -274,7 +299,9 @@ def _find_modules_v2( if flag: continue # Skip this linear if it's a child of a LoraInjectedLinear - if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]): + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): continue # Otherwise, yield it yield parent, name, module diff --git a/sat/sgm/modules/diffusionmodules/loss.py b/sat/sgm/modules/diffusionmodules/loss.py index 66916c1..589e441 100644 --- a/sat/sgm/modules/diffusionmodules/loss.py +++ b/sat/sgm/modules/diffusionmodules/loss.py @@ -38,13 +38,17 @@ class StandardDiffusionLoss(nn.Module): def __call__(self, network, denoiser, conditioner, input, batch): cond = conditioner(batch) - additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} + additional_model_inputs = { + key: batch[key] for key in self.batch2model_keys.intersection(batch) + } sigmas = self.sigma_sampler(input.shape[0]).to(input.device) noise = torch.randn_like(input) if self.offset_noise_level > 0.0: noise = ( - noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level + noise + + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) + * self.offset_noise_level ) noise = noise.to(input.dtype) noised_input = input.float() + noise * append_dims(sigmas, input.ndim) @@ -63,7 +67,9 @@ class StandardDiffusionLoss(nn.Module): class VideoDiffusionLoss(StandardDiffusionLoss): - def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs): + def __init__( + self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs + ): self.fixed_frames = fixed_frames self.block_scale = block_scale self.block_size = block_size @@ -72,7 +78,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss): def __call__(self, network, denoiser, conditioner, input, batch): cond = conditioner(batch) - additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} + additional_model_inputs = { + key: batch[key] for key in self.batch2model_keys.intersection(batch) + } alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True) alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) @@ -86,24 +94,30 @@ class VideoDiffusionLoss(StandardDiffusionLoss): src = global_rank * mp_size torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group()) torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group()) - torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()) + torch.distributed.broadcast( + alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group() + ) additional_model_inputs["idx"] = idx if self.offset_noise_level > 0.0: noise = ( - noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level + noise + + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) + * self.offset_noise_level ) - noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims( - (1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim - ) + noised_input = input.float() * append_dims( + alphas_cumprod_sqrt, input.ndim + ) + noise * append_dims((1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim) if "concat_images" in batch.keys(): cond["concat"] = batch["concat_images"] # [2, 13, 16, 60, 90],[2] 
dict_keys(['crossattn', 'concat']) dict_keys(['idx']) - model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs) + model_output = denoiser( + network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs + ) w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred if self.min_snr_value is not None: diff --git a/sat/sgm/modules/diffusionmodules/model.py b/sat/sgm/modules/diffusionmodules/model.py index 466f01a..573833b 100644 --- a/sat/sgm/modules/diffusionmodules/model.py +++ b/sat/sgm/modules/diffusionmodules/model.py @@ -47,7 +47,9 @@ def nonlinearity(x): def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) class Upsample(nn.Module): @@ -55,7 +57,9 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") @@ -70,7 +74,9 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, x): if self.with_conv: @@ -107,9 +113,13 @@ class ResnetBlock(nn.Module): self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x, temb): h = x @@ -150,7 +160,9 @@ class AttnBlock(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def attention(self, h_: torch.Tensor) -> torch.Tensor: h_ = self.norm(h_) @@ -160,7 +172,9 @@ class AttnBlock(nn.Module): b, c, h, w = q.shape q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) - h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default + h_ = torch.nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default # compute attention return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) @@ -188,7 +202,9 @@ class MemoryEfficientAttnBlock(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = 
torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) self.attention_op: Optional[Any] = None def attention(self, h_: torch.Tensor) -> torch.Tensor: @@ -211,7 +227,12 @@ class MemoryEfficientAttnBlock(nn.Module): ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) def forward(self, x, **kwargs): @@ -581,7 +602,11 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) make_attn_cls = self._make_attn() make_resblock_cls = self._make_resblock() diff --git a/sat/sgm/modules/diffusionmodules/openaimodel.py b/sat/sgm/modules/diffusionmodules/openaimodel.py index 3f0b83c..9cde5e2 100644 --- a/sat/sgm/modules/diffusionmodules/openaimodel.py +++ b/sat/sgm/modules/diffusionmodules/openaimodel.py @@ -47,7 +47,9 @@ class AttentionPool2d(nn.Module): output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -303,7 +305,9 @@ class ResBlock(TimestepBlock): if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding) + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -437,7 +441,9 @@ class QKVAttentionLegacy(nn.Module): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -574,9 +580,7 @@ class UNetModel(nn.Module): ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert ( - use_spatial_transformer - ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
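# --- Editor's illustrative sketch, not part of the patch. ---
# The half-precision trick used by QKVAttentionLegacy above: instead of computing
# (q @ k^T) / sqrt(ch), the 1/sqrt(ch) factor is split as ch**-0.25 onto q and k
# before the einsum, so the fp16 intermediate never holds the unscaled logits.
import math

import torch

def attn_logits_fp16_friendly(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # q, k: (batch * heads, channels, tokens), matching the "bct,bcs->bts" einsum above
    ch = q.shape[1]
    scale = 1 / math.sqrt(math.sqrt(ch))  # == ch ** -0.25
    return torch.einsum("bct,bcs->bts", q * scale, k * scale)

# In exact arithmetic this equals torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch);
# splitting the scale simply keeps float16 intermediates in range.
# --- End of editor's sketch. ---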
if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -640,7 +644,9 @@ class UNetModel(nn.Module): self.num_heads_upsample = num_heads_upsample self.predict_codebook_ids = n_embed is not None - assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint) + assert use_fairscale_checkpoint != use_checkpoint or not ( + use_checkpoint or use_fairscale_checkpoint + ) self.use_fairscale_checkpoint = False checkpoint_wrapper_fn = ( @@ -942,7 +948,9 @@ class UNetModel(nn.Module): print(f"loading lora from {ckpt_path}") sd = th.load(ckpt_path)["module"] sd = { - key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model") + key[len("model.diffusion_model") :]: sd[key] + for key in sd + if key.startswith("model.diffusion_model") } self.load_state_dict(sd, strict=False) @@ -978,7 +986,9 @@ class UNetModel(nn.Module): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False, dtype=self.dtype + ) emb = self.time_embed(t_emb) if self.num_classes is not None: diff --git a/sat/sgm/modules/diffusionmodules/sampling.py b/sat/sgm/modules/diffusionmodules/sampling.py index 6efd154..edae1d6 100644 --- a/sat/sgm/modules/diffusionmodules/sampling.py +++ b/sat/sgm/modules/diffusionmodules/sampling.py @@ -1,8 +1,7 @@ """ - Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py """ - from typing import Dict, Union import torch @@ -85,9 +84,7 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler): class EDMSampler(SingleStepDiffusionSampler): - def __init__( - self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs - ): + def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) self.s_churn = s_churn @@ -106,15 +103,11 @@ class EDMSampler(SingleStepDiffusionSampler): dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step( - euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ) + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): gamma = ( @@ -136,30 +129,23 @@ class EDMSampler(SingleStepDiffusionSampler): class DDIMSampler(SingleStepDiffusionSampler): - def __init__( - self, s_noise=0.1, *args, **kwargs - ): + def __init__(self, s_noise=0.1, *args, **kwargs): super().__init__(*args, **kwargs) self.s_noise = s_noise def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): - denoised = self.denoise(x, denoiser, sigma, cond, uc) d = to_d(x, sigma, denoised) - dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim) + dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) - x = self.possible_correction_step( - 
euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ) + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -198,9 +184,7 @@ class AncestralSampler(SingleStepDiffusionSampler): return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -227,43 +211,32 @@ class LinearMultistepSampler(BaseDiffusionSampler): self.order = order def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) ds = [] sigmas_cpu = sigmas.detach().cpu().numpy() for i in self.get_sigma_gen(num_sigmas): sigma = s_in * sigmas[i] - denoised = denoiser( - *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs - ) + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) denoised = self.guider(denoised, sigma) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > self.order: ds.pop(0) cur_order = min(i + 1, self.order) - coeffs = [ - linear_multistep_coeff(cur_order, sigmas_cpu, i, j) - for j in range(cur_order) - ] + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class EulerEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): return euler_step class HeunEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): if torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 return euler_step @@ -273,9 +246,7 @@ class HeunEDMSampler(EDMSampler): d_prime = (d + d_new) / 2.0 # apply correction if noise level is not 0 - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step - ) + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) return x @@ -314,9 +285,7 @@ class DPMPP2SAncestralSampler(AncestralSampler): x = x_euler else: h, s, t, t_next = self.get_variables(sigma, sigma_down) - mult = [ - append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) - ] + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] x2 = mult[0] * x - mult[1] * denoised denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) @@ -367,8 +336,7 @@ class DPMPP2MSampler(BaseDiffusionSampler): h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) mult = [ - append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, t, t_next, previous_sigma) + append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma) 
] x_standard = mult[0] * x - mult[1] * denoised @@ -380,16 +348,12 @@ class DPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d # apply correction if noise level is not 0 and not first step - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard - ) + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -406,6 +370,7 @@ class DPMPP2MSampler(BaseDiffusionSampler): return x + class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_variables(self, sigma, next_sigma, previous_sigma=None): t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] @@ -420,7 +385,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_mult(self, h, r, t, t_next, previous_sigma): mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() - mult2 = (-2*h).expm1() + mult2 = (-2 * h).expm1() if previous_sigma is not None: mult3 = 1 + 1 / (2 * r) @@ -444,10 +409,9 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) mult = [ - append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, t, t_next, previous_sigma) + append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma) ] - mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim) + mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_sigma) < 1e-14: @@ -458,16 +422,12 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) # apply correction if noise level is not 0 and not first step - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard - ) + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -484,6 +444,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): return x + class SdeditEDMSampler(EulerEDMSampler): def __init__(self, edit_ratio=0.5, *args, **kwargs): super().__init__(*args, **kwargs) @@ -525,8 +486,8 @@ class SdeditEDMSampler(EulerEDMSampler): return x -class VideoDDIMSampler(BaseDiffusionSampler): +class VideoDDIMSampler(BaseDiffusionSampler): def __init__(self, fixed_frames=0, sdedit=False, **kwargs): super().__init__(**kwargs) self.fixed_frames = fixed_frames @@ -534,10 +495,15 @@ class VideoDDIMSampler(BaseDiffusionSampler): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): alpha_cumprod_sqrt, timesteps = self.discretization( - self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False + self.num_steps if num_steps is None else num_steps, + device=self.device, + return_idx=True, + do_append_zero=False, ) alpha_cumprod_sqrt = 
torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) - timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))]) + timesteps = torch.cat( + [torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))] + ) uc = default(uc, cond) @@ -547,7 +513,19 @@ class VideoDDIMSampler(BaseDiffusionSampler): return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps - def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None): + def denoise( + self, + x, + denoiser, + alpha_cumprod_sqrt, + cond, + uc, + timestep=None, + idx=None, + scale=None, + scale_emb=None, + ofs=None, + ): additional_model_inputs = {} if ofs is not None: @@ -557,26 +535,62 @@ class VideoDDIMSampler(BaseDiffusionSampler): additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep if scale_emb is not None: additional_model_inputs['scale_emb'] = scale_emb - denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) + denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to( + torch.float32 + ) else: additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) - denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32) + denoised = denoiser( + *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), + **additional_model_inputs, + ).to(torch.float32) if isinstance(self.guider, DynamicCFG): - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale) + denoised = self.guider( + denoised, + (1 - alpha_cumprod_sqrt**2) ** 0.5, + step_index=self.num_steps - timestep, + scale=scale, + ) else: - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) return denoised - def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None): - denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 + def sampler_step( + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ofs=None, + ): + denoised = self.denoise( + x, + denoiser, + alpha_cumprod_sqrt, + cond, + uc, + timestep, + idx, + scale=scale, + scale_emb=scale_emb, + ofs=ofs, + ).to(torch.float32) # 1020 - a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5 + a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised return x - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 + def __call__( + self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None + ): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) @@ -590,17 +604,16 @@ class VideoDDIMSampler(BaseDiffusionSampler): cond, uc, idx=self.num_steps - i, - timestep=timesteps[-(i+1)], + timestep=timesteps[-(i + 1)], scale=scale, scale_emb=scale_emb, - ofs=ofs # 1020 
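# --- Editor's illustrative sketch, not part of the patch. ---
# The deterministic update in VideoDDIMSampler.sampler_step above, written for scalars.
# With s = sqrt(alpha_cumprod_t) and s_next = sqrt(alpha_cumprod_{t-1}), the step keeps
# the predicted clean sample ("denoised") and rescales only the noise part:
def ddim_style_step(x: float, denoised: float, s: float, s_next: float) -> float:
    a_t = ((1 - s_next**2) / (1 - s**2)) ** 0.5  # carries the noise component across levels
    b_t = s_next - s * a_t                       # re-injects the x0 estimate at the new scale
    return a_t * x + b_t * denoised

# If x = s * x0 + sqrt(1 - s**2) * eps and denoised == x0, the result is
# s_next * x0 + sqrt(1 - s_next**2) * eps, i.e. the same eps at the next noise level.
# --- End of editor's sketch. ---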
+ ofs=ofs, # 1020 ) return x class Image2VideoDDIMSampler(BaseDiffusionSampler): - def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): alpha_cumprod_sqrt, timesteps = self.discretization( self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True @@ -616,22 +629,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler): def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None): additional_model_inputs = {} additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) - denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to( - torch.float32) + denoised = denoiser( + *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs + ).to(torch.float32) if isinstance(self.guider, DynamicCFG): - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep) + denoised = self.guider( + denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep + ) else: - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5) return denoised - def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, - timestep=None): + def sampler_step( + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + ): # 此处的sigma实际上是alpha_cumprod_sqrt - denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32) + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to( + torch.float32 + ) if idx == 1: return denoised - a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 + a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised @@ -651,31 +678,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler): cond, uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)] + timestep=timesteps[-(i + 1)], ) return x + class VPSDEDPMPP2MSampler(VideoDDIMSampler): - def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt ** 2 - lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 - lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() + def get_variables( + self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None + ): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 - lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next else: return h, None, lamb, lamb_next - def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): 
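# --- Editor's illustrative sketch, not part of the patch. ---
# The change of variables behind VPSDEDPMPP2MSampler.get_variables in this hunk: with
# abar = alpha_cumprod, the solver works in lambda = log(sqrt(abar / (1 - abar))),
# i.e. half the log-SNR. Each step is described by h = lambda_next - lambda and, for the
# second-order (2M) correction, the ratio r = h_last / h.
import torch

def half_log_snr(alpha_cumprod_sqrt: torch.Tensor) -> torch.Tensor:
    abar = alpha_cumprod_sqrt**2
    return ((abar / (1 - abar)) ** 0.5).log()

def dpmpp_2m_step_sizes(s: torch.Tensor, s_next: torch.Tensor, s_prev: torch.Tensor):
    lamb, lamb_next, lamb_prev = map(half_log_snr, (s, s_next, s_prev))
    h = lamb_next - lamb
    r = (lamb - lamb_prev) / h  # previous step size relative to the current one
    return h, r

# The mult1/mult2 factors in get_mult then combine x with the current and previous
# denoised estimates using exp(-h), expm1(-2 * h), and r.
# --- End of editor's sketch. ---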
- mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp() - mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt + def get_mult( + self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: mult3 = 1 + 1 / (2 * r) @@ -698,18 +730,35 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): timestep=None, scale=None, scale_emb=None, - ofs=None # 1020 + ofs=None, # 1020 ): - denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 + denoised = self.denoise( + x, + denoiser, + alpha_cumprod_sqrt, + cond, + uc, + timestep, + idx, + scale=scale, + scale_emb=scale_emb, + ofs=ofs, + ).to(torch.float32) # 1020 if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) mult = [ append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + for mult in self.get_mult( + h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) ] - mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim) + mult_noise = append_dims( + (1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim + ) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: @@ -723,23 +772,26 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return x, denoised - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 + def __call__( + self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None + ): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) if self.fixed_frames > 0: - prefix_frames = x[:, :self.fixed_frames] + prefix_frames = x[:, : self.fixed_frames] old_denoised = None for i in self.get_sigma_gen(num_sigmas): - if self.fixed_frames > 0: if self.sdedit: rd = torch.randn_like(prefix_frames) - noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape)) - x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( + s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) + ) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) else: - x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], @@ -750,37 +802,41 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - timestep=timesteps[-(i+1)], + timestep=timesteps[-(i + 1)], scale=scale, scale_emb=scale_emb, - ofs=ofs # 1020 + ofs=ofs, # 1020 ) if self.fixed_frames > 0: - x = torch.cat([prefix_frames, x[:, 
self.fixed_frames:]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) return x class VPODEDPMPP2MSampler(VideoDDIMSampler): - def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt ** 2 - lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 - lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() + def get_variables( + self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None + ): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 - lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next else: return h, None, lamb, lamb_next - def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 + def get_mult( + self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 mult2 = (-h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: @@ -801,16 +857,22 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=None, idx=None, - timestep=None + timestep=None, ): - denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to( + torch.float32 + ) if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) mult = [ append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + for mult in self.get_mult( + h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) ] x_standard = mult[0] * x - mult[1] * denoised @@ -842,22 +904,44 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - timestep=timesteps[-(i+1)] + timestep=timesteps[-(i + 1)], ) return x + class VideoDDPMSampler(VideoDDIMSampler): - def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None): + def sampler_step( + self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None + ): # 此处的sigma实际上是alpha_cumprod_sqrt - denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32) + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, idx * 1000 // self.num_steps + ).to(torch.float32) if idx == 1: return denoised alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt - x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / 
(1-alpha_cumprod_sqrt**2), x.ndim) * x \ - + append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \ - + append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x) + x = ( + append_dims( + alpha_sqrt * (1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim + ) + * x + + append_dims( + next_alpha_cumprod_sqrt * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim + ) + * denoised + + append_dims( + ( + (1 - next_alpha_cumprod_sqrt**2) + * (1 - alpha_sqrt**2) + / (1 - alpha_cumprod_sqrt**2) + ) + ** 0.5, + x.ndim, + ) + * torch.randn_like(x) + ) return x @@ -874,7 +958,7 @@ class VideoDDPMSampler(VideoDDIMSampler): x, cond, uc, - idx=self.num_steps - i + idx=self.num_steps - i, ) - return x \ No newline at end of file + return x diff --git a/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/sat/sgm/modules/diffusionmodules/sigma_sampling.py index 8bb623e..5382f27 100644 --- a/sat/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -17,7 +17,15 @@ class EDMSampling: class DiscreteSampling: - def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0): + def __init__( + self, + discretization_config, + num_idx, + do_append_zero=False, + flip=True, + uniform_sampling=False, + group_num=0, + ): self.num_idx = num_idx self.sigmas = instantiate_from_config(discretization_config)( num_idx, do_append_zero=do_append_zero, flip=flip @@ -30,7 +38,7 @@ class DiscreteSampling: if self.uniform_sampling: assert self.group_num > 0 assert world_size % group_num == 0 - self.group_width = world_size // group_num # the number of rank in one group + self.group_width = world_size // group_num # the number of rank in one group self.sigma_interval = self.num_idx // self.group_num def idx_to_sigma(self, idx): @@ -42,7 +50,11 @@ class DiscreteSampling: group_index = rank // self.group_width idx = default( rand, - torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)), + torch.randint( + group_index * self.sigma_interval, + (group_index + 1) * self.sigma_interval, + (n_samples,), + ), ) else: idx = default( @@ -54,8 +66,11 @@ class DiscreteSampling: else: return self.idx_to_sigma(idx) + class PartialDiscreteSampling: - def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): + def __init__( + self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True + ): self.total_num_idx = total_num_idx self.partial_num_idx = partial_num_idx self.sigmas = instantiate_from_config(discretization_config)( diff --git a/sat/sgm/modules/diffusionmodules/util.py b/sat/sgm/modules/diffusionmodules/util.py index abf72a7..64d2d46 100644 --- a/sat/sgm/modules/diffusionmodules/util.py +++ b/sat/sgm/modules/diffusionmodules/util.py @@ -24,7 +24,9 @@ def make_beta_schedule( linear_end=2e-2, ): if schedule == "linear": - betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 + betas = ( + torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 + ) return betas.numpy() @@ -50,7 +52,9 @@ def mixed_checkpoint(func, inputs: dict, params, flag): tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] tensor_inputs = [inputs[key] for key in inputs if 
isinstance(inputs[key], torch.Tensor)] non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)] - non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)] + non_tensor_inputs = [ + inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) return MixedCheckpointFunction.apply( func, @@ -84,9 +88,14 @@ class MixedCheckpointFunction(torch.autograd.Function): } assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors - ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} + ctx.input_tensors = { + key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) + } ctx.input_non_tensors = { - key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])) + key: val + for (key, val) in zip( + non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) + ) } ctx.run_function = run_function ctx.input_params = list(args[ctx.end_non_tensors :]) @@ -98,13 +107,18 @@ class MixedCheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} - ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors} + ctx.input_tensors = { + key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors + } with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. 
- shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors} + shallow_copies = { + key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) + for key in ctx.input_tensors + } # shallow_copies.update(additional_args) output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) input_grads = torch.autograd.grad( @@ -188,9 +202,9 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtyp """ if not repeat_only: half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: diff --git a/sat/sgm/modules/diffusionmodules/wrappers.py b/sat/sgm/modules/diffusionmodules/wrappers.py index d0b78ff..9f646da 100644 --- a/sat/sgm/modules/diffusionmodules/wrappers.py +++ b/sat/sgm/modules/diffusionmodules/wrappers.py @@ -6,7 +6,9 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" class IdentityWrapper(nn.Module): - def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): + def __init__( + self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32 + ): super().__init__() compile = ( torch.compile diff --git a/sat/sgm/modules/distributions/distributions.py b/sat/sgm/modules/distributions/distributions.py index 0338a86..5373436 100644 --- a/sat/sgm/modules/distributions/distributions.py +++ b/sat/sgm/modules/distributions/distributions.py @@ -87,8 +87,14 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). 
- logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2) + ] return 0.5 * ( - -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) diff --git a/sat/sgm/modules/ema.py b/sat/sgm/modules/ema.py index 9f1f760..bf33b51 100644 --- a/sat/sgm/modules/ema.py +++ b/sat/sgm/modules/ema.py @@ -12,7 +12,9 @@ class LitEma(nn.Module): self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) self.register_buffer( "num_updates", - torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), ) for name, p in model.named_parameters(): @@ -45,9 +47,11 @@ class LitEma(nn.Module): if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -56,7 +60,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/sat/sgm/modules/encoders/modules.py b/sat/sgm/modules/encoders/modules.py index e8a16fc..c2ffdfd 100644 --- a/sat/sgm/modules/encoders/modules.py +++ b/sat/sgm/modules/encoders/modules.py @@ -99,7 +99,9 @@ class GeneralConditioner(nn.Module): elif "input_keys" in embconfig: embedder.input_keys = embconfig["input_keys"] else: - raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + raise KeyError( + f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" + ) embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) if embedder.legacy_ucg_val is not None: @@ -160,7 +162,10 @@ class GeneralConditioner(nn.Module): if cond_or_not is None: emb = ( expand_dims_like( - torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), + torch.bernoulli( + (1.0 - embedder.ucg_rate) + * torch.ones(emb.shape[0], device=emb.device) + ), emb, ) * emb diff --git a/sat/sgm/modules/video_attention.py b/sat/sgm/modules/video_attention.py index 9f968d7..364a079 100644 --- a/sat/sgm/modules/video_attention.py +++ b/sat/sgm/modules/video_attention.py @@ -96,7 +96,9 @@ class VideoTransformerBlock(nn.Module): if self.checkpoint: print(f"{self.__class__.__name__} is using checkpointing") - def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None + ) -> torch.Tensor: if self.checkpoint: return checkpoint(self._forward, x, context, timesteps) else: @@ -239,7 +241,9 @@ class SpatialVideoTransformer(SpatialTransformer): spatial_context = context if self.use_spatial_context: - assert context.ndim == 3, f"n dims of spatial 
context should be 3 but are {context.ndim}" + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" time_context = context time_context_first_timestep = time_context[::timesteps] diff --git a/sat/sgm/util.py b/sat/sgm/util.py index b93a049..ea95cb0 100644 --- a/sat/sgm/util.py +++ b/sat/sgm/util.py @@ -86,7 +86,9 @@ class SafeConv3d(torch.nn.Conv3d): input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + torch.cat( + (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2 + ) for i in range(1, len(input_chunks)) ] @@ -252,7 +254,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config, **extra_kwargs): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": diff --git a/sat/sgm/webds.py b/sat/sgm/webds.py index b99f9f3..b139633 100644 --- a/sat/sgm/webds.py +++ b/sat/sgm/webds.py @@ -93,7 +93,12 @@ class SimpleDistributedWebDataset(DataPipeline): def tar_file_iterator_with_meta( - fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None + fileobj, + meta_names, + skip_meta=r"__[^/]*__($|/)", + suffix=None, + handler=reraise_exception, + meta_stream=None, ): """Iterate over tar file, yielding filename, content pairs for the given tar stream. @@ -122,10 +127,13 @@ def tar_file_iterator_with_meta( except Exception as exn: from sat.helpers import print_rank0 - print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG") + print_rank0( + f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", + level="DEBUG", + ) continue for item in meta_list: - if not item["key"] in meta_data: + if item["key"] not in meta_data: meta_data[item["key"]] = {} for meta_name in meta_names: if meta_name in item: @@ -186,7 +194,9 @@ def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception): try: assert isinstance(source, dict) assert "stream" in source - for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]): + for sample in tar_file_iterator_with_meta( + source["stream"], meta_names, meta_stream=source["meta_stream"] + ): assert isinstance(sample, dict) and "data" in sample and "fname" in sample sample["__url__"] = url yield sample @@ -250,7 +260,15 @@ class MetaDistributedWebDataset(DataPipeline): """ def __init__( - self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None + self, + path, + process_fn, + seed, + *, + meta_names=[], + nshards=sys.maxsize, + shuffle_buffer=1000, + include_dirs=None, ): # os.environ['WDS_SHOW_SEED'] = '1' import torch @@ -361,7 +379,10 @@ def gopen_boto3(url, mode="rb", bufsize=8192 * 2): if mode[0] == "r": s3_client = boto3.client( - "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key + "s3", + endpoint_url=endpoint_url, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, ) bucket, key = url.split("/", 1) diff --git a/sat/train_video.py b/sat/train_video.py index b63bca8..58e9294 100644 --- a/sat/train_video.py +++ b/sat/train_video.py @@ -37,7 +37,9 @@ def save_texts(texts, save_dir, iterations): f.write(text + "\n") -def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: 
int, fps: int = 5, args=None, key=None): +def save_video_as_grid_and_mp4( + video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None +): os.makedirs(save_path, exist_ok=True) for i, vid in enumerate(video_batch): @@ -52,7 +54,8 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int writer.append_data(frame) if args is not None and args.wandb: wandb.log( - {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1 + {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, + step=args.iteration + 1, ) @@ -138,7 +141,9 @@ def broad_cast_batch(batch): return batch -def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None): +def forward_step_eval( + data_iterator, model, args, timers, only_log_video_latents=False, data_class=None +): if mpu.get_model_parallel_rank() == 0: timers("data loader").start() batch_video = next(data_iterator) @@ -209,7 +214,9 @@ if __name__ == "__main__": args = argparse.Namespace(**vars(args), **vars(known)) data_class = get_obj_from_str(args.data_config["target"]) - create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"]) + create_dataset_function = partial( + data_class.create_dataset_function, **args.data_config["params"] + ) import yaml @@ -225,7 +232,9 @@ if __name__ == "__main__": model_cls=SATVideoDiffusionEngine, forward_step_function=partial(forward_step, data_class=data_class), forward_step_eval=partial( - forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents + forward_step_eval, + data_class=data_class, + only_log_video_latents=args.only_log_video_latents, ), create_dataset_function=create_dataset_function, ) diff --git a/sat/vae_modules/attention.py b/sat/vae_modules/attention.py index 041df77..365aebf 100644 --- a/sat/vae_modules/attention.py +++ b/sat/vae_modules/attention.py @@ -94,7 +94,11 @@ class FeedForward(nn.Module): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) @@ -126,7 +130,9 @@ class LinearAttention(nn.Module): def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) @@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module): self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, x): h_ = x @@ -244,7 +252,9 @@ class CrossAttention(nn.Module): # new with sdp_kernel(**BACKEND_MAP[self.backend]): # print("dispatching into backend", 
self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask + ) # scale is dim_head ** -0.5 per default del q, k, v out = rearrange(out, "b h n d -> b n (h d)", h=h) @@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module): self.norm1(x), context=context if self.disable_self_attn else None, additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self + if not self.disable_self_attn + else 0, ) + x ) @@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module): sdp_backend=None, ): super().__init__() - print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + print( + f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" + ) from omegaconf import ListConfig if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): @@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module): ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) else: # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py index 0adad85..7c129a0 100644 --- a/sat/vae_modules/autoencoder.py +++ b/sat/vae_modules/autoencoder.py @@ -97,9 +97,7 @@ class AbstractAutoencoder(pl.LightningModule): def instantiate_optimizer_from_config(self, params, lr, cfg): logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) def configure_optimizers(self) -> Any: raise NotImplementedError() @@ -196,11 +194,11 @@ class AutoencodingEngine(AbstractAutoencoder): return self.decoder.get_last_layer() def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, - **kwargs, + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: z = self.encoder(x, **kwargs) if unregularized: @@ -214,14 +212,20 @@ class AutoencodingEngine(AbstractAutoencoder): x = self.decoder(z, **kwargs) return x - def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True) dec = self.decode(z, **additional_decode_kwargs) return z, dec, reg_log - def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: + def inner_training_step( + self, batch: dict, batch_idx: int, optimizer_idx: int = 0 + ) -> torch.Tensor: x = self.get_input(batch) - additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + additional_decode_kwargs = { + key: batch[key] for key in self.additional_decode_keys.intersection(batch) + } z, xrec, 
regularization_log = self(x, **additional_decode_kwargs) if hasattr(self.loss, "forward_keys"): extra_info = { @@ -357,12 +361,16 @@ class AutoencodingEngine(AbstractAutoencoder): if self.trainable_ae_params is None: ae_params = self.get_autoencoder_params() else: - ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) + ae_params, num_ae_params = self.get_param_groups( + self.trainable_ae_params, self.ae_optimizer_args + ) logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") if self.trainable_disc_params is None: disc_params = self.get_discriminator_params() else: - disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) + disc_params, num_disc_params = self.get_param_groups( + self.trainable_disc_params, self.disc_optimizer_args + ) logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") opt_ae = self.instantiate_optimizer_from_config( ae_params, @@ -371,17 +379,23 @@ class AutoencodingEngine(AbstractAutoencoder): ) opts = [opt_ae] if len(disc_params) > 0: - opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) + opt_disc = self.instantiate_optimizer_from_config( + disc_params, self.learning_rate, self.optimizer_config + ) opts.append(opt_disc) return opts @torch.no_grad() - def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + def log_images( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: log = dict() additional_decode_kwargs = {} x = self.get_input(batch) - additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) + additional_decode_kwargs.update( + {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + ) _, xrec, _ = self(x, **additional_decode_kwargs) log["inputs"] = x @@ -400,7 +414,9 @@ class AutoencodingEngine(AbstractAutoencoder): diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema.clamp_(0, 1.0) log["diff_ema"] = 2.0 * diff_ema - 1.0 - log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + log["diff_boost_ema"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + ) if additional_log_kwargs: additional_decode_kwargs.update(additional_log_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs) @@ -442,7 +458,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine): params = super().get_autoencoder_params() return params - def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) z = self.quant_conv(z) @@ -485,7 +503,9 @@ class AutoencoderKL(AutoencodingEngineLegacy): if "lossconfig" in kwargs: kwargs["loss_config"] = kwargs.pop("lossconfig") super().__init__( - regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")}, + regularizer_config={ + "target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer") + }, **kwargs, ) @@ -519,7 +539,9 @@ class VideoAutoencodingEngine(AutoencodingEngine): if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = 
None, **kwargs) -> dict: + def log_videos( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: return self.log_images(batch, additional_log_kwargs, **kwargs) def get_input(self, batch: dict) -> torch.Tensor: @@ -530,7 +552,9 @@ class VideoAutoencodingEngine(AutoencodingEngine): batch = batch[self.input_key] global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size - torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) + torch.distributed.broadcast( + batch, src=global_src_rank, group=get_context_parallel_group() + ) batch = _conv_split(batch, dim=2, kernel_size=1) return batch diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index 4d77324..ee7efed 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -201,7 +201,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size): recv_rank += cp_world_size if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend( + input_[-kernel_size + 1 :].contiguous(), send_rank, group=group + ) if cp_rank > 0: recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) @@ -246,11 +248,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend( + input_[-kernel_size + 1 :].contiguous(), send_rank, group=group + ) if cp_rank > 0: req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) - if cp_rank == 0: if cache_padding is not None: input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0) @@ -334,7 +337,9 @@ def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding): class ContextParallelCausalConv3d(nn.Module): - def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs + ): super().__init__() kernel_size = cast_tuple(kernel_size, 3) @@ -354,7 +359,9 @@ class ContextParallelCausalConv3d(nn.Module): stride = (stride, stride, stride) dilation = (1, 1, 1) - self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.conv = Conv3d( + chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs + ) self.cache_padding = None def forward(self, input_, clear_cache=True): @@ -369,7 +376,11 @@ class ContextParallelCausalConv3d(nn.Module): global_rank = torch.distributed.get_rank() if cp_world_size == 1: self.cache_padding = ( - input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + input_parallel[:, :, -self.time_kernel_size + 1 :] + .contiguous() + .detach() + .clone() + .cpu() ) else: if cp_rank == cp_world_size - 1: @@ -379,9 +390,13 @@ class ContextParallelCausalConv3d(nn.Module): group=get_context_parallel_group(), ) if cp_rank == 0: - recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous() + recv_buffer = torch.empty_like( + input_parallel[:, :, -self.time_kernel_size + 1 :] + ).contiguous() 
torch.distributed.recv( - recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group() + recv_buffer, + global_rank - 1 + cp_world_size, + group=get_context_parallel_group(), ) self.cache_padding = recv_buffer.contiguous().detach().clone().cpu() @@ -406,7 +421,9 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm): def Normalize(in_channels, gather=False, **kwargs): if gather: - return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + return ContextParallelGroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) else: return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) @@ -460,7 +477,8 @@ class SpatialNorm3D(nn.Module): zq_rest_splits = torch.split(zq_rest, 32, dim=1) interpolated_splits = [ - torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits + torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") + for split in zq_rest_splits ] zq_rest = torch.cat(interpolated_splits, dim=1) @@ -471,7 +489,8 @@ class SpatialNorm3D(nn.Module): zq_splits = torch.split(zq, 32, dim=1) interpolated_splits = [ - torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits + torch.nn.functional.interpolate(split, size=f_size, mode="nearest") + for split in zq_splits ] zq = torch.cat(interpolated_splits, dim=1) @@ -511,7 +530,9 @@ class Upsample3D(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) self.compress_time = compress_time def forward(self, x, fake_cp=True): @@ -523,14 +544,16 @@ class Upsample3D(nn.Module): splits = torch.split(x_rest, 32, dim=1) interpolated_splits = [ - torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") + for split in splits ] x_rest = torch.cat(interpolated_splits, dim=1) x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) else: splits = torch.split(x, 32, dim=1) interpolated_splits = [ - torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") + for split in splits ] x = torch.cat(interpolated_splits, dim=1) @@ -541,7 +564,8 @@ class Upsample3D(nn.Module): splits = torch.split(x, 32, dim=1) interpolated_splits = [ - torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") + for split in splits ] x = torch.cat(interpolated_splits, dim=1) @@ -563,7 +587,9 @@ class DownSample3D(nn.Module): out_channels = in_channels if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=0 + ) self.compress_time = compress_time def forward(self, x, fake_cp=True): @@ -578,7 +604,8 @@ class DownSample3D(nn.Module): if x_rest.shape[-1] > 0: splits = torch.split(x_rest, 32, dim=1) interpolated_splits = [ - torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + torch.nn.functional.avg_pool1d(split, 
kernel_size=2, stride=2) + for split in splits ] x_rest = torch.cat(interpolated_splits, dim=1) x = torch.cat([x_first[..., None], x_rest], dim=-1) @@ -587,7 +614,8 @@ class DownSample3D(nn.Module): # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) splits = torch.split(x, 32, dim=1) interpolated_splits = [ - torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) + for split in splits ] x = torch.cat(interpolated_splits, dim=1) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) @@ -923,9 +951,13 @@ class ContextParallelDecoder3D(nn.Module): up.attn = attn if i_level != 0: if i_level < self.num_resolutions - self.temporal_compress_level: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) + up.upsample = Upsample3D( + block_in, with_conv=resamp_with_conv, compress_time=False + ) else: - up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) + up.upsample = Upsample3D( + block_in, with_conv=resamp_with_conv, compress_time=True + ) self.up.insert(0, up) self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) diff --git a/sat/vae_modules/ema.py b/sat/vae_modules/ema.py index 9f1f760..bf33b51 100644 --- a/sat/vae_modules/ema.py +++ b/sat/vae_modules/ema.py @@ -12,7 +12,9 @@ class LitEma(nn.Module): self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) self.register_buffer( "num_updates", - torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), ) for name, p in model.named_parameters(): @@ -45,9 +47,11 @@ class LitEma(nn.Module): if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -56,7 +60,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/sat/vae_modules/regularizers.py b/sat/vae_modules/regularizers.py index 205bd4a..eed2e72 100644 --- a/sat/vae_modules/regularizers.py +++ b/sat/vae_modules/regularizers.py @@ -77,7 +77,9 @@ class IdentityRegularizer(AbstractRegularizer): yield from () -def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: +def measure_perplexity( + predicted_indices: torch.Tensor, num_centroids: int +) -> Tuple[torch.Tensor, torch.Tensor]: # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) diff --git a/sat/vae_modules/utils.py b/sat/vae_modules/utils.py index 8c8dba6..5ec1eec 100644 --- a/sat/vae_modules/utils.py +++ b/sat/vae_modules/utils.py @@ -78,7 +78,9 @@ class SafeConv3d(torch.nn.Conv3d): input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + torch.cat( + (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2 + ) for i in range(1, len(input_chunks)) ] @@ -244,7 +246,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": diff --git a/tools/caption/README.md b/tools/caption/README.md index e81691a..2cb8575 100644 --- a/tools/caption/README.md +++ b/tools/caption/README.md @@ -11,7 +11,7 @@ data into textual descriptions to provide the essential training data for text-t ## Video Caption via CogVLM2-Caption -🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) +🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) CogVLM2-Caption is a video captioning model used to generate training data for the CogVideoX model. @@ -64,4 +64,4 @@ CogVLM2-Video: journal={arXiv preprint arXiv:2408.16500}, year={2024} } -``` \ No newline at end of file +``` diff --git a/tools/caption/README_ja.md b/tools/caption/README_ja.md index 25c6cce..e328a9a 100644 --- a/tools/caption/README_ja.md +++ b/tools/caption/README_ja.md @@ -8,7 +8,7 @@ がオープンソース化されました。ぜひダウンロードしてご利用ください。 ## CogVLM2-Captionによるビデオキャプション -🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) +🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) CogVLM2-Captionは、CogVideoXモデルのトレーニングデータを生成するために使用されるビデオキャプションモデルです。 diff --git a/tools/caption/README_zh.md b/tools/caption/README_zh.md index f6da7a6..8cc79e1 100644 --- a/tools/caption/README_zh.md +++ b/tools/caption/README_zh.md @@ -9,7 +9,7 @@ ## 通过 CogVLM2-Caption 模型生成视频Caption -🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) +🤗 [Hugging Face](https://huggingface.co/THUDM/cogvlm2-llama3-caption) | 🤖 [ModelScope](https://modelscope.cn/models/ZhipuAI/cogvlm2-llama3-caption/) CogVLM2-Caption是用于生成CogVideoX模型训练数据的视频caption模型。 @@ -64,4 +64,4 @@ CogVLM2-Video: journal={arXiv preprint arXiv:2408.16500}, year={2024} } -``` \ No newline at end of file +``` diff --git a/tools/caption/requirements.txt b/tools/caption/requirements.txt index ce2e17e..1c909e4 100644 --- a/tools/caption/requirements.txt +++ b/tools/caption/requirements.txt @@ -20,4 +20,4 @@ flask gunicorn gevent requests -gradio \ No newline at end of file +gradio diff --git a/tools/caption/video_caption.py b/tools/caption/video_caption.py index 1110fca..7b2bd71 100644 --- a/tools/caption/video_caption.py +++ b/tools/caption/video_caption.py @@ -9,11 +9,16 @@ from 
transformers import AutoModelForCausalLM, AutoTokenizer MODEL_PATH = "THUDM/cogvlm2-llama3-caption" DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' -TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[ - 0] >= 8 else torch.float16 +TORCH_TYPE = ( + torch.bfloat16 + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 +) parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo") -parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0) +parser.add_argument( + '--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0 +) args = parser.parse_args([]) @@ -29,8 +34,11 @@ def load_video(video_data, strategy='chat'): clip_end_sec = 60 clip_start_sec = 0 start_frame = int(clip_start_sec * decord_vr.get_avg_fps()) - end_frame = min(total_frames, - int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames + end_frame = ( + min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps())) + if clip_end_sec is not None + else total_frames + ) frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int) elif strategy == 'chat': timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames)) @@ -54,11 +62,11 @@ tokenizer = AutoTokenizer.from_pretrained( trust_remote_code=True, ) -model = AutoModelForCausalLM.from_pretrained( - MODEL_PATH, - torch_dtype=TORCH_TYPE, - trust_remote_code=True -).eval().to(DEVICE) +model = ( + AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=TORCH_TYPE, trust_remote_code=True) + .eval() + .to(DEVICE) +) def predict(prompt, video_data, temperature): @@ -69,11 +77,7 @@ def predict(prompt, video_data, temperature): history = [] query = prompt inputs = model.build_conversation_input_ids( - tokenizer=tokenizer, - query=query, - images=[video], - history=history, - template_version=strategy + tokenizer=tokenizer, query=query, images=[video], history=history, template_version=strategy ) inputs = { 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'), @@ -91,7 +95,7 @@ def predict(prompt, video_data, temperature): } with torch.no_grad(): outputs = model.generate(**inputs, **gen_kwargs) - outputs = outputs[:, inputs['input_ids'].shape[1]:] + outputs = outputs[:, inputs['input_ids'].shape[1] :] response = tokenizer.decode(outputs[0], skip_special_tokens=True) return response diff --git a/tools/convert_weight_deepspeed2hf.py b/tools/convert_weight_deepspeed2hf.py index 3c5ed88..7d228e1 100644 --- a/tools/convert_weight_deepspeed2hf.py +++ b/tools/convert_weight_deepspeed2hf.py @@ -31,9 +31,18 @@ from dataclasses import dataclass # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # DeepSpeed data structures it has to be available in the current python environment. 
from deepspeed.utils import logger -from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, - FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, - FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) +from deepspeed.checkpoint.constants import ( + DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES, + FROZEN_PARAM_SHAPES, + FROZEN_PARAM_FRAGMENTS, +) @dataclass @@ -134,12 +143,14 @@ def parse_model_states(files): frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) - z_model_state = zero_model_state(buffers=buffers, - param_shapes=param_shapes, - shared_params=shared_params, - ds_version=ds_version, - frozen_param_shapes=frozen_param_shapes, - frozen_param_fragments=frozen_param_fragments) + z_model_state = zero_model_state( + buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments, + ) zero_model_states.append(z_model_state) return zero_model_states @@ -155,7 +166,7 @@ def parse_optim_states(files, ds_checkpoint_dir): state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) state_dicts.append(state_dict) - if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]: raise ValueError(f"{files[0]} is not a zero checkpoint") zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] @@ -181,7 +192,9 @@ def parse_optim_states(files, ds_checkpoint_dir): else: raise ValueError(f"unknown zero stage {zero_stage}") - fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts)) + ] return zero_stage, world_size, fp32_flat_groups @@ -205,15 +218,20 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters) + return _get_fp32_state_dict_from_zero2_checkpoint( + world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters + ) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters) + return _get_fp32_state_dict_from_zero3_checkpoint( + world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters + ) def _zero2_merge_frozen_params(state_dict, zero_model_states): - if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + if ( + zero_model_states[0].frozen_param_shapes is None + or len(zero_model_states[0].frozen_param_shapes) == 0 + ): return frozen_param_shapes = zero_model_states[0].frozen_param_shapes @@ -269,11 +287,17 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero full_single_fp32_vector = torch.cat(merged_partitions, 0) merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) avail_numel = sum( - [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + [ + 
full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ] + ) if debug: wanted_params = sum([len(shapes) for shapes in param_shapes]) - wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes] + ) # not asserting if there is a mismatch due to possible padding print(f"Have {avail_numel} numels to process.") print(f"Need {wanted_numel} numels in {wanted_params} params.") @@ -283,18 +307,23 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero # out-of-core computing solution total_numel = 0 total_params = 0 - for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + for shapes, full_single_fp32_vector in zip( + param_shapes, merged_single_partition_of_fp32_groups + ): offset = 0 avail_numel = full_single_fp32_vector.numel() for name, shape in shapes.items(): - - unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + unpartitioned_numel = ( + shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + ) total_numel += unpartitioned_numel total_params += 1 if debug: print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") - state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view( + shape + ) offset += unpartitioned_numel # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and @@ -322,8 +351,9 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero2_checkpoint( + world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters +): state_dict = OrderedDict() # buffers @@ -353,7 +383,10 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size): def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): - if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + if ( + zero_model_states[0].frozen_param_shapes is None + or len(zero_model_states[0].frozen_param_shapes) == 0 + ): return if debug: @@ -364,7 +397,10 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): frozen_param_shapes = zero_model_states[0].frozen_param_shapes wanted_params = len(frozen_param_shapes) wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) - avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + avail_numel = ( + sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) + * world_size + ) print(f'Frozen params: Have {avail_numel} numels to process.') print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') @@ -375,10 +411,14 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel - param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + param_frags = tuple( + model_state.frozen_param_fragments[name] 
for model_state in zero_model_states + ) state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info( + unpartitioned_numel, world_size + ) if debug: print( @@ -416,21 +456,32 @@ class GatheredTensor: start_group_id = None end_group_id = None for group_id in range(len(self.flat_groups_offset)): - if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]: + if ( + self.flat_groups_offset[group_id] + <= self.offset + < self.flat_groups_offset[group_id + 1] + ): start_group_id = group_id - if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]: + if ( + self.flat_groups_offset[group_id] + < end_idx + <= self.flat_groups_offset[group_id + 1] + ): end_group_id = group_id break # collect weights from related group/groups for group_id in range(start_group_id, end_group_id + 1): flat_tensor = flat_groups_at_rank_i[group_id] start_offset = self.offset - self.flat_groups_offset[group_id] - end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id] + end_offset = ( + min(end_idx, self.flat_groups_offset[group_id + 1]) + - self.flat_groups_offset[group_id] + ) pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset]) # collect weights from all ranks pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0) - param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous() + param = pad_flat_param[: self.shape.numel()].view(self.shape).contiguous() return param @@ -461,12 +512,16 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero offset = 0 total_numel = 0 total_params = 0 - flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])) + flat_groups_offset = [0] + list( + np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]) + ) for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'): unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel total_params += 1 - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info( + unpartitioned_numel, world_size + ) if debug: print( @@ -474,7 +529,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero ) # memory efficient tensor - tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape) + tensor = GatheredTensor( + fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape + ) state_dict[name] = tensor offset += partitioned_numel @@ -484,11 +541,14 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero if offset != avail_numel: raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") - print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + print( + f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements" + ) -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero3_checkpoint( + world_size, fp32_flat_groups, zero_model_states, 
exclude_frozen_parameters +): state_dict = OrderedDict() # buffers @@ -530,10 +590,9 @@ def to_torch_tensor(state_dict, return_empty_tensor=False): return torch_state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, - tag=None, - exclude_frozen_parameters=False, - lazy_mode=False): +def get_fp32_state_dict_from_zero_checkpoint( + checkpoint_dir, tag=None, exclude_frozen_parameters=False, lazy_mode=False +): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -588,19 +647,23 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + state_dict = _get_fp32_state_dict_from_zero_checkpoint( + ds_checkpoint_dir, exclude_frozen_parameters + ) if lazy_mode: return state_dict else: return to_torch_tensor(state_dict) -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, - output_dir, - max_shard_size="5GB", - safe_serialization=False, - tag=None, - exclude_frozen_parameters=False): +def convert_zero_checkpoint_to_fp32_state_dict( + checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=False, + tag=None, + exclude_frozen_parameters=False, +): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. @@ -629,25 +692,28 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, raise # Convert zero checkpoint to state_dict - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, - tag, - exclude_frozen_parameters, - lazy_mode=True) + state_dict = get_fp32_state_dict_from_zero_checkpoint( + checkpoint_dir, tag, exclude_frozen_parameters, lazy_mode=True + ) # Shard the model if it is too big. 
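The hunks above only reformat DeepSpeed's ZeRO-checkpoint conversion helpers, but their docstrings describe the intended call pattern: merge the per-rank ZeRO-2/3 partitions into a single fp32 `state_dict`, then either load it directly or write it to disk. A minimal usage sketch, assuming the script is importable as a module named `zero_to_fp32` and that `path/checkpoint-12` contains a ZeRO `global_step*` folder (both names are illustrative, not part of this diff):

```python
# Hypothetical import path; adjust to wherever this script lives in the repository.
from zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

# Merge the per-rank ZeRO-2/3 partitions back into a single fp32 state dict in memory.
state_dict = get_fp32_state_dict_from_zero_checkpoint("path/checkpoint-12")
# model.load_state_dict(state_dict)  # load into your plain (non-DeepSpeed) nn.Module

# Or write consolidated, optionally sharded, weight files straight to disk.
convert_zero_checkpoint_to_fp32_state_dict(
    "path/checkpoint-12",
    "path/checkpoint-12-output",
    max_shard_size="5GB",
    safe_serialization=True,
)
```

The `convert_zero_checkpoint_to_bf16_state_dict` variant defined further down follows the same call pattern but, per its docstring, casts the merged tensors to bf16 and names a single-shard output `diffusion_pytorch_model.safetensors`.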
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" if max_shard_size is not None: - filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( + ".safetensors", "{suffix}.safetensors" + ) # an memory-efficient approach for sharding empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) - state_dict_split = split_torch_state_dict_into_shards(empty_state_dict, - filename_pattern=filename_pattern, - max_shard_size=max_shard_size) + state_dict_split = split_torch_state_dict_into_shards( + empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) else: from collections import namedtuple + StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"]) - state_dict_split = StateDictSplit(is_sharded=False, - filename_to_tensors={weights_name: list(state_dict.keys())}) + state_dict_split = StateDictSplit( + is_sharded=False, filename_to_tensors={weights_name: list(state_dict.keys())} + ) # Save the model by shard os.makedirs(output_dir, exist_ok=True) @@ -673,7 +739,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, "metadata": state_dict_split.metadata, "weight_map": state_dict_split.tensor_to_filename, } - save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + save_index_file = ( + "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json" + ) save_index_file = os.path.join(output_dir, save_index_file) with open(save_index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" @@ -719,12 +787,14 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): return model -def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir, - output_dir, - max_shard_size="5GB", - safe_serialization=True, - tag=None, - exclude_frozen_parameters=False): +def convert_zero_checkpoint_to_bf16_state_dict( + checkpoint_dir, + output_dir, + max_shard_size="5GB", + safe_serialization=True, + tag=None, + exclude_frozen_parameters=False, +): """ 将 ZeRO 2 或 ZeRO 3 格式的 DeepSpeed 检查点转换为 BF16,并输出到指定目录下,命名规则为: - 如果只有一个分片: @@ -748,10 +818,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir, raise ImportError("You need `pip install huggingface_hub` to use the sharding feature.") state_dict = get_fp32_state_dict_from_zero_checkpoint( - checkpoint_dir, - tag=tag, - exclude_frozen_parameters=exclude_frozen_parameters, - lazy_mode=True + checkpoint_dir, tag=tag, exclude_frozen_parameters=exclude_frozen_parameters, lazy_mode=True ) state_dict = to_torch_tensor(state_dict, return_empty_tensor=False) @@ -766,9 +833,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir, empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True) state_dict_split = split_torch_state_dict_into_shards( - empty_state_dict, - filename_pattern=filename_pattern, - max_shard_size=max_shard_size + empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size ) os.makedirs(output_dir, exist_ok=True) @@ -789,7 +854,6 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir, del shard_state_dict gc.collect() - if state_dict_split.is_sharded: index = { "metadata": state_dict_split.metadata, @@ -801,21 +865,29 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir, else: only_filename = 
list(state_dict_split.filename_to_tensors.keys())[0] old_path = os.path.join(output_dir, only_filename) - new_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors" if safe_serialization - else "diffusion_pytorch_model.bin") + new_path = os.path.join( + output_dir, + "diffusion_pytorch_model.safetensors" + if safe_serialization + else "diffusion_pytorch_model.bin", + ) if old_path != new_path: os.rename(old_path, new_path) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("checkpoint_dir", - type=str, - help="path to the desired checkpoint folder, e.g., path/checkpoint-12") - parser.add_argument("output_dir", - type=str, - help="directory to the pytorch fp32 state_dict output files" - "(e.g. path/checkpoint-12-output/)") + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12", + ) + parser.add_argument( + "output_dir", + type=str, + help="directory to the pytorch fp32 state_dict output files" + "(e.g. path/checkpoint-12-output/)", + ) parser.add_argument( "--max_shard_size", type=str, @@ -823,26 +895,34 @@ if __name__ == "__main__": help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size" "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`" "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances" - "without CPU OOM issues.") + "without CPU OOM issues.", + ) parser.add_argument( "--safe_serialization", default=False, action='store_true', - help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).") - parser.add_argument("-t", - "--tag", - type=str, - default=None, - help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") - parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).", + ) + parser.add_argument( + "-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. 
e.g., global_step1", + ) + parser.add_argument( + "--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters" + ) parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_bf16_state_dict(args.checkpoint_dir, - args.output_dir, - max_shard_size=args.max_shard_size, - safe_serialization=args.safe_serialization, - tag=args.tag, - exclude_frozen_parameters=args.exclude_frozen_parameters) + convert_zero_checkpoint_to_bf16_state_dict( + args.checkpoint_dir, + args.output_dir, + max_shard_size=args.max_shard_size, + safe_serialization=args.safe_serialization, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters, + ) diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py index b70af1a..35e7fa7 100644 --- a/tools/convert_weight_sat2hf.py +++ b/tools/convert_weight_sat2hf.py @@ -10,6 +10,7 @@ Original Script: https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py """ + import argparse from typing import Any, Dict @@ -143,7 +144,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: return state_dict -def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: +def update_state_dict_inplace( + state_dict: Dict[str, Any], old_key: str, new_key: str +) -> Dict[str, Any]: state_dict[new_key] = state_dict.pop(old_key) @@ -164,8 +167,11 @@ def convert_transformer( num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, - ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V - use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V + ofs_embed_dim=512 + if (i2v and init_kwargs["patch_size_t"] is not None) + else None, # CogVideoX1.5-5B-I2V + use_learned_positional_embeddings=i2v + and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V **init_kwargs, ).to(dtype=dtype) @@ -240,17 +246,40 @@ def get_transformer_init_kwargs(version: str): def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" - ) - parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") - parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") - parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16") - parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16") - parser.add_argument( - "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving" + "--transformer_ckpt_path", + type=str, + default=None, + help="Path to original transformer checkpoint", ) parser.add_argument( - "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + "--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint" + ) + parser.add_argument( + "--output_path", type=str, required=True, help="Path where converted model should be saved" + ) + parser.add_argument( + "--fp16", + action="store_true", + default=False, + help="Whether to save the model weights in fp16", + ) + parser.add_argument( + "--bf16", + 
action="store_true", + default=False, + help="Whether to save the model weights in bf16", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + default=False, + help="Whether to push to HF Hub after saving", + ) + parser.add_argument( + "--text_encoder_cache_dir", + type=str, + default=None, + help="Path to text encoder cache directory", ) parser.add_argument( "--typecast_text_encoder", @@ -261,15 +290,24 @@ def get_args(): # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 - parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads") + parser.add_argument( + "--num_attention_heads", type=int, default=30, help="Number of attention heads" + ) # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True parser.add_argument( - "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not" + "--use_rotary_positional_embeddings", + action="store_true", + default=False, + help="Whether to use RoPE or not", ) # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7 - parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") + parser.add_argument( + "--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE" + ) # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0 - parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") + parser.add_argument( + "--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE" + ) parser.add_argument( "--i2v", action="store_true", @@ -313,7 +351,9 @@ if __name__ == "__main__": text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) - text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + text_encoder = T5EncoderModel.from_pretrained( + text_encoder_id, cache_dir=args.text_encoder_cache_dir + ) if args.typecast_text_encoder: text_encoder = text_encoder.to(dtype=dtype) @@ -355,4 +395,9 @@ if __name__ == "__main__": # This is necessary This is necessary for users with insufficient memory, # such as those using Colab and notebooks, as it can save some memory used for model loading. 
- pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) + pipe.save_pretrained( + args.output_path, + safe_serialization=True, + max_shard_size="5GB", + push_to_hub=args.push_to_hub, + ) diff --git a/tools/export_sat_lora_weight.py b/tools/export_sat_lora_weight.py index 9340d8f..4f4be2c 100644 --- a/tools/export_sat_lora_weight.py +++ b/tools/export_sat_lora_weight.py @@ -1,6 +1,6 @@ from typing import Any, Dict -import torch -import argparse +import torch +import argparse from diffusers.loaders.lora_base import LoraBaseMixin from diffusers.models.modeling_utils import load_state_dict @@ -15,8 +15,8 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: state_dict = state_dict["state_dict"] return state_dict -LORA_KEYS_RENAME = { +LORA_KEYS_RENAME = { 'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight', 'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight', 'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight', @@ -24,22 +24,18 @@ LORA_KEYS_RENAME = { 'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight', 'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight', 'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight', - 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight' + 'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight', } - PREFIX_KEY = "model.diffusion_model." SAT_UNIT_KEY = "layers" LORA_PREFIX_KEY = "transformer_blocks" - -def export_lora_weight(ckpt_path,lora_save_directory): - +def export_lora_weight(ckpt_path, lora_save_directory): merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - lora_state_dict = {} for key in list(merge_original_state_dict.keys()): new_key = key[len(PREFIX_KEY) :] @@ -50,9 +46,7 @@ def export_lora_weight(ckpt_path,lora_save_directory): lora_state_dict[new_key] = merge_original_state_dict[key] - - - # final length should be 240 + # final length should be 240 if len(lora_state_dict) != 240: raise ValueError("lora_state_dict length is not 240") @@ -64,7 +58,7 @@ def export_lora_weight(ckpt_path,lora_save_directory): is_main_process=True, weight_name=None, save_function=None, - safe_serialization=True + safe_serialization=True, ) @@ -73,7 +67,12 @@ def get_args(): parser.add_argument( "--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint" ) - parser.add_argument("--lora_save_directory", type=str, required=True, help="Path where converted lora should be saved") + parser.add_argument( + "--lora_save_directory", + type=str, + required=True, + help="Path where converted lora should be saved", + ) return parser.parse_args() diff --git a/tools/llm_flux_cogvideox/generate.sh b/tools/llm_flux_cogvideox/generate.sh index c455273..ade1cbe 100644 --- a/tools/llm_flux_cogvideox/generate.sh +++ b/tools/llm_flux_cogvideox/generate.sh @@ -32,4 +32,4 @@ do --use_dynamic_cfg \ --output_dir ${OUTPUT_DIR_PREFIX}${GPU} \ > ${LOG_DIR_PREFIX}${GPU}.log 2>&1 & -done \ No newline at end of file +done diff --git a/tools/llm_flux_cogvideox/gradio_page.py b/tools/llm_flux_cogvideox/gradio_page.py index 588c469..27fc1a0 100644 --- a/tools/llm_flux_cogvideox/gradio_page.py +++ b/tools/llm_flux_cogvideox/gradio_page.py @@ -35,20 +35,16 @@ caption_generator = transformers.pipeline( "torch_dtype": torch.bfloat16, }, trust_remote_code=True, - tokenizer=tokenizer + tokenizer=tokenizer, ) image_generator = DiffusionPipeline.from_pretrained( - 
image_generator_model_id, - torch_dtype=torch.bfloat16, - device_map="balanced" + image_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced" ) # image_generator.to("cuda") video_generator = CogVideoXImageToVideoPipeline.from_pretrained( - video_generator_model_id, - torch_dtype=torch.bfloat16, - device_map="balanced" + video_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced" ) video_generator.vae.enable_slicing() @@ -87,11 +83,7 @@ def generate_caption(prompt): {"role": "user", "content": prompt + "\n" + user_prompt}, ] - response = caption_generator( - messages, - max_new_tokens=226, - return_full_text=False - ) + response = caption_generator(messages, max_new_tokens=226, return_full_text=False) caption = response[0]["generated_text"] if caption.startswith("\"") and caption.endswith("\""): caption = caption[1:-1] @@ -109,11 +101,7 @@ def generate_image(caption, progress=gr.Progress(track_tqdm=True)): return image, image # One for output One for State -def generate_video( - caption, - image, - progress=gr.Progress(track_tqdm=True) -): +def generate_video(caption, image, progress=gr.Progress(track_tqdm=True)): generator = torch.Generator().manual_seed(seed) video_frames = video_generator( image=image, @@ -181,14 +169,19 @@ with gr.Blocks() as demo: image_output = gr.Image(label="Generated Image") state_image = gr.State() generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption) - generate_image_button.click(fn=generate_image, inputs=caption, outputs=[image_output, state_image]) + generate_image_button.click( + fn=generate_image, inputs=caption, outputs=[image_output, state_image] + ) with gr.Column(): video_output = gr.Video(label="Generated Video", width=720, height=480) download_video_button = gr.File(label="📥 Download Video", visible=False) download_gif_button = gr.File(label="📥 Download GIF", visible=False) generate_video_button = gr.Button("Generate Video from Image") - generate_video_button.click(fn=generate_video, inputs=[caption, state_image], - outputs=[video_output, download_gif_button]) + generate_video_button.click( + fn=generate_video, + inputs=[caption, state_image], + outputs=[video_output, download_gif_button], + ) if __name__ == "__main__": demo.launch() diff --git a/tools/llm_flux_cogvideox/llm_flux_cogvideox.py b/tools/llm_flux_cogvideox/llm_flux_cogvideox.py index 8e97888..853f4fc 100644 --- a/tools/llm_flux_cogvideox/llm_flux_cogvideox.py +++ b/tools/llm_flux_cogvideox/llm_flux_cogvideox.py @@ -54,7 +54,7 @@ You responses should just be the video generation prompt. Here are examples: """.strip() USER_PROMPT = """ -Could you generate a prompt for a video generation model? +Could you generate a prompt for a video generation model? Please limit the prompt to [{0}] words. """.strip() @@ -65,7 +65,7 @@ def get_args(): "--num_videos", type=int, default=5, - help="Number of unique videos you would like to generate." + help="Number of unique videos you would like to generate.", ) parser.add_argument( "--model_path", @@ -83,31 +83,28 @@ def get_args(): "--caption_generator_cache_dir", type=str, default=None, - help="Cache directory for caption generation model." + help="Cache directory for caption generation model.", ) parser.add_argument( "--image_generator_model_id", type=str, default="black-forest-labs/FLUX.1-dev", - help="Image generation model." + help="Image generation model.", ) parser.add_argument( "--image_generator_cache_dir", type=str, default=None, - help="Cache directory for image generation model." 
+ help="Cache directory for image generation model.", ) parser.add_argument( "--image_generator_num_inference_steps", type=int, default=50, - help="Caption generation model." + help="Caption generation model.", ) parser.add_argument( - "--guidance_scale", - type=float, - default=7, - help="Guidance scale to be use for generation." + "--guidance_scale", type=float, default=7, help="Guidance scale to be use for generation." ) parser.add_argument( "--use_dynamic_cfg", @@ -123,19 +120,14 @@ def get_args(): parser.add_argument( "--compile", action="store_true", - help="Whether or not to compile the transformer of image and video generators." + help="Whether or not to compile the transformer of image and video generators.", ) parser.add_argument( "--enable_vae_tiling", action="store_true", - help="Whether or not to use VAE tiling when encoding/decoding." - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Seed for reproducibility." + help="Whether or not to use VAE tiling when encoding/decoding.", ) + parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.") return parser.parse_args() @@ -157,7 +149,9 @@ def main(args: Dict[str, Any]) -> None: torch.cuda.manual_seed_all(args.seed) reset_memory() - tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + args.caption_generator_model_id, trust_remote_code=True + ) caption_generator = transformers.pipeline( "text-generation", model=args.caption_generator_model_id, @@ -168,7 +162,7 @@ def main(args: Dict[str, Any]) -> None: "torch_dtype": torch.bfloat16, }, trust_remote_code=True, - tokenizer=tokenizer + tokenizer=tokenizer, ) captions = [] @@ -197,12 +191,14 @@ def main(args: Dict[str, Any]) -> None: image_generator = DiffusionPipeline.from_pretrained( args.image_generator_model_id, cache_dir=args.image_generator_cache_dir, - torch_dtype=torch.bfloat16 + torch_dtype=torch.bfloat16, ) image_generator.to("cuda") if args.compile: - image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True) + image_generator.transformer = torch.compile( + image_generator.transformer, mode="max-autotune", fullgraph=True + ) if args.enable_vae_tiling: image_generator.vae.enable_tiling() @@ -216,7 +212,9 @@ def main(args: Dict[str, Any]) -> None: num_inference_steps=args.image_generator_num_inference_steps, guidance_scale=3.5, ).images[0] - filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + filename = ( + caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + ) image.save(output_dir / f"{index}_{filename}.png") images.append(image) @@ -224,13 +222,16 @@ def main(args: Dict[str, Any]) -> None: reset_memory() video_generator = CogVideoXImageToVideoPipeline.from_pretrained( - args.model_path, torch_dtype=torch.bfloat16).to("cuda") + args.model_path, torch_dtype=torch.bfloat16 + ).to("cuda") video_generator.scheduler = CogVideoXDPMScheduler.from_config( - video_generator.scheduler.config, - timestep_spacing="trailing") + video_generator.scheduler.config, timestep_spacing="trailing" + ) if args.compile: - video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True) + video_generator.transformer = torch.compile( + video_generator.transformer, mode="max-autotune", fullgraph=True + ) if args.enable_vae_tiling: video_generator.vae.enable_tiling() @@ -248,7 +249,9 @@ def 
main(args: Dict[str, Any]) -> None: use_dynamic_cfg=args.use_dynamic_cfg, generator=generator, ).frames[0] - filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + filename = ( + caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_") + ) export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8) diff --git a/tools/load_cogvideox_lora.py b/tools/load_cogvideox_lora.py index 1b12975..2adf30e 100644 --- a/tools/load_cogvideox_lora.py +++ b/tools/load_cogvideox_lora.py @@ -14,8 +14,8 @@ # limitations under the License. -import math -import random +import math +import random import time from diffusers.utils import export_to_video from diffusers.image_processor import VaeImageProcessor @@ -49,8 +49,8 @@ def get_args(): "--lora_r", type=int, default=128, - help="""LoRA weights have a rank parameter, with the default for 2B trans set at 128 and 5B trans set at 256. - This part is used to calculate the value for lora_scale, which is by default divided by the alpha value, + help="""LoRA weights have a rank parameter, with the default for 2B trans set at 128 and 5B trans set at 256. + This part is used to calculate the value for lora_scale, which is by default divided by the alpha value, used for stable learning and to prevent underflow. In the SAT training framework, alpha is set to 1 by default. The higher the rank, the better the expressive capability, but it requires more memory and training time. Increasing this number blindly isn't always better. @@ -61,8 +61,8 @@ def get_args(): "--lora_alpha", type=int, default=1, - help="""LoRA weights have a rank parameter, with the default for 2B trans set at 128 and 5B trans set at 256. - This part is used to calculate the value for lora_scale, which is by default divided by the alpha value, + help="""LoRA weights have a rank parameter, with the default for 2B trans set at 128 and 5B trans set at 256. + This part is used to calculate the value for lora_scale, which is by default divided by the alpha value, used for stable learning and to prevent underflow. In the SAT training framework, alpha is set to 1 by default. The higher the rank, the better the expressive capability, but it requires more memory and training time. Increasing this number blindly isn't always better. 
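The two help strings above reduce to a single rule: the adapter is applied with an effective scale of `lora_alpha / lora_r`, which the main block in the next hunk computes as `lora_scaling` and passes to `set_adapters`. A quick illustration with the defaults quoted in the help text (alpha 1, rank 128 for the 2B model and 256 for the 5B model):

```python
# Effective LoRA scale as described by the --lora_r / --lora_alpha help text.
def lora_scale(lora_alpha: int, lora_r: int) -> float:
    return lora_alpha / lora_r

print(lora_scale(1, 128))  # 2B default rank -> 0.0078125
print(lora_scale(1, 256))  # 5B default rank -> 0.00390625
```

Keeping alpha fixed at 1 (the SAT training default mentioned above) means higher-rank adapters are down-weighted proportionally at inference time.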
@@ -85,17 +85,24 @@ def get_args(): if __name__ == "__main__": args = get_args() - pipe = CogVideoXPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device) - pipe.load_lora_weights(args.lora_weights_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora") + pipe = CogVideoXPipeline.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16 + ).to(device) + pipe.load_lora_weights( + args.lora_weights_path, + weight_name="pytorch_lora_weights.safetensors", + adapter_name="cogvideox-lora", + ) # pipe.fuse_lora(lora_scale=args.lora_alpha/args.lora_r, ['transformer']) - lora_scaling=args.lora_alpha/args.lora_r + lora_scaling = args.lora_alpha / args.lora_r pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) - - pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") + pipe.scheduler = CogVideoXDPMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) os.makedirs(args.output_dir, exist_ok=True) - + latents = pipe( prompt=args.prompt, num_videos_per_prompt=1, @@ -120,6 +127,6 @@ if __name__ == "__main__": video_path = f"{args.output_dir}/{timestamp}.mp4" os.makedirs(os.path.dirname(video_path), exist_ok=True) tensor = batch_video_frames[0] - fps=math.ceil((len(batch_video_frames[0]) - 1) / 6) + fps = math.ceil((len(batch_video_frames[0]) - 1) / 6) - export_to_video(tensor, video_path, fps=fps) \ No newline at end of file + export_to_video(tensor, video_path, fps=fps) diff --git a/tools/parallel_inference/run.sh b/tools/parallel_inference/run.sh index 7f9d5a8..ee421a4 100644 --- a/tools/parallel_inference/run.sh +++ b/tools/parallel_inference/run.sh @@ -3,8 +3,8 @@ set -x export PYTHONPATH=$PWD:$PYTHONPATH # Select the model type -# The model is downloaded to a specified location on disk, -# or you can simply use the model's ID on Hugging Face, +# The model is downloaded to a specified location on disk, +# or you can simply use the model's ID on Hugging Face, # which will then be downloaded to the default cache path on Hugging Face. export MODEL_TYPE="CogVideoX" diff --git a/tools/replicate/predict_i2v.py b/tools/replicate/predict_i2v.py index 5e45796..69c294a 100644 --- a/tools/replicate/predict_i2v.py +++ b/tools/replicate/predict_i2v.py @@ -11,9 +11,7 @@ from cog import BasePredictor, Input, Path MODEL_CACHE = "model_cache_i2v" -MODEL_URL = ( - f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" -) +MODEL_URL = f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" os.environ["HF_DATASETS_OFFLINE"] = "1" os.environ["TRANSFORMERS_OFFLINE"] = "1" os.environ["HF_HOME"] = MODEL_CACHE @@ -48,9 +46,7 @@ class Predictor(BasePredictor): def predict( self, - prompt: str = Input( - description="Input prompt", default="Starry sky slowly rotating." - ), + prompt: str = Input(description="Input prompt", default="Starry sky slowly rotating."), image: Path = Input(description="Input image"), num_inference_steps: int = Input( description="Number of denoising steps", ge=1, le=500, default=50 @@ -58,9 +54,7 @@ class Predictor(BasePredictor): guidance_scale: float = Input( description="Scale for classifier-free guidance", ge=1, le=20, default=6 ), - num_frames: int = Input( - description="Number of frames for the output video", default=49 - ), + num_frames: int = Input(description="Number of frames for the output video", default=49), seed: int = Input( description="Random seed. 
Leave blank to randomize the seed", default=None ), diff --git a/tools/replicate/predict_t2v.py b/tools/replicate/predict_t2v.py index cadeee2..51f1106 100644 --- a/tools/replicate/predict_t2v.py +++ b/tools/replicate/predict_t2v.py @@ -11,9 +11,7 @@ from cog import BasePredictor, Input, Path MODEL_CACHE = "model_cache" -MODEL_URL = ( - f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" -) +MODEL_URL = f"https://weights.replicate.delivery/default/THUDM/CogVideo/{MODEL_CACHE}.tar" os.environ["HF_DATASETS_OFFLINE"] = "1" os.environ["TRANSFORMERS_OFFLINE"] = "1" os.environ["HF_HOME"] = MODEL_CACHE @@ -59,9 +57,7 @@ class Predictor(BasePredictor): guidance_scale: float = Input( description="Scale for classifier-free guidance", ge=1, le=20, default=6 ), - num_frames: int = Input( - description="Number of frames for the output video", default=49 - ), + num_frames: int = Input(description="Number of frames for the output video", default=49), seed: int = Input( description="Random seed. Leave blank to randomize the seed", default=None ), diff --git a/tools/venhancer/README.md b/tools/venhancer/README.md index cc6f45c..05e037d 100644 --- a/tools/venhancer/README.md +++ b/tools/venhancer/README.md @@ -95,4 +95,4 @@ Typical runtime logs are as follows: ``` Running on a single A100 GPU, enhancing each 6-second CogVideoX generated video with default settings will consume 60GB -of VRAM and take 40-50 minutes. \ No newline at end of file +of VRAM and take 40-50 minutes. diff --git a/tools/venhancer/README_zh.md b/tools/venhancer/README_zh.md index a481cd1..0deabcc 100644 --- a/tools/venhancer/README_zh.md +++ b/tools/venhancer/README_zh.md @@ -38,7 +38,7 @@ python enhance_a_video.py \ --solver_mode 'fast' --steps 15 \ --input_path inputs/000000.mp4 \ --prompt 'Wide-angle aerial shot at dawn,soft morning light casting long shadows,an elderly man walking his dog through a quiet,foggy park,trees and benches in the background,peaceful and serene atmosphere' \ ---save_dir 'results/' +--save_dir 'results/' ``` 其中: @@ -98,4 +98,4 @@ python enhance_a_video.py \ ``` -使用A100单卡运行,对于每个CogVideoX生产的6秒视频,按照默认配置,会消耗60G显存,并用时40-50分钟。 \ No newline at end of file +使用A100单卡运行,对于每个CogVideoX生产的6秒视频,按照默认配置,会消耗60G显存,并用时40-50分钟。