Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 19:41:59 +08:00)

Commit 39c6562dc8 ("format"), parent b9b0539dbe
.github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md (new file, vendored, +28 lines)
@@ -0,0 +1,28 @@
# Contribution Guide

We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines.

## What We Accept

+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.

## Code Style Guide

Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:

1. Install the required dependencies:

```shell
pip install ruff pre-commit
```

2. Then, run the following command:

```shell
pre-commit run --all-files
```

If your code complies with the standards, you should not see any errors.

## Naming Conventions

- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.
.github/PULL_REQUEST_TEMPLATE/pr_template.md (deleted file, vendored, -34 lines)
@@ -1,34 +0,0 @@
# Raise valuable PR / 提出有价值的PR

## Caution / 注意事项:

Users should keep the following points in mind when submitting PRs:

1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if there are multiple ideas and optimizations, they should be split into different PRs.

用户在提交PR时候应该注意以下几点:

1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。

## 不应该提出的PR / PRs that should not be proposed

If a developer proposes a PR about any of the following, it may be closed or rejected.

1. PRs that do not describe an improvement plan.
2. Multiple issues of different types combined in one PR.
3. PRs that are highly duplicative of already existing PRs.

如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。

1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。

# Check your PR / 检查您的PR

- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 GitHub 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Are your PRs for only one issue? / 您的PR是否仅针对一个问题?
.pre-commit-config.yaml (new file, +19 lines)
@@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.5
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
      - id: ruff-format
        args: [--config=pyproject.toml]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements
@@ -22,7 +22,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
 
 ## Project Updates
 
-- 🔥🔥 **News**: ```2025/03/16```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
+- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
 - 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
 - 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
 - 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
@@ -444,8 +444,6 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
 }
 ```
 
-We welcome your contributions! You can click [here](resources/contribute.md) for more information.
-
 ## Model-License
 
 The code in this repository is released under the [Apache 2.0 License](LICENSE).
@@ -21,7 +21,7 @@
 </p>
 
 ## 更新とニュース
-- 🔥🔥 ```2025/03/16```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
+- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
 - **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。
 - **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAM(ビデオメモリ)で動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
 - **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
@@ -418,8 +418,6 @@ CogVideoのデモは [https://models.aminer.cn/cogvideo](https://models.aminer.c
 }
 ```
 
-あなたの貢献をお待ちしています!詳細は[こちら](resources/contribute_ja.md)をクリックしてください。
-
 ## ライセンス契約
 
 このリポジトリのコードは [Apache 2.0 License](LICENSE) の下で公開されています。
@@ -22,7 +22,7 @@
 
 ## 项目更新
 
-- 🔥🔥 **News**: ```2025/03/16```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
+- 🔥🔥 **News**: ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
 - 🔥 **News**: ```2025/02/28```: DDIM Inverse 已经在`CogVideoX-5B` 和 `CogVideoX1.5 -5B` 支持,查看 [here](inference/ddim_inversion.py).
 - 🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。
 - 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本,仅需调整部分参数仅可沿用之前的代码。
@@ -398,8 +398,6 @@ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.amine
 }
 ```
 
-我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
-
 ## 模型协议
 
 本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。
@@ -26,7 +26,11 @@ class BucketSampler(Sampler):
     """
 
     def __init__(
-        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+        self,
+        data_source: Dataset,
+        batch_size: int = 8,
+        shuffle: bool = True,
+        drop_last: bool = False,
     ) -> None:
         self.data_source = data_source
         self.batch_size = batch_size

@@ -48,7 +52,11 @@ class BucketSampler(Sampler):
     def __iter__(self):
         for index, data in enumerate(self.data_source):
             video_metadata = data["video_metadata"]
-            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
+            f, h, w = (
+                video_metadata["num_frames"],
+                video_metadata["height"],
+                video_metadata["width"],
+            )
 
             self.buckets[(f, h, w)].append(data)
             if len(self.buckets[(f, h, w)]) == self.batch_size:
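For readers skimming this refactor: the logic being re-wrapped above groups samples by their (num_frames, height, width) key and emits a batch once a bucket fills up. A minimal, self-contained sketch of that bucketing idea (hypothetical data, not the project's actual `BucketSampler`):

```python
from collections import defaultdict


def iter_bucketed_batches(samples, batch_size=2):
    """Yield batches whose items share the same (frames, height, width) key."""
    buckets = defaultdict(list)
    for sample in samples:
        meta = sample["video_metadata"]
        key = (meta["num_frames"], meta["height"], meta["width"])
        buckets[key].append(sample)
        if len(buckets[key]) == batch_size:  # bucket is full: emit it and reset
            yield buckets[key]
            buckets[key] = []


# Tiny usage example with made-up metadata
samples = [
    {"video_metadata": {"num_frames": 49, "height": 480, "width": 720}, "id": i}
    for i in range(4)
]
for batch in iter_bucketed_batches(samples):
    print([s["id"] for s in batch])  # -> [0, 1] then [2, 3]
```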
@@ -115,7 +115,9 @@ class BaseI2VDataset(Dataset):
         train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
 
         cache_dir = self.trainer.args.data_root / "cache"
-        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        video_latent_dir = (
+            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        )
         prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
         prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

@@ -136,7 +138,9 @@ class BaseI2VDataset(Dataset):
             # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
             prompt_embedding = prompt_embedding[0]
             save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
-            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
+            logger.info(
+                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
+            )
 
         if encoded_video_path.exists():
             encoded_video = load_file(encoded_video_path)["encoded_video"]

@@ -177,7 +181,9 @@ class BaseI2VDataset(Dataset):
             },
         }
 
-    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(
+        self, video_path: Path | None, image_path: Path | None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Loads and preprocesses a video and an image.
         If either path is None, no preprocessing will be done for that input.
@@ -249,13 +255,19 @@ class I2VDatasetWithResize(BaseI2VDataset):
         self.height = height
         self.width = width
 
-        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transforms = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
         self.__image_transforms = self.__frame_transforms
 
     @override
-    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(
+        self, video_path: Path | None, image_path: Path | None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if video_path is not None:
-            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+            video = preprocess_video_with_resize(
+                video_path, self.max_num_frames, self.height, self.width
+            )
         else:
             video = None
         if image_path is not None:

@@ -293,7 +305,9 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
             )
             for b in video_resolution_buckets
         ]
-        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transforms = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
         self.__image_transforms = self.__frame_transforms
 
     @override
@@ -11,7 +11,12 @@ from typing_extensions import override
 
 from finetune.constants import LOG_LEVEL, LOG_NAME
 
-from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
+from .utils import (
+    load_prompts,
+    load_videos,
+    preprocess_video_with_buckets,
+    preprocess_video_with_resize,
+)
 
 
 if TYPE_CHECKING:

@@ -93,7 +98,9 @@ class BaseT2VDataset(Dataset):
         train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
 
         cache_dir = self.trainer.args.data_root / "cache"
-        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        video_latent_dir = (
+            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        )
         prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
         prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

@@ -114,7 +121,9 @@ class BaseT2VDataset(Dataset):
             # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
             prompt_embedding = prompt_embedding[0]
             save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
-            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
+            logger.info(
+                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
+            )
 
         if encoded_video_path.exists():
             # encoded_video = torch.load(encoded_video_path, weights_only=True)

@@ -202,7 +211,9 @@ class T2VDatasetWithResize(BaseT2VDataset):
         self.height = height
         self.width = width
 
-        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transform = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
 
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:

@@ -240,7 +251,9 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
             for b in video_resolution_buckets
         ]
 
-        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transform = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
 
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
@@ -24,12 +24,16 @@ def load_prompts(prompt_path: Path) -> List[str]:
 
 def load_videos(video_path: Path) -> List[Path]:
     with open(video_path, "r", encoding="utf-8") as file:
-        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+        return [
+            video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
+        ]
 
 
 def load_images(image_path: Path) -> List[Path]:
     with open(image_path, "r", encoding="utf-8") as file:
-        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+        return [
+            image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
+        ]
 
 
 def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
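The two list comprehensions re-wrapped above read a plain-text index file and resolve each non-empty line relative to the file's parent directory. A small stand-alone sketch of that pattern (the `videos.txt` path below is hypothetical):

```python
from pathlib import Path
from typing import List


def load_relative_paths(index_path: Path) -> List[Path]:
    # Each non-empty line in the index is treated as a path relative to the index file.
    with open(index_path, "r", encoding="utf-8") as file:
        return [
            index_path.parent / line.strip() for line in file if len(line.strip()) > 0
        ]


# e.g. load_relative_paths(Path("data/videos.txt")) -> [Path("data/clips/0001.mp4"), ...]
```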
@@ -169,7 +173,9 @@ def preprocess_video_with_buckets(
     video_num_frames = len(video_reader)
     resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
     if len(resolution_buckets) == 0:
-        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
+        raise ValueError(
+            f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
+        )
 
     nearest_frame_bucket = min(
         resolution_buckets,

@@ -181,7 +187,9 @@ def preprocess_video_with_buckets(
     frames = frames[:nearest_frame_bucket].float()
     frames = frames.permute(0, 3, 1, 2).contiguous()
 
-    nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
+    nearest_res = min(
+        resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
+    )
     nearest_res = (nearest_res[1], nearest_res[2])
     frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
 
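The `min(..., key=...)` calls reformatted above pick the frame bucket and the resolution bucket closest to the source video. A compact sketch of that kind of selection with made-up buckets; the exact key used for the frame bucket is not fully visible in this hunk, so treat that part as an assumption:

```python
# Buckets are (num_frames, height, width); the toy video below has 53 frames at 500x700.
resolution_buckets = [(17, 480, 720), (49, 480, 720), (49, 768, 1360)]
video_num_frames, height, width = 53, 500, 700

# Keep only buckets whose frame count fits into the video, then take the closest one.
frame_candidates = [b for b in resolution_buckets if b[0] <= video_num_frames]
nearest_frame_bucket = min(frame_candidates, key=lambda b: abs(b[0] - video_num_frames))

# Pick the bucket whose spatial size is closest in L1 distance to the video's size.
nearest_res = min(frame_candidates, key=lambda b: abs(b[1] - height) + abs(b[2] - width))
print(nearest_frame_bucket[0], (nearest_res[1], nearest_res[2]))  # 49 (480, 720)
```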
@@ -32,13 +32,19 @@ class CogVideoXI2VLoraTrainer(Trainer):
 
         components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
 
-        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
+        components.text_encoder = T5EncoderModel.from_pretrained(
+            model_path, subfolder="text_encoder"
+        )
 
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+            model_path, subfolder="transformer"
+        )
 
         components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
 
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+            model_path, subfolder="scheduler"
+        )
 
         return components
 

@@ -73,7 +79,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
             return_tensors="pt",
         )
         prompt_token_ids = prompt_token_ids.input_ids
-        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        prompt_embedding = self.components.text_encoder(
+            prompt_token_ids.to(self.accelerator.device)
+        )[0]
         return prompt_embedding
 
     @override
@@ -122,22 +130,34 @@ class CogVideoXI2VLoraTrainer(Trainer):
         # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
         images = images.unsqueeze(2)
         # Add noise to images
-        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
+        image_noise_sigma = torch.normal(
+            mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
+        )
         image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
-        noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
-        image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
+        noisy_images = (
+            images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
+        )
+        image_latent_dist = self.components.vae.encode(
+            noisy_images.to(dtype=self.components.vae.dtype)
+        ).latent_dist
         image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
 
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
+            0,
+            self.components.scheduler.config.num_train_timesteps,
+            (batch_size,),
+            device=self.accelerator.device,
         )
         timesteps = timesteps.long()
 
         # from [B, C, F, H, W] to [B, F, C, H, W]
         latent = latent.permute(0, 2, 1, 3, 4)
         image_latents = image_latents.permute(0, 2, 1, 3, 4)
-        assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
+        assert (latent.shape[0], *latent.shape[2:]) == (
+            image_latents.shape[0],
+            *image_latents.shape[2:],
+        )
 
         # Padding image_latents to the same frame number as latent
         padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
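The block above perturbs the conditioning image with a sigma drawn as exp(N(mean=-3.0, std=0.5)), i.e. a log-normal scale that usually stays well below 0.15. A small sketch that isolates just that step (shapes are made up):

```python
import torch

torch.manual_seed(0)
images = torch.randn(2, 3, 1, 32, 32)  # [B, C, F, H, W] with a single frame

# Log-normal noise scale, same recipe as the hunk above (one scalar, broadcast over the batch).
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,))
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)

noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
print(image_noise_sigma.item())  # typically on the order of 0.02-0.13
```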
@@ -169,7 +189,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
 
         # Predict noise, For CogVideoX1.5 Only.
         ofs_emb = (
-            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+            None
+            if self.state.transformer_config.ofs_embed_dim is None
+            else latent.new_full((1,), fill_value=2.0)
         )
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,

@@ -181,7 +203,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         )[0]
 
         # Denoise
-        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
+        latent_pred = self.components.scheduler.get_velocity(
+            predicted_noise, latent_noisy, timesteps
+        )
 
         alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
         weights = 1 / (1 - alphas_cumprod)
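The `get_velocity` call and the `1 / (1 - alphas_cumprod)` weights re-wrapped above belong to the v-prediction setup, where the scheduler combines a sample and a noise term as sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample, and the loss is weighted more heavily at low-noise timesteps. A rough numeric sketch of those two quantities (toy values only; the exact argument order this trainer passes to `get_velocity` is specific to its loss formulation):

```python
import torch

alphas_cumprod = torch.tensor([0.9, 0.5, 0.1])  # alpha_bar at three example timesteps
sample = torch.ones(3)                          # toy per-timestep scalar "sample"
noise = torch.randn(3)

# Standard v-prediction combination of a sample and a noise term.
velocity = alphas_cumprod.sqrt() * noise - (1 - alphas_cumprod).sqrt() * sample

# Loss weighting used in the trainer: larger weight where little noise has been added.
weights = 1 / (1 - alphas_cumprod)
print(velocity, weights)  # weights -> tensor([10.0000, 2.0000, 1.1111])
```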
@@ -228,7 +252,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         if transformer_config.patch_size_t is None:
             base_num_frames = num_frames
         else:
-            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+            base_num_frames = (
+                num_frames + transformer_config.patch_size_t - 1
+            ) // transformer_config.patch_size_t
 
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,
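The wrapped expression above is a ceiling division: it computes how many temporal patches of size `patch_size_t` are needed to cover `num_frames`. A one-line check:

```python
def ceil_div(num_frames: int, patch_size_t: int) -> int:
    # (a + b - 1) // b == ceil(a / b) for positive integers
    return (num_frames + patch_size_t - 1) // patch_size_t


assert ceil_div(13, 4) == 4  # 13 frames with temporal patch 4 -> 4 patches
assert ceil_div(12, 4) == 3
```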
@@ -31,13 +31,19 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
         components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
 
-        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
+        components.text_encoder = T5EncoderModel.from_pretrained(
+            model_path, subfolder="text_encoder"
+        )
 
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+            model_path, subfolder="transformer"
+        )
 
         components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
 
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+            model_path, subfolder="scheduler"
+        )
 
         return components
 

@@ -72,7 +78,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
             return_tensors="pt",
         )
         prompt_token_ids = prompt_token_ids.input_ids
-        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        prompt_embedding = self.components.text_encoder(
+            prompt_token_ids.to(self.accelerator.device)
+        )[0]
         return prompt_embedding
 
     @override

@@ -115,7 +123,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
 
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
+            0,
+            self.components.scheduler.config.num_train_timesteps,
+            (batch_size,),
+            device=self.accelerator.device,
         )
         timesteps = timesteps.long()
 

@@ -150,7 +161,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
         )[0]
 
         # Denoise
-        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
+        latent_pred = self.components.scheduler.get_velocity(
+            predicted_noise, latent_added_noise, timesteps
+        )
 
         alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
         weights = 1 / (1 - alphas_cumprod)

@@ -196,7 +209,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
         if transformer_config.patch_size_t is None:
             base_num_frames = num_frames
         else:
-            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+            base_num_frames = (
+                num_frames + transformer_config.patch_size_t - 1
+            ) // transformer_config.patch_size_t
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,
             crops_coords=None,
@@ -52,6 +52,8 @@ def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Tra
         print(f"\nSupported training types for '{model_type}' are:")
         for supported_type in SUPPORTED_MODELS[model_type]:
             print(f"  • {supported_type}")
-        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
+        raise ValueError(
+            f"Training type '{training_type}' is not supported for model '{model_type}'"
+        )
 
     return SUPPORTED_MODELS[model_type][training_type]
@@ -115,14 +115,18 @@ class Args(BaseModel):
     def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
-            raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
+            raise ValueError(
+                "validation_images must be specified when do_validation is True and model_type is i2v"
+            )
         return v
 
     @field_validator("validation_videos")
     def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
-            raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
+            raise ValueError(
+                "validation_videos must be specified when do_validation is True and model_type is v2v"
+            )
         return v
 
     @field_validator("validation_steps")

@@ -148,7 +152,9 @@ class Args(BaseModel):
         model_name = info.data.get("model_name", "")
         if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
             if (height, width) != (480, 720):
-                raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
+                raise ValueError(
+                    "For cogvideox-5b models, height must be 480 and width must be 720"
+                )
 
         return v
 
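The validators being re-wrapped above follow the usual pydantic v2 pattern: a `@field_validator` receives the candidate value plus `ValidationInfo`, cross-checks it against fields validated earlier, and raises `ValueError` to reject it. A minimal sketch of that pattern (simplified fields, not the project's full `Args` model):

```python
from pydantic import BaseModel, ValidationInfo, field_validator


class MiniArgs(BaseModel):
    do_validation: bool = False
    model_type: str = "t2v"
    validation_images: str | None = None

    @field_validator("validation_images")
    def check_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
        values = info.data  # fields declared (and validated) before this one
        if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
            raise ValueError("validation_images is required when do_validation is True for i2v")
        return v


# MiniArgs(do_validation=True, model_type="i2v") would raise a ValidationError here.
```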
@@ -221,7 +227,9 @@ class Args(BaseModel):
         # LoRA parameters
         parser.add_argument("--rank", type=int, default=128)
         parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
+        parser.add_argument(
+            "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
+        )
 
         # Checkpointing
         parser.add_argument("--checkpointing_steps", type=int, default=200)

@@ -8,7 +8,10 @@ import cv2
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
+        "--datadir",
+        type=str,
+        required=True,
+        help="Root directory containing videos.txt and video subdirectory",
     )
     return parser.parse_args()
 
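Both argparse hunks above rely on standard argparse behavior; for example, the `--target_modules` flag uses `nargs="+"`, so several module names after one flag are collected into a list, with the attention projections as the default. A tiny demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
)
args = parser.parse_args(["--target_modules", "to_q", "to_v"])
print(args.target_modules)  # ['to_q', 'to_v']
```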
@@ -88,7 +88,9 @@ class Trainer:
 
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
-        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
+        project_config = ProjectConfiguration(
+            project_dir=self.args.output_dir, logging_dir=logging_dir
+        )
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
         init_process_group_kwargs = InitProcessGroupKwargs(
             backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)

@@ -183,7 +185,9 @@
         # Prepare VAE and text encoder for encoding
         self.components.vae.requires_grad_(False)
         self.components.text_encoder.requires_grad_(False)
-        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
+        self.components.vae = self.components.vae.to(
+            self.accelerator.device, dtype=self.state.weight_dtype
+        )
         self.components.text_encoder = self.components.text_encoder.to(
             self.accelerator.device, dtype=self.state.weight_dtype
         )

@@ -263,7 +267,9 @@
 
         # For LoRA, we only want to train the LoRA weights
         # For SFT, we want to train all the parameters
-        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        trainable_parameters = list(
+            filter(lambda p: p.requires_grad, self.components.transformer.parameters())
+        )
         transformer_parameters_with_lr = {
             "params": trainable_parameters,
             "lr": self.args.learning_rate,

@@ -287,7 +293,9 @@
             use_deepspeed=use_deepspeed_opt,
         )
 
-        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(
+            len(self.data_loader) / self.args.gradient_accumulation_steps
+        )
         if self.args.train_steps is None:
             self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
             self.state.overwrote_max_train_steps = True
@@ -322,12 +330,16 @@
         self.lr_scheduler = lr_scheduler
 
     def prepare_for_training(self) -> None:
-        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
-            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
-        )
+        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
+            self.accelerator.prepare(
+                self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
+            )
+        )
 
         # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(
+            len(self.data_loader) / self.args.gradient_accumulation_steps
+        )
         if self.state.overwrote_max_train_steps:
             self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
             # Afterwards we recalculate our number of training epochs
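The arithmetic reformatted here converts epochs into optimizer steps: one update happens every `gradient_accumulation_steps` batches, so the per-epoch update count is a ceiling division over the dataloader length. A small sketch with made-up numbers:

```python
import math

num_batches_per_epoch = 1000          # stand-in for len(data_loader)
gradient_accumulation_steps = 4
train_epochs = 3

num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)
train_steps = train_epochs * num_update_steps_per_epoch
print(num_update_steps_per_epoch, train_steps)  # 250 750
```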
@@ -364,7 +376,9 @@
         logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
 
         self.state.total_batch_size_count = (
-            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
+            self.args.batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
         )
         info = {
             "trainable parameters": self.state.num_trainable_parameters,

@@ -454,7 +468,9 @@
                 progress_bar.set_postfix(logs)
 
                 # Maybe run validation
-                should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
+                should_run_validation = (
+                    self.args.do_validation and global_step % self.args.validation_steps == 0
+                )
                 if should_run_validation:
                     del loss
                     free_memory()

@@ -466,7 +482,9 @@
                     break
 
             memory_statistics = get_memory_statistics()
-            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
+            logger.info(
+                f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
+            )
 
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)

@@ -504,7 +522,9 @@
             # Can't using model_cpu_offload in deepspeed,
             # so we need to move all components in pipe to device
             # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
-            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
+            self.__move_components_to_device(
+                dtype=self.state.weight_dtype, ignore_list=["transformer"]
+            )
         else:
             # if not using deepspeed, use model_cpu_offload to further reduce memory usage
             # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage

@@ -528,7 +548,9 @@
             video = self.state.validation_videos[i]
 
             if image is not None:
-                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
+                image = preprocess_image_with_resize(
+                    image, self.state.train_height, self.state.train_width
+                )
                 # Convert image tensor (C, H, W) to PIL images
                 image = image.to(torch.uint8)
                 image = image.permute(1, 2, 0).cpu().numpy()

@@ -546,7 +568,9 @@
                 f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                 main_process_only=False,
             )
-            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+            validation_artifacts = self.validation_step(
+                {"prompt": prompt, "image": image, "video": video}, pipe
+            )
 
             if (
                 self.state.using_deepspeed

@@ -565,7 +589,9 @@
                 "video": {"type": "video", "value": video},
             }
             for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
-                artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
+                artifacts.update(
+                    {f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}
+                )
             logger.debug(
                 f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                 main_process_only=False,

@@ -600,8 +626,12 @@
         tracker_key = "validation"
         for tracker in accelerator.trackers:
             if tracker.name == "wandb":
-                image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
-                video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+                image_artifacts = [
+                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
+                ]
+                video_artifacts = [
+                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
+                ]
                 tracker.log(
                     {
                         tracker_key: {"images": image_artifacts, "videos": video_artifacts},

@@ -618,7 +648,9 @@
         pipe.remove_all_hooks()
         del pipe
         # Load models except those not needed for training
-        self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
+        self.__move_components_to_device(
+            dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
+        )
         self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
 
         # Change trainable weights back to fp32 to keep with dtype after prepare the model

@@ -687,7 +719,9 @@
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
                 if name not in ignore_list:
-                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+                    setattr(
+                        self.components, name, component.to(self.accelerator.device, dtype=dtype)
+                    )
 
     def __move_components_to_cpu(self, unload_list: List[str] = []):
         unload_list = set(unload_list)
@@ -732,11 +766,13 @@
             ):
                 transformer_ = unwrap_model(self.accelerator, model)
             else:
-                raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
+                raise ValueError(
+                    f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
+                )
         else:
-            transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
-                self.args.model_path, subfolder="transformer"
-            )
+            transformer_ = unwrap_model(
+                self.accelerator, self.components.transformer
+            ).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
             transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)

@@ -745,7 +781,9 @@
                 for k, v in lora_state_dict.items()
                 if k.startswith("transformer.")
             }
-            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+            incompatible_keys = set_peft_model_state_dict(
+                transformer_, transformer_state_dict, adapter_name="default"
+            )
             if incompatible_keys is not None:
                 # check only for unexpected keys
                 unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)

@@ -759,7 +797,10 @@
         self.accelerator.register_load_state_pre_hook(load_model_hook)
 
     def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
-        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
+        if (
+            self.accelerator.distributed_type == DistributedType.DEEPSPEED
+            or self.accelerator.is_main_process
+        ):
             if must_save or global_step % self.args.checkpointing_steps == 0:
                 # for training
                 save_path = get_intermediate_ckpt_path(
@@ -23,7 +23,9 @@ def get_latest_ckpt_path_to_resume_from(
     else:
         resume_from_checkpoint_path = Path(resume_from_checkpoint)
         if not resume_from_checkpoint_path.exists():
-            logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
+            logger.info(
+                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
             initial_global_step = 0
             global_step = 0
             first_epoch = 0

@@ -55,7 +55,9 @@ def unload_model(model):
     model.to("cpu")
 
 
-def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+def make_contiguous(
+    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
     if isinstance(x, torch.Tensor):
         return x.contiguous()
    elif isinstance(x, dict):
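`make_contiguous`, whose signature is re-wrapped above, normalizes tensors (or dicts of tensors) to contiguous memory, which matters after view-changing ops like `permute`. A quick illustration of why the call is needed:

```python
import torch

frames = torch.randn(8, 32, 32, 3)       # [F, H, W, C]
chw = frames.permute(0, 3, 1, 2)         # [F, C, H, W] -- a non-contiguous view
print(chw.is_contiguous())               # False
print(chw.contiguous().is_contiguous())  # True; data is copied into a contiguous layout
```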
@@ -67,7 +67,9 @@ def get_optimizer(
         optimizer_name = "adamw"
 
     if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
-        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
+        raise ValueError(
+            "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
+        )
 
     if use_8bit:
         try:

@@ -81,7 +83,9 @@ def get_optimizer(
     if use_torchao:
         from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
 
-        optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
+        optimizer_class = (
+            AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
+        )
     else:
         optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
 

@@ -109,7 +113,9 @@ def get_optimizer(
         try:
             import prodigyopt
         except ImportError:
-            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
+            )
 
         optimizer_class = prodigyopt.Prodigy
 

@@ -133,7 +139,9 @@ def get_optimizer(
         try:
             import came_pytorch
         except ImportError:
-            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
+            raise ImportError(
+                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
+            )
 
         optimizer_class = came_pytorch.CAME
 

@@ -151,7 +159,10 @@ def get_optimizer(
         init_kwargs.update({"fused": True})
 
         optimizer = CPUOffloadOptimizer(
-            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
+            params_to_optimize,
+            optimizer_class=optimizer_class,
+            offload_gradients=offload_gradients,
+            **init_kwargs,
        )
     else:
         optimizer = optimizer_class(params_to_optimize, **init_kwargs)
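The conditionals reformatted in this section resolve an optimizer class from string flags, falling back to `torch.optim.AdamW` when no low-bit variant is requested. A stripped-down sketch of that dispatch; the low-bit classes here are stand-ins, since the real code imports `AdamW8bit`/`AdamW4bit` from torchao or bitsandbytes only when those packages are installed:

```python
import torch

use_8bit, use_4bit = False, False

# Stand-ins; replace with torchao's AdamW8bit / AdamW4bit when available.
AdamW8bit = AdamW4bit = torch.optim.AdamW

optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
params = [torch.nn.Parameter(torch.zeros(2, 2))]
optimizer = optimizer_class(params, lr=1e-4)
print(type(optimizer).__name__)  # AdamW
```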
||||||
|
@ -99,7 +99,9 @@ def generate_video(
|
|||||||
desired_resolution = RESOLUTION_MAP[model_name]
|
desired_resolution = RESOLUTION_MAP[model_name]
|
||||||
if width is None or height is None:
|
if width is None or height is None:
|
||||||
height, width = desired_resolution
|
height, width = desired_resolution
|
||||||
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
|
logging.info(
|
||||||
|
f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
|
||||||
|
)
|
||||||
elif (height, width) != desired_resolution:
|
elif (height, width) != desired_resolution:
|
||||||
if generate_type == "i2v":
|
if generate_type == "i2v":
|
||||||
# For i2v models, use user-defined width and height
|
# For i2v models, use user-defined width and height
|
||||||
@@ -124,7 +126,9 @@ def generate_video(
 
     # If you're using with lora, add this code
     if lora_path:
-        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+        pipe.load_lora_weights(
+            lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
+        )
         pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
 
     # 2. Set Scheduler.
@@ -133,7 +137,9 @@ def generate_video(
     # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
 
     # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    pipe.scheduler = CogVideoXDPMScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing"
+    )
 
     # 3. Enable CPU offload for the model.
     # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
@@ -190,8 +196,12 @@ def generate_video(
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
-    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt using CogVideoX"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="The description of the video to be generated"
+    )
     parser.add_argument(
         "--image_or_video_path",
         type=str,
@@ -199,20 +209,44 @@ if __name__ == "__main__":
         help="The path of the image to be used as the background of the video",
     )
     parser.add_argument(
-        "--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use"
+        "--model_path",
+        type=str,
+        default="THUDM/CogVideoX1.5-5B",
+        help="Path of the pre-trained model use",
+    )
+    parser.add_argument(
+        "--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
     )
-    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
     parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
-    parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+    parser.add_argument(
+        "--output_path", type=str, default="./output.mp4", help="The path save generated video"
+    )
+    parser.add_argument(
+        "--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
+    )
     parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
-    parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
+    parser.add_argument(
+        "--num_frames", type=int, default=81, help="Number of steps for the inference process"
+    )
     parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
-    parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
-    parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
-    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
-    parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
+    parser.add_argument(
+        "--height", type=int, default=None, help="The height of the generated video"
+    )
+    parser.add_argument(
+        "--fps", type=int, default=16, help="The frames per second for the generated video"
+    )
+    parser.add_argument(
+        "--num_videos_per_prompt",
+        type=int,
+        default=1,
+        help="Number of videos to generate per prompt",
+    )
+    parser.add_argument(
+        "--generate_type", type=str, default="t2v", help="The type of video generation"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="The data type for computation"
+    )
     parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
 
     args = parser.parse_args()
@@ -19,7 +19,12 @@ import argparse
 import os
 import torch
 import torch._dynamo
-from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXTransformer3DModel,
+    CogVideoXPipeline,
+    CogVideoXDPMScheduler,
+)
 from diffusers.utils import export_to_video
 from transformers import T5EncoderModel
 from torchao.quantization import quantize_, int8_weight_only
@@ -68,9 +73,13 @@ def generate_video(
     - quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
     - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
     """
-    text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
+    text_encoder = T5EncoderModel.from_pretrained(
+        model_path, subfolder="text_encoder", torch_dtype=dtype
+    )
     text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
-    transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+    transformer = CogVideoXTransformer3DModel.from_pretrained(
+        model_path, subfolder="transformer", torch_dtype=dtype
+    )
     transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
     vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
     vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
@@ -81,7 +90,9 @@ def generate_video(
         vae=vae,
         torch_dtype=dtype,
     )
-    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    pipe.scheduler = CogVideoXDPMScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing"
+    )
     pipe.enable_model_cpu_offload()
     pipe.vae.enable_slicing()
     pipe.vae.enable_tiling()
@@ -100,16 +111,34 @@ def generate_video(
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
-    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
-    parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model")
-    parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video")
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt using CogVideoX"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="The description of the video to be generated"
+    )
+    parser.add_argument(
+        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
+    )
+    parser.add_argument(
+        "--output_path", type=str, default="./output.mp4", help="Path to save generated video"
+    )
     parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
-    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt")
-    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')")
+    parser.add_argument(
+        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
+    )
+    parser.add_argument(
+        "--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
+    )
     parser.add_argument(
-        "--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme"
+        "--quantization_scheme",
+        type=str,
+        default="fp8",
+        choices=["int8", "fp8"],
+        help="Quantization scheme",
     )
     parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
     parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
@@ -104,18 +104,34 @@ def save_video(tensor, output_path):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
-    parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
+    parser.add_argument(
+        "--model_path", type=str, required=True, help="The path to the CogVideoX model"
+    )
     parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
-    parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
-    parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
     parser.add_argument(
-        "--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
+        "--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
     )
     parser.add_argument(
-        "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
+        "--output_path", type=str, default=".", help="The path to save the output file"
     )
     parser.add_argument(
-        "--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
+        "--mode",
+        type=str,
+        choices=["encode", "decode", "both"],
+        required=True,
+        help="Mode: encode, decode, or both",
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="The device to use for computation (e.g., 'cuda' or 'cpu')",
     )
     args = parser.parse_args()
 
@@ -126,15 +142,21 @@ if __name__ == "__main__":
         assert args.video_path, "Video path must be provided for encoding."
         encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
         torch.save(encoded_output, args.output_path + "/encoded.pt")
-        print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt")
+        print(
+            f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt"
+        )
     elif args.mode == "decode":
         assert args.encoded_path, "Encoded tensor path must be provided for decoding."
         decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
         save_video(decoded_output, args.output_path)
-        print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
+        print(
+            f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4"
+        )
    elif args.mode == "both":
         assert args.video_path, "Video path must be provided for encoding."
         encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
         torch.save(encoded_output, args.output_path + "/encoded.pt")
-        decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
+        decoded_output = decode_video(
+            args.model_path, args.output_path + "/encoded.pt", dtype, device
+        )
         save_video(decoded_output, args.output_path)
@@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
-    parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
+    parser.add_argument(
+        "--retry_times", type=int, default=3, help="Number of times to retry the conversion"
+    )
     parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
     parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
     args = parser.parse_args()
@@ -30,7 +30,10 @@ import torchvision.transforms as T
 from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
 from diffusers.models.autoencoders import AutoencoderKLCogVideoX
 from diffusers.models.embeddings import apply_rotary_emb
-from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
+from diffusers.models.transformers.cogvideox_transformer_3d import (
+    CogVideoXBlock,
+    CogVideoXTransformer3DModel,
+)
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
 from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
 from diffusers.utils import export_to_video
@@ -62,22 +65,48 @@ class DDIMInversionArguments(TypedDict):
 def get_args() -> DDIMInversionArguments:
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
-    parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
-    parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
-    parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
-    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
-    parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
-    parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
-    parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
-    parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
+    parser.add_argument(
+        "--model_path", type=str, required=True, help="Path of the pretrained model"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="Prompt for the direct sample procedure"
+    )
+    parser.add_argument(
+        "--video_path", type=str, required=True, help="Path of the video for inversion"
+    )
+    parser.add_argument(
+        "--output_path", type=str, default="output", help="Path of the output videos"
+    )
+    parser.add_argument(
+        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
+    )
+    parser.add_argument(
+        "--num_inference_steps", type=int, default=50, help="Number of inference steps"
+    )
+    parser.add_argument(
+        "--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start"
+    )
+    parser.add_argument(
+        "--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end"
+    )
+    parser.add_argument(
+        "--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames"
+    )
+    parser.add_argument(
+        "--max_num_frames", type=int, default=81, help="Max number of sampled frames"
+    )
     parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
-    parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
+    parser.add_argument(
+        "--height", type=int, default=480, help="Resized height of the video frames"
+    )
     parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
-    parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
+    parser.add_argument(
+        "--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model"
+    )
     parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
-    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
+    parser.add_argument(
+        "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference"
+    )
 
     args = parser.parse_args()
     args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
@@ -116,13 +145,20 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            query[:, :, text_seq_length:] = apply_rotary_emb(
+                query[:, :, text_seq_length:], image_rotary_emb
+            )
             if not attn.is_cross_attention:
                 if key.size(2) == query.size(2):  # Attention for reference hidden states
-                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+                    key[:, :, text_seq_length:] = apply_rotary_emb(
+                        key[:, :, text_seq_length:], image_rotary_emb
+                    )
                 else:  # RoPE should be applied to each group of image tokens
-                    key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
-                        key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
+                    key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
+                        apply_rotary_emb(
+                            key[:, :, text_seq_length : text_seq_length + image_seq_length],
+                            image_rotary_emb,
+                        )
                     )
                     key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
                         key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
@@ -162,8 +198,12 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
         )
 
         if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
 
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
@@ -260,14 +300,18 @@ def get_video_frames(
     return frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
 
 
-def encode_video_frames(vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor) -> torch.FloatTensor:
+def encode_video_frames(
+    vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
+) -> torch.FloatTensor:
     video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
     video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
     latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
     return latent_dist * vae.config.scaling_factor
 
 
-def export_latents_to_video(pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int):
+def export_latents_to_video(
+    pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
+):
     video = pipeline.decode_latents(latents)
     frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
     export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
@@ -320,7 +364,9 @@ def sample(
 
     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
     extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
-    if isinstance(scheduler, DDIMInverseScheduler):  # Inverse scheduler does not accept extra kwargs
+    if isinstance(
+        scheduler, DDIMInverseScheduler
+    ):  # Inverse scheduler does not accept extra kwargs
         extra_step_kwargs = {}
 
     # 7. Create rotary embeds if required
@@ -344,7 +390,9 @@ def sample(
             if pipeline.interrupt:
                 continue
 
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = (
+                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            )
             if reference_latents is not None:
                 reference = reference_latents[i]
                 reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
@@ -371,18 +419,31 @@ def sample(
             # perform guidance
             if use_dynamic_cfg:
                 pipeline._guidance_scale = 1 + guidance_scale * (
-                    (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    (
+                        1
+                        - math.cos(
+                            math.pi
+                            * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
+                        )
+                    )
+                    / 2
                 )
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + pipeline.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
 
             # compute the noisy sample x_t-1 -> x_t
-            latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+            latents = scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+            )[0]
             latents = latents.to(prompt_embeds.dtype)
             trajectory[i] = latents
 
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
+            if i == len(timesteps) - 1 or (
+                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
+            ):
                 progress_bar.update()
 
     # Offload all models
@@ -410,7 +471,9 @@ def ddim_inversion(
     seed: int,
     device: torch.device,
 ):
-    pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device)
+    pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
+        model_path, torch_dtype=dtype
+    ).to(device=device)
     if not pipeline.transformer.config.use_rotary_positional_embeddings:
         raise NotImplementedError("This script supports CogVideoX 5B model only.")
     video_frames = get_video_frames(
@@ -43,5 +43,3 @@ pip install -r requirements.txt
 ```bash
 python app.py
 ```
-
-
@@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 MODEL = "THUDM/CogVideoX-5b"
 
-hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
+hf_hub_download(
+    repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
+)
 snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
 
 pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
-pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+pipe.scheduler = CogVideoXDPMScheduler.from_config(
+    pipe.scheduler.config, timestep_spacing="trailing"
+)
 pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
     MODEL,
     transformer=pipe.transformer,
@@ -296,8 +300,16 @@ def delete_old_files():
 
 
 threading.Thread(target=delete_old_files, daemon=True).start()
-examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
-examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
+examples_videos = [
+    ["example_videos/horse.mp4"],
+    ["example_videos/kitten.mp4"],
+    ["example_videos/train_running.mp4"],
+]
+examples_images = [
+    ["example_images/beach.png"],
+    ["example_images/street.png"],
+    ["example_images/camping.png"],
+]
 
 with gr.Blocks() as demo:
     gr.Markdown("""
|
|||||||
""")
|
""")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
|
with gr.Accordion(
|
||||||
|
"I2V: Image Input (cannot be used simultaneously with video input)", open=False
|
||||||
|
):
|
||||||
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
|
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
|
||||||
examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
|
examples_component_images = gr.Examples(
|
||||||
with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
|
examples_images, inputs=[image_input], cache_examples=False
|
||||||
video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
|
)
|
||||||
|
with gr.Accordion(
|
||||||
|
"V2V: Video Input (cannot be used simultaneously with image input)", open=False
|
||||||
|
):
|
||||||
|
video_input = gr.Video(
|
||||||
|
label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
|
||||||
|
)
|
||||||
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
|
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
|
||||||
examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
|
examples_component_videos = gr.Examples(
|
||||||
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
examples_videos, inputs=[video_input], cache_examples=False
|
||||||
|
)
|
||||||
|
prompt = gr.Textbox(
|
||||||
|
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
@ -340,11 +364,16 @@ with gr.Blocks() as demo:
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed_param = gr.Number(
|
seed_param = gr.Number(
|
||||||
label="Inference Seed (Enter a positive number, -1 for random)", value=-1
|
label="Inference Seed (Enter a positive number, -1 for random)",
|
||||||
|
value=-1,
|
||||||
)
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
|
enable_scale = gr.Checkbox(
|
||||||
enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
|
label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
|
||||||
|
)
|
||||||
|
enable_rife = gr.Checkbox(
|
||||||
|
label="Frame Interpolation (8fps -> 16fps)", value=False
|
||||||
|
)
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br> The entire process is based on open-source solutions."
|
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br> The entire process is based on open-source solutions."
|
||||||
)
|
)
|
||||||
@@ -430,7 +459,7 @@ with gr.Blocks() as demo:
         seed_value,
         scale_status,
         rife_status,
-        progress=gr.Progress(track_tqdm=True)
+        progress=gr.Progress(track_tqdm=True),
     ):
         latents, seed = infer(
             prompt,
@@ -457,7 +486,9 @@ with gr.Blocks() as demo:
             image_pil = VaeImageProcessor.numpy_to_pil(image_np)
             batch_video_frames.append(image_pil)
 
-        video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
+        video_path = utils.save_video(
+            batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
+        )
         video_update = gr.update(visible=True, value=video_path)
         gif_path = convert_to_gif(video_path)
         gif_update = gr.update(visible=True, value=gif_path)
@@ -3,7 +3,9 @@ from .refine import *
 
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )
 
@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -102,7 +108,9 @@ class IFNet(nn.Module):
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -110,11 +118,16 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         c0 = self.contextnet(img0, flow[:, :2])
         c1 = self.contextnet(img1, flow[:, 2:4])
         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
@@ -3,7 +3,9 @@ from .refine_2R import *
 
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )
 
@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -102,7 +108,9 @@ class IFNet(nn.Module):
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -110,11 +118,16 @@ class IFNet(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         c0 = self.contextnet(img0, flow[:, :2])
         c1 = self.contextnet(img1, flow[:, 2:4])
         tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
@@ -61,11 +61,19 @@ class IFBlock(nn.Module):
 
     def forward(self, x, flow, scale=1):
         x = F.interpolate(
-            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+            x,
+            scale_factor=1.0 / scale,
+            mode="bilinear",
+            align_corners=False,
+            recompute_scale_factor=False,
         )
         flow = (
             F.interpolate(
-                flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+                flow,
+                scale_factor=1.0 / scale,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
             )
             * 1.0
             / scale
@@ -78,11 +86,21 @@ class IFBlock(nn.Module):
         flow = self.conv1(feat)
         mask = self.conv2(feat)
         flow = (
-            F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=scale,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * scale
         )
         mask = F.interpolate(
-            mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
+            mask,
+            scale_factor=scale,
+            mode="bilinear",
+            align_corners=False,
+            recompute_scale_factor=False,
         )
         return flow, mask
 
@@ -112,7 +130,11 @@ class IFNet(nn.Module):
         loss_cons = 0
         block = [self.block0, self.block1, self.block2]
         for i in range(3):
-            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
+            f0, m0 = block[i](
+                torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
+                flow,
+                scale=scale_list[i],
+            )
             f1, m1 = block[i](
                 torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
                 torch.cat((flow[:, 2:4], flow[:, :2]), 1),
@@ -3,7 +3,9 @@ from .refine import *
 
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
+        torch.nn.ConvTranspose2d(
+            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
+        ),
         nn.PReLU(out_planes),
     )
 
@@ -46,7 +48,11 @@ class IFBlock(nn.Module):
         if scale != 1:
             x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
         if flow != None:
-            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
+            flow = (
+                F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
+                * 1.0
+                / scale
+            )
             x = torch.cat((x, flow), 1)
         x = self.conv0(x)
         x = self.convblock(x) + x
@@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
         for i in range(3):
             if flow != None:
                 flow_d, mask_d = stu[i](
-                    torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
+                    torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
+                    flow,
+                    scale=scale[i],
                 )
                 flow = flow + flow_d
                 mask = mask + mask_d
@@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
             merged.append(merged_student)
         if gt.shape[1] == 3:
             flow_d, mask_d = self.block_tea(
-                torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
+                torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
+                flow,
+                scale=1,
             )
             flow_teacher = flow + flow_d
             warped_img0_teacher = warp(img0, flow_teacher[:, :2])
             warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
             mask_teacher = torch.sigmoid(mask + mask_d)
-            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
+            merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
+                1 - mask_teacher
+            )
         else:
             flow_teacher = None
             merged_teacher = None
@@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
             merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
             if gt.shape[1] == 3:
                 loss_mask = (
-                    ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
+                    (
+                        (merged[i] - gt).abs().mean(1, True)
+                        > (merged_teacher - gt).abs().mean(1, True) + 0.01
+                    )
                     .float()
                     .detach()
                 )
-                loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
+                loss_distill += (
+                    ((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
+                ).mean()
         if returnflow:
             return flow
         else:
@@ -44,7 +44,9 @@ class Model:
             if torch.cuda.is_available():
                 self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
             else:
-                self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")))
+                self.flownet.load_state_dict(
+                    convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
+                )
 
     def save_model(self, path, rank=0):
         if rank == 0:
@@ -29,10 +29,14 @@ def downsample(x):
 
 
 def upsample(x):
-    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
+    cc = torch.cat(
+        [x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
+    )
     cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
     cc = cc.permute(0, 1, 3, 2)
-    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3)
+    cc = torch.cat(
+        [cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
+    )
     cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
     x_up = cc.permute(0, 1, 3, 2)
     return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
@@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
         self.gauss_kernel = gauss_kernel(channels=channels)
 
     def forward(self, input, target):
-        pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
-        pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
+        pyr_input = laplacian_pyramid(
+            img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
+        )
+        pyr_target = laplacian_pyramid(
+            img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
+        )
         return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
@@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
+    gauss = torch.Tensor(
+        [exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
+    )
     return gauss / gauss.sum()
 
 
@@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
     _2D_window = _1D_window.mm(_1D_window.t())
     _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
-    window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+    window = (
+        _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
+    )
     return window
 
 
@@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
 
     # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
-    mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
-    mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
+    mu1 = F.conv2d(
+        F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
+    )
+    mu2 = F.conv2d(
+        F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
+    )
 
     mu1_sq = mu1.pow(2)
     mu2_sq = mu2.pow(2)
     mu1_mu2 = mu1 * mu2
 
-    sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
-    sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
-    sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
+    sigma1_sq = (
+        F.conv2d(
+            F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu1_sq
+    )
+    sigma2_sq = (
+        F.conv2d(
+            F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu2_sq
+    )
+    sigma12 = (
+        F.conv2d(
+            F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
+        )
+        - mu1_mu2
+    )
 
     C1 = (0.01 * L) ** 2
     C2 = (0.03 * L) ** 2
@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
def ssim_matlab(
|
||||||
|
img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
|
||||||
|
):
|
||||||
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
||||||
if val_range is None:
|
if val_range is None:
|
||||||
if torch.max(img1) > 128:
|
if torch.max(img1) > 128:
|
||||||
@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
|
|||||||
img1 = img1.unsqueeze(1)
|
img1 = img1.unsqueeze(1)
|
||||||
img2 = img2.unsqueeze(1)
|
img2 = img2.unsqueeze(1)
|
||||||
|
|
||||||
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
|
mu1 = F.conv3d(
|
||||||
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
|
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
|
||||||
|
)
|
||||||
|
mu2 = F.conv3d(
|
||||||
|
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
|
||||||
|
)
|
||||||
|
|
||||||
mu1_sq = mu1.pow(2)
|
mu1_sq = mu1.pow(2)
|
||||||
mu2_sq = mu2.pow(2)
|
mu2_sq = mu2.pow(2)
|
||||||
mu1_mu2 = mu1 * mu2
|
mu1_mu2 = mu1 * mu2
|
||||||
|
|
||||||
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
|
sigma1_sq = (
|
||||||
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
|
F.conv3d(
|
||||||
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
|
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||||
|
)
|
||||||
|
- mu1_sq
|
||||||
|
)
|
||||||
|
sigma2_sq = (
|
||||||
|
F.conv3d(
|
||||||
|
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||||
|
)
|
||||||
|
- mu2_sq
|
||||||
|
)
|
||||||
|
sigma12 = (
|
||||||
|
F.conv3d(
|
||||||
|
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||||
|
)
|
||||||
|
- mu1_mu2
|
||||||
|
)
|
||||||
|
|
||||||
C1 = (0.01 * L) ** 2
|
C1 = (0.01 * L) ** 2
|
||||||
C2 = (0.03 * L) ** 2
|
C2 = (0.03 * L) ** 2
|
||||||
@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
|
|||||||
mssim = []
|
mssim = []
|
||||||
mcs = []
|
mcs = []
|
||||||
for _ in range(levels):
|
for _ in range(levels):
|
||||||
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
|
sim, cs = ssim(
|
||||||
|
img1,
|
||||||
|
img2,
|
||||||
|
window_size=window_size,
|
||||||
|
size_average=size_average,
|
||||||
|
full=True,
|
||||||
|
val_range=val_range,
|
||||||
|
)
|
||||||
mssim.append(sim)
|
mssim.append(sim)
|
||||||
mcs.append(cs)
|
mcs.append(cs)
|
||||||
|
|
||||||
@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
|
|||||||
self.window = window
|
self.window = window
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
|
|
||||||
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
|
_ssim = ssim(
|
||||||
|
img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
|
||||||
|
)
|
||||||
dssim = (1 - _ssim) / 2
|
dssim = (1 - _ssim) / 2
|
||||||
return dssim
|
return dssim
|
||||||
|
|
||||||
|
@@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
         torch.nn.ConvTranspose2d(
-            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=True,
         ),
         nn.PReLU(out_planes),
     )
@@ -56,25 +61,49 @@ class Contextnet(nn.Module):
     def forward(self, x, flow):
         x = self.conv1(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f1 = warp(x, flow)
         x = self.conv2(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f2 = warp(x, flow)
         x = self.conv3(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f3 = warp(x, flow)
         x = self.conv4(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f4 = warp(x, flow)
@@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
         torch.nn.ConvTranspose2d(
-            in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
+            in_channels=in_planes,
+            out_channels=out_planes,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=True,
         ),
         nn.PReLU(out_planes),
     )
@@ -59,19 +64,37 @@ class Contextnet(nn.Module):
         f1 = warp(x, flow)
         x = self.conv2(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f2 = warp(x, flow)
         x = self.conv3(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f3 = warp(x, flow)
         x = self.conv4(x)
         flow = (
-            F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            F.interpolate(
+                flow,
+                scale_factor=0.5,
+                mode="bilinear",
+                align_corners=False,
+                recompute_scale_factor=False,
+            )
             * 0.5
         )
         f4 = warp(x, flow)
@@ -9,6 +9,7 @@ import logging
 import skvideo.io
 from rife.RIFE_HDv3 import Model
 from huggingface_hub import hf_hub_download, snapshot_download

 logger = logging.getLogger(__name__)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -78,13 +79,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
         # print(f'I1[0] unpadded shape:{I1.shape}')
         I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-        if padding[3] > 0 and padding[1] >0 :
-            frame = I1[:, :, : -padding[3],:-padding[1]]
+        if padding[3] > 0 and padding[1] > 0:
+            frame = I1[:, :, : -padding[3], : -padding[1]]
         elif padding[3] > 0:
-            frame = I1[:, :, : -padding[3],:]
-        elif padding[1] >0:
-            frame = I1[:, :, :,:-padding[1]]
+            frame = I1[:, :, : -padding[3], :]
+        elif padding[1] > 0:
+            frame = I1[:, :, :, : -padding[1]]
         else:
             frame = I1

@@ -102,7 +102,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
         frame = F.interpolate(frame, size=(h, w))
         output.append(frame.to(output_device))
         for i, tmp_frame in enumerate(tmp_output):
-
             # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
             tmp_frame = F.interpolate(tmp_frame, size=(h, w))
             output.append(tmp_frame.to(output_device))
@@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
         frame_rgb = frame[..., ::-1]
         frame_rgb = frame_rgb.copy()
         tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
-        pt_frame_data.append(
-            tensor.permute(2, 0, 1)
-        )  # to [c, h, w,]
+        pt_frame_data.append(tensor.permute(2, 0, 1))  # to [c, h, w,]

     pt_frame = torch.from_numpy(np.stack(pt_frame_data))
     pt_frame = pt_frame.to(device)
@@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
         latent = latents[i]

         frames = ssim_interpolation_rife(model, latent)
-        pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))])  # (to [f, c, w, h])
+        pt_image = torch.stack(
+            [frames[i].squeeze(0) for i in range(len(frames))]
+        )  # (to [f, c, w, h])
         rife_results.append(pt_image)

     return torch.stack(rife_results)
@@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
-        if not "weights_only" in torch.load.__code__.co_varnames:
+        if "weights_only" not in torch.load.__code__.co_varnames:
            logger.warning(
                "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
            )
@@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):

 @torch.inference_mode()
 def tiled_scale_multidim(
-    samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
+    samples,
+    function,
+    tile=(64, 64),
+    overlap=8,
+    upscale_amount=4,
+    out_channels=3,
+    output_device="cpu",
+    pbar=None,
 ):
     dims = len(tile)
     print(f"samples dtype:{samples.dtype}")
     output = torch.empty(
-        [samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
+        [samples.shape[0], out_channels]
+        + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
         device=output_device,
     )

     for b in range(samples.shape[0]):
         s = samples[b : b + 1]
         out = torch.zeros(
-            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
+            [s.shape[0], out_channels]
+            + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
             device=output_device,
         )
         out_div = torch.zeros(
-            [s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
+            [s.shape[0], out_channels]
+            + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
             device=output_device,
         )

-        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
+        for it in itertools.product(
+            *map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
+        ):
             s_in = s
             upscaled = []

@@ -142,7 +154,14 @@ def tiled_scale(
     pbar=None,
 ):
     return tiled_scale_multidim(
-        samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar
+        samples,
+        function,
+        (tile_y, tile_x),
+        overlap,
+        upscale_amount,
+        out_channels,
+        output_device,
+        pbar,
     )


@@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
     return s


-def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
+def upscale_batch_and_concatenate(
+    upscale_model, latents, inf_device, output_device="cpu"
+) -> torch.Tensor:
     upscaled_latents = []
     for i in range(latents.size(0)):
         latent = latents[i]
@@ -207,7 +228,9 @@ class ProgressBar:
     def __init__(self, total, desc=None):
         self.total = total
         self.current = 0
-        self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc)
+        self.b_unit = tqdm.tqdm(
+            total=total, desc="ProgressBar context index: 0" if desc is None else desc
+        )

     def update(self, value):
         if value > self.total:
@@ -22,7 +22,9 @@ from datetime import datetime, timedelta
 from openai import OpenAI
 from moviepy import VideoFileClip

-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
+    "cuda"
+)

 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
@@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
     return prompt


-def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
+def infer(
+    prompt: str,
+    num_inference_steps: int,
+    guidance_scale: float,
+    progress=gr.Progress(track_tqdm=True),
+):
     torch.cuda.empty_cache()
     video = pipe(
         prompt=prompt,
@@ -151,7 +158,9 @@ with gr.Blocks() as demo:

     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
+            prompt = gr.Textbox(
+                label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
+            )

             with gr.Row():
                 gr.Markdown(
@@ -176,7 +185,13 @@ with gr.Blocks() as demo:
             download_video_button = gr.File(label="📥 Download Video", visible=False)
             download_gif_button = gr.File(label="📥 Download GIF", visible=False)

-    def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
+    def generate(
+        prompt,
+        num_inference_steps,
+        guidance_scale,
+        model_choice,
+        progress=gr.Progress(track_tqdm=True),
+    ):
         tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
         video_path = save_video(tensor)
         video_update = gr.update(visible=True, value=video_path)
@@ -4,4 +4,3 @@
     <p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
     <p> Scan the QR code to follow the official account and join the "CogVLM Discussion Group" </p>
 </div>
-
@@ -1,49 +0,0 @@
-# Contribution Guide
-
-There may still be many incomplete aspects in this project.
-
-We look forward to your contributions to the repository in the following areas. If you complete the work mentioned above
-and are willing to submit a PR and share it with the community, upon review, we
-will acknowledge your contribution on the project homepage.
-
-## Model Algorithms
-
-- Support for model quantization inference (Int4 quantization project)
-- Optimization of model fine-tuning data loading (replacing the existing decord tool)
-
-## Model Engineering
-
-- Model fine-tuning examples / Best prompt practices
-- Inference adaptation on different devices (e.g., MLX framework)
-- Any tools related to the model
-- Any minimal fully open-source project using the CogVideoX open-source model
-
-## Code Standards
-
-Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
-style. You can organize the code according to the following specifications:
-
-1. Install the `ruff` tool
-
-```shell
-pip install ruff
-```
-
-Then, run the `ruff` tool
-
-```shell
-ruff check tools sat inference
-```
-
-Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.
-
-```shell
-ruff format tools sat inference
-```
-
-Once your code meets the standard, there should be no errors.
-
-## Naming Conventions
-1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
-2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
-
@@ -1,47 +0,0 @@
-# コントリビューションガイド
-
-本プロジェクトにはまだ多くの未完成の部分があります。
-
-以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
-
-## モデルアルゴリズム
-
-- モデル量子化推論のサポート (Int4量子化プロジェクト)
-- モデルのファインチューニングデータロードの最適化(既存のdecordツールの置き換え)
-
-## モデルエンジニアリング
-
-- モデルのファインチューニング例 / 最適なプロンプトの実践
-- 異なるデバイスでの推論適応(例: MLXフレームワーク)
-- モデルに関連するツール
-- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
-
-## コード標準
-
-良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
-設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
-
-1. `ruff` ツールをインストールする
-
-```shell
-pip install ruff
-```
-
-次に、`ruff` ツールを実行します
-
-```shell
-ruff check tools sat inference
-```
-
-コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
-
-```shell
-ruff format tools sat inference
-```
-
-コードが標準に準拠したら、エラーはなくなるはずです。
-
-## 命名規則
-
-1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
-2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。
@@ -1,44 +0,0 @@
-# 贡献指南
-
-本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
-
-## 模型算法
-
-- 模型量化推理支持 (Int4量化工程)
-- 模型微调数据载入优化支持(替换现有的decord工具)
-
-## 模型工程
-
-- 模型微调示例 / 最佳提示词实践
-- 不同设备上的推理适配(MLX等框架)
-- 任何模型周边工具
-- 任何使用CogVideoX开源模型制作的最小完整开源项目
-
-## 代码规范
-
-良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
-
-1. 安装`ruff`工具
-
-```shell
-pip install ruff
-```
-
-接着,运行`ruff`工具
-
-```shell
-ruff check tools sat inference
-```
-
-检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
-
-```shell
-ruff format tools sat inference
-```
-
-如果您的代码符合规范,应该不会出现任何的错误。
-
-## 命名规范
-
-- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
-- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
@@ -18,7 +18,10 @@ def add_model_config_args(parser):
     group = parser.add_argument_group("model", "model configuration")
     group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
     group.add_argument(
-        "--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
+        "--model-parallel-size",
+        type=int,
+        default=1,
+        help="size of the model parallel. only use if you are an expert.",
     )
     group.add_argument("--force-pretrain", action="store_true")
     group.add_argument("--device", type=int, default=-1)
@@ -74,10 +77,15 @@ def get_args(args_list=None, parser=None):
     if not args.train_data:
         print_rank0("No training data specified", level="WARNING")

-    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
+    assert (args.train_iters is None) or (
+        args.epochs is None
+    ), "only one of train_iters and epochs should be set."
     if args.train_iters is None and args.epochs is None:
         args.train_iters = 10000  # default 10k iters
-        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
+        print_rank0(
+            "No train_iters (recommended) or epochs specified, use default 10k iters.",
+            level="WARNING",
+        )

     args.cuda = torch.cuda.is_available()

@@ -213,7 +221,10 @@ def initialize_distributed(args):
     args.master_port = os.getenv("MASTER_PORT", default_master_port)
     init_method += args.master_ip + ":" + args.master_port
     torch.distributed.init_process_group(
-        backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
+        backend=args.distributed_backend,
+        world_size=args.world_size,
+        rank=args.rank,
+        init_method=init_method,
     )

     # Set the model-parallel / data-parallel communicators.
@@ -232,7 +243,10 @@ def initialize_distributed(args):
         import deepspeed

         deepspeed.init_distributed(
-            dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
+            dist_backend=args.distributed_backend,
+            world_size=args.world_size,
+            rank=args.rank,
+            init_method=init_method,
         )
         # # It seems that it has no negative influence to configure it even without using checkpointing.
         # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
@@ -262,7 +276,9 @@ def process_config_to_args(args):

     args_config = config.pop("args", OmegaConf.create())
     for key in args_config:
-        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
+        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
+            args_config[key], omegaconf.ListConfig
+        ):
             arg = OmegaConf.to_object(args_config[key])
         else:
             arg = args_config[key]
@@ -56,7 +56,9 @@ def read_video(
         end_pts = float("inf")

     if end_pts < start_pts:
-        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
+        raise ValueError(
+            f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
+        )

     info = {}
     audio_frames = []
@@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
         super().__init__(
             path,
             partial(
-                process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
+                process_fn_video,
+                num_frames=num_frames,
+                image_size=image_size,
+                fps=fps,
+                skip_frms_num=skip_frms_num,
             ),
             seed,
             meta_names=meta_names,
@@ -400,7 +406,9 @@ class SFTDataset(Dataset):
                 indices = np.arange(start, end, (end - start) // num_frames).astype(int)
                 temp_frms = vr.get_batch(np.arange(start, end_safty))
                 assert temp_frms is not None
-                tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
                 tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
             else:
                 if ori_vlen > self.max_num_frames:
@@ -410,7 +418,11 @@ class SFTDataset(Dataset):
                     indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
                     temp_frms = vr.get_batch(np.arange(start, end))
                     assert temp_frms is not None
-                    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                    tensor_frms = (
+                        torch.from_numpy(temp_frms)
+                        if type(temp_frms) is not torch.Tensor
+                        else temp_frms
+                    )
                     tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
                 else:

@@ -423,11 +435,17 @@ class SFTDataset(Dataset):

                    start = int(self.skip_frms_num)
                    end = int(ori_vlen - self.skip_frms_num)
-                    num_frames = nearest_smaller_4k_plus_1(end - start)  # 3D VAE requires the number of frames to be 4k+1
+                    num_frames = nearest_smaller_4k_plus_1(
+                        end - start
+                    )  # 3D VAE requires the number of frames to be 4k+1
                    end = int(start + num_frames)
                    temp_frms = vr.get_batch(np.arange(start, end))
                    assert temp_frms is not None
-                    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                    tensor_frms = (
+                        torch.from_numpy(temp_frms)
+                        if type(temp_frms) is not torch.Tensor
+                        else temp_frms
+                    )

         tensor_frms = pad_last_frame(
             tensor_frms, self.max_num_frames
@@ -41,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
         latent_input = model_config.get("latent_input", False)
         disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
         no_cond_log = model_config.get("disable_first_stage_autocast", False)
-        not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
+        not_trainable_prefixes = model_config.get(
+            "not_trainable_prefixes", ["first_stage_model", "conditioner"]
+        )
         compile_model = model_config.get("compile_model", False)
         en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
         lr_scale = model_config.get("lr_scale", None)
@@ -76,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
         )

         self.denoiser = instantiate_from_config(denoiser_config)
-        self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
-        self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
+        self.sampler = (
+            instantiate_from_config(sampler_config) if sampler_config is not None else None
+        )
+        self.conditioner = instantiate_from_config(
+            default(conditioner_config, UNCONDITIONAL_CONFIG)
+        )

         self._init_first_stage(first_stage_config)

-        self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
+        self.loss_fn = (
+            instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
+        )

         self.latent_input = latent_input
         self.scale_factor = scale_factor
@@ -151,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
     def shared_step(self, batch: Dict) -> Any:
         x = self.get_input(batch)
         if self.lr_scale is not None:
-            lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
-            lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
+            lr_x = F.interpolate(
+                x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
+            )
+            lr_x = F.interpolate(
+                lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
+            )
             lr_z = self.encode_first_stage(lr_x, batch)
             batch["lr_input"] = lr_z

@@ -195,7 +207,11 @@ class SATVideoDiffusionEngine(nn.Module):
         recons = []
         start_frame = 0
         for i in range(fake_cp_size):
-            end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
+            end_frame = (
+                start_frame
+                + latent_time // fake_cp_size
+                + (1 if i < latent_time % fake_cp_size else 0)
+            )

             use_cp = True if i == 0 else False
             clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
@@ -264,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
             self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
         )

-        samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
+        samples = self.sampler(
+            denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
+        )
         samples = samples.to(self.dtype)
         return samples

@@ -278,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
         log = dict()

         for embedder in self.conditioner.embedders:
-            if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
+            if (
+                (self.log_keys is None) or (embedder.input_key in self.log_keys)
+            ) and not self.no_cond_log:
                 x = batch[embedder.input_key][:n]
                 if isinstance(x, torch.Tensor):
                     if x.dim() == 1:
@@ -354,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
                 image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
                 c["concat"] = image
                 uc["concat"] = image
-                samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
+                samples = self.sample(
+                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+                )  # b t c h w
                 samples = samples.permute(0, 2, 1, 3, 4).contiguous()
                 if only_log_video_latents:
                     latents = 1.0 / self.scale_factor * samples
@@ -364,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
                     samples = samples.permute(0, 2, 1, 3, 4).contiguous()
                     log["samples"] = samples
             else:
-                samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs)  # b t c h w
+                samples = self.sample(
+                    c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+                )  # b t c h w
                 samples = samples.permute(0, 2, 1, 3, 4).contiguous()
                 if only_log_video_latents:
                     latents = 1.0 / self.scale_factor * samples
@@ -94,7 +94,9 @@ def get_3d_sincos_pos_embed(

     # concate: [T, H, W] order
     pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
-    pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1)  # [T, H*W, D // 4]
+    pos_embed_temporal = np.repeat(
+        pos_embed_temporal, grid_height * grid_width, axis=1
+    )  # [T, H*W, D // 4]
     pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
     pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0)  # [T, H*W, D // 4 * 3]

@@ -160,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
         self.width = width
         self.spatial_length = height * width
         self.pos_embedding = nn.Parameter(
-            torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False
+            torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
+            requires_grad=False,
         )

     def position_embedding_forward(self, position_ids, **kwargs):
@@ -169,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
     def reinit(self, parent_model=None):
         del self.transformer.position_embeddings
         pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
-        self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+        self.pos_embedding.data[:, -self.spatial_length :].copy_(
+            torch.from_numpy(pos_embed).float().unsqueeze(0)
+        )


 class Basic3DPositionEmbeddingMixin(BaseMixin):
@@ -192,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
         self.spatial_length = height * width
         self.num_patches = height * width * compressed_num_frames
         self.pos_embedding = nn.Parameter(
-            torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
+            torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
+            requires_grad=False,
         )
         self.height_interpolation = height_interpolation
         self.width_interpolation = width_interpolation
@@ -285,7 +291,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
         freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

-        freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+        freqs = broadcat(
+            (freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
+            dim=-1,
+        )

         freqs = freqs.contiguous()
         self.freqs_sin = freqs.sin().cuda()
@@ -293,7 +302,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
         self.text_length = text_length
         if learnable_pos_embed:
             num_patches = height * width * compressed_num_frames + text_length
-            self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
+            self.pos_embedding = nn.Parameter(
+                torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
+            )
         else:
             self.pos_embedding = None

@@ -440,16 +451,26 @@ class FinalLayerMixin(BaseMixin):
         self.out_channels = out_channels
         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
         self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
+        )

     def final_forward(self, logits, **kwargs):
-        x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"]  # x:(b,(t n),d),只取了x中后面images的部分
+        x, emb = (
+            logits[:, kwargs["text_length"] :, :],
+            kwargs["emb"],
+        )  # x:(b,(t n),d),只取了x中后面images的部分
         shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
         x = modulate(self.norm_final(x), shift, scale)
         x = self.linear(x)

         return unpatchify(
-            x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
+            x,
+            c=self.out_channels,
+            patch_size=self.patch_size,
+            w=kwargs["rope_W"],
+            h=kwargs["rope_H"],
+            **kwargs,
         )

     def reinit(self, parent_model=None):
@@ -500,7 +521,10 @@ class AdaLNMixin(BaseMixin):
         self.compressed_num_frames = compressed_num_frames

         self.adaLN_modulations = nn.ModuleList(
-            [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
+            [
+                nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
+                for _ in range(num_layers)
+            ]
         )

         self.qk_ln = qk_ln
@@ -560,7 +584,9 @@ class AdaLNMixin(BaseMixin):
         img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
         text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)

-        attention_input = torch.cat((text_attention_input, img_attention_input), dim=1)  # (b,n_t+t*n_i,d)
+        attention_input = torch.cat(
+            (text_attention_input, img_attention_input), dim=1
+        )  # (b,n_t+t*n_i,d)
         attention_output = layer.attention(attention_input, mask, **kwargs)
         text_attention_output = attention_output[:, :text_length]  # (b,n,d)
         img_attention_output = attention_output[:, text_length:]  # (b,(t n),d)
@@ -584,9 +610,13 @@ class AdaLNMixin(BaseMixin):
             img_mlp_output = layer.fourth_layernorm(img_mlp_output)

         img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output  # vision (b,(t n),d)
-        text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output  # language (b,n,d)
+        text_hidden_states = (
+            text_hidden_states + text_gate_mlp * text_mlp_output
+        )  # language (b,n,d)

-        hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1)  # (b,(n_t+t*n_i),d)
+        hidden_states = torch.cat(
+            (text_hidden_states, img_hidden_states), dim=1
+        )  # (b,(n_t+t*n_i),d)
         return hidden_states

     def reinit(self, parent_model=None):
@@ -694,7 +724,9 @@ class DiffusionTransformer(BaseModel):
         if use_RMSNorm:
             kwargs["layernorm"] = RMSNorm
         else:
-            kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
+            kwargs["layernorm"] = partial(
+                LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
+            )

         transformer_args.num_layers = num_layers
         transformer_args.hidden_size = hidden_size
@@ -707,7 +739,9 @@ class DiffusionTransformer(BaseModel):

         if use_SwiGLU:
             self.add_mixin(
-                "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
+                "swiglu",
+                SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
+                reinit=True,
             )

     def _build_modules(self, module_configs):
@@ -813,7 +847,9 @@ class DiffusionTransformer(BaseModel):
             )
         if "lora_config" in module_configs:
             lora_config = module_configs["lora_config"]
-            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
+            self.add_mixin(
+                "lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
+            )
         return

     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@@ -829,7 +865,9 @@ class DiffusionTransformer(BaseModel):
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
+        t_emb = timestep_embedding(
+            timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
+        )
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
@@ -838,7 +876,9 @@ class DiffusionTransformer(BaseModel):
             emb = emb + self.label_emb(y)

         if self.ofs_embed_dim is not None:
-            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
+            ofs_emb = timestep_embedding(
+                kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
+            )
             ofs_emb = self.ofs_embed(ofs_emb)
             emb = emb + ofs_emb

@@ -852,6 +892,8 @@ class DiffusionTransformer(BaseModel):
         kwargs["rope_H"] = h // self.patch_size[1]
         kwargs["rope_W"] = w // self.patch_size[2]

-        kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
+        kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
+            (1, 1)
+        ).to(x.dtype)
         output = super().forward(**kwargs)[0]
         return output
@@ -19,6 +19,7 @@ from sat import mpu
 from diffusion_video import SATVideoDiffusionEngine
 from arguments import get_args

+
 def read_from_cli():
     cnt = 0
     try:
@@ -50,34 +51,50 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda

     for key in keys:
         if key == "txt":
-            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
-            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+            batch["txt"] = (
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+            )
+            batch_uc["txt"] = (
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+            )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
-                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+                .to(device)
+                .repeat(*N, 1)
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
+                .to(device)
+                .repeat(*N, 1)
             )
         elif key == "aesthetic_score":
-            batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+            batch["aesthetic_score"] = (
+                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+            )
             batch_uc["aesthetic_score"] = (
                 torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
             )

         elif key == "target_size_as_tuple":
             batch["target_size_as_tuple"] = (
-                torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+                .to(device)
+                .repeat(*N, 1)
             )
         elif key == "fps":
             batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
         elif key == "fps_id":
             batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
         elif key == "motion_bucket_id":
-            batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
+            batch[key] = (
+                torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
+            )
         elif key == "pool_image":
-            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
+            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
+                device, dtype=torch.half
+            )
         elif key == "cond_aug":
             batch[key] = repeat(
                 torch.tensor([value_dict["cond_aug"]]).to("cuda"),
@@ -100,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
     return batch, batch_uc


-def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
+def save_video_as_grid_and_mp4(
+    video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
+):
     os.makedirs(save_path, exist_ok=True)

     for i, vid in enumerate(video_batch):
@@ -160,7 +179,9 @@ def sampling_main(args, model_cls):
                 W = 96
                 H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
                 chained_trainsforms = []
-                chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
+                chained_trainsforms.append(
+                    TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
+                )
                 chained_trainsforms.append(TT.ToTensor())
                 transform = TT.Compose(chained_trainsforms)
                 image = transform(image).unsqueeze(0).to("cuda")
@@ -170,7 +191,9 @@ def sampling_main(args, model_cls):
                 image = image / model.scale_factor
                 image = image.permute(0, 2, 1, 3, 4).contiguous()
                 pad_shape = (image.shape[0], T - 1, C, H, W)
-                image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
+                image = torch.concat(
+                    [image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
+                )
             else:
                 image_size = args.sampling_image_size
                 H, W = image_size[0], image_size[1]
@ -181,12 +204,20 @@ def sampling_main(args, model_cls):
|
|||||||
mp_size = mpu.get_model_parallel_world_size()
|
mp_size = mpu.get_model_parallel_world_size()
|
||||||
global_rank = torch.distributed.get_rank() // mp_size
|
global_rank = torch.distributed.get_rank() // mp_size
|
||||||
src = global_rank * mp_size
|
src = global_rank * mp_size
|
||||||
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
|
torch.distributed.broadcast_object_list(
|
||||||
|
text_cast, src=src, group=mpu.get_model_parallel_group()
|
||||||
|
)
|
||||||
text = text_cast[0]
|
text = text_cast[0]
|
||||||
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
|
value_dict = {
|
||||||
|
"prompt": text,
|
||||||
|
"negative_prompt": "",
|
||||||
|
"num_frames": torch.tensor(T).unsqueeze(0),
|
||||||
|
}
|
||||||
|
|
||||||
batch, batch_uc = get_batch(
|
batch, batch_uc = get_batch(
|
||||||
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||||
|
value_dict,
|
||||||
|
num_samples,
|
||||||
)
|
)
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if isinstance(batch[key], torch.Tensor):
|
if isinstance(batch[key], torch.Tensor):
|
||||||
@ -212,7 +243,11 @@ def sampling_main(args, model_cls):
|
|||||||
for index in range(args.batch_size):
|
for index in range(args.batch_size):
|
||||||
if args.image2video:
|
if args.image2video:
|
||||||
samples_z = sample_func(
|
samples_z = sample_func(
|
||||||
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
|
c,
|
||||||
|
uc=uc,
|
||||||
|
batch_size=1,
|
||||||
|
shape=(T, C, H, W),
|
||||||
|
ofs=torch.tensor([2.0]).to("cuda"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
samples_z = sample_func(
|
samples_z = sample_func(
|
||||||
@ -226,7 +261,9 @@ def sampling_main(args, model_cls):
|
|||||||
if args.only_save_latents:
|
if args.only_save_latents:
|
||||||
samples_z = 1.0 / model.scale_factor * samples_z
|
samples_z = 1.0 / model.scale_factor * samples_z
|
||||||
save_path = os.path.join(
|
save_path = os.path.join(
|
||||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
args.output_dir,
|
||||||
|
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
|
||||||
|
str(index),
|
||||||
)
|
)
|
||||||
os.makedirs(save_path, exist_ok=True)
|
os.makedirs(save_path, exist_ok=True)
|
||||||
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
|
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
|
||||||
@ -237,7 +274,9 @@ def sampling_main(args, model_cls):
|
|||||||
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
|
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||||
save_path = os.path.join(
|
save_path = os.path.join(
|
||||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
args.output_dir,
|
||||||
|
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
|
||||||
|
str(index),
|
||||||
)
|
)
|
||||||
if mpu.get_model_parallel_rank() == 0:
|
if mpu.get_model_parallel_rank() == 0:
|
||||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||||
|
@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
|
|||||||
n = n - self.cum_cycles[cycle]
|
n = n - self.cum_cycles[cycle]
|
||||||
if self.verbosity_interval > 0:
|
if self.verbosity_interval > 0:
|
||||||
if n % self.verbosity_interval == 0:
|
if n % self.verbosity_interval == 0:
|
||||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
print(
|
||||||
|
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}"
|
||||||
|
)
|
||||||
if n < self.lr_warm_up_steps[cycle]:
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||||
|
cycle
|
||||||
|
] * n + self.f_start[cycle]
|
||||||
self.last_f = f
|
self.last_f = f
|
||||||
return f
|
return f
|
||||||
else:
|
else:
|
||||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
t = (n - self.lr_warm_up_steps[cycle]) / (
|
||||||
|
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
||||||
|
)
|
||||||
t = min(t, 1.0)
|
t = min(t, 1.0)
|
||||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
|
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||||
|
1 + np.cos(t * np.pi)
|
||||||
|
)
|
||||||
self.last_f = f
|
self.last_f = f
|
||||||
return f
|
return f
|
||||||
|
|
||||||
@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
|||||||
n = n - self.cum_cycles[cycle]
|
n = n - self.cum_cycles[cycle]
|
||||||
if self.verbosity_interval > 0:
|
if self.verbosity_interval > 0:
|
||||||
if n % self.verbosity_interval == 0:
|
if n % self.verbosity_interval == 0:
|
||||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
print(
|
||||||
|
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}"
|
||||||
|
)
|
||||||
|
|
||||||
if n < self.lr_warm_up_steps[cycle]:
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||||
|
cycle
|
||||||
|
] * n + self.f_start[cycle]
|
||||||
self.last_f = f
|
self.last_f = f
|
||||||
return f
|
return f
|
||||||
else:
|
else:
|
||||||
|
@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
x = self.decoder(z, **kwargs)
|
x = self.decoder(z, **kwargs)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
def forward(
|
||||||
|
self, x: torch.Tensor, **additional_decode_kwargs
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||||
z, reg_log = self.encode(x, return_reg_log=True)
|
z, reg_log = self.encode(x, return_reg_log=True)
|
||||||
dec = self.decode(z, **additional_decode_kwargs)
|
dec = self.decode(z, **additional_decode_kwargs)
|
||||||
return z, dec, reg_log
|
return z, dec, reg_log
|
||||||
|
|
||||||
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
def inner_training_step(
|
||||||
|
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
||||||
|
) -> torch.Tensor:
|
||||||
x = self.get_input(batch)
|
x = self.get_input(batch)
|
||||||
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
additional_decode_kwargs = {
|
||||||
|
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
||||||
|
}
|
||||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||||
if hasattr(self.loss, "forward_keys"):
|
if hasattr(self.loss, "forward_keys"):
|
||||||
extra_info = {
|
extra_info = {
|
||||||
@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
if self.trainable_ae_params is None:
|
if self.trainable_ae_params is None:
|
||||||
ae_params = self.get_autoencoder_params()
|
ae_params = self.get_autoencoder_params()
|
||||||
else:
|
else:
|
||||||
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
ae_params, num_ae_params = self.get_param_groups(
|
||||||
|
self.trainable_ae_params, self.ae_optimizer_args
|
||||||
|
)
|
||||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||||
if self.trainable_disc_params is None:
|
if self.trainable_disc_params is None:
|
||||||
disc_params = self.get_discriminator_params()
|
disc_params = self.get_discriminator_params()
|
||||||
else:
|
else:
|
||||||
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
disc_params, num_disc_params = self.get_param_groups(
|
||||||
|
self.trainable_disc_params, self.disc_optimizer_args
|
||||||
|
)
|
||||||
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
||||||
opt_ae = self.instantiate_optimizer_from_config(
|
opt_ae = self.instantiate_optimizer_from_config(
|
||||||
ae_params,
|
ae_params,
|
||||||
@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
)
|
)
|
||||||
opts = [opt_ae]
|
opts = [opt_ae]
|
||||||
if len(disc_params) > 0:
|
if len(disc_params) > 0:
|
||||||
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
opt_disc = self.instantiate_optimizer_from_config(
|
||||||
|
disc_params, self.learning_rate, self.optimizer_config
|
||||||
|
)
|
||||||
opts.append(opt_disc)
|
opts.append(opt_disc)
|
||||||
|
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
def log_images(
|
||||||
|
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||||
|
) -> dict:
|
||||||
log = dict()
|
log = dict()
|
||||||
additional_decode_kwargs = {}
|
additional_decode_kwargs = {}
|
||||||
x = self.get_input(batch)
|
x = self.get_input(batch)
|
||||||
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
additional_decode_kwargs.update(
|
||||||
|
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||||
|
)
|
||||||
|
|
||||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
|
|||||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||||
diff_ema.clamp_(0, 1.0)
|
diff_ema.clamp_(0, 1.0)
|
||||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||||
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
log["diff_boost_ema"] = (
|
||||||
|
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||||
|
)
|
||||||
if additional_log_kwargs:
|
if additional_log_kwargs:
|
||||||
additional_decode_kwargs.update(additional_log_kwargs)
|
additional_decode_kwargs.update(additional_log_kwargs)
|
||||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||||
@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
|||||||
params = super().get_autoencoder_params()
|
params = super().get_autoencoder_params()
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
def encode(
|
||||||
|
self, x: torch.Tensor, return_reg_log: bool = False
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||||
if self.max_batch_size is None:
|
if self.max_batch_size is None:
|
||||||
z = self.encoder(x)
|
z = self.encoder(x)
|
||||||
z = self.quant_conv(z)
|
z = self.quant_conv(z)
|
||||||
@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
|||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
def log_videos(
|
||||||
|
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||||
|
) -> dict:
|
||||||
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
||||||
|
|
||||||
def get_input(self, batch: dict) -> torch.Tensor:
|
def get_input(self, batch: dict) -> torch.Tensor:
|
||||||
@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
|||||||
batch = batch[self.input_key]
|
batch = batch[self.input_key]
|
||||||
|
|
||||||
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
||||||
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
|
torch.distributed.broadcast(
|
||||||
|
batch, src=global_src_rank, group=get_context_parallel_group()
|
||||||
|
)
|
||||||
|
|
||||||
batch = _conv_split(batch, dim=2, kernel_size=1)
|
batch = _conv_split(batch, dim=2, kernel_size=1)
|
||||||
return batch
|
return batch
|
||||||
|
@ -94,7 +94,11 @@ class FeedForward(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = int(dim * mult)
|
inner_dim = int(dim * mult)
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
project_in = (
|
||||||
|
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||||
|
if not glu
|
||||||
|
else GEGLU(dim, inner_dim)
|
||||||
|
)
|
||||||
|
|
||||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||||
|
|
||||||
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
b, c, h, w = x.shape
|
b, c, h, w = x.shape
|
||||||
qkv = self.to_qkv(x)
|
qkv = self.to_qkv(x)
|
||||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
q, k, v = rearrange(
|
||||||
|
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||||
|
)
|
||||||
k = k.softmax(dim=-1)
|
k = k.softmax(dim=-1)
|
||||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||||
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
|
|||||||
# new
|
# new
|
||||||
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
||||||
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
||||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=mask
|
||||||
|
) # scale is dim_head ** -0.5 per default
|
||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||||
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
self.norm1(x),
|
self.norm1(x),
|
||||||
context=context if self.disable_self_attn else None,
|
context=context if self.disable_self_attn else None,
|
||||||
additional_tokens=additional_tokens,
|
additional_tokens=additional_tokens,
|
||||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
||||||
|
if not self.disable_self_attn
|
||||||
|
else 0,
|
||||||
)
|
)
|
||||||
+ x
|
+ x
|
||||||
)
|
)
|
||||||
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
|
|||||||
sdp_backend=None,
|
sdp_backend=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
print(
|
||||||
|
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||||
|
)
|
||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||||
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
if not use_linear:
|
if not use_linear:
|
||||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
self.proj_out = zero_module(
|
||||||
|
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||||
|
@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
|||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
|
def log_images(
|
||||||
|
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
# calc logits of real/fake
|
# calc logits of real/fake
|
||||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||||
if len(logits_real.shape) < 4:
|
if len(logits_real.shape) < 4:
|
||||||
@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
|||||||
weights: Union[None, float, torch.Tensor] = None,
|
weights: Union[None, float, torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, dict]:
|
) -> Tuple[torch.Tensor, dict]:
|
||||||
if self.scale_input_to_tgt_size:
|
if self.scale_input_to_tgt_size:
|
||||||
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
|
inputs = torch.nn.functional.interpolate(
|
||||||
|
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
||||||
|
)
|
||||||
|
|
||||||
if self.dims > 2:
|
if self.dims > 2:
|
||||||
inputs, reconstructions = map(
|
inputs, reconstructions = map(
|
||||||
@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
|||||||
input_frames = pick_video_frame(inputs, frame_indices)
|
input_frames = pick_video_frame(inputs, frame_indices)
|
||||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||||
|
|
||||||
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
p_loss = self.perceptual_loss(
|
||||||
|
input_frames.contiguous(), recon_frames.contiguous()
|
||||||
|
).mean()
|
||||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||||
|
|
||||||
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
||||||
@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
|||||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||||
g_loss = -torch.mean(logits_fake)
|
g_loss = -torch.mean(logits_fake)
|
||||||
if self.training:
|
if self.training:
|
||||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
d_weight = self.calculate_adaptive_weight(
|
||||||
|
nll_loss, g_loss, last_layer=last_layer
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
d_weight = torch.tensor(1.0)
|
d_weight = torch.tensor(1.0)
|
||||||
else:
|
else:
|
||||||
|
@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
|
|||||||
if self.perceptual_weight > 0.0:
|
if self.perceptual_weight > 0.0:
|
||||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
image_reconstructions = self.decoder.decode(latent_predictions)
|
||||||
image_targets = self.decoder.decode(latent_inputs)
|
image_targets = self.decoder.decode(latent_inputs)
|
||||||
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
|
perceptual_loss = self.perceptual_loss(
|
||||||
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
image_targets.contiguous(), image_reconstructions.contiguous()
|
||||||
|
)
|
||||||
|
loss = (
|
||||||
|
self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
||||||
|
)
|
||||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
||||||
|
|
||||||
if self.perceptual_weight_on_inputs > 0.0:
|
if self.perceptual_weight_on_inputs > 0.0:
|
||||||
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
|
image_reconstructions = default(
|
||||||
|
image_reconstructions, self.decoder.decode(latent_predictions)
|
||||||
|
)
|
||||||
if self.scale_input_to_tgt_size:
|
if self.scale_input_to_tgt_size:
|
||||||
image_inputs = torch.nn.functional.interpolate(
|
image_inputs = torch.nn.functional.interpolate(
|
||||||
image_inputs,
|
image_inputs,
|
||||||
@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
|
|||||||
antialias=True,
|
antialias=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
|
perceptual_loss2 = self.perceptual_loss(
|
||||||
|
image_inputs.contiguous(), image_reconstructions.contiguous()
|
||||||
|
)
|
||||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
||||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
||||||
return loss, log
|
return loss, log
|
||||||
|
@ -45,7 +45,9 @@ def hinge_gen_loss(fake):
|
|||||||
@autocast(enabled=False)
|
@autocast(enabled=False)
|
||||||
@beartype
|
@beartype
|
||||||
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
||||||
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
|
return torch_grad(
|
||||||
|
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
|
||||||
|
)[0].detach()
|
||||||
|
|
||||||
|
|
||||||
def pick_video_frame(video, frame_indices):
|
def pick_video_frame(video, frame_indices):
|
||||||
@ -126,7 +128,8 @@ class DiscriminatorBlock(nn.Module):
|
|||||||
|
|
||||||
self.downsample = (
|
self.downsample = (
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
|
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
|
||||||
|
nn.Conv2d(filters * 4, filters, 1),
|
||||||
)
|
)
|
||||||
if downsample
|
if downsample
|
||||||
else None
|
else None
|
||||||
@ -185,11 +188,18 @@ class Discriminator(nn.Module):
|
|||||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||||
|
|
||||||
block = DiscriminatorBlock(
|
block = DiscriminatorBlock(
|
||||||
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
|
in_chan,
|
||||||
|
out_chan,
|
||||||
|
downsample=is_not_last,
|
||||||
|
antialiased_downsample=antialiased_downsample,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_block = nn.Sequential(
|
attn_block = nn.Sequential(
|
||||||
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
|
Residual(
|
||||||
|
LinearSpaceAttention(
|
||||||
|
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||||
|
)
|
||||||
|
),
|
||||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -363,7 +373,9 @@ class Discriminator3D(nn.Module):
|
|||||||
)
|
)
|
||||||
attn_block = nn.Sequential(
|
attn_block = nn.Sequential(
|
||||||
Residual(
|
Residual(
|
||||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
LinearSpaceAttention(
|
||||||
|
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||||
|
)
|
||||||
),
|
),
|
||||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||||
)
|
)
|
||||||
@ -458,7 +470,9 @@ class Discriminator3DWithfirstframe(nn.Module):
|
|||||||
)
|
)
|
||||||
attn_block = nn.Sequential(
|
attn_block = nn.Sequential(
|
||||||
Residual(
|
Residual(
|
||||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
LinearSpaceAttention(
|
||||||
|
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||||
|
)
|
||||||
),
|
),
|
||||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||||
)
|
)
|
||||||
@ -581,11 +595,17 @@ class VideoAutoencoderLoss(nn.Module):
|
|||||||
input_frames = pick_video_frame(inputs, frame_indices)
|
input_frames = pick_video_frame(inputs, frame_indices)
|
||||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||||
|
|
||||||
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
perceptual_loss = self.perceptual_model(
|
||||||
|
input_frames.contiguous(), recon_frames.contiguous()
|
||||||
|
).mean()
|
||||||
else:
|
else:
|
||||||
perceptual_loss = self.zero
|
perceptual_loss = self.zero
|
||||||
|
|
||||||
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
|
if (
|
||||||
|
global_step >= self.disc_start
|
||||||
|
or not self.training
|
||||||
|
or self.adversarial_loss_weight == 0
|
||||||
|
):
|
||||||
gen_loss = self.zero
|
gen_loss = self.zero
|
||||||
adaptive_weight = 0
|
adaptive_weight = 0
|
||||||
else:
|
else:
|
||||||
@ -598,9 +618,13 @@ class VideoAutoencoderLoss(nn.Module):
|
|||||||
|
|
||||||
adaptive_weight = 1
|
adaptive_weight = 1
|
||||||
if self.perceptual_weight > 0 and last_layer is not None:
|
if self.perceptual_weight > 0 and last_layer is not None:
|
||||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
|
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
|
||||||
|
perceptual_loss, last_layer
|
||||||
|
).norm(p=2)
|
||||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
|
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
|
||||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
|
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
|
||||||
|
min=1e-3
|
||||||
|
)
|
||||||
adaptive_weight.clamp_(max=1e3)
|
adaptive_weight.clamp_(max=1e3)
|
||||||
|
|
||||||
if torch.isnan(adaptive_weight).any():
|
if torch.isnan(adaptive_weight).any():
|
||||||
|
@ -48,7 +48,9 @@ class LPIPS(nn.Module):
|
|||||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||||
|
|
||||||
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
res = [
|
||||||
|
spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))
|
||||||
|
]
|
||||||
val = res[0]
|
val = res[0]
|
||||||
for l in range(1, len(self.chns)):
|
for l in range(1, len(self.chns)):
|
||||||
val += res[l]
|
val += res[l]
|
||||||
@ -118,7 +120,9 @@ class vgg16(torch.nn.Module):
|
|||||||
h_relu4_3 = h
|
h_relu4_3 = h
|
||||||
h = self.slice5(h)
|
h = self.slice5(h)
|
||||||
h_relu5_3 = h
|
h_relu5_3 = h
|
||||||
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
|
vgg_outputs = namedtuple(
|
||||||
|
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
||||||
|
)
|
||||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -35,7 +35,9 @@ class NLayerDiscriminator(nn.Module):
|
|||||||
norm_layer = nn.BatchNorm2d
|
norm_layer = nn.BatchNorm2d
|
||||||
else:
|
else:
|
||||||
norm_layer = ActNorm
|
norm_layer = ActNorm
|
||||||
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
if (
|
||||||
|
type(norm_layer) == functools.partial
|
||||||
|
): # no need to use bias as BatchNorm2d has affine parameters
|
||||||
use_bias = norm_layer.func != nn.BatchNorm2d
|
use_bias = norm_layer.func != nn.BatchNorm2d
|
||||||
else:
|
else:
|
||||||
use_bias = norm_layer != nn.BatchNorm2d
|
use_bias = norm_layer != nn.BatchNorm2d
|
||||||
|
@ -11,6 +11,7 @@ def hinge_d_loss(logits_real, logits_fake):
|
|||||||
|
|
||||||
def vanilla_d_loss(logits_real, logits_fake):
|
def vanilla_d_loss(logits_real, logits_fake):
|
||||||
d_loss = 0.5 * (
|
d_loss = 0.5 * (
|
||||||
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
|
torch.mean(torch.nn.functional.softplus(-logits_real))
|
||||||
|
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
||||||
)
|
)
|
||||||
return d_loss
|
return d_loss
|
||||||
|
@ -147,7 +147,9 @@ def hinge_gen_loss(fake):
|
|||||||
@autocast(enabled=False)
|
@autocast(enabled=False)
|
||||||
@beartype
|
@beartype
|
||||||
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
||||||
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
|
return torch_grad(
|
||||||
|
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
|
||||||
|
)[0].detach()
|
||||||
|
|
||||||
|
|
||||||
# helper decorators
|
# helper decorators
|
||||||
@ -223,7 +225,10 @@ class SqueezeExcite(Module):
|
|||||||
dim_hidden = max(dim_hidden_min, dim_out // 2)
|
dim_hidden = max(dim_hidden_min, dim_out // 2)
|
||||||
|
|
||||||
self.net = nn.Sequential(
|
self.net = nn.Sequential(
|
||||||
nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
|
nn.Conv2d(dim, dim_hidden, 1),
|
||||||
|
nn.LeakyReLU(0.1),
|
||||||
|
nn.Conv2d(dim_hidden, dim_out, 1),
|
||||||
|
nn.Sigmoid(),
|
||||||
)
|
)
|
||||||
|
|
||||||
nn.init.zeros_(self.net[-2].weight)
|
nn.init.zeros_(self.net[-2].weight)
|
||||||
@ -282,7 +287,10 @@ class RMSNorm(Module):
|
|||||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
return (
|
||||||
|
F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma
|
||||||
|
+ self.bias
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveRMSNorm(Module):
|
class AdaptiveRMSNorm(Module):
|
||||||
@ -353,7 +361,8 @@ class Attention(Module):
|
|||||||
self.norm = RMSNorm(dim)
|
self.norm = RMSNorm(dim)
|
||||||
|
|
||||||
self.to_qkv = nn.Sequential(
|
self.to_qkv = nn.Sequential(
|
||||||
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
|
nn.Linear(dim, dim_inner * 3, bias=False),
|
||||||
|
Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert num_memory_kv > 0
|
assert num_memory_kv > 0
|
||||||
@ -361,7 +370,9 @@ class Attention(Module):
|
|||||||
|
|
||||||
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
|
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
|
||||||
|
|
||||||
self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
|
self.to_out = nn.Sequential(
|
||||||
|
Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
|
def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
|
||||||
@ -455,7 +466,9 @@ class FeedForward(Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
conv_klass = nn.Conv2d if images else nn.Conv3d
|
conv_klass = nn.Conv2d if images else nn.Conv3d
|
||||||
|
|
||||||
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
|
rmsnorm_klass = (
|
||||||
|
RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
|
||||||
|
)
|
||||||
|
|
||||||
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
|
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
|
||||||
|
|
||||||
@ -463,7 +476,9 @@ class FeedForward(Module):
|
|||||||
|
|
||||||
self.norm = maybe_adaptive_norm_klass(dim)
|
self.norm = maybe_adaptive_norm_klass(dim)
|
||||||
|
|
||||||
self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
|
self.net = Sequential(
|
||||||
|
conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)
|
||||||
|
)
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
|
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
|
||||||
@ -525,7 +540,8 @@ class DiscriminatorBlock(Module):
|
|||||||
|
|
||||||
self.downsample = (
|
self.downsample = (
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
|
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
|
||||||
|
nn.Conv2d(filters * 4, filters, 1),
|
||||||
)
|
)
|
||||||
if downsample
|
if downsample
|
||||||
else None
|
else None
|
||||||
@ -584,11 +600,18 @@ class Discriminator(Module):
|
|||||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||||
|
|
||||||
block = DiscriminatorBlock(
|
block = DiscriminatorBlock(
|
||||||
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
|
in_chan,
|
||||||
|
out_chan,
|
||||||
|
downsample=is_not_last,
|
||||||
|
antialiased_downsample=antialiased_downsample,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_block = Sequential(
|
attn_block = Sequential(
|
||||||
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
|
Residual(
|
||||||
|
LinearSpaceAttention(
|
||||||
|
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||||
|
)
|
||||||
|
),
|
||||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -628,7 +651,16 @@ class Discriminator(Module):
|
|||||||
class Conv3DMod(Module):
|
class Conv3DMod(Module):
|
||||||
@beartype
|
@beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
|
self,
|
||||||
|
dim,
|
||||||
|
*,
|
||||||
|
spatial_kernel,
|
||||||
|
time_kernel,
|
||||||
|
causal=True,
|
||||||
|
dim_out=None,
|
||||||
|
demod=True,
|
||||||
|
eps=1e-8,
|
||||||
|
pad_mode="zeros",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
@ -644,7 +676,9 @@ class Conv3DMod(Module):
|
|||||||
|
|
||||||
self.pad_mode = pad_mode
|
self.pad_mode = pad_mode
|
||||||
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
|
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
|
||||||
self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
|
self.weights = nn.Parameter(
|
||||||
|
torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))
|
||||||
|
)
|
||||||
|
|
||||||
self.demod = demod
|
self.demod = demod
|
||||||
|
|
||||||
@ -675,7 +709,11 @@ class Conv3DMod(Module):
|
|||||||
weights = weights * (cond + 1)
|
weights = weights * (cond + 1)
|
||||||
|
|
||||||
if self.demod:
|
if self.demod:
|
||||||
inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
|
inv_norm = (
|
||||||
|
reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum")
|
||||||
|
.clamp(min=self.eps)
|
||||||
|
.rsqrt()
|
||||||
|
)
|
||||||
weights = weights * inv_norm
|
weights = weights * inv_norm
|
||||||
|
|
||||||
fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
|
fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
|
||||||
@ -742,7 +780,9 @@ class SpatialUpsample2x(Module):
|
|||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
||||||
|
|
||||||
self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
|
self.net = nn.Sequential(
|
||||||
|
conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)
|
||||||
|
)
|
||||||
|
|
||||||
self.init_conv_(conv)
|
self.init_conv_(conv)
|
||||||
|
|
||||||
@ -808,7 +848,12 @@ def SameConv2d(dim_in, dim_out, kernel_size):
|
|||||||
class CausalConv3d(Module):
|
class CausalConv3d(Module):
|
||||||
@beartype
|
@beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
|
self,
|
||||||
|
chan_in,
|
||||||
|
chan_out,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
pad_mode="constant",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kernel_size = cast_tuple(kernel_size, 3)
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
@ -830,7 +875,9 @@ class CausalConv3d(Module):
|
|||||||
|
|
||||||
stride = (stride, 1, 1)
|
stride = (stride, 1, 1)
|
||||||
dilation = (dilation, 1, 1)
|
dilation = (dilation, 1, 1)
|
||||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
self.conv = nn.Conv3d(
|
||||||
|
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
|
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
|
||||||
@ -855,7 +902,13 @@ def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: s
|
|||||||
@beartype
|
@beartype
|
||||||
class ResidualUnitMod(Module):
|
class ResidualUnitMod(Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
|
self,
|
||||||
|
dim,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
*,
|
||||||
|
dim_cond,
|
||||||
|
pad_mode: str = "constant",
|
||||||
|
demod=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kernel_size = cast_tuple(kernel_size, 3)
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
@ -892,7 +945,15 @@ class ResidualUnitMod(Module):
|
|||||||
|
|
||||||
|
|
||||||
class CausalConvTranspose3d(Module):
|
class CausalConvTranspose3d(Module):
|
||||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
chan_in,
|
||||||
|
chan_out,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
*,
|
||||||
|
time_stride,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kernel_size = cast_tuple(kernel_size, 3)
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
|
|
||||||
@ -908,7 +969,9 @@ class CausalConvTranspose3d(Module):
|
|||||||
stride = (time_stride, 1, 1)
|
stride = (time_stride, 1, 1)
|
||||||
padding = (0, height_pad, width_pad)
|
padding = (0, height_pad, width_pad)
|
||||||
|
|
||||||
self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
|
self.conv = nn.ConvTranspose3d(
|
||||||
|
chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
assert x.ndim == 5
|
assert x.ndim == 5
|
||||||
@ -936,7 +999,9 @@ LossBreakdown = namedtuple(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
|
DiscrLossBreakdown = namedtuple(
|
||||||
|
"DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VideoTokenizer(Module):
|
class VideoTokenizer(Module):
|
||||||
@ -1050,10 +1115,14 @@ class VideoTokenizer(Module):
|
|||||||
has_cond = True
|
has_cond = True
|
||||||
|
|
||||||
encoder_layer = ResidualUnitMod(
|
encoder_layer = ResidualUnitMod(
|
||||||
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
|
dim,
|
||||||
|
residual_conv_kernel_size,
|
||||||
|
dim_cond=int(dim_cond * dim_cond_expansion_factor),
|
||||||
)
|
)
|
||||||
decoder_layer = ResidualUnitMod(
|
decoder_layer = ResidualUnitMod(
|
||||||
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
|
dim,
|
||||||
|
residual_conv_kernel_size,
|
||||||
|
dim_cond=int(dim_cond * dim_cond_expansion_factor),
|
||||||
)
|
)
|
||||||
dim_out = dim
|
dim_out = dim
|
||||||
|
|
||||||
@ -1080,15 +1149,25 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
elif layer_type == "attend_space":
|
elif layer_type == "attend_space":
|
||||||
attn_kwargs = dict(
|
attn_kwargs = dict(
|
||||||
dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
|
dim=dim,
|
||||||
|
dim_head=attn_dim_head,
|
||||||
|
heads=attn_heads,
|
||||||
|
dropout=attn_dropout,
|
||||||
|
flash=flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
encoder_layer = Sequential(
|
||||||
|
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||||
|
)
|
||||||
|
|
||||||
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
decoder_layer = Sequential(
|
||||||
|
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||||
|
)
|
||||||
|
|
||||||
elif layer_type == "linear_attend_space":
|
elif layer_type == "linear_attend_space":
|
||||||
linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
|
linear_attn_kwargs = dict(
|
||||||
|
dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads
|
||||||
|
)
|
||||||
|
|
||||||
encoder_layer = Sequential(
|
encoder_layer = Sequential(
|
||||||
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
|
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
|
||||||
@ -1136,9 +1215,13 @@ class VideoTokenizer(Module):
|
|||||||
flash=flash_attn,
|
flash=flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
encoder_layer = Sequential(
|
||||||
|
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||||
|
)
|
||||||
|
|
||||||
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
decoder_layer = Sequential(
|
||||||
|
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||||
|
)
|
||||||
|
|
||||||
elif layer_type == "cond_linear_attend_space":
|
elif layer_type == "cond_linear_attend_space":
|
||||||
has_cond = True
|
has_cond = True
|
||||||
@ -1153,11 +1236,13 @@ class VideoTokenizer(Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
encoder_layer = Sequential(
|
encoder_layer = Sequential(
|
||||||
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
|
Residual(LinearSpaceAttention(**attn_kwargs)),
|
||||||
|
Residual(FeedForward(dim, dim_cond=dim_cond)),
|
||||||
)
|
)
|
||||||
|
|
||||||
decoder_layer = Sequential(
|
decoder_layer = Sequential(
|
||||||
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
|
Residual(LinearSpaceAttention(**attn_kwargs)),
|
||||||
|
Residual(FeedForward(dim, dim_cond=dim_cond)),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif layer_type == "cond_attend_time":
|
elif layer_type == "cond_attend_time":
|
||||||
@ -1283,7 +1368,9 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
# discriminator
|
# discriminator
|
||||||
|
|
||||||
discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
|
discr_kwargs = default(
|
||||||
|
discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)
|
||||||
|
)
|
||||||
|
|
||||||
self.discr = Discriminator(**discr_kwargs)
|
self.discr = Discriminator(**discr_kwargs)
|
||||||
|
|
||||||
@ -1380,8 +1467,16 @@ class VideoTokenizer(Module):
|
|||||||
self.load_state_dict(state_dict, strict=strict)
|
self.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
def encode(
|
||||||
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
|
self,
|
||||||
|
video: Tensor,
|
||||||
|
quantize=False,
|
||||||
|
cond: Optional[Tensor] = None,
|
||||||
|
video_contains_first_frame=True,
|
||||||
|
):
|
||||||
|
encode_first_frame_separately = (
|
||||||
|
self.separate_first_frame_encoding and video_contains_first_frame
|
||||||
|
)
|
||||||
|
|
||||||
# whether to pad video or not
|
# whether to pad video or not
|
||||||
|
|
||||||
@ -1389,12 +1484,16 @@ class VideoTokenizer(Module):
|
|||||||
video_len = video.shape[2]
|
video_len = video.shape[2]
|
||||||
|
|
||||||
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
|
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
|
||||||
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
|
video_packed_shape = [
|
||||||
|
torch.Size([self.time_padding]),
|
||||||
|
torch.Size([]),
|
||||||
|
torch.Size([video_len - 1]),
|
||||||
|
]
|
||||||
|
|
||||||
# conditioning, if needed
|
# conditioning, if needed
|
||||||
|
|
||||||
assert (not self.has_cond) or exists(
|
assert (
|
||||||
cond
|
(not self.has_cond) or exists(cond)
|
||||||
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
||||||
|
|
||||||
if exists(cond):
|
if exists(cond):
|
||||||
@ -1431,7 +1530,9 @@ class VideoTokenizer(Module):
|
|||||||
return maybe_quantize(video)
|
return maybe_quantize(video)
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
def decode_from_code_indices(
|
||||||
|
self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
|
||||||
|
):
|
||||||
assert codes.dtype in (torch.long, torch.int32)
|
assert codes.dtype in (torch.long, torch.int32)
|
||||||
|
|
||||||
if codes.ndim == 2:
|
if codes.ndim == 2:
|
||||||
@ -1444,18 +1545,24 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
quantized = self.quantizers.indices_to_codes(codes)
|
quantized = self.quantizers.indices_to_codes(codes)
|
||||||
|
|
||||||
return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
|
return self.decode(
|
||||||
|
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
|
||||||
|
)
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
def decode(
|
||||||
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
|
self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
|
||||||
|
):
|
||||||
|
decode_first_frame_separately = (
|
||||||
|
self.separate_first_frame_encoding and video_contains_first_frame
|
||||||
|
)
|
||||||
|
|
||||||
batch = quantized.shape[0]
|
batch = quantized.shape[0]
|
||||||
|
|
||||||
# conditioning, if needed
|
# conditioning, if needed
|
||||||
|
|
||||||
assert (not self.has_cond) or exists(
|
assert (
|
||||||
cond
|
(not self.has_cond) or exists(cond)
|
||||||
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
||||||
|
|
||||||
if exists(cond):
|
if exists(cond):
|
||||||
@ -1558,14 +1665,18 @@ class VideoTokenizer(Module):
|
|||||||
aux_losses = self.zero
|
aux_losses = self.zero
|
||||||
quantizer_loss_breakdown = None
|
quantizer_loss_breakdown = None
|
||||||
else:
|
else:
|
||||||
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
|
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(
|
||||||
|
x, return_loss_breakdown=True
|
||||||
|
)
|
||||||
|
|
||||||
if return_codes and not return_recon:
|
if return_codes and not return_recon:
|
||||||
return codes
|
return codes
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
|
recon_video = self.decode(
|
||||||
|
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
|
||||||
|
)
|
||||||
|
|
||||||
if return_codes:
|
if return_codes:
|
||||||
return codes, recon_video
|
return codes, recon_video
|
||||||
@ -1613,7 +1724,9 @@ class VideoTokenizer(Module):
|
|||||||
multiscale_real_logits = discr(video)
|
multiscale_real_logits = discr(video)
|
||||||
multiscale_fake_logits = discr(recon_video.detach())
|
multiscale_fake_logits = discr(recon_video.detach())
|
||||||
|
|
||||||
multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
|
multiscale_discr_loss = hinge_discr_loss(
|
||||||
|
multiscale_fake_logits, multiscale_real_logits
|
||||||
|
)
|
||||||
|
|
||||||
multiscale_discr_losses.append(multiscale_discr_loss)
|
multiscale_discr_losses.append(multiscale_discr_loss)
|
||||||
else:
|
else:
|
||||||
@ -1634,7 +1747,9 @@ class VideoTokenizer(Module):
|
|||||||
+ sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
|
+ sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
|
discr_loss_breakdown = DiscrLossBreakdown(
|
||||||
|
discr_loss, multiscale_discr_losses, gradient_penalty_loss
|
||||||
|
)
|
||||||
|
|
||||||
return total_loss, discr_loss_breakdown
|
return total_loss, discr_loss_breakdown
|
||||||
|
|
||||||
@ -1669,7 +1784,9 @@ class VideoTokenizer(Module):
|
|||||||
norm_grad_wrt_perceptual_loss = None
|
norm_grad_wrt_perceptual_loss = None
|
||||||
|
|
||||||
if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
|
if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
|
||||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
|
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
|
||||||
|
perceptual_loss, last_dec_layer
|
||||||
|
).norm(p=2)
|
||||||
|
|
||||||
# per-frame image discriminator
|
# per-frame image discriminator
|
||||||
|
|
||||||
@ -1686,7 +1803,9 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
if exists(norm_grad_wrt_perceptual_loss):
|
if exists(norm_grad_wrt_perceptual_loss):
|
||||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
|
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
|
||||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
|
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
|
||||||
|
min=1e-3
|
||||||
|
)
|
||||||
adaptive_weight.clamp_(max=1e3)
|
adaptive_weight.clamp_(max=1e3)
|
||||||
|
|
||||||
if torch.isnan(adaptive_weight).any():
|
if torch.isnan(adaptive_weight).any():
|
||||||
@ -1713,8 +1832,12 @@ class VideoTokenizer(Module):
|
|||||||
multiscale_adaptive_weight = 1.0
|
multiscale_adaptive_weight = 1.0
|
||||||
|
|
||||||
if exists(norm_grad_wrt_perceptual_loss):
|
if exists(norm_grad_wrt_perceptual_loss):
|
||||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
|
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
|
||||||
multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
|
multiscale_gen_loss, last_dec_layer
|
||||||
|
).norm(p=2)
|
||||||
|
multiscale_adaptive_weight = (
|
||||||
|
norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
|
||||||
|
)
|
||||||
multiscale_adaptive_weight.clamp_(max=1e3)
|
multiscale_adaptive_weight.clamp_(max=1e3)
|
||||||
|
|
||||||
multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
|
multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
|
||||||
@ -1730,10 +1853,13 @@ class VideoTokenizer(Module):
|
|||||||
|
|
||||||
if self.has_multiscale_discrs:
|
if self.has_multiscale_discrs:
|
||||||
weighted_multiscale_gen_losses = sum(
|
weighted_multiscale_gen_losses = sum(
|
||||||
loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
|
loss * weight
|
||||||
|
for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
|
total_loss = (
|
||||||
|
total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
|
||||||
|
)
|
||||||
|
|
||||||
# loss breakdown
|
# loss breakdown
|
||||||
|
|
||||||
|
@ -26,7 +26,9 @@ class IdentityRegularizer(AbstractRegularizer):
|
|||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
|
|
||||||
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def measure_perplexity(
|
||||||
|
predicted_indices: torch.Tensor, num_centroids: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||||
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||||
|
@ -79,13 +79,19 @@ class FSQ(Module):
|
|||||||
self.dim = default(dim, len(_levels) * num_codebooks)
|
self.dim = default(dim, len(_levels) * num_codebooks)
|
||||||
|
|
||||||
has_projections = self.dim != effective_codebook_dim
|
has_projections = self.dim != effective_codebook_dim
|
||||||
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
self.project_in = (
|
||||||
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
||||||
|
)
|
||||||
|
self.project_out = (
|
||||||
|
nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
||||||
|
)
|
||||||
self.has_projections = has_projections
|
self.has_projections = has_projections
|
||||||
|
|
||||||
self.codebook_size = self._levels.prod().item()
|
self.codebook_size = self._levels.prod().item()
|
||||||
|
|
||||||
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
|
implicit_codebook = self.indices_to_codes(
|
||||||
|
torch.arange(self.codebook_size), project_out=False
|
||||||
|
)
|
||||||
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
|
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
|
||||||
|
|
||||||
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
|
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
|
||||||
@ -153,7 +159,9 @@ class FSQ(Module):
|
|||||||
z = rearrange(z, "b d ... -> b ... d")
|
z = rearrange(z, "b d ... -> b ... d")
|
||||||
z, ps = pack_one(z, "b * d")
|
z, ps = pack_one(z, "b * d")
|
||||||
|
|
||||||
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
assert (
|
||||||
|
z.shape[-1] == self.dim
|
||||||
|
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
||||||
|
|
||||||
z = self.project_in(z)
|
z = self.project_in(z)
|
||||||
|
|
||||||
|
@ -78,7 +78,9 @@ class LFQ(Module):
|
|||||||
|
|
||||||
# some assert validations
|
# some assert validations
|
||||||
|
|
||||||
assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
|
assert exists(dim) or exists(
|
||||||
|
codebook_size
|
||||||
|
), "either dim or codebook_size must be specified for LFQ"
|
||||||
assert (
|
assert (
|
||||||
not exists(codebook_size) or log2(codebook_size).is_integer()
|
not exists(codebook_size) or log2(codebook_size).is_integer()
|
||||||
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
|
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
|
||||||
@ -195,7 +197,9 @@ class LFQ(Module):
|
|||||||
x = rearrange(x, "b d ... -> b ... d")
|
x = rearrange(x, "b d ... -> b ... d")
|
||||||
x, ps = pack_one(x, "b * d")
|
x, ps = pack_one(x, "b * d")
|
||||||
|
|
||||||
assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
|
assert (
|
||||||
|
x.shape[-1] == self.dim
|
||||||
|
), f"expected dimension of {self.dim} but received {x.shape[-1]}"
|
||||||
|
|
||||||
x = self.project_in(x)
|
x = self.project_in(x)
|
||||||
|
|
||||||
@ -299,7 +303,9 @@ class LFQ(Module):
|
|||||||
|
|
||||||
# complete aux loss
|
# complete aux loss
|
||||||
|
|
||||||
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
aux_loss = (
|
||||||
|
entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
||||||
|
)
|
||||||
|
|
||||||
ret = Return(x, indices, aux_loss)
|
ret = Return(x, indices, aux_loss)
|
||||||
|
|
||||||
|
@ -33,7 +33,9 @@ class AbstractQuantizer(AbstractRegularizer):
|
|||||||
new = match.argmax(-1)
|
new = match.argmax(-1)
|
||||||
unknown = match.sum(2) < 1
|
unknown = match.sum(2) < 1
|
||||||
if self.unknown_index == "random":
|
if self.unknown_index == "random":
|
||||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||||
|
device=new.device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new[unknown] = self.unknown_index
|
new[unknown] = self.unknown_index
|
||||||
return new.reshape(ishape)
|
return new.reshape(ishape)
|
||||||
@ -50,7 +52,9 @@ class AbstractQuantizer(AbstractRegularizer):
|
|||||||
return back.reshape(ishape)
|
return back.reshape(ishape)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
def get_codebook_entry(
|
||||||
|
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
||||||
@ -239,7 +243,8 @@ class VectorQuantizer(AbstractQuantizer):
|
|||||||
d = (
|
d = (
|
||||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
- 2
|
||||||
|
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||||
)
|
)
|
||||||
|
|
||||||
min_encoding_indices = torch.argmin(d, dim=1)
|
min_encoding_indices = torch.argmin(d, dim=1)
|
||||||
@ -267,15 +272,21 @@ class VectorQuantizer(AbstractQuantizer):
|
|||||||
|
|
||||||
if self.sane_index_shape:
|
if self.sane_index_shape:
|
||||||
if do_reshape:
|
if do_reshape:
|
||||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
|
min_encoding_indices = rearrange(
|
||||||
|
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
|
||||||
|
)
|
||||||
|
|
||||||
loss_dict["min_encoding_indices"] = min_encoding_indices
|
loss_dict["min_encoding_indices"] = min_encoding_indices
|
||||||
|
|
||||||
return z_q, loss_dict
|
return z_q, loss_dict
|
||||||
|
|
||||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
def get_codebook_entry(
|
||||||
|
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
# shape specifying (batch, height, width, channel)
|
# shape specifying (batch, height, width, channel)
|
||||||
if self.remap is not None:
|
if self.remap is not None:
|
||||||
assert shape is not None, "Need to give shape for remap"
|
assert shape is not None, "Need to give shape for remap"
|
||||||
@ -448,6 +459,8 @@ class VectorQuantizerWithInputProjection(VectorQuantizer):
|
|||||||
elif len(in_shape) == 5:
|
elif len(in_shape) == 5:
|
||||||
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
|
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
|
raise NotImplementedError(
|
||||||
|
f"rearranging not available for {len(in_shape)}-dimensional input."
|
||||||
|
)
|
||||||
|
|
||||||
return z_q, loss_dict
|
return z_q, loss_dict
|
||||||
|
@ -248,7 +248,9 @@ def make_time_attn(
|
|||||||
"vanilla",
|
"vanilla",
|
||||||
"vanilla-xformers",
|
"vanilla-xformers",
|
||||||
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
||||||
print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
|
print(
|
||||||
|
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
|
||||||
|
)
|
||||||
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
||||||
print(
|
print(
|
||||||
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
||||||
|
@ -125,9 +125,13 @@ class ResnetBlock3D(nn.Module):
|
|||||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv_shortcut = CausalConv3d(
|
||||||
|
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.nin_shortcut = torch.nn.Conv3d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, temb, zq):
|
def forward(self, x, temb, zq):
|
||||||
h = x
|
h = x
|
||||||
@ -161,7 +165,9 @@ class AttnBlock2D(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, zq):
|
def forward(self, x, zq):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -380,7 +386,11 @@ class NewDecoder3D(nn.Module):
|
|||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||||
|
@ -148,9 +148,13 @@ class ResnetBlock3D(nn.Module):
|
|||||||
# kernel_size=3,
|
# kernel_size=3,
|
||||||
# stride=1,
|
# stride=1,
|
||||||
# padding=1)
|
# padding=1)
|
||||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv_shortcut = CausalConv3d(
|
||||||
|
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.nin_shortcut = torch.nn.Conv3d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||||
|
|
||||||
def forward(self, x, temb, zq):
|
def forward(self, x, temb, zq):
|
||||||
@ -185,7 +189,9 @@ class AttnBlock2D(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, zq):
|
def forward(self, x, zq):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -261,7 +267,11 @@ class MOVQDecoder3D(nn.Module):
|
|||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||||
@ -420,7 +430,11 @@ class NewDecoder3D(nn.Module):
|
|||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||||
|
@ -51,7 +51,12 @@ def nonlinearity(x):
|
|||||||
class CausalConv3d(nn.Module):
|
class CausalConv3d(nn.Module):
|
||||||
@beartype
|
@beartype
|
||||||
def __init__(
|
def __init__(
|
||||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
|
self,
|
||||||
|
chan_in,
|
||||||
|
chan_out,
|
||||||
|
kernel_size: Union[int, Tuple[int, int, int]],
|
||||||
|
pad_mode="constant",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
kernel_size = cast_tuple(kernel_size, 3)
|
kernel_size = cast_tuple(kernel_size, 3)
|
||||||
@ -75,11 +80,20 @@ class CausalConv3d(nn.Module):
|
|||||||
|
|
||||||
stride = (stride, 1, 1)
|
stride = (stride, 1, 1)
|
||||||
dilation = (dilation, 1, 1)
|
dilation = (dilation, 1, 1)
|
||||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
self.conv = nn.Conv3d(
|
||||||
|
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.pad_mode == "constant":
|
if self.pad_mode == "constant":
|
||||||
causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
causal_padding_3d = (
|
||||||
|
self.time_pad,
|
||||||
|
0,
|
||||||
|
self.width_pad,
|
||||||
|
self.width_pad,
|
||||||
|
self.height_pad,
|
||||||
|
self.height_pad,
|
||||||
|
)
|
||||||
x = F.pad(x, causal_padding_3d, mode="constant", value=0)
|
x = F.pad(x, causal_padding_3d, mode="constant", value=0)
|
||||||
elif self.pad_mode == "first":
|
elif self.pad_mode == "first":
|
||||||
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
|
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
|
||||||
@ -91,7 +105,9 @@ class CausalConv3d(nn.Module):
|
|||||||
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
|
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
|
||||||
if reflect_x.shape[2] < self.time_pad:
|
if reflect_x.shape[2] < self.time_pad:
|
||||||
reflect_x = torch.cat(
|
reflect_x = torch.cat(
|
||||||
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
|
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2])
|
||||||
|
+ [reflect_x],
|
||||||
|
dim=2,
|
||||||
)
|
)
|
||||||
x = torch.cat([reflect_x, x], dim=2)
|
x = torch.cat([reflect_x, x], dim=2)
|
||||||
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||||
@ -110,7 +126,9 @@ class Upsample3D(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
self.compress_time = compress_time
|
self.compress_time = compress_time
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -149,7 +167,9 @@ class DownSample3D(nn.Module):
|
|||||||
out_channels = in_channels
|
out_channels = in_channels
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
self.compress_time = compress_time
|
self.compress_time = compress_time
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -182,7 +202,14 @@ class DownSample3D(nn.Module):
|
|||||||
|
|
||||||
class ResnetBlock3D(nn.Module):
|
class ResnetBlock3D(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
|
self,
|
||||||
|
*,
|
||||||
|
in_channels,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
dropout,
|
||||||
|
temb_channels=512,
|
||||||
|
pad_mode="constant",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
@ -214,9 +241,13 @@ class ResnetBlock3D(nn.Module):
|
|||||||
# kernel_size=3,
|
# kernel_size=3,
|
||||||
# stride=1,
|
# stride=1,
|
||||||
# padding=1)
|
# padding=1)
|
||||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
self.conv_shortcut = CausalConv3d(
|
||||||
|
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.nin_shortcut = torch.nn.Conv3d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||||
|
|
||||||
def forward(self, x, temb):
|
def forward(self, x, temb):
|
||||||
@ -251,7 +282,9 @@ class AttnBlock2D(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -365,12 +398,20 @@ class Encoder3D(nn.Module):
|
|||||||
# middle
|
# middle
|
||||||
self.mid = nn.Module()
|
self.mid = nn.Module()
|
||||||
self.mid.block_1 = ResnetBlock3D(
|
self.mid.block_1 = ResnetBlock3D(
|
||||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
pad_mode=pad_mode,
|
||||||
)
|
)
|
||||||
# remove attention block
|
# remove attention block
|
||||||
# self.mid.attn_1 = AttnBlock2D(block_in)
|
# self.mid.attn_1 = AttnBlock2D(block_in)
|
||||||
self.mid.block_2 = ResnetBlock3D(
|
self.mid.block_2 = ResnetBlock3D(
|
||||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
|
pad_mode=pad_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# end
|
# end
|
||||||
|
@ -80,7 +80,9 @@ class Upsample(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
@ -95,7 +97,9 @@ class Downsample(nn.Module):
|
|||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
@ -134,9 +138,13 @@ class ResnetBlock(nn.Module):
|
|||||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.nin_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, temb, zq):
|
def forward(self, x, temb, zq):
|
||||||
h = x
|
h = x
|
||||||
@ -170,7 +178,9 @@ class AttnBlock(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, zq):
|
def forward(self, x, zq):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -232,7 +242,11 @@ class MOVQDecoder(nn.Module):
|
|||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||||
|
@ -15,7 +15,16 @@ class VectorQuantizer2(nn.Module):
|
|||||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||||
# backwards compatibility we use the buggy version by default, but you can
|
# backwards compatibility we use the buggy version by default, but you can
|
||||||
# specify legacy=False to fix it.
|
# specify legacy=False to fix it.
|
||||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_e,
|
||||||
|
e_dim,
|
||||||
|
beta,
|
||||||
|
remap=None,
|
||||||
|
unknown_index="random",
|
||||||
|
sane_index_shape=False,
|
||||||
|
legacy=True,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_e = n_e
|
self.n_e = n_e
|
||||||
self.e_dim = e_dim
|
self.e_dim = e_dim
|
||||||
@ -51,7 +60,9 @@ class VectorQuantizer2(nn.Module):
|
|||||||
new = match.argmax(-1)
|
new = match.argmax(-1)
|
||||||
unknown = match.sum(2) < 1
|
unknown = match.sum(2) < 1
|
||||||
if self.unknown_index == "random":
|
if self.unknown_index == "random":
|
||||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||||
|
device=new.device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new[unknown] = self.unknown_index
|
new[unknown] = self.unknown_index
|
||||||
return new.reshape(ishape)
|
return new.reshape(ishape)
|
||||||
@ -78,7 +89,8 @@ class VectorQuantizer2(nn.Module):
|
|||||||
d = (
|
d = (
|
||||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
- 2
|
||||||
|
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||||
)
|
)
|
||||||
|
|
||||||
min_encoding_indices = torch.argmin(d, dim=1)
|
min_encoding_indices = torch.argmin(d, dim=1)
|
||||||
@ -88,9 +100,13 @@ class VectorQuantizer2(nn.Module):
|
|||||||
|
|
||||||
# compute loss for embedding
|
# compute loss for embedding
|
||||||
if not self.legacy:
|
if not self.legacy:
|
||||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||||
|
(z_q - z.detach()) ** 2
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||||
|
(z_q - z.detach()) ** 2
|
||||||
|
)
|
||||||
|
|
||||||
# preserve gradients
|
# preserve gradients
|
||||||
z_q = z + (z_q - z).detach()
|
z_q = z + (z_q - z).detach()
|
||||||
@ -104,7 +120,9 @@ class VectorQuantizer2(nn.Module):
|
|||||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||||
|
|
||||||
if self.sane_index_shape:
|
if self.sane_index_shape:
|
||||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
min_encoding_indices = min_encoding_indices.reshape(
|
||||||
|
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||||
|
)
|
||||||
|
|
||||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||||
|
|
||||||
@ -184,7 +202,9 @@ class GumbelQuantize(nn.Module):
|
|||||||
new = match.argmax(-1)
|
new = match.argmax(-1)
|
||||||
unknown = match.sum(2) < 1
|
unknown = match.sum(2) < 1
|
||||||
if self.unknown_index == "random":
|
if self.unknown_index == "random":
|
||||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||||
|
device=new.device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
new[unknown] = self.unknown_index
|
new[unknown] = self.unknown_index
|
||||||
return new.reshape(ishape)
|
return new.reshape(ishape)
|
||||||
|
@ -40,7 +40,9 @@ class Upsample(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||||
@ -55,7 +57,9 @@ class Downsample(nn.Module):
|
|||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
self.conv = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.with_conv:
|
if self.with_conv:
|
||||||
@ -68,7 +72,9 @@ class Downsample(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
def __init__(
|
||||||
|
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
out_channels = in_channels if out_channels is None else out_channels
|
out_channels = in_channels if out_channels is None else out_channels
|
||||||
@ -84,9 +90,13 @@ class ResnetBlock(nn.Module):
|
|||||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
self.nin_shortcut = torch.nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, temb):
|
def forward(self, x, temb):
|
||||||
h = x
|
h = x
|
||||||
@ -120,7 +130,9 @@ class AttnBlock(nn.Module):
|
|||||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
self.proj_out = torch.nn.Conv2d(
|
||||||
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -194,7 +206,10 @@ class Encoder(nn.Module):
|
|||||||
for i_block in range(self.num_res_blocks):
|
for i_block in range(self.num_res_blocks):
|
||||||
block.append(
|
block.append(
|
||||||
ResnetBlock(
|
ResnetBlock(
|
||||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
@ -326,7 +341,11 @@ class Decoder(nn.Module):
|
|||||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
print(
|
||||||
|
"Working with z of shape {} = {} dimensions.".format(
|
||||||
|
self.z_shape, np.prod(self.z_shape)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||||
@ -350,7 +369,10 @@ class Decoder(nn.Module):
|
|||||||
for i_block in range(self.num_res_blocks + 1):
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
block.append(
|
block.append(
|
||||||
ResnetBlock(
|
ResnetBlock(
|
||||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
temb_channels=self.temb_ch,
|
||||||
|
dropout=dropout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
|
@ -136,9 +136,9 @@ def _conv_split(input_, dim, kernel_size):
|
|||||||
if cp_rank == 0:
|
if cp_rank == 0:
|
||||||
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
||||||
else:
|
else:
|
||||||
output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
|
output = input_.transpose(dim, 0)[
|
||||||
dim, 0
|
cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size
|
||||||
)
|
].transpose(dim, 0)
|
||||||
output = output.contiguous()
|
output = output.contiguous()
|
||||||
|
|
||||||
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||||
|
@ -35,7 +35,9 @@ class Denoiser(nn.Module):
|
|||||||
sigma = append_dims(sigma, input.ndim)
|
sigma = append_dims(sigma, input.ndim)
|
||||||
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
|
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
|
||||||
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
|
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
|
||||||
return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
|
return (
|
||||||
|
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteDenoiser(Denoiser):
|
class DiscreteDenoiser(Denoiser):
|
||||||
@ -50,7 +52,9 @@ class DiscreteDenoiser(Denoiser):
|
|||||||
flip=True,
|
flip=True,
|
||||||
):
|
):
|
||||||
super().__init__(weighting_config, scaling_config)
|
super().__init__(weighting_config, scaling_config)
|
||||||
sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
|
sigmas = instantiate_from_config(discretization_config)(
|
||||||
|
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||||
|
)
|
||||||
self.sigmas = sigmas
|
self.sigmas = sigmas
|
||||||
# self.register_buffer("sigmas", sigmas)
|
# self.register_buffer("sigmas", sigmas)
|
||||||
self.quantize_c_noise = quantize_c_noise
|
self.quantize_c_noise = quantize_c_noise
|
||||||
|
@ -6,7 +6,9 @@ import torch
|
|||||||
|
|
||||||
class DenoiserScaling(ABC):
|
class DenoiserScaling(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def __call__(
|
||||||
|
self, sigma: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +16,9 @@ class EDMScaling:
|
|||||||
def __init__(self, sigma_data: float = 0.5):
|
def __init__(self, sigma_data: float = 0.5):
|
||||||
self.sigma_data = sigma_data
|
self.sigma_data = sigma_data
|
||||||
|
|
||||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def __call__(
|
||||||
|
self, sigma: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
||||||
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||||
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||||
@ -23,7 +27,9 @@ class EDMScaling:
|
|||||||
|
|
||||||
|
|
||||||
class EpsScaling:
|
class EpsScaling:
|
||||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def __call__(
|
||||||
|
self, sigma: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
c_skip = torch.ones_like(sigma, device=sigma.device)
|
c_skip = torch.ones_like(sigma, device=sigma.device)
|
||||||
c_out = -sigma
|
c_out = -sigma
|
||||||
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
||||||
@ -32,7 +38,9 @@ class EpsScaling:
|
|||||||
|
|
||||||
|
|
||||||
class VScaling:
|
class VScaling:
|
||||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def __call__(
|
||||||
|
self, sigma: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||||
@ -41,7 +49,9 @@ class VScaling:
|
|||||||
|
|
||||||
|
|
||||||
class VScalingWithEDMcNoise(DenoiserScaling):
|
class VScalingWithEDMcNoise(DenoiserScaling):
|
||||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
def __call__(
|
||||||
|
self, sigma: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||||
|
@ -52,7 +52,9 @@ class LegacyDDPMDiscretization(Discretization):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_timesteps = num_timesteps
|
self.num_timesteps = num_timesteps
|
||||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
betas = make_beta_schedule(
|
||||||
|
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
||||||
|
)
|
||||||
alphas = 1.0 - betas
|
alphas = 1.0 - betas
|
||||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||||
@ -85,14 +87,18 @@ class ZeroSNRDDPMDiscretization(Discretization):
|
|||||||
if keep_start and not post_shift:
|
if keep_start and not post_shift:
|
||||||
linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
|
linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
|
||||||
self.num_timesteps = num_timesteps
|
self.num_timesteps = num_timesteps
|
||||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
betas = make_beta_schedule(
|
||||||
|
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
||||||
|
)
|
||||||
alphas = 1.0 - betas
|
alphas = 1.0 - betas
|
||||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||||
|
|
||||||
# SNR shift
|
# SNR shift
|
||||||
if not post_shift:
|
if not post_shift:
|
||||||
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
|
self.alphas_cumprod = self.alphas_cumprod / (
|
||||||
|
shift_scale + (1 - shift_scale) * self.alphas_cumprod
|
||||||
|
)
|
||||||
|
|
||||||
self.post_shift = post_shift
|
self.post_shift = post_shift
|
||||||
self.shift_scale = shift_scale
|
self.shift_scale = shift_scale
|
||||||
@ -113,11 +119,14 @@ class ZeroSNRDDPMDiscretization(Discretization):
|
|||||||
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
|
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
|
||||||
|
|
||||||
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
|
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
|
||||||
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
|
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (
|
||||||
|
alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T
|
||||||
|
)
|
||||||
|
|
||||||
if self.post_shift:
|
if self.post_shift:
|
||||||
alphas_cumprod_sqrt = (
|
alphas_cumprod_sqrt = (
|
||||||
alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
|
alphas_cumprod_sqrt**2
|
||||||
|
/ (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
|
||||||
) ** 0.5
|
) ** 0.5
|
||||||
|
|
||||||
if return_idx:
|
if return_idx:
|
||||||
|
@ -15,7 +15,9 @@ class Guider(ABC):
|
|||||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
|
def prepare_inputs(
|
||||||
|
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||||
|
) -> Tuple[torch.Tensor, float, Dict]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +59,8 @@ class DynamicCFG(VanillaCFG):
|
|||||||
def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
|
def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
|
||||||
super().__init__(scale, dyn_thresh_config)
|
super().__init__(scale, dyn_thresh_config)
|
||||||
scale_schedule = (
|
scale_schedule = (
|
||||||
lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
|
lambda scale, sigma, step_index: 1
|
||||||
|
+ scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
|
||||||
)
|
)
|
||||||
self.scale_schedule = partial(scale_schedule, scale)
|
self.scale_schedule = partial(scale_schedule, scale)
|
||||||
self.dyn_thresh = instantiate_from_config(
|
self.dyn_thresh = instantiate_from_config(
|
||||||
|
@ -20,7 +20,9 @@ from torch import nn
|
|||||||
|
|
||||||
|
|
||||||
class LoRALinearLayer(nn.Module):
|
class LoRALinearLayer(nn.Module):
|
||||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
def __init__(
|
||||||
|
self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||||
@ -50,11 +52,20 @@ class LoRALinearLayer(nn.Module):
|
|||||||
|
|
||||||
class LoRAConv2dLayer(nn.Module):
|
class LoRAConv2dLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
self,
|
||||||
|
in_features,
|
||||||
|
out_features,
|
||||||
|
rank=4,
|
||||||
|
kernel_size=(1, 1),
|
||||||
|
stride=(1, 1),
|
||||||
|
padding=0,
|
||||||
|
network_alpha=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
self.down = nn.Conv2d(
|
||||||
|
in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False
|
||||||
|
)
|
||||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||||
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
||||||
@ -85,7 +96,9 @@ class LoRACompatibleConv(nn.Conv2d):
|
|||||||
A convolutional layer that can be used with LoRA.
|
A convolutional layer that can be used with LoRA.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs):
|
def __init__(
|
||||||
|
self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lora_layer = lora_layer
|
self.lora_layer = lora_layer
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -144,7 +157,13 @@ class LoRACompatibleConv(nn.Conv2d):
|
|||||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||||
return F.conv2d(
|
return F.conv2d(
|
||||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
hidden_states,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.stride,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.groups,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
||||||
@ -155,7 +174,9 @@ class LoRACompatibleLinear(nn.Linear):
|
|||||||
A Linear layer that can be used with LoRA.
|
A Linear layer that can be used with LoRA.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs):
|
def __init__(
|
||||||
|
self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lora_layer = lora_layer
|
self.lora_layer = lora_layer
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -197,7 +218,9 @@ class LoRACompatibleLinear(nn.Linear):
|
|||||||
w_up = self.w_up.to(device=device).float()
|
w_up = self.w_up.to(device=device).float()
|
||||||
w_down = self.w_down.to(device).float()
|
w_down = self.w_down.to(device).float()
|
||||||
|
|
||||||
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
unfused_weight = fused_weight.float() - (
|
||||||
|
self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||||
|
)
|
||||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
self.w_up = None
|
self.w_up = None
|
||||||
@ -252,7 +275,9 @@ def _find_modules_v2(
|
|||||||
|
|
||||||
# Get the targets we should replace all linears under
|
# Get the targets we should replace all linears under
|
||||||
if ancestor_class is not None:
|
if ancestor_class is not None:
|
||||||
ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class)
|
ancestors = (
|
||||||
|
module for module in model.modules() if module.__class__.__name__ in ancestor_class
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# this, incase you want to naively iterate over all modules.
|
# this, incase you want to naively iterate over all modules.
|
||||||
ancestors = [module for module in model.modules()]
|
ancestors = [module for module in model.modules()]
|
||||||
@@ -274,7 +299,9 @@ def _find_modules_v2(
 if flag:
 continue
 # Skip this linear if it's a child of a LoraInjectedLinear
-if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]):
+if exclude_children_of and any(
+[isinstance(parent, _class) for _class in exclude_children_of]
+):
 continue
 # Otherwise, yield it
 yield parent, name, module
@@ -38,13 +38,17 @@ class StandardDiffusionLoss(nn.Module):

 def __call__(self, network, denoiser, conditioner, input, batch):
 cond = conditioner(batch)
-additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
+additional_model_inputs = {
+key: batch[key] for key in self.batch2model_keys.intersection(batch)
+}

 sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
 noise = torch.randn_like(input)
 if self.offset_noise_level > 0.0:
 noise = (
-noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
+noise
++ append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
+* self.offset_noise_level
 )
 noise = noise.to(input.dtype)
 noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
@@ -63,7 +67,9 @@ class StandardDiffusionLoss(nn.Module):


 class VideoDiffusionLoss(StandardDiffusionLoss):
-def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
+def __init__(
+self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs
+):
 self.fixed_frames = fixed_frames
 self.block_scale = block_scale
 self.block_size = block_size
@@ -72,7 +78,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):

 def __call__(self, network, denoiser, conditioner, input, batch):
 cond = conditioner(batch)
-additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
+additional_model_inputs = {
+key: batch[key] for key in self.batch2model_keys.intersection(batch)
+}

 alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
 alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
@@ -86,24 +94,30 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
 src = global_rank * mp_size
 torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
 torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
-torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())
+torch.distributed.broadcast(
+alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()
+)

 additional_model_inputs["idx"] = idx

 if self.offset_noise_level > 0.0:
 noise = (
-noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
+noise
++ append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
+* self.offset_noise_level
 )

-noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
-(1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
-)
+noised_input = input.float() * append_dims(
+alphas_cumprod_sqrt, input.ndim
+) + noise * append_dims((1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim)

 if "concat_images" in batch.keys():
 cond["concat"] = batch["concat_images"]

 # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
-model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
+model_output = denoiser(
+network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
+)
 w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

 if self.min_snr_value is not None:
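For context (not part of the diff): the wrapped expression above is the usual variance-preserving noising, `x_t = sqrt(alpha_cumprod) * x_0 + sqrt(1 - alpha_cumprod) * noise`, paired with a v-prediction style weight `w = 1 / (1 - alpha_cumprod)`. A small numeric sketch of the same formula, with hypothetical shapes and a standalone `append_dims` helper:

```python
import torch


def append_dims(t, target_ndim):
    # right-pad with singleton dims so a per-sample scalar broadcasts over (B, C, H, W)
    return t.reshape(t.shape + (1,) * (target_ndim - t.ndim))


x0 = torch.randn(2, 4, 8, 8)                    # clean latents
alphas_cumprod_sqrt = torch.tensor([0.9, 0.5])  # one value per batch element
noise = torch.randn_like(x0)

noised = x0 * append_dims(alphas_cumprod_sqrt, x0.ndim) + noise * append_dims(
    (1 - alphas_cumprod_sqrt**2) ** 0.5, x0.ndim
)
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), x0.ndim)  # v-pred loss weight
print(noised.shape, w.squeeze())
```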
@@ -47,7 +47,9 @@ def nonlinearity(x):


 def Normalize(in_channels, num_groups=32):
-return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+return torch.nn.GroupNorm(
+num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+)


 class Upsample(nn.Module):
@@ -55,7 +57,9 @@ class Upsample(nn.Module):
 super().__init__()
 self.with_conv = with_conv
 if self.with_conv:
-self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+self.conv = torch.nn.Conv2d(
+in_channels, in_channels, kernel_size=3, stride=1, padding=1
+)

 def forward(self, x):
 x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -70,7 +74,9 @@ class Downsample(nn.Module):
 self.with_conv = with_conv
 if self.with_conv:
 # no asymmetric padding in torch conv, must do it ourselves
-self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+self.conv = torch.nn.Conv2d(
+in_channels, in_channels, kernel_size=3, stride=2, padding=0
+)

 def forward(self, x):
 if self.with_conv:
@@ -107,9 +113,13 @@ class ResnetBlock(nn.Module):
 self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 if self.in_channels != self.out_channels:
 if self.use_conv_shortcut:
-self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+self.conv_shortcut = torch.nn.Conv2d(
+in_channels, out_channels, kernel_size=3, stride=1, padding=1
+)
 else:
-self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+self.nin_shortcut = torch.nn.Conv2d(
+in_channels, out_channels, kernel_size=1, stride=1, padding=0
+)

 def forward(self, x, temb):
 h = x
@@ -150,7 +160,9 @@ class AttnBlock(nn.Module):
 self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+self.proj_out = torch.nn.Conv2d(
+in_channels, in_channels, kernel_size=1, stride=1, padding=0
+)

 def attention(self, h_: torch.Tensor) -> torch.Tensor:
 h_ = self.norm(h_)
@@ -160,7 +172,9 @@ class AttnBlock(nn.Module):

 b, c, h, w = q.shape
 q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
-h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
+h_ = torch.nn.functional.scaled_dot_product_attention(
+q, k, v
+) # scale is dim ** -0.5 per default
 # compute attention

 return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
@@ -188,7 +202,9 @@ class MemoryEfficientAttnBlock(nn.Module):
 self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+self.proj_out = torch.nn.Conv2d(
+in_channels, in_channels, kernel_size=1, stride=1, padding=0
+)
 self.attention_op: Optional[Any] = None

 def attention(self, h_: torch.Tensor) -> torch.Tensor:
@@ -211,7 +227,12 @@ class MemoryEfficientAttnBlock(nn.Module):
 )
 out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

-out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
+out = (
+out.unsqueeze(0)
+.reshape(B, 1, out.shape[1], C)
+.permute(0, 2, 1, 3)
+.reshape(B, out.shape[1], C)
+)
 return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)

 def forward(self, x, **kwargs):
@@ -581,7 +602,11 @@ class Decoder(nn.Module):
 block_in = ch * ch_mult[self.num_resolutions - 1]
 curr_res = resolution // 2 ** (self.num_resolutions - 1)
 self.z_shape = (1, z_channels, curr_res, curr_res)
-print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+print(
+"Working with z of shape {} = {} dimensions.".format(
+self.z_shape, np.prod(self.z_shape)
+)
+)

 make_attn_cls = self._make_attn()
 make_resblock_cls = self._make_resblock()
@@ -47,7 +47,9 @@ class AttentionPool2d(nn.Module):
 output_dim: int = None,
 ):
 super().__init__()
-self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+self.positional_embedding = nn.Parameter(
+th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+)
 self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
 self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
 self.num_heads = embed_dim // num_heads_channels
@@ -303,7 +305,9 @@ class ResBlock(TimestepBlock):
 if self.out_channels == channels:
 self.skip_connection = nn.Identity()
 elif use_conv:
-self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding)
+self.skip_connection = conv_nd(
+dims, channels, self.out_channels, kernel_size, padding=padding
+)
 else:
 self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

@@ -437,7 +441,9 @@ class QKVAttentionLegacy(nn.Module):
 ch = width // (3 * self.n_heads)
 q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
 scale = 1 / math.sqrt(math.sqrt(ch))
-weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+weight = th.einsum(
+"bct,bcs->bts", q * scale, k * scale
+) # More stable with f16 than dividing afterwards
 weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
 a = th.einsum("bts,bcs->bct", weight, v)
 return a.reshape(bs, -1, length)
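For context (not part of the diff): the einsum above pre-scales both q and k by `1 / sqrt(sqrt(ch))`, so the product still carries the usual `1 / sqrt(ch)` factor but without a large intermediate that could overflow in float16. A quick check of that equivalence, with hypothetical sizes:

```python
import math

import torch

bs_heads, ch, length = 2, 64, 16
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # scale applied to both operands
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)  # scale applied afterwards
print(torch.allclose(w1, w2, atol=1e-5))  # True
```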
@@ -574,9 +580,7 @@ class UNetModel(nn.Module):
 ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

 if context_dim is not None:
-assert (
-use_spatial_transformer
-), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
 if type(context_dim) == ListConfig:
 context_dim = list(context_dim)

@@ -640,7 +644,9 @@ class UNetModel(nn.Module):
 self.num_heads_upsample = num_heads_upsample
 self.predict_codebook_ids = n_embed is not None

-assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint)
+assert use_fairscale_checkpoint != use_checkpoint or not (
+use_checkpoint or use_fairscale_checkpoint
+)

 self.use_fairscale_checkpoint = False
 checkpoint_wrapper_fn = (
@@ -942,7 +948,9 @@ class UNetModel(nn.Module):
 print(f"loading lora from {ckpt_path}")
 sd = th.load(ckpt_path)["module"]
 sd = {
-key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model")
+key[len("model.diffusion_model") :]: sd[key]
+for key in sd
+if key.startswith("model.diffusion_model")
 }
 self.load_state_dict(sd, strict=False)

@@ -978,7 +986,9 @@ class UNetModel(nn.Module):
 self.num_classes is not None
 ), "must specify y if and only if the model is class-conditional"
 hs = []
-t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
+t_emb = timestep_embedding(
+timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
+)
 emb = self.time_embed(t_emb)

 if self.num_classes is not None:
@@ -1,8 +1,7 @@
 """
 Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
 """
-

 from typing import Dict, Union

 import torch
@@ -85,9 +84,7 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):


 class EDMSampler(SingleStepDiffusionSampler):
-def __init__(
-self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
-):
+def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
 super().__init__(*args, **kwargs)

 self.s_churn = s_churn
@@ -106,15 +103,11 @@ class EDMSampler(SingleStepDiffusionSampler):
 dt = append_dims(next_sigma - sigma_hat, x.ndim)

 euler_step = self.euler_step(x, d, dt)
-x = self.possible_correction_step(
-euler_step, x, d, dt, next_sigma, denoiser, cond, uc
-)
+x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
 return x

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 for i in self.get_sigma_gen(num_sigmas):
 gamma = (
@@ -136,30 +129,23 @@ class EDMSampler(SingleStepDiffusionSampler):


 class DDIMSampler(SingleStepDiffusionSampler):
-def __init__(
-self, s_noise=0.1, *args, **kwargs
-):
+def __init__(self, s_noise=0.1, *args, **kwargs):
 super().__init__(*args, **kwargs)

 self.s_noise = s_noise

 def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
-
 denoised = self.denoise(x, denoiser, sigma, cond, uc)
 d = to_d(x, sigma, denoised)
-dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
+dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)

 euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)

-x = self.possible_correction_step(
-euler_step, x, d, dt, next_sigma, denoiser, cond, uc
-)
+x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
 return x

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 for i in self.get_sigma_gen(num_sigmas):
 x = self.sampler_step(
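For context (not part of the diff): the samplers being re-wrapped here all build on the same Euler step: `d = (x - denoised) / sigma`, then `x + d * (next_sigma - sigma)`. A minimal sketch of that loop with a fake denoiser (hypothetical, for illustration only):

```python
import torch


def to_d(x, sigma, denoised):
    # ODE derivative in the EDM parameterization
    return (x - denoised) / sigma


def euler_step(x, sigma, next_sigma, denoise_fn):
    denoised = denoise_fn(x, sigma)
    d = to_d(x, sigma, denoised)
    return x + d * (next_sigma - sigma)


fake_denoiser = lambda x, sigma: x / (1 + sigma**2) ** 0.5  # stand-in for the network
x = torch.randn(1, 4, 8, 8) * 14.6
for sigma, next_sigma in [(14.6, 7.0), (7.0, 3.0), (3.0, 0.0)]:
    x = euler_step(x, sigma, next_sigma, fake_denoiser)
print(x.abs().mean())
```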
@@ -198,9 +184,7 @@ class AncestralSampler(SingleStepDiffusionSampler):
 return x

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 for i in self.get_sigma_gen(num_sigmas):
 x = self.sampler_step(
@@ -227,43 +211,32 @@ class LinearMultistepSampler(BaseDiffusionSampler):
 self.order = order

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 ds = []
 sigmas_cpu = sigmas.detach().cpu().numpy()
 for i in self.get_sigma_gen(num_sigmas):
 sigma = s_in * sigmas[i]
-denoised = denoiser(
-*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
-)
+denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
 denoised = self.guider(denoised, sigma)
 d = to_d(x, sigma, denoised)
 ds.append(d)
 if len(ds) > self.order:
 ds.pop(0)
 cur_order = min(i + 1, self.order)
-coeffs = [
-linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
-for j in range(cur_order)
-]
+coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
 x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

 return x


 class EulerEDMSampler(EDMSampler):
-def possible_correction_step(
-self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
-):
+def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
 return euler_step


 class HeunEDMSampler(EDMSampler):
-def possible_correction_step(
-self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
-):
+def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
 if torch.sum(next_sigma) < 1e-14:
 # Save a network evaluation if all noise levels are 0
 return euler_step
@@ -273,9 +246,7 @@ class HeunEDMSampler(EDMSampler):
 d_prime = (d + d_new) / 2.0

 # apply correction if noise level is not 0
-x = torch.where(
-append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
-)
+x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
 return x


@@ -314,9 +285,7 @@ class DPMPP2SAncestralSampler(AncestralSampler):
 x = x_euler
 else:
 h, s, t, t_next = self.get_variables(sigma, sigma_down)
-mult = [
-append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
-]
+mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]

 x2 = mult[0] * x - mult[1] * denoised
 denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@@ -367,8 +336,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):

 h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
 mult = [
-append_dims(mult, x.ndim)
-for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
 ]

 x_standard = mult[0] * x - mult[1] * denoised
@@ -380,16 +348,12 @@ class DPMPP2MSampler(BaseDiffusionSampler):
 x_advanced = mult[0] * x - mult[1] * denoised_d

 # apply correction if noise level is not 0 and not first step
-x = torch.where(
-append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
-)
+x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)

 return x, denoised

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 old_denoised = None
 for i in self.get_sigma_gen(num_sigmas):
@@ -406,6 +370,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):

 return x

+
 class SDEDPMPP2MSampler(BaseDiffusionSampler):
 def get_variables(self, sigma, next_sigma, previous_sigma=None):
 t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@@ -420,7 +385,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

 def get_mult(self, h, r, t, t_next, previous_sigma):
 mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
-mult2 = (-2*h).expm1()
+mult2 = (-2 * h).expm1()

 if previous_sigma is not None:
 mult3 = 1 + 1 / (2 * r)
@@ -444,10 +409,9 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

 h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
 mult = [
-append_dims(mult, x.ndim)
-for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
 ]
-mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)
+mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)

 x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
 if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@@ -458,16 +422,12 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
 x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

 # apply correction if noise level is not 0 and not first step
-x = torch.where(
-append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
-)
+x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)

 return x, denoised

 def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
-x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
-x, cond, uc, num_steps
-)
+x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

 old_denoised = None
 for i in self.get_sigma_gen(num_sigmas):
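For context (not part of the diff): `get_mult` in this sampler computes the DPM-Solver++(2M) SDE coefficients in negative-log-sigma space: with `h = t_next - t`, `mult1 = (sigma_next / sigma) * exp(-h)` and `mult2 = expm1(-2h)`, plus a `1 + 1/(2r)` correction once a previous step exists. A small numeric sketch with hypothetical sigma values (standalone, no guider or denoiser):

```python
import torch

sigma, next_sigma, previous_sigma = torch.tensor(2.0), torch.tensor(1.0), torch.tensor(4.0)

t, t_next, t_prev = (-s.log() for s in (sigma, next_sigma, previous_sigma))
h = t_next - t            # step size in negative log-sigma
r = (t - t_prev) / h      # ratio against the previous step

mult1 = next_sigma / sigma * (-h).exp()
mult2 = (-2 * h).expm1()
mult3 = 1 + 1 / (2 * r)
mult_noise = next_sigma * (1 - (-2 * h).exp()) ** 0.5
print(h.item(), r.item(), mult1.item(), mult2.item(), mult3.item(), mult_noise.item())
```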
@@ -484,6 +444,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

 return x

+
 class SdeditEDMSampler(EulerEDMSampler):
 def __init__(self, edit_ratio=0.5, *args, **kwargs):
 super().__init__(*args, **kwargs)
@@ -525,8 +486,8 @@ class SdeditEDMSampler(EulerEDMSampler):

 return x

-class VideoDDIMSampler(BaseDiffusionSampler):

+class VideoDDIMSampler(BaseDiffusionSampler):
 def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
 super().__init__(**kwargs)
 self.fixed_frames = fixed_frames
@@ -534,10 +495,15 @@ class VideoDDIMSampler(BaseDiffusionSampler):

 def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
 alpha_cumprod_sqrt, timesteps = self.discretization(
-self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
+self.num_steps if num_steps is None else num_steps,
+device=self.device,
+return_idx=True,
+do_append_zero=False,
 )
 alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
-timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])
+timesteps = torch.cat(
+[torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]
+)

 uc = default(uc, cond)

@@ -547,7 +513,19 @@ class VideoDDIMSampler(BaseDiffusionSampler):

 return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

-def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
+def denoise(
+self,
+x,
+denoiser,
+alpha_cumprod_sqrt,
+cond,
+uc,
+timestep=None,
+idx=None,
+scale=None,
+scale_emb=None,
+ofs=None,
+):
 additional_model_inputs = {}

 if ofs is not None:
@@ -557,26 +535,62 @@ class VideoDDIMSampler(BaseDiffusionSampler):
 additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
 if scale_emb is not None:
 additional_model_inputs['scale_emb'] = scale_emb
-denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
+denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(
+torch.float32
+)
 else:
 additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
-denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
+denoised = denoiser(
+*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc),
+**additional_model_inputs,
+).to(torch.float32)
 if isinstance(self.guider, DynamicCFG):
-denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
+denoised = self.guider(
+denoised,
+(1 - alpha_cumprod_sqrt**2) ** 0.5,
+step_index=self.num_steps - timestep,
+scale=scale,
+)
 else:
-denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
+denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
 return denoised

-def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
-denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
+def sampler_step(
+self,
+alpha_cumprod_sqrt,
+next_alpha_cumprod_sqrt,
+denoiser,
+x,
+cond,
+uc=None,
+idx=None,
+timestep=None,
+scale=None,
+scale_emb=None,
+ofs=None,
+):
+denoised = self.denoise(
+x,
+denoiser,
+alpha_cumprod_sqrt,
+cond,
+uc,
+timestep,
+idx,
+scale=scale,
+scale_emb=scale_emb,
+ofs=ofs,
+).to(torch.float32) # 1020

-a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
+a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

 x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
 return x

-def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
+def __call__(
+self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
+): # 1020
 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
 x, cond, uc, num_steps
 )
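For context (not part of the diff): `sampler_step` above is a plain DDIM update written in terms of `alpha_cumprod_sqrt`. With `a_t = sqrt((1 - alpha_next**2) / (1 - alpha**2))` and `b_t = alpha_next - alpha * a_t`, the next sample is `a_t * x + b_t * x0_hat`. A tiny numeric sketch with hypothetical values:

```python
import torch

alpha = torch.tensor(0.60)       # current sqrt(alpha_cumprod)
alpha_next = torch.tensor(0.80)  # next (less noisy) sqrt(alpha_cumprod)

x = torch.randn(1, 4, 8, 8)      # current noisy latent
x0_hat = torch.zeros_like(x)     # pretend the denoiser predicted a clean latent

a_t = ((1 - alpha_next**2) / (1 - alpha**2)) ** 0.5
b_t = alpha_next - alpha * a_t
x_next = a_t * x + b_t * x0_hat  # deterministic DDIM step toward x0_hat
print(a_t.item(), b_t.item(), x_next.std().item())
```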
@@ -590,17 +604,16 @@ class VideoDDIMSampler(BaseDiffusionSampler):
 cond,
 uc,
 idx=self.num_steps - i,
-timestep=timesteps[-(i+1)],
+timestep=timesteps[-(i + 1)],
 scale=scale,
 scale_emb=scale_emb,
-ofs=ofs # 1020
+ofs=ofs, # 1020
 )

 return x

-
 class Image2VideoDDIMSampler(BaseDiffusionSampler):

 def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
 alpha_cumprod_sqrt, timesteps = self.discretization(
 self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
@@ -616,22 +629,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
 def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
 additional_model_inputs = {}
 additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
-denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
-torch.float32)
+denoised = denoiser(
+*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
+).to(torch.float32)
 if isinstance(self.guider, DynamicCFG):
-denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
+denoised = self.guider(
+denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep
+)
 else:
-denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
+denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5)
 return denoised

-def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
-timestep=None):
+def sampler_step(
+self,
+alpha_cumprod_sqrt,
+next_alpha_cumprod_sqrt,
+denoiser,
+x,
+cond,
+uc=None,
+idx=None,
+timestep=None,
+):
 # The sigma here is actually alpha_cumprod_sqrt
-denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
+denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(
+torch.float32
+)
 if idx == 1:
 return denoised

-a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
+a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

 x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
@@ -651,31 +678,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
 cond,
 uc,
 idx=self.num_steps - i,
-timestep=timesteps[-(i + 1)]
+timestep=timesteps[-(i + 1)],
 )

 return x


 class VPSDEDPMPP2MSampler(VideoDDIMSampler):
-def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
-alpha_cumprod = alpha_cumprod_sqrt ** 2
-lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
-next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
-lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
+def get_variables(
+self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
+):
+alpha_cumprod = alpha_cumprod_sqrt**2
+lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
+next_alpha_cumprod = next_alpha_cumprod_sqrt**2
+lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
 h = lamb_next - lamb

 if previous_alpha_cumprod_sqrt is not None:
-previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
-lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
+previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
+lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
 h_last = lamb - lamb_previous
 r = h_last / h
 return h, r, lamb, lamb_next
 else:
 return h, None, lamb, lamb_next

-def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
-mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
-mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt
+def get_mult(
+self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+):
+mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
+mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt

 if previous_alpha_cumprod_sqrt is not None:
 mult3 = 1 + 1 / (2 * r)
@@ -698,18 +730,35 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
 timestep=None,
 scale=None,
 scale_emb=None,
-ofs=None # 1020
+ofs=None, # 1020
 ):
-denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
+denoised = self.denoise(
+x,
+denoiser,
+alpha_cumprod_sqrt,
+cond,
+uc,
+timestep,
+idx,
+scale=scale,
+scale_emb=scale_emb,
+ofs=ofs,
+).to(torch.float32) # 1020
 if idx == 1:
 return denoised, denoised

-h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+h, r, lamb, lamb_next = self.get_variables(
+alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+)
 mult = [
 append_dims(mult, x.ndim)
-for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+for mult in self.get_mult(
+h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+)
 ]
-mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)
+mult_noise = append_dims(
+(1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim
+)

 x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
 if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@@ -723,23 +772,26 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):

 return x, denoised

-def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
+def __call__(
+self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
+): # 1020
 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
 x, cond, uc, num_steps
 )

 if self.fixed_frames > 0:
-prefix_frames = x[:, :self.fixed_frames]
+prefix_frames = x[:, : self.fixed_frames]
 old_denoised = None
 for i in self.get_sigma_gen(num_sigmas):

 if self.fixed_frames > 0:
 if self.sdedit:
 rd = torch.randn_like(prefix_frames)
-noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
-x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
+noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
+s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
+)
+x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
 else:
-x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
+x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
 x, old_denoised = self.sampler_step(
 old_denoised,
 None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
@@ -750,37 +802,41 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
 cond,
 uc=uc,
 idx=self.num_steps - i,
-timestep=timesteps[-(i+1)],
+timestep=timesteps[-(i + 1)],
 scale=scale,
 scale_emb=scale_emb,
-ofs=ofs # 1020
+ofs=ofs, # 1020
 )

 if self.fixed_frames > 0:
-x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
+x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)

 return x


 class VPODEDPMPP2MSampler(VideoDDIMSampler):
-def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
-alpha_cumprod = alpha_cumprod_sqrt ** 2
-lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
-next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
-lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
+def get_variables(
+self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
+):
+alpha_cumprod = alpha_cumprod_sqrt**2
+lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
+next_alpha_cumprod = next_alpha_cumprod_sqrt**2
+lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
 h = lamb_next - lamb

 if previous_alpha_cumprod_sqrt is not None:
-previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
-lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
+previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
+lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
 h_last = lamb - lamb_previous
 r = h_last / h
 return h, r, lamb, lamb_next
 else:
 return h, None, lamb, lamb_next

-def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
-mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
+def get_mult(
+self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+):
+mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
 mult2 = (-h).expm1() * next_alpha_cumprod_sqrt

 if previous_alpha_cumprod_sqrt is not None:
@@ -801,16 +857,22 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
 cond,
 uc=None,
 idx=None,
-timestep=None
+timestep=None,
 ):
-denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
+denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(
+torch.float32
+)
 if idx == 1:
 return denoised, denoised

-h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+h, r, lamb, lamb_next = self.get_variables(
+alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+)
 mult = [
 append_dims(mult, x.ndim)
-for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
+for mult in self.get_mult(
+h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
+)
 ]

 x_standard = mult[0] * x - mult[1] * denoised
@@ -842,22 +904,44 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
 cond,
 uc=uc,
 idx=self.num_steps - i,
-timestep=timesteps[-(i+1)]
+timestep=timesteps[-(i + 1)],
 )

 return x


 class VideoDDPMSampler(VideoDDIMSampler):
-def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
+def sampler_step(
+self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None
+):
 # The sigma here is actually alpha_cumprod_sqrt
-denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
+denoised = self.denoise(
+x, denoiser, alpha_cumprod_sqrt, cond, uc, idx * 1000 // self.num_steps
+).to(torch.float32)
 if idx == 1:
 return denoised

 alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
-x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
-+ append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
-+ append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
+x = (
+append_dims(
+alpha_sqrt * (1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
+)
+* x
++ append_dims(
+next_alpha_cumprod_sqrt * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
+)
+* denoised
++ append_dims(
+(
+(1 - next_alpha_cumprod_sqrt**2)
+* (1 - alpha_sqrt**2)
+/ (1 - alpha_cumprod_sqrt**2)
+)
+** 0.5,
+x.ndim,
+)
+* torch.randn_like(x)
+)

 return x

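For context (not part of the diff): the long expression above is the DDPM posterior step. With `alpha_sqrt = alpha / alpha_next` as the per-step alpha, the posterior mean mixes the current sample and the predicted clean latent, and the added noise carries the posterior standard deviation. A compact sketch of the three coefficients with hypothetical values:

```python
import torch

alpha = torch.tensor(0.60)       # sqrt(alpha_cumprod) at the current step
alpha_next = torch.tensor(0.80)  # sqrt(alpha_cumprod) at the next step
alpha_sqrt = alpha / alpha_next  # sqrt of the single-step alpha

coef_x = alpha_sqrt * (1 - alpha_next**2) / (1 - alpha**2)
coef_x0 = alpha_next * (1 - alpha_sqrt**2) / (1 - alpha**2)
coef_noise = ((1 - alpha_next**2) * (1 - alpha_sqrt**2) / (1 - alpha**2)) ** 0.5

# these weight x, the denoised prediction, and fresh Gaussian noise, respectively
print(coef_x.item(), coef_x0.item(), coef_noise.item())
```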
@@ -874,7 +958,7 @@ class VideoDDPMSampler(VideoDDIMSampler):
 x,
 cond,
 uc,
-idx=self.num_steps - i
+idx=self.num_steps - i,
 )

 return x
@@ -17,7 +17,15 @@ class EDMSampling:


 class DiscreteSampling:
-def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
+def __init__(
+self,
+discretization_config,
+num_idx,
+do_append_zero=False,
+flip=True,
+uniform_sampling=False,
+group_num=0,
+):
 self.num_idx = num_idx
 self.sigmas = instantiate_from_config(discretization_config)(
 num_idx, do_append_zero=do_append_zero, flip=flip
@@ -42,7 +50,11 @@ class DiscreteSampling:
 group_index = rank // self.group_width
 idx = default(
 rand,
-torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
+torch.randint(
+group_index * self.sigma_interval,
+(group_index + 1) * self.sigma_interval,
+(n_samples,),
+),
 )
 else:
 idx = default(
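For context (not part of the diff): with `uniform_sampling` enabled, each model-parallel group draws timestep indices only from its own slice of `[0, num_idx)`, so the groups jointly cover the whole range for every batch. A standalone sketch of that slicing with hypothetical `group_num` and `rank` values (no `torch.distributed` involved):

```python
import torch

num_idx = 1000         # total number of discrete timesteps
group_num = 4          # hypothetical number of model-parallel groups
world_size = 8
group_width = world_size // group_num
sigma_interval = num_idx // group_num

for rank in range(world_size):
    group_index = rank // group_width
    idx = torch.randint(
        group_index * sigma_interval,
        (group_index + 1) * sigma_interval,
        (2,),  # n_samples
    )
    print(rank, group_index, idx.tolist())  # each group samples inside its own interval
```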
@ -54,8 +66,11 @@ class DiscreteSampling:
|
|||||||
else:
|
else:
|
||||||
return self.idx_to_sigma(idx)
|
return self.idx_to_sigma(idx)
|
||||||
|
|
||||||
|
|
||||||
class PartialDiscreteSampling:
|
class PartialDiscreteSampling:
|
||||||
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
|
def __init__(
|
||||||
|
self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True
|
||||||
|
):
|
||||||
self.total_num_idx = total_num_idx
|
self.total_num_idx = total_num_idx
|
||||||
self.partial_num_idx = partial_num_idx
|
self.partial_num_idx = partial_num_idx
|
||||||
self.sigmas = instantiate_from_config(discretization_config)(
|
self.sigmas = instantiate_from_config(discretization_config)(
|
||||||
|
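For readers following the `uniform_sampling` branch just reformatted, a minimal standalone sketch of the idea: each group of ranks draws discrete timestep indices only from its own contiguous slice of `[0, num_idx)`. The `sigma_interval = num_idx // group_num` split is an assumption made for illustration; only the `torch.randint` call over `[group_index * interval, (group_index + 1) * interval)` is taken from the hunk above.

```python
import torch

def sample_group_timestep_indices(num_idx: int, group_num: int, group_index: int, n_samples: int) -> torch.Tensor:
    # Partition [0, num_idx) into group_num contiguous slices (illustrative assumption)
    # and sample uniformly inside this group's slice, as in the hunk above.
    sigma_interval = num_idx // group_num
    low = group_index * sigma_interval
    high = (group_index + 1) * sigma_interval
    return torch.randint(low, high, (n_samples,))

# e.g. with 1000 timesteps split across 4 groups, group 2 samples indices in [500, 750)
idx = sample_group_timestep_indices(1000, 4, 2, n_samples=8)
```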
@@ -24,7 +24,9 @@ def make_beta_schedule(
     linear_end=2e-2,
 ):
     if schedule == "linear":
-        betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
+        betas = (
+            torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
+        )
     return betas.numpy()
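The wrapped `linear` branch is simply a linspace in square-root space that is squared afterwards; as a restatement for checking,

```latex
\beta_i \;=\; \Bigl(\sqrt{\beta_{\text{start}}} + \tfrac{i}{n-1}\bigl(\sqrt{\beta_{\text{end}}}-\sqrt{\beta_{\text{start}}}\bigr)\Bigr)^{2},
\qquad i = 0,\dots,n-1 .
```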
@@ -50,7 +52,9 @@ def mixed_checkpoint(func, inputs: dict, params, flag):
     non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
-    non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
+    non_tensor_inputs = [
+        inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
+    ]
     args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
     return MixedCheckpointFunction.apply(
         func,

@@ -84,9 +88,14 @@ class MixedCheckpointFunction(torch.autograd.Function):
         assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors

-        ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
+        ctx.input_tensors = {
+            key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
+        }
         ctx.input_non_tensors = {
-            key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
+            key: val
+            for (key, val) in zip(
+                non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
+            )
         }
         ctx.run_function = run_function
         ctx.input_params = list(args[ctx.end_non_tensors :])

@@ -98,13 +107,18 @@ class MixedCheckpointFunction(torch.autograd.Function):
     def backward(ctx, *output_grads):
         # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
-        ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
+        ctx.input_tensors = {
+            key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors
+        }

         with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
             # Fixes a bug where the first op in run_function modifies the
             # Tensor storage in place, which is not allowed for detach()'d
             # Tensors.
-            shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
+            shallow_copies = {
+                key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
+                for key in ctx.input_tensors
+            }
             # shallow_copies.update(additional_args)
             output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
             input_grads = torch.autograd.grad(
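A rough single-function sketch of what this tensor/non-tensor split is for, built on the public `torch.utils.checkpoint` API rather than the custom autograd Function above; the closure over non-tensor inputs and the `use_reentrant=False` flag are illustrative choices, not the repository's implementation.

```python
import torch
from torch.utils.checkpoint import checkpoint

def mixed_checkpoint_sketch(fn, inputs: dict):
    # Only tensor inputs are routed through checkpointing; everything else is
    # closed over, mirroring the tensor/non-tensor bookkeeping reformatted above.
    tensor_keys = [k for k, v in inputs.items() if isinstance(v, torch.Tensor)]
    non_tensors = {k: v for k, v in inputs.items() if not isinstance(v, torch.Tensor)}

    def run(*tensors):
        return fn(**dict(zip(tensor_keys, tensors)), **non_tensors)

    return checkpoint(run, *(inputs[k] for k in tensor_keys), use_reentrant=False)
```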
@@ -188,9 +202,9 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtyp
     if not repeat_only:
         half = dim // 2
-        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-            device=timesteps.device
-        )
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
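Stripped of the device and dtype plumbing, the embedding being rewrapped above is the familiar sinusoidal one; a self-contained restatement for reference (the odd-dimension zero padding mirrors the `if dim % 2` branch):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # Geometric frequency ladder from 1 down to 1/max_period, then cos/sin features.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)  # pad odd dims
    return emb

emb = sinusoidal_timestep_embedding(torch.arange(4), dim=16)  # shape (4, 16)
```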
@@ -6,7 +6,9 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"

 class IdentityWrapper(nn.Module):
-    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
+    def __init__(
+        self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32
+    ):
         super().__init__()
         compile = (
             torch.compile
@@ -87,8 +87,14 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
     # Force variances to be Tensors. Broadcasting helps convert scalars to
     # Tensors, but it does not work for torch.exp().
-    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)
+    ]

     return 0.5 * (
-        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
     )
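For reference, the expression inside `normal_kl` matches the closed form of the KL divergence between two Gaussians, written per dimension with the code's `logvar` standing for log σ²:

```latex
\mathrm{KL}\bigl(\mathcal N(\mu_1,\sigma_1^2)\,\|\,\mathcal N(\mu_2,\sigma_2^2)\bigr)
= \tfrac12\Bigl(-1 + \log\sigma_2^2 - \log\sigma_1^2 + \frac{\sigma_1^2}{\sigma_2^2} + \frac{(\mu_1-\mu_2)^2}{\sigma_2^2}\Bigr).
```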
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
             "num_updates",
-            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
         )

         for name, p in model.named_parameters():

@@ -45,9 +47,11 @@ class LitEma(nn.Module):
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
                 else:
-                    assert not key in self.m_name2s_name
+                    assert key not in self.m_name2s_name

     def copy_to(self, model):
         m_param = dict(model.named_parameters())

@@ -56,7 +60,7 @@ class LitEma(nn.Module):
             if m_param[key].requires_grad:
                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
-                assert not key in self.m_name2s_name
+                assert key not in self.m_name2s_name

     def store(self, parameters):
         """
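The in-place `sub_` in the hunk above is the usual exponential moving average; with decay d and `one_minus_decay` equal to 1 − d, it is equivalent to

```latex
s \;\leftarrow\; s - (1-d)\,(s - p) \;=\; d\,s + (1-d)\,p ,
```

where s is the shadow (EMA) parameter and p the live model parameter.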
@@ -99,7 +99,9 @@ class GeneralConditioner(nn.Module):
             elif "input_keys" in embconfig:
                 embedder.input_keys = embconfig["input_keys"]
             else:
-                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
+                raise KeyError(
+                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+                )

             embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
             if embedder.legacy_ucg_val is not None:

@@ -160,7 +162,10 @@ class GeneralConditioner(nn.Module):
                 if cond_or_not is None:
                     emb = (
                         expand_dims_like(
-                            torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
+                            torch.bernoulli(
+                                (1.0 - embedder.ucg_rate)
+                                * torch.ones(emb.shape[0], device=emb.device)
+                            ),
                             emb,
                         )
                         * emb
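A compact sketch of the conditioning dropout that the second hunk reformats, with the broadcasting helper inlined rather than reproducing `expand_dims_like` itself:

```python
import torch

def drop_conditioning(emb: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # Keep each sample's conditioning embedding with probability (1 - ucg_rate),
    # zeroing it otherwise, as in the Bernoulli masking above.
    keep = torch.bernoulli((1.0 - ucg_rate) * torch.ones(emb.shape[0], device=emb.device))
    return keep.view(-1, *([1] * (emb.dim() - 1))) * emb

cond = torch.randn(4, 77, 1024)
cond = drop_conditioning(cond, ucg_rate=0.1)
```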
@@ -96,7 +96,9 @@ class VideoTransformerBlock(nn.Module):
         if self.checkpoint:
             print(f"{self.__class__.__name__} is using checkpointing")

-    def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
+    ) -> torch.Tensor:
         if self.checkpoint:
             return checkpoint(self._forward, x, context, timesteps)
         else:

@@ -239,7 +241,9 @@ class SpatialVideoTransformer(SpatialTransformer):
         spatial_context = context

         if self.use_spatial_context:
-            assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"
+            assert (
+                context.ndim == 3
+            ), f"n dims of spatial context should be 3 but are {context.ndim}"

             time_context = context
             time_context_first_timestep = time_context[::timesteps]
@@ -86,7 +86,9 @@ class SafeConv3d(torch.nn.Conv3d):
             input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
             if kernel_size > 1:
                 input_chunks = [input_chunks[0]] + [
-                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+                    torch.cat(
+                        (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
+                    )
                     for i in range(1, len(input_chunks))
                 ]

@@ -252,7 +254,7 @@ def count_params(model, verbose=False):

 def instantiate_from_config(config, **extra_kwargs):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
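The `SafeConv3d` hunk splits a long clip along time and prepends the last `kernel_size - 1` frames of the previous chunk so each piece still sees full temporal context. A standalone sketch of just that splitting step, with the memory accounting and the convolution itself omitted:

```python
import torch

def chunk_with_temporal_overlap(x: torch.Tensor, part_num: int, kernel_size: int):
    # x: (N, C, T, H, W).  Split along T, then give every chunk after the first
    # the trailing (kernel_size - 1) frames of its predecessor as left context.
    chunks = list(torch.chunk(x, part_num, dim=2))
    if kernel_size > 1:
        chunks = [chunks[0]] + [
            torch.cat((chunks[i - 1][:, :, -(kernel_size - 1):], chunks[i]), dim=2)
            for i in range(1, len(chunks))
        ]
    return chunks

parts = chunk_with_temporal_overlap(torch.randn(1, 8, 17, 32, 32), part_num=4, kernel_size=3)
```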
@@ -93,7 +93,12 @@ class SimpleDistributedWebDataset(DataPipeline):

 def tar_file_iterator_with_meta(
-    fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
+    fileobj,
+    meta_names,
+    skip_meta=r"__[^/]*__($|/)",
+    suffix=None,
+    handler=reraise_exception,
+    meta_stream=None,
 ):
     """Iterate over tar file, yielding filename, content pairs for the given tar stream.

@@ -122,10 +127,13 @@ def tar_file_iterator_with_meta(
                 except Exception as exn:
                     from sat.helpers import print_rank0

-                    print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
+                    print_rank0(
+                        f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}",
+                        level="DEBUG",
+                    )
                     continue
                 for item in meta_list:
-                    if not item["key"] in meta_data:
+                    if item["key"] not in meta_data:
                         meta_data[item["key"]] = {}
                     for meta_name in meta_names:
                         if meta_name in item:

@@ -186,7 +194,9 @@ def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
         try:
             assert isinstance(source, dict)
             assert "stream" in source
-            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
+            for sample in tar_file_iterator_with_meta(
+                source["stream"], meta_names, meta_stream=source["meta_stream"]
+            ):
                 assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                 sample["__url__"] = url
                 yield sample

@@ -250,7 +260,15 @@ class MetaDistributedWebDataset(DataPipeline):
     def __init__(
-        self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
+        self,
+        path,
+        process_fn,
+        seed,
+        *,
+        meta_names=[],
+        nshards=sys.maxsize,
+        shuffle_buffer=1000,
+        include_dirs=None,
     ):
         # os.environ['WDS_SHOW_SEED'] = '1'
         import torch
@@ -361,7 +379,10 @@ def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
     if mode[0] == "r":
         s3_client = boto3.client(
-            "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
+            "s3",
+            endpoint_url=endpoint_url,
+            aws_access_key_id=access_key,
+            aws_secret_access_key=secret_key,
         )
         bucket, key = url.split("/", 1)
@@ -37,7 +37,9 @@ def save_texts(texts, save_dir, iterations):
             f.write(text + "\n")

-def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None):
+def save_video_as_grid_and_mp4(
+    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None
+):
     os.makedirs(save_path, exist_ok=True)

     for i, vid in enumerate(video_batch):

@@ -52,7 +54,8 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int
                 writer.append_data(frame)
         if args is not None and args.wandb:
             wandb.log(
-                {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1
+                {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")},
+                step=args.iteration + 1,
             )

@@ -138,7 +141,9 @@ def broad_cast_batch(batch):
     return batch

-def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None):
+def forward_step_eval(
+    data_iterator, model, args, timers, only_log_video_latents=False, data_class=None
+):
     if mpu.get_model_parallel_rank() == 0:
         timers("data loader").start()
         batch_video = next(data_iterator)

@@ -209,7 +214,9 @@ if __name__ == "__main__":
     args = argparse.Namespace(**vars(args), **vars(known))

     data_class = get_obj_from_str(args.data_config["target"])
-    create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"])
+    create_dataset_function = partial(
+        data_class.create_dataset_function, **args.data_config["params"]
+    )

     import yaml

@@ -225,7 +232,9 @@ if __name__ == "__main__":
         model_cls=SATVideoDiffusionEngine,
         forward_step_function=partial(forward_step, data_class=data_class),
         forward_step_eval=partial(
-            forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents
+            forward_step_eval,
+            data_class=data_class,
+            only_log_video_latents=args.only_log_video_latents,
         ),
         create_dataset_function=create_dataset_function,
     )
@@ -94,7 +94,11 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+        project_in = (
+            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+            if not glu
+            else GEGLU(dim, inner_dim)
+        )

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

@@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        q, k, v = rearrange(
+            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+        )
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)

@@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )

     def forward(self, x):
         h_ = x

@@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
        # new
        with sdp_kernel(**BACKEND_MAP[self.backend]):
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-           out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default
+           out = F.scaled_dot_product_attention(
+               q, k, v, attn_mask=mask
+           )  # scale is dim_head ** -0.5 per default

        del q, k, v
        out = rearrange(out, "b h n d -> b n (h d)", h=h)

@@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
                 self.norm1(x),
                 context=context if self.disable_self_attn else None,
                 additional_tokens=additional_tokens,
-                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
+                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+                if not self.disable_self_attn
+                else 0,
             )
             + x
         )

@@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
+        print(
+            f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
+        )
         from omegaconf import ListConfig

         if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):

@@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
         if not use_linear:
-            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+            self.proj_out = zero_module(
+                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+            )
         else:
             # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
             self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
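The `LinearAttention` einsums reformatted above compute attention in the kernelized order, aggregating keys against values before touching the queries. A minimal restatement with the same tensor layout as the `rearrange` in that hunk:

```python
import torch

def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, heads, dim, tokens), matching "b heads c (h w)" above.
    k = k.softmax(dim=-1)                            # normalize keys over tokens
    context = torch.einsum("bhdn,bhen->bhde", k, v)  # dim x dim summary, linear in token count
    return torch.einsum("bhde,bhdn->bhen", context, q)

out = linear_attention(
    torch.randn(2, 4, 32, 256), torch.randn(2, 4, 32, 256), torch.randn(2, 4, 32, 256)
)
```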
@@ -97,9 +97,7 @@ class AbstractAutoencoder(pl.LightningModule):
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
-        return get_obj_from_str(cfg["target"])(
-            params, lr=lr, **cfg.get("params", dict())
-        )
+        return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))

     def configure_optimizers(self) -> Any:
         raise NotImplementedError()

@@ -214,14 +212,20 @@ class AutoencodingEngine(AbstractAutoencoder):
         x = self.decoder(z, **kwargs)
         return x

-    def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+    def forward(
+        self, x: torch.Tensor, **additional_decode_kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
         z, reg_log = self.encode(x, return_reg_log=True)
         dec = self.decode(z, **additional_decode_kwargs)
         return z, dec, reg_log

-    def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
+    def inner_training_step(
+        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
+    ) -> torch.Tensor:
         x = self.get_input(batch)
-        additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+        additional_decode_kwargs = {
+            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
+        }
         z, xrec, regularization_log = self(x, **additional_decode_kwargs)
         if hasattr(self.loss, "forward_keys"):
             extra_info = {

@@ -357,12 +361,16 @@ class AutoencodingEngine(AbstractAutoencoder):
         if self.trainable_ae_params is None:
             ae_params = self.get_autoencoder_params()
         else:
-            ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
+            ae_params, num_ae_params = self.get_param_groups(
+                self.trainable_ae_params, self.ae_optimizer_args
+            )
             logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
         if self.trainable_disc_params is None:
             disc_params = self.get_discriminator_params()
         else:
-            disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
+            disc_params, num_disc_params = self.get_param_groups(
+                self.trainable_disc_params, self.disc_optimizer_args
+            )
             logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
         opt_ae = self.instantiate_optimizer_from_config(
             ae_params,

@@ -371,17 +379,23 @@ class AutoencodingEngine(AbstractAutoencoder):
         opts = [opt_ae]
         if len(disc_params) > 0:
-            opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
+            opt_disc = self.instantiate_optimizer_from_config(
+                disc_params, self.learning_rate, self.optimizer_config
+            )
             opts.append(opt_disc)

         return opts

     @torch.no_grad()
-    def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+    def log_images(
+        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
+    ) -> dict:
         log = dict()
         additional_decode_kwargs = {}
         x = self.get_input(batch)
-        additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
+        additional_decode_kwargs.update(
+            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+        )

         _, xrec, _ = self(x, **additional_decode_kwargs)
         log["inputs"] = x
@@ -400,7 +414,9 @@ class AutoencodingEngine(AbstractAutoencoder):
             diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
             diff_ema.clamp_(0, 1.0)
             log["diff_ema"] = 2.0 * diff_ema - 1.0
-            log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+            log["diff_boost_ema"] = (
+                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+            )
         if additional_log_kwargs:
             additional_decode_kwargs.update(additional_log_kwargs)
             _, xrec_add, _ = self(x, **additional_decode_kwargs)

@@ -442,7 +458,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
         params = super().get_autoencoder_params()
         return params

-    def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+    def encode(
+        self, x: torch.Tensor, return_reg_log: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
         if self.max_batch_size is None:
             z = self.encoder(x)
             z = self.quant_conv(z)

@@ -485,7 +503,9 @@ class AutoencoderKL(AutoencodingEngineLegacy):
         if "lossconfig" in kwargs:
             kwargs["loss_config"] = kwargs.pop("lossconfig")
         super().__init__(
-            regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
+            regularizer_config={
+                "target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")
+            },
             **kwargs,
         )

@@ -519,7 +539,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

-    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
+    def log_videos(
+        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
+    ) -> dict:
         return self.log_images(batch, additional_log_kwargs, **kwargs)

     def get_input(self, batch: dict) -> torch.Tensor:

@@ -530,7 +552,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
             batch = batch[self.input_key]

             global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
-            torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
+            torch.distributed.broadcast(
+                batch, src=global_src_rank, group=get_context_parallel_group()
+            )

             batch = _conv_split(batch, dim=2, kernel_size=1)
         return batch
@@ -201,7 +201,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
         recv_rank += cp_world_size

     if cp_rank < cp_world_size - 1:
-        req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
+        req_send = torch.distributed.isend(
+            input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
+        )
     if cp_rank > 0:
         recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

@@ -246,11 +248,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
     recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
     if cp_rank < cp_world_size - 1:
-        req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
+        req_send = torch.distributed.isend(
+            input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
+        )
     if cp_rank > 0:
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

     if cp_rank == 0:
         if cache_padding is not None:
             input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)

@@ -334,7 +337,9 @@ def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):

 class ContextParallelCausalConv3d(nn.Module):
-    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
+    def __init__(
+        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs
+    ):
         super().__init__()
         kernel_size = cast_tuple(kernel_size, 3)

@@ -354,7 +359,9 @@ class ContextParallelCausalConv3d(nn.Module):
         stride = (stride, stride, stride)
         dilation = (1, 1, 1)
-        self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+        self.conv = Conv3d(
+            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
+        )
         self.cache_padding = None

     def forward(self, input_, clear_cache=True):

@@ -369,7 +376,11 @@ class ContextParallelCausalConv3d(nn.Module):
             global_rank = torch.distributed.get_rank()
             if cp_world_size == 1:
                 self.cache_padding = (
-                    input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+                    input_parallel[:, :, -self.time_kernel_size + 1 :]
+                    .contiguous()
+                    .detach()
+                    .clone()
+                    .cpu()
                 )
             else:
                 if cp_rank == cp_world_size - 1:

@@ -379,9 +390,13 @@ class ContextParallelCausalConv3d(nn.Module):
                         group=get_context_parallel_group(),
                     )
                 if cp_rank == 0:
-                    recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
+                    recv_buffer = torch.empty_like(
+                        input_parallel[:, :, -self.time_kernel_size + 1 :]
+                    ).contiguous()
                     torch.distributed.recv(
-                        recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
+                        recv_buffer,
+                        global_rank - 1 + cp_world_size,
+                        group=get_context_parallel_group(),
                     )
                     self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
@@ -406,7 +421,9 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):

 def Normalize(in_channels, gather=False, **kwargs):
     if gather:
-        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        return ContextParallelGroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
     else:
         return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@@ -460,7 +477,8 @@ class SpatialNorm3D(nn.Module):
             zq_rest_splits = torch.split(zq_rest, 32, dim=1)
             interpolated_splits = [
-                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
+                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest")
+                for split in zq_rest_splits
             ]

             zq_rest = torch.cat(interpolated_splits, dim=1)

@@ -471,7 +489,8 @@ class SpatialNorm3D(nn.Module):
             zq_splits = torch.split(zq, 32, dim=1)
             interpolated_splits = [
-                torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
+                torch.nn.functional.interpolate(split, size=f_size, mode="nearest")
+                for split in zq_splits
             ]
             zq = torch.cat(interpolated_splits, dim=1)

@@ -511,7 +530,9 @@ class Upsample3D(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
         self.compress_time = compress_time

     def forward(self, x, fake_cp=True):

@@ -523,14 +544,16 @@ class Upsample3D(nn.Module):
                 splits = torch.split(x_rest, 32, dim=1)
                 interpolated_splits = [
-                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+                    for split in splits
                 ]
                 x_rest = torch.cat(interpolated_splits, dim=1)
                 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
             else:
                 splits = torch.split(x, 32, dim=1)
                 interpolated_splits = [
-                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+                    for split in splits
                 ]
                 x = torch.cat(interpolated_splits, dim=1)

@@ -541,7 +564,8 @@ class Upsample3D(nn.Module):
             splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
-                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+                for split in splits
            ]
            x = torch.cat(interpolated_splits, dim=1)

@@ -563,7 +587,9 @@ class DownSample3D(nn.Module):
             out_channels = in_channels
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+            self.conv = torch.nn.Conv2d(
+                in_channels, out_channels, kernel_size=3, stride=2, padding=0
+            )
         self.compress_time = compress_time

     def forward(self, x, fake_cp=True):

@@ -578,7 +604,8 @@ class DownSample3D(nn.Module):
             if x_rest.shape[-1] > 0:
                 splits = torch.split(x_rest, 32, dim=1)
                 interpolated_splits = [
-                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
+                    for split in splits
                 ]
                 x_rest = torch.cat(interpolated_splits, dim=1)
                 x = torch.cat([x_first[..., None], x_rest], dim=-1)

@@ -587,7 +614,8 @@ class DownSample3D(nn.Module):
             # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
             splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
-                torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
+                for split in splits
             ]
             x = torch.cat(interpolated_splits, dim=1)
             x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
@@ -923,9 +951,13 @@ class ContextParallelDecoder3D(nn.Module):
             up.attn = attn
             if i_level != 0:
                 if i_level < self.num_resolutions - self.temporal_compress_level:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+                    up.upsample = Upsample3D(
+                        block_in, with_conv=resamp_with_conv, compress_time=False
+                    )
                 else:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+                    up.upsample = Upsample3D(
+                        block_in, with_conv=resamp_with_conv, compress_time=True
+                    )
             self.up.insert(0, up)

         self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
             "num_updates",
-            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
         )

         for name, p in model.named_parameters():

@@ -45,9 +47,11 @@ class LitEma(nn.Module):
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
                 else:
-                    assert not key in self.m_name2s_name
+                    assert key not in self.m_name2s_name

     def copy_to(self, model):
         m_param = dict(model.named_parameters())

@@ -56,7 +60,7 @@ class LitEma(nn.Module):
             if m_param[key].requires_grad:
                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
-                assert not key in self.m_name2s_name
+                assert key not in self.m_name2s_name

     def store(self, parameters):
         """
|
@ -77,7 +77,9 @@ class IdentityRegularizer(AbstractRegularizer):
|
|||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
|
|
||||||
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def measure_perplexity(
|
||||||
|
predicted_indices: torch.Tensor, num_centroids: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||||
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||||
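As a reminder of what `measure_perplexity` reports: with p_i the empirical usage frequency of codebook entry i over the predicted indices,

```latex
\mathrm{perplexity} \;=\; \exp\Bigl(-\sum_{i=1}^{K} p_i \log p_i\Bigr),
```

which reaches its maximum K, the number of centroids, exactly when all entries are used equally often.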
|
@ -78,7 +78,9 @@ class SafeConv3d(torch.nn.Conv3d):
|
|||||||
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
|
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
|
||||||
if kernel_size > 1:
|
if kernel_size > 1:
|
||||||
input_chunks = [input_chunks[0]] + [
|
input_chunks = [input_chunks[0]] + [
|
||||||
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
|
torch.cat(
|
||||||
|
(input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
|
||||||
|
)
|
||||||
for i in range(1, len(input_chunks))
|
for i in range(1, len(input_chunks))
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -244,7 +246,7 @@ def count_params(model, verbose=False):
|
|||||||
|
|
||||||
|
|
||||||
def instantiate_from_config(config):
|
def instantiate_from_config(config):
|
||||||
if not "target" in config:
|
if "target" not in config:
|
||||||
if config == "__is_first_stage__":
|
if config == "__is_first_stage__":
|
||||||
return None
|
return None
|
||||||
elif config == "__is_unconditional__":
|
elif config == "__is_unconditional__":
|
||||||
|
@@ -9,11 +9,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
-    0] >= 8 else torch.float16
+TORCH_TYPE = (
+    torch.bfloat16
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+    else torch.float16
+)

 parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
-parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
+parser.add_argument(
+    '--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0
+)
 args = parser.parse_args([])

@@ -29,8 +34,11 @@ def load_video(video_data, strategy='chat'):
         clip_end_sec = 60
         clip_start_sec = 0
         start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
-        end_frame = min(total_frames,
-                        int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
+        end_frame = (
+            min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps()))
+            if clip_end_sec is not None
+            else total_frames
+        )
         frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
     elif strategy == 'chat':
         timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))

@@ -54,11 +62,11 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True,
 )

-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    torch_dtype=TORCH_TYPE,
-    trust_remote_code=True
-).eval().to(DEVICE)
+model = (
+    AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=TORCH_TYPE, trust_remote_code=True)
+    .eval()
+    .to(DEVICE)
+)

 def predict(prompt, video_data, temperature):

@@ -69,11 +77,7 @@ def predict(prompt, video_data, temperature):
     history = []
     query = prompt
     inputs = model.build_conversation_input_ids(
-        tokenizer=tokenizer,
-        query=query,
-        images=[video],
-        history=history,
-        template_version=strategy
+        tokenizer=tokenizer, query=query, images=[video], history=history, template_version=strategy
     )
     inputs = {
         'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),

@@ -91,7 +95,7 @@ def predict(prompt, video_data, temperature):
     }
     with torch.no_grad():
         outputs = model.generate(**inputs, **gen_kwargs)
-        outputs = outputs[:, inputs['input_ids'].shape[1]:]
+        outputs = outputs[:, inputs['input_ids'].shape[1] :]
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
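The clip handling in `load_video` boils down to picking evenly spaced frame indices inside an optional start/end window; a small sketch of just that selection, with illustrative names:

```python
import numpy as np

def uniform_frame_indices(total_frames: int, num_frames: int, start_frame: int = 0, end_frame=None) -> np.ndarray:
    # Evenly spaced indices over [start_frame, end_frame), clamped to the clip length.
    end_frame = total_frames if end_frame is None else min(total_frames, end_frame)
    return np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)

print(uniform_frame_indices(total_frames=600, num_frames=24))
```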
@@ -31,9 +31,18 @@ from dataclasses import dataclass
 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
 # DeepSpeed data structures it has to be available in the current python environment.
 from deepspeed.utils import logger
-from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
-                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
-                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+from deepspeed.checkpoint.constants import (
+    DS_VERSION,
+    OPTIMIZER_STATE_DICT,
+    SINGLE_PARTITION_OF_FP32_GROUPS,
+    FP32_FLAT_GROUPS,
+    ZERO_STAGE,
+    PARTITION_COUNT,
+    PARAM_SHAPES,
+    BUFFER_NAMES,
+    FROZEN_PARAM_SHAPES,
+    FROZEN_PARAM_FRAGMENTS,
+)

 @dataclass

@@ -134,12 +143,14 @@ def parse_model_states(files):
         frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

-        z_model_state = zero_model_state(buffers=buffers,
+        z_model_state = zero_model_state(
+            buffers=buffers,
             param_shapes=param_shapes,
             shared_params=shared_params,
             ds_version=ds_version,
             frozen_param_shapes=frozen_param_shapes,
-            frozen_param_fragments=frozen_param_fragments)
+            frozen_param_fragments=frozen_param_fragments,
+        )
         zero_model_states.append(z_model_state)

     return zero_model_states

@@ -155,7 +166,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
         state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
         state_dicts.append(state_dict)

-    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
         raise ValueError(f"{files[0]} is not a zero checkpoint")
     zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
     world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

@@ -181,7 +192,9 @@ def parse_optim_states(files, ds_checkpoint_dir):
     else:
         raise ValueError(f"unknown zero stage {zero_stage}")

-    fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+    fp32_flat_groups = [
+        state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))
+    ]
     return zero_stage, world_size, fp32_flat_groups
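The merge functions in the rest of this script walk each flat fp32 partition by offset and parameter shape. A single-rank sketch of that bookkeeping, without ZeRO padding or frozen-parameter handling, and with an illustrative helper name:

```python
import math
import torch

def unflatten_fp32_group(flat: torch.Tensor, param_shapes: dict) -> dict:
    # Carve one tensor per (name, shape) out of a flat fp32 vector, advancing an offset.
    state, offset = {}, 0
    for name, shape in param_shapes.items():
        numel = math.prod(shape)
        state[name] = flat[offset:offset + numel].view(shape).clone()
        offset += numel
    return state

flat = torch.arange(12, dtype=torch.float32)
print(unflatten_fp32_group(flat, {"w": (2, 3), "b": (6,)}))
```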
@ -205,15 +218,20 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_
|
|||||||
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
||||||
|
|
||||||
if zero_stage <= 2:
|
if zero_stage <= 2:
|
||||||
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
return _get_fp32_state_dict_from_zero2_checkpoint(
|
||||||
exclude_frozen_parameters)
|
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||||
|
)
|
||||||
elif zero_stage == 3:
|
elif zero_stage == 3:
|
||||||
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
return _get_fp32_state_dict_from_zero3_checkpoint(
|
||||||
exclude_frozen_parameters)
|
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
||||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
if (
|
||||||
|
zero_model_states[0].frozen_param_shapes is None
|
||||||
|
or len(zero_model_states[0].frozen_param_shapes) == 0
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||||
@ -269,11 +287,17 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
         full_single_fp32_vector = torch.cat(merged_partitions, 0)
         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
     avail_numel = sum(
-        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+        [
+            full_single_fp32_vector.numel()
+            for full_single_fp32_vector in merged_single_partition_of_fp32_groups
+        ]
+    )

     if debug:
         wanted_params = sum([len(shapes) for shapes in param_shapes])
-        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+        wanted_numel = sum(
+            [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]
+        )
         # not asserting if there is a mismatch due to possible padding
         print(f"Have {avail_numel} numels to process.")
         print(f"Need {wanted_numel} numels in {wanted_params} params.")
@ -283,18 +307,23 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     # out-of-core computing solution
     total_numel = 0
     total_params = 0
-    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+    for shapes, full_single_fp32_vector in zip(
+        param_shapes, merged_single_partition_of_fp32_groups
+    ):
         offset = 0
         avail_numel = full_single_fp32_vector.numel()
         for name, shape in shapes.items():
-            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            unpartitioned_numel = (
+                shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            )
             total_numel += unpartitioned_numel
             total_params += 1

             if debug:
                 print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
-            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(
+                shape
+            )
             offset += unpartitioned_numel

     # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
@ -322,8 +351,9 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")


-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero2_checkpoint(
+    world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
+):
     state_dict = OrderedDict()

     # buffers
@ -353,7 +383,10 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):


 def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
-    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+    if (
+        zero_model_states[0].frozen_param_shapes is None
+        or len(zero_model_states[0].frozen_param_shapes) == 0
+    ):
         return

     if debug:
@ -364,7 +397,10 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
     frozen_param_shapes = zero_model_states[0].frozen_param_shapes
     wanted_params = len(frozen_param_shapes)
     wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
-    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+    avail_numel = (
+        sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()])
+        * world_size
+    )
     print(f'Frozen params: Have {avail_numel} numels to process.')
     print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

@ -375,10 +411,14 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel

-        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+        param_frags = tuple(
+            model_state.frozen_param_fragments[name] for model_state in zero_model_states
+        )
         state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

-        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
+            unpartitioned_numel, world_size
+        )

         if debug:
             print(
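
As a quick aside on the arithmetic above: `zero3_partitioned_param_info(unpartitioned_numel, world_size)` reports how many elements each rank stores for a parameter and how much padding the shards carry. The function body is not shown in this diff, so the following is only a minimal sketch, assuming the usual ceil-division partitioning:

```python
import math


def zero3_partitioned_param_info(unpartitioned_numel: int, world_size: int):
    # Each rank stores ceil(numel / world_size) elements; the tail is padded.
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel


# A 10-element parameter sharded across 4 ranks: 3 elements per rank, 2 padding elements.
print(zero3_partitioned_param_info(10, 4))  # -> (3, 2)
```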
@ -416,21 +456,32 @@ class GatheredTensor:
             start_group_id = None
             end_group_id = None
             for group_id in range(len(self.flat_groups_offset)):
-                if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
+                if (
+                    self.flat_groups_offset[group_id]
+                    <= self.offset
+                    < self.flat_groups_offset[group_id + 1]
+                ):
                     start_group_id = group_id
-                if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
+                if (
+                    self.flat_groups_offset[group_id]
+                    < end_idx
+                    <= self.flat_groups_offset[group_id + 1]
+                ):
                     end_group_id = group_id
                     break
             # collect weights from related group/groups
             for group_id in range(start_group_id, end_group_id + 1):
                 flat_tensor = flat_groups_at_rank_i[group_id]
                 start_offset = self.offset - self.flat_groups_offset[group_id]
-                end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
+                end_offset = (
+                    min(end_idx, self.flat_groups_offset[group_id + 1])
+                    - self.flat_groups_offset[group_id]
+                )
                 pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])

         # collect weights from all ranks
         pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
-        param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
+        param = pad_flat_param[: self.shape.numel()].view(self.shape).contiguous()
         return param


@ -461,12 +512,16 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     offset = 0
     total_numel = 0
     total_params = 0
-    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
+    flat_groups_offset = [0] + list(
+        np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])
+    )
     for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
         unpartitioned_numel = shape.numel()
         total_numel += unpartitioned_numel
         total_params += 1
-        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
+            unpartitioned_numel, world_size
+        )

         if debug:
             print(
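
For orientation, `flat_groups_offset` above is just a prefix sum of the per-group element counts, which is what lets `GatheredTensor` map a parameter's global offset to the right flat group and local slice. A small illustration with made-up group sizes:

```python
import numpy as np

# Hypothetical element counts of rank 0's flat fp32 groups.
group_numels = [3, 5, 2]
flat_groups_offset = [0] + list(np.cumsum(group_numels))  # [0, 3, 8, 10]

# A parameter that starts at global offset 4 lives in group 1 (offsets 3..8),
# at local offset 4 - 3 = 1 inside that group's flat tensor.
offset = 4
group_id = next(
    g for g in range(len(group_numels))
    if flat_groups_offset[g] <= offset < flat_groups_offset[g + 1]
)
print(group_id, offset - flat_groups_offset[group_id])  # -> 1 1
```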
@ -474,7 +529,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
             )

         # memory efficient tensor
-        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
+        tensor = GatheredTensor(
+            fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape
+        )
         state_dict[name] = tensor
         offset += partitioned_numel

@ -484,11 +541,14 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     if offset != avail_numel:
         raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

-    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+    print(
+        f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements"
+    )


-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero3_checkpoint(
+    world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
+):
     state_dict = OrderedDict()

     # buffers
@ -530,10 +590,9 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
     return torch_state_dict


-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
-                                             tag=None,
-                                             exclude_frozen_parameters=False,
-                                             lazy_mode=False):
+def get_fp32_state_dict_from_zero_checkpoint(
+    checkpoint_dir, tag=None, exclude_frozen_parameters=False, lazy_mode=False
+):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@ -588,19 +647,23 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

-    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    state_dict = _get_fp32_state_dict_from_zero_checkpoint(
+        ds_checkpoint_dir, exclude_frozen_parameters
+    )
     if lazy_mode:
         return state_dict
     else:
         return to_torch_tensor(state_dict)


-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
+def convert_zero_checkpoint_to_fp32_state_dict(
+    checkpoint_dir,
     output_dir,
     max_shard_size="5GB",
     safe_serialization=False,
     tag=None,
-    exclude_frozen_parameters=False):
+    exclude_frozen_parameters=False,
+):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
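
To make the docstrings above concrete, here is a hedged usage sketch of the fp32 helper. The import location and checkpoint path are placeholders, and the tiny `torch.nn.Linear` only stands in for whatever architecture was actually trained:

```python
import torch

# Placeholder import: point this at wherever the script above lives in your checkout.
from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# "path/checkpoint-12" is a placeholder for a real DeepSpeed checkpoint folder
# containing global_step*/ subdirectories.
state_dict = get_fp32_state_dict_from_zero_checkpoint("path/checkpoint-12", tag="global_step1")

model = torch.nn.Linear(4, 4)  # stand-in for the trained model
model.load_state_dict(state_dict, strict=False)
```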
@ -629,25 +692,28 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         raise

     # Convert zero checkpoint to state_dict
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
-                                                          tag,
-                                                          exclude_frozen_parameters,
-                                                          lazy_mode=True)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(
+        checkpoint_dir, tag, exclude_frozen_parameters, lazy_mode=True
+    )

     # Shard the model if it is too big.
     weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
     if max_shard_size is not None:
-        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
         # an memory-efficient approach for sharding
         empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
-        state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
-                                                              filename_pattern=filename_pattern,
-                                                              max_shard_size=max_shard_size)
+        state_dict_split = split_torch_state_dict_into_shards(
+            empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
     else:
         from collections import namedtuple

         StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
-        state_dict_split = StateDictSplit(is_sharded=False,
-                                          filename_to_tensors={weights_name: list(state_dict.keys())})
+        state_dict_split = StateDictSplit(
+            is_sharded=False, filename_to_tensors={weights_name: list(state_dict.keys())}
+        )

     # Save the model by shard
     os.makedirs(output_dir, exist_ok=True)
@ -673,7 +739,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
             "metadata": state_dict_split.metadata,
             "weight_map": state_dict_split.tensor_to_filename,
         }
-        save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
+        save_index_file = (
+            "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
+        )
         save_index_file = os.path.join(output_dir, save_index_file)
         with open(save_index_file, "w", encoding="utf-8") as f:
             content = json.dumps(index, indent=2, sort_keys=True) + "\n"
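
A hedged example of driving the conversion above from Python; the paths are placeholders and the keyword names follow the signature shown in this diff:

```python
# Placeholder import path; the function is defined in the script shown above.
from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# Writes model.safetensors (sharded, with an index file, if it exceeds max_shard_size).
convert_zero_checkpoint_to_fp32_state_dict(
    "path/checkpoint-12",          # DeepSpeed checkpoint folder (placeholder)
    "path/checkpoint-12-output",   # output folder (placeholder)
    max_shard_size="5GB",
    safe_serialization=True,
)
```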
@ -719,12 +787,14 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
     return model


-def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
+def convert_zero_checkpoint_to_bf16_state_dict(
+    checkpoint_dir,
     output_dir,
     max_shard_size="5GB",
     safe_serialization=True,
     tag=None,
-    exclude_frozen_parameters=False):
+    exclude_frozen_parameters=False,
+):
     """
     Convert a DeepSpeed checkpoint in ZeRO 2 or ZeRO 3 format to BF16 and write it to the specified output directory, with the following naming rules:
     - If there is only one shard:
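
Likewise, a hedged sketch of calling the BF16 variant defined above; per the docstring and the renaming logic further down, a single-shard result ends up as `diffusion_pytorch_model.safetensors` (or `.bin`):

```python
# Placeholder import path; the function is defined in the script shown above.
from zero_to_fp32 import convert_zero_checkpoint_to_bf16_state_dict

convert_zero_checkpoint_to_bf16_state_dict(
    "path/checkpoint-12",        # DeepSpeed checkpoint folder (placeholder)
    "path/transformer-bf16",     # output folder (placeholder)
    max_shard_size="5GB",
    safe_serialization=True,
)
```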
@ -748,10 +818,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
         raise ImportError("You need `pip install huggingface_hub` to use the sharding feature.")

     state_dict = get_fp32_state_dict_from_zero_checkpoint(
-        checkpoint_dir,
-        tag=tag,
-        exclude_frozen_parameters=exclude_frozen_parameters,
-        lazy_mode=True
+        checkpoint_dir, tag=tag, exclude_frozen_parameters=exclude_frozen_parameters, lazy_mode=True
     )

     state_dict = to_torch_tensor(state_dict, return_empty_tensor=False)
@ -766,9 +833,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,

     empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
     state_dict_split = split_torch_state_dict_into_shards(
-        empty_state_dict,
-        filename_pattern=filename_pattern,
-        max_shard_size=max_shard_size
+        empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )

     os.makedirs(output_dir, exist_ok=True)
@ -789,7 +854,6 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
         del shard_state_dict
         gc.collect()
-

     if state_dict_split.is_sharded:
         index = {
             "metadata": state_dict_split.metadata,
@ -801,21 +865,29 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
     else:
         only_filename = list(state_dict_split.filename_to_tensors.keys())[0]
         old_path = os.path.join(output_dir, only_filename)
-        new_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors" if safe_serialization
-                                else "diffusion_pytorch_model.bin")
+        new_path = os.path.join(
+            output_dir,
+            "diffusion_pytorch_model.safetensors"
+            if safe_serialization
+            else "diffusion_pytorch_model.bin",
+        )
         if old_path != new_path:
             os.rename(old_path, new_path)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("checkpoint_dir",
+    parser.add_argument(
+        "checkpoint_dir",
         type=str,
-        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
-    parser.add_argument("output_dir",
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12",
+    )
+    parser.add_argument(
+        "output_dir",
         type=str,
         help="directory to the pytorch fp32 state_dict output files"
-        "(e.g. path/checkpoint-12-output/)")
+        "(e.g. path/checkpoint-12-output/)",
+    )
     parser.add_argument(
         "--max_shard_size",
         type=str,
@ -823,26 +895,34 @@ if __name__ == "__main__":
         help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
         "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
-        "without CPU OOM issues.")
+        "without CPU OOM issues.",
+    )
     parser.add_argument(
         "--safe_serialization",
         default=False,
         action='store_true',
-        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
-    parser.add_argument("-t",
+        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
+    )
+    parser.add_argument(
+        "-t",
         "--tag",
         type=str,
         default=None,
-        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
-    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
+        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1",
+    )
+    parser.add_argument(
+        "--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters"
+    )
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
     args = parser.parse_args()

     debug = args.debug

-    convert_zero_checkpoint_to_bf16_state_dict(args.checkpoint_dir,
+    convert_zero_checkpoint_to_bf16_state_dict(
+        args.checkpoint_dir,
         args.output_dir,
         max_shard_size=args.max_shard_size,
         safe_serialization=args.safe_serialization,
         tag=args.tag,
-        exclude_frozen_parameters=args.exclude_frozen_parameters)
+        exclude_frozen_parameters=args.exclude_frozen_parameters,
+    )
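
Putting the parser above together, a command-line invocation would look roughly like this; the script filename is an assumption, since it is not shown in this diff:

```shell
# Script name and paths are placeholders; the flags are the ones defined above.
python zero_checkpoint_to_bf16.py path/checkpoint-12 path/checkpoint-12-output \
    --max_shard_size 5GB --safe_serialization --tag global_step1
```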
@ -10,6 +10,7 @@ Original Script:
 https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py

 """
+
 import argparse
 from typing import Any, Dict

@ -143,7 +144,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
     return state_dict


-def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+def update_state_dict_inplace(
+    state_dict: Dict[str, Any], old_key: str, new_key: str
+) -> Dict[str, Any]:
     state_dict[new_key] = state_dict.pop(old_key)


@ -164,8 +167,11 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
-        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        ofs_embed_dim=512
+        if (i2v and init_kwargs["patch_size_t"] is not None)
+        else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v
+        and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
         **init_kwargs,
     ).to(dtype=dtype)

@ -240,17 +246,40 @@ def get_transformer_init_kwargs(version: str):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
-    )
-    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
-    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
-    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
-    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
-    parser.add_argument(
-        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
-    )
-    parser.add_argument(
-        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+        "--transformer_ckpt_path",
+        type=str,
+        default=None,
+        help="Path to original transformer checkpoint",
+    )
+    parser.add_argument(
+        "--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint"
+    )
+    parser.add_argument(
+        "--output_path", type=str, required=True, help="Path where converted model should be saved"
+    )
+    parser.add_argument(
+        "--fp16",
+        action="store_true",
+        default=False,
+        help="Whether to save the model weights in fp16",
+    )
+    parser.add_argument(
+        "--bf16",
+        action="store_true",
+        default=False,
+        help="Whether to save the model weights in bf16",
+    )
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        default=False,
+        help="Whether to push to HF Hub after saving",
+    )
+    parser.add_argument(
+        "--text_encoder_cache_dir",
+        type=str,
+        default=None,
+        help="Path to text encoder cache directory",
     )
     parser.add_argument(
         "--typecast_text_encoder",
@ -261,15 +290,24 @@ def get_args():
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
-    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+    parser.add_argument(
+        "--num_attention_heads", type=int, default=30, help="Number of attention heads"
+    )
     # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
     parser.add_argument(
-        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+        "--use_rotary_positional_embeddings",
+        action="store_true",
+        default=False,
+        help="Whether to use RoPE or not",
     )
     # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
-    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+    parser.add_argument(
+        "--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE"
+    )
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
-    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
+    parser.add_argument(
+        "--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE"
+    )
     parser.add_argument(
         "--i2v",
         action="store_true",
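
For reference, the flags defined above would typically be combined along these lines; the script name and checkpoint paths are placeholders, and the 5B values simply follow the comments in the parser:

```shell
# Placeholder script name and paths; flag names are taken from the parser above.
python convert_cogvideox_to_diffusers.py \
    --transformer_ckpt_path ckpts/cogvideox-5b/transformer.pt \
    --vae_ckpt_path ckpts/cogvideox-5b/vae.pt \
    --output_path ./cogvideox-5b-diffusers \
    --num_layers 42 \
    --num_attention_heads 48 \
    --use_rotary_positional_embeddings \
    --scaling_factor 0.7 \
    --snr_shift_scale 1.0 \
    --bf16
```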
@ -313,7 +351,9 @@ if __name__ == "__main__":

     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
-    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+    text_encoder = T5EncoderModel.from_pretrained(
+        text_encoder_id, cache_dir=args.text_encoder_cache_dir
+    )

     if args.typecast_text_encoder:
         text_encoder = text_encoder.to(dtype=dtype)
@ -355,4 +395,9 @@ if __name__ == "__main__":

     # This is necessary This is necessary for users with insufficient memory,
     # such as those using Colab and notebooks, as it can save some memory used for model loading.
-    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
+    pipe.save_pretrained(
+        args.output_path,
+        safe_serialization=True,
+        max_shard_size="5GB",
+        push_to_hub=args.push_to_hub,
+    )
@ -15,8 +15,8 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
         state_dict = state_dict["state_dict"]
     return state_dict

-LORA_KEYS_RENAME = {

+LORA_KEYS_RENAME = {
     'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
     'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
     'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
@ -24,22 +24,18 @@ LORA_KEYS_RENAME = {
     'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
     'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
     'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
-    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
+    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight',
 }


 PREFIX_KEY = "model.diffusion_model."
 SAT_UNIT_KEY = "layers"
 LORA_PREFIX_KEY = "transformer_blocks"


-def export_lora_weight(ckpt_path,lora_save_directory):
+def export_lora_weight(ckpt_path, lora_save_directory):

     merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))


     lora_state_dict = {}
     for key in list(merge_original_state_dict.keys()):
         new_key = key[len(PREFIX_KEY) :]
@ -50,8 +46,6 @@ def export_lora_weight(ckpt_path,lora_save_directory):

         lora_state_dict[new_key] = merge_original_state_dict[key]

-
-
     # final length should be 240
     if len(lora_state_dict) != 240:
         raise ValueError("lora_state_dict length is not 240")
@ -64,7 +58,7 @@ def export_lora_weight(ckpt_path,lora_save_directory):
         is_main_process=True,
         weight_name=None,
         save_function=None,
-        safe_serialization=True
+        safe_serialization=True,
     )


@ -73,7 +67,12 @@ def get_args():
     parser.add_argument(
         "--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint"
     )
-    parser.add_argument("--lora_save_directory", type=str, required=True, help="Path where converted lora should be saved")
+    parser.add_argument(
+        "--lora_save_directory",
+        type=str,
+        required=True,
+        help="Path where converted lora should be saved",
+    )
     return parser.parse_args()

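
A hedged invocation of the LoRA export entry point above; the script filename and the paths are placeholders, only the two flags come from the parser:

```shell
# Placeholder script name and paths.
python export_sat_lora_weight.py \
    --sat_pt_path ckpts/lora-sat/mp_rank_00_model_states.pt \
    --lora_save_directory ./lora-diffusers
```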
@ -35,20 +35,16 @@ caption_generator = transformers.pipeline(
         "torch_dtype": torch.bfloat16,
     },
     trust_remote_code=True,
-    tokenizer=tokenizer
+    tokenizer=tokenizer,
 )

 image_generator = DiffusionPipeline.from_pretrained(
-    image_generator_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="balanced"
+    image_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
 )
 # image_generator.to("cuda")

 video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
-    video_generator_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="balanced"
+    video_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
 )

 video_generator.vae.enable_slicing()
@ -87,11 +83,7 @@ def generate_caption(prompt):
         {"role": "user", "content": prompt + "\n" + user_prompt},
     ]

-    response = caption_generator(
-        messages,
-        max_new_tokens=226,
-        return_full_text=False
-    )
+    response = caption_generator(messages, max_new_tokens=226, return_full_text=False)
     caption = response[0]["generated_text"]
     if caption.startswith("\"") and caption.endswith("\""):
         caption = caption[1:-1]
@ -109,11 +101,7 @@ def generate_image(caption, progress=gr.Progress(track_tqdm=True)):
     return image, image  # One for output One for State


-def generate_video(
-    caption,
-    image,
-    progress=gr.Progress(track_tqdm=True)
-):
+def generate_video(caption, image, progress=gr.Progress(track_tqdm=True)):
     generator = torch.Generator().manual_seed(seed)
     video_frames = video_generator(
         image=image,
@ -181,14 +169,19 @@ with gr.Blocks() as demo:
             image_output = gr.Image(label="Generated Image")
             state_image = gr.State()
             generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption)
-            generate_image_button.click(fn=generate_image, inputs=caption, outputs=[image_output, state_image])
+            generate_image_button.click(
+                fn=generate_image, inputs=caption, outputs=[image_output, state_image]
+            )
         with gr.Column():
             video_output = gr.Video(label="Generated Video", width=720, height=480)
             download_video_button = gr.File(label="📥 Download Video", visible=False)
             download_gif_button = gr.File(label="📥 Download GIF", visible=False)
             generate_video_button = gr.Button("Generate Video from Image")
-            generate_video_button.click(fn=generate_video, inputs=[caption, state_image],
-                                        outputs=[video_output, download_gif_button])
+            generate_video_button.click(
+                fn=generate_video,
+                inputs=[caption, state_image],
+                outputs=[video_output, download_gif_button],
+            )

 if __name__ == "__main__":
     demo.launch()
@ -65,7 +65,7 @@ def get_args():
         "--num_videos",
         type=int,
         default=5,
-        help="Number of unique videos you would like to generate."
+        help="Number of unique videos you would like to generate.",
     )
     parser.add_argument(
         "--model_path",
@ -83,31 +83,28 @@ def get_args():
         "--caption_generator_cache_dir",
         type=str,
         default=None,
-        help="Cache directory for caption generation model."
+        help="Cache directory for caption generation model.",
     )
     parser.add_argument(
         "--image_generator_model_id",
         type=str,
         default="black-forest-labs/FLUX.1-dev",
-        help="Image generation model."
+        help="Image generation model.",
     )
     parser.add_argument(
         "--image_generator_cache_dir",
         type=str,
         default=None,
-        help="Cache directory for image generation model."
+        help="Cache directory for image generation model.",
     )
     parser.add_argument(
         "--image_generator_num_inference_steps",
         type=int,
         default=50,
-        help="Caption generation model."
+        help="Caption generation model.",
     )
     parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=7,
-        help="Guidance scale to be use for generation."
+        "--guidance_scale", type=float, default=7, help="Guidance scale to be use for generation."
     )
     parser.add_argument(
         "--use_dynamic_cfg",
@ -123,19 +120,14 @@ def get_args():
     parser.add_argument(
         "--compile",
         action="store_true",
-        help="Whether or not to compile the transformer of image and video generators."
+        help="Whether or not to compile the transformer of image and video generators.",
     )
     parser.add_argument(
         "--enable_vae_tiling",
         action="store_true",
-        help="Whether or not to use VAE tiling when encoding/decoding."
-    )
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=42,
-        help="Seed for reproducibility."
+        help="Whether or not to use VAE tiling when encoding/decoding.",
     )
+    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
     return parser.parse_args()


@ -157,7 +149,9 @@ def main(args: Dict[str, Any]) -> None:
         torch.cuda.manual_seed_all(args.seed)

     reset_memory()
-    tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.caption_generator_model_id, trust_remote_code=True
+    )
     caption_generator = transformers.pipeline(
         "text-generation",
         model=args.caption_generator_model_id,
@ -168,7 +162,7 @@ def main(args: Dict[str, Any]) -> None:
             "torch_dtype": torch.bfloat16,
         },
         trust_remote_code=True,
-        tokenizer=tokenizer
+        tokenizer=tokenizer,
     )

     captions = []
@ -197,12 +191,14 @@ def main(args: Dict[str, Any]) -> None:
     image_generator = DiffusionPipeline.from_pretrained(
         args.image_generator_model_id,
         cache_dir=args.image_generator_cache_dir,
-        torch_dtype=torch.bfloat16
+        torch_dtype=torch.bfloat16,
     )
     image_generator.to("cuda")

     if args.compile:
-        image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)
+        image_generator.transformer = torch.compile(
+            image_generator.transformer, mode="max-autotune", fullgraph=True
+        )

     if args.enable_vae_tiling:
         image_generator.vae.enable_tiling()
@ -216,7 +212,9 @@ def main(args: Dict[str, Any]) -> None:
             num_inference_steps=args.image_generator_num_inference_steps,
             guidance_scale=3.5,
         ).images[0]
-        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        filename = (
+            caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        )
         image.save(output_dir / f"{index}_{filename}.png")
         images.append(image)

@ -224,13 +222,16 @@ def main(args: Dict[str, Any]) -> None:
     reset_memory()

     video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
-        args.model_path, torch_dtype=torch.bfloat16).to("cuda")
+        args.model_path, torch_dtype=torch.bfloat16
+    ).to("cuda")
     video_generator.scheduler = CogVideoXDPMScheduler.from_config(
-        video_generator.scheduler.config,
-        timestep_spacing="trailing")
+        video_generator.scheduler.config, timestep_spacing="trailing"
+    )

     if args.compile:
-        video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)
+        video_generator.transformer = torch.compile(
+            video_generator.transformer, mode="max-autotune", fullgraph=True
+        )

     if args.enable_vae_tiling:
         video_generator.vae.enable_tiling()
@ -248,7 +249,9 @@ def main(args: Dict[str, Any]) -> None:
             use_dynamic_cfg=args.use_dynamic_cfg,
             generator=generator,
         ).frames[0]
-        filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        filename = (
+            caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
+        )
         export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)