Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-05 03:04:56 +08:00)
Commit c624cb0d91
28  .github/ISSUE_TEMPLATE/PULL_REQUEST_TEMPLATE.md  (vendored, new file)
@@ -0,0 +1,28 @@
# Contribution Guide

We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines.

## What We Accept

+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue; please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature; please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.

## Code Style Guide

Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:

1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors.

## Naming Conventions

- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`; a short illustration follows below.
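
Purely as an illustration of the naming rules above (the function and variable names here are hypothetical, not taken from the repository), a minimal sketch:

```python
# Avoid: single-letter, meaningless names.
def f(a, b, c):
    return a * b * c


# Prefer: descriptive, PEP8-style snake_case names with English comments.
def compute_clip_volume(num_frames: int, frame_height: int, frame_width: int) -> int:
    """Return the total number of pixels across all frames of a clip."""
    return num_frames * frame_height * frame_width
```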
34  .github/PULL_REQUEST_TEMPLATE/pr_template.md  (vendored)
@@ -1,34 +0,0 @@
# Raise a Valuable PR

## Caution

Users should keep the following points in mind when submitting a PR:

1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md) ([Chinese version](../../resources/contribute_zh.md)).
2. The proposed PR should be focused; if you have several different ideas or optimizations, split them into separate PRs.

## PRs That Should Not Be Proposed

A PR that falls into any of the following categories may be closed or rejected:

1. It does not describe the proposed improvement.
2. It combines several issues of different types in one PR.
3. It is highly duplicative of an already existing PR.

# Check Your PR

- [ ] Have you read the Contributor Guidelines, Pull Request section?
- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link.
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips.
- [ ] Did you write the required new tests?
- [ ] Is your PR for only one issue?
19  .pre-commit-config.yaml  (new file)
@@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.5
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
      - id: ruff-format
        args: [--config=pyproject.toml]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements
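
The Python hunks below are the output of these hooks: `ruff format` re-wraps statements that exceed the line-length limit configured in `pyproject.toml` (the pyproject file itself is not shown in this diff). As a rough sketch of the transformation, using hypothetical names rather than code from this repository:

```python
# Before formatting: the whole call sits on one (over-long) line.
def describe_clip(path: str, num_frames: int, height: int, width: int, fps: int) -> str:
    return f"{path}: {num_frames} frames at {width}x{height}, {fps} fps"

summary = describe_clip(path="clip0001.mp4", num_frames=81, height=480, width=720, fps=16)

# After formatting: the same call, split with one argument per line and a trailing comma,
# which is the pattern repeated throughout the hunks that follow.
summary = describe_clip(
    path="clip0001.mp4",
    num_frames=81,
    height=480,
    width=720,
    fps=16,
)
```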
@@ -22,6 +22,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac

## Project Updates

- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
- 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
@@ -445,8 +446,6 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
}
```

We welcome your contributions! You can click [here](resources/contribute.md) for more information.

## Model-License

The code in this repository is released under the [Apache 2.0 License](LICENSE).
11  README_ja.md
@@ -21,10 +21,11 @@
</p>

## Updates and News
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. See [here](inference/ddim_inversion.py) for details.
- 🔥 **News**: ```2025/01/08```: We have updated the `Lora` fine-tuning code based on the `diffusers` version of the model, which runs with less VRAM (video memory). See [here](finetune/README_ja.md) for details.
- 🔥 **News**: ```2024/11/15```: We released the diffusers version of the `CogVideoX1.5` model. Previous code can be reused with only minor parameter adjustments.
- 🔥 **News**: ```2024/11/08```: We released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) is a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. With this toolkit, you can make full use of our multimodal generation models.
- **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. See [here](inference/ddim_inversion.py) for details.
- **News**: ```2025/01/08```: We have updated the `Lora` fine-tuning code based on the `diffusers` version of the model, which runs with less VRAM (video memory). See [here](finetune/README_ja.md) for details.
- **News**: ```2024/11/15```: We released the diffusers version of the `CogVideoX1.5` model. Previous code can be reused with only minor parameter adjustments.
- **News**: ```2024/11/08```: We released the `CogVideoX1.5` model. CogVideoX1.5 is an upgraded version of the open-source CogVideoX model.
The CogVideoX1.5-5B series models support 10-second videos and higher resolutions, and `CogVideoX1.5-5B-I2V` supports video generation at arbitrary resolutions.
The SAT code has already been updated, while the `diffusers` version is still being adapted.
The SAT version of the code can be downloaded [here](https://huggingface.co/THUDM/CogVideoX1.5-5B-SAT).
@@ -419,8 +420,6 @@ The CogVideo demo is available at [https://models.aminer.cn/cogvideo](https://models.aminer.c
}
```

We look forward to your contributions! Click [here](resources/contribute_ja.md) for more details.

## License Agreement

The code in this repository is released under the [Apache 2.0 License](LICENSE).
@@ -22,6 +22,7 @@

## Project Updates

- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series; a toolkit for getting the most out of our multimodal generation models.
- 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported for `CogVideoX-5B` and `CogVideoX1.5-5B`; see [here](inference/ddim_inversion.py).
- 🔥 **News**: ```2025/01/08```: We have updated the `Lora` fine-tuning code based on the `diffusers` version of the model, which uses less GPU memory. For details, see [here](finetune/README_zh.md).
- 🔥 **News**: ```2024/11/15```: We released the diffusers version of the `CogVideoX1.5` model. Previous code can be reused with only a few parameter adjustments.
@@ -399,8 +400,6 @@ The CogVideo demo site is at [https://models.aminer.cn/cogvideo](https://models.amine
}
```

We welcome your contributions; you can click [here](resources/contribute_zh.md) for more information.

## Model License

The code in this repository is released under the [Apache 2.0 License](LICENSE).
@@ -26,7 +26,11 @@ class BucketSampler(Sampler):
"""

def __init__(
self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
self,
data_source: Dataset,
batch_size: int = 8,
shuffle: bool = True,
drop_last: bool = False,
) -> None:
self.data_source = data_source
self.batch_size = batch_size
@@ -48,7 +52,11 @@ class BucketSampler(Sampler):
def __iter__(self):
for index, data in enumerate(self.data_source):
video_metadata = data["video_metadata"]
f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
f, h, w = (
video_metadata["num_frames"],
video_metadata["height"],
video_metadata["width"],
)

self.buckets[(f, h, w)].append(data)
if len(self.buckets[(f, h, w)]) == self.batch_size:
@@ -115,7 +115,9 @@ class BaseI2VDataset(Dataset):
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
video_latent_dir = (
cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
)
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@@ -136,7 +138,9 @@ class BaseI2VDataset(Dataset):
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
logger.info(
f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
)

if encoded_video_path.exists():
encoded_video = load_file(encoded_video_path)["encoded_video"]
@@ -177,7 +181,9 @@ class BaseI2VDataset(Dataset):
},
}

def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
def preprocess(
self, video_path: Path | None, image_path: Path | None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Loads and preprocesses a video and an image.
If either path is None, no preprocessing will be done for that input.
@@ -249,13 +255,19 @@ class I2VDatasetWithResize(BaseI2VDataset):
self.height = height
self.width = width

self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transforms = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
self.__image_transforms = self.__frame_transforms

@override
def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
def preprocess(
self, video_path: Path | None, image_path: Path | None
) -> Tuple[torch.Tensor, torch.Tensor]:
if video_path is not None:
video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
video = preprocess_video_with_resize(
video_path, self.max_num_frames, self.height, self.width
)
else:
video = None
if image_path is not None:
@@ -293,7 +305,9 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
)
for b in video_resolution_buckets
]
self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transforms = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)
self.__image_transforms = self.__frame_transforms

@override
@@ -11,7 +11,12 @@ from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
from .utils import (
load_prompts,
load_videos,
preprocess_video_with_buckets,
preprocess_video_with_resize,
)


if TYPE_CHECKING:
@@ -93,7 +98,9 @@ class BaseT2VDataset(Dataset):
train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

cache_dir = self.trainer.args.data_root / "cache"
video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
video_latent_dir = (
cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
)
prompt_embeddings_dir = cache_dir / "prompt_embeddings"
video_latent_dir.mkdir(parents=True, exist_ok=True)
prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@@ -114,7 +121,9 @@ class BaseT2VDataset(Dataset):
# [1, seq_len, hidden_size] -> [seq_len, hidden_size]
prompt_embedding = prompt_embedding[0]
save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
logger.info(
f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
)

if encoded_video_path.exists():
# encoded_video = torch.load(encoded_video_path, weights_only=True)
@@ -202,7 +211,9 @@ class T2VDatasetWithResize(BaseT2VDataset):
self.height = height
self.width = width

self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transform = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)

@override
def preprocess(self, video_path: Path) -> torch.Tensor:
@@ -240,7 +251,9 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
for b in video_resolution_buckets
]

self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
self.__frame_transform = transforms.Compose(
[transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
)

@override
def preprocess(self, video_path: Path) -> torch.Tensor:
@@ -24,12 +24,16 @@ def load_prompts(prompt_path: Path) -> List[str]:

def load_videos(video_path: Path) -> List[Path]:
with open(video_path, "r", encoding="utf-8") as file:
return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
return [
video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
]


def load_images(image_path: Path) -> List[Path]:
with open(image_path, "r", encoding="utf-8") as file:
return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
return [
image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
]


def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
@@ -169,7 +173,9 @@ def preprocess_video_with_buckets(
video_num_frames = len(video_reader)
resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
if len(resolution_buckets) == 0:
raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
raise ValueError(
f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
)

nearest_frame_bucket = min(
resolution_buckets,
@@ -181,7 +187,9 @@ def preprocess_video_with_buckets(
frames = frames[:nearest_frame_bucket].float()
frames = frames.permute(0, 3, 1, 2).contiguous()

nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
nearest_res = min(
resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
)
nearest_res = (nearest_res[1], nearest_res[2])
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

@ -32,13 +32,19 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
||||
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(
|
||||
model_path, subfolder="text_encoder"
|
||||
)
|
||||
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
model_path, subfolder="transformer"
|
||||
)
|
||||
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
||||
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
|
||||
model_path, subfolder="scheduler"
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@ -73,7 +79,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
return_tensors="pt",
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids.input_ids
|
||||
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
|
||||
prompt_embedding = self.components.text_encoder(
|
||||
prompt_token_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
return prompt_embedding
|
||||
|
||||
@override
|
||||
@ -122,22 +130,34 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
|
||||
images = images.unsqueeze(2)
|
||||
# Add noise to images
|
||||
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
|
||||
image_noise_sigma = torch.normal(
|
||||
mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
|
||||
)
|
||||
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
|
||||
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
||||
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
|
||||
noisy_images = (
|
||||
images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
||||
)
|
||||
image_latent_dist = self.components.vae.encode(
|
||||
noisy_images.to(dtype=self.components.vae.dtype)
|
||||
).latent_dist
|
||||
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
|
||||
|
||||
# Sample a random timestep for each sample
|
||||
timesteps = torch.randint(
|
||||
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
||||
0,
|
||||
self.components.scheduler.config.num_train_timesteps,
|
||||
(batch_size,),
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# from [B, C, F, H, W] to [B, F, C, H, W]
|
||||
latent = latent.permute(0, 2, 1, 3, 4)
|
||||
image_latents = image_latents.permute(0, 2, 1, 3, 4)
|
||||
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
|
||||
assert (latent.shape[0], *latent.shape[2:]) == (
|
||||
image_latents.shape[0],
|
||||
*image_latents.shape[2:],
|
||||
)
|
||||
|
||||
# Padding image_latents to the same frame number as latent
|
||||
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
|
||||
@ -169,7 +189,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
# Predict noise, For CogVideoX1.5 Only.
|
||||
ofs_emb = (
|
||||
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
|
||||
None
|
||||
if self.state.transformer_config.ofs_embed_dim is None
|
||||
else latent.new_full((1,), fill_value=2.0)
|
||||
)
|
||||
predicted_noise = self.components.transformer(
|
||||
hidden_states=latent_img_noisy,
|
||||
@ -181,7 +203,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
)[0]
|
||||
|
||||
# Denoise
|
||||
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
|
||||
latent_pred = self.components.scheduler.get_velocity(
|
||||
predicted_noise, latent_noisy, timesteps
|
||||
)
|
||||
|
||||
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
|
||||
weights = 1 / (1 - alphas_cumprod)
|
||||
@ -228,7 +252,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
if transformer_config.patch_size_t is None:
|
||||
base_num_frames = num_frames
|
||||
else:
|
||||
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
|
||||
base_num_frames = (
|
||||
num_frames + transformer_config.patch_size_t - 1
|
||||
) // transformer_config.patch_size_t
|
||||
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=transformer_config.attention_head_dim,
|
||||
|
@ -31,13 +31,19 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
||||
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
||||
components.text_encoder = T5EncoderModel.from_pretrained(
|
||||
model_path, subfolder="text_encoder"
|
||||
)
|
||||
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
||||
components.transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
model_path, subfolder="transformer"
|
||||
)
|
||||
|
||||
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
||||
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
||||
components.scheduler = CogVideoXDPMScheduler.from_pretrained(
|
||||
model_path, subfolder="scheduler"
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
@ -72,7 +78,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
return_tensors="pt",
|
||||
)
|
||||
prompt_token_ids = prompt_token_ids.input_ids
|
||||
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
|
||||
prompt_embedding = self.components.text_encoder(
|
||||
prompt_token_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
return prompt_embedding
|
||||
|
||||
@override
|
||||
@ -115,7 +123,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
# Sample a random timestep for each sample
|
||||
timesteps = torch.randint(
|
||||
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
||||
0,
|
||||
self.components.scheduler.config.num_train_timesteps,
|
||||
(batch_size,),
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
@ -150,7 +161,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
)[0]
|
||||
|
||||
# Denoise
|
||||
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
|
||||
latent_pred = self.components.scheduler.get_velocity(
|
||||
predicted_noise, latent_added_noise, timesteps
|
||||
)
|
||||
|
||||
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
|
||||
weights = 1 / (1 - alphas_cumprod)
|
||||
@ -196,7 +209,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
if transformer_config.patch_size_t is None:
|
||||
base_num_frames = num_frames
|
||||
else:
|
||||
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
|
||||
base_num_frames = (
|
||||
num_frames + transformer_config.patch_size_t - 1
|
||||
) // transformer_config.patch_size_t
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=transformer_config.attention_head_dim,
|
||||
crops_coords=None,
|
||||
|
@@ -52,6 +52,8 @@ def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Tra
print(f"\nSupported training types for '{model_type}' are:")
for supported_type in SUPPORTED_MODELS[model_type]:
print(f" • {supported_type}")
raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
raise ValueError(
f"Training type '{training_type}' is not supported for model '{model_type}'"
)

return SUPPORTED_MODELS[model_type][training_type]
@ -115,14 +115,18 @@ class Args(BaseModel):
|
||||
def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||
values = info.data
|
||||
if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
|
||||
raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
|
||||
raise ValueError(
|
||||
"validation_images must be specified when do_validation is True and model_type is i2v"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("validation_videos")
|
||||
def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
|
||||
values = info.data
|
||||
if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
|
||||
raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
|
||||
raise ValueError(
|
||||
"validation_videos must be specified when do_validation is True and model_type is v2v"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("validation_steps")
|
||||
@ -148,7 +152,9 @@ class Args(BaseModel):
|
||||
model_name = info.data.get("model_name", "")
|
||||
if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
|
||||
if (height, width) != (480, 720):
|
||||
raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
|
||||
raise ValueError(
|
||||
"For cogvideox-5b models, height must be 480 and width must be 720"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@ -221,7 +227,9 @@ class Args(BaseModel):
|
||||
# LoRA parameters
|
||||
parser.add_argument("--rank", type=int, default=128)
|
||||
parser.add_argument("--lora_alpha", type=int, default=64)
|
||||
parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
|
||||
parser.add_argument(
|
||||
"--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=200)
|
||||
|
@@ -8,7 +8,10 @@ import cv2

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
"--datadir",
type=str,
required=True,
help="Root directory containing videos.txt and video subdirectory",
)
return parser.parse_args()
@ -88,7 +88,9 @@ class Trainer:
|
||||
|
||||
def _init_distributed(self):
|
||||
logging_dir = Path(self.args.output_dir, "logs")
|
||||
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
|
||||
project_config = ProjectConfiguration(
|
||||
project_dir=self.args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
init_process_group_kwargs = InitProcessGroupKwargs(
|
||||
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
|
||||
@ -183,7 +185,9 @@ class Trainer:
|
||||
# Prepare VAE and text encoder for encoding
|
||||
self.components.vae.requires_grad_(False)
|
||||
self.components.text_encoder.requires_grad_(False)
|
||||
self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
self.components.vae = self.components.vae.to(
|
||||
self.accelerator.device, dtype=self.state.weight_dtype
|
||||
)
|
||||
self.components.text_encoder = self.components.text_encoder.to(
|
||||
self.accelerator.device, dtype=self.state.weight_dtype
|
||||
)
|
||||
@ -263,7 +267,9 @@ class Trainer:
|
||||
|
||||
# For LoRA, we only want to train the LoRA weights
|
||||
# For SFT, we want to train all the parameters
|
||||
trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
|
||||
trainable_parameters = list(
|
||||
filter(lambda p: p.requires_grad, self.components.transformer.parameters())
|
||||
)
|
||||
transformer_parameters_with_lr = {
|
||||
"params": trainable_parameters,
|
||||
"lr": self.args.learning_rate,
|
||||
@ -287,7 +293,9 @@ class Trainer:
|
||||
use_deepspeed=use_deepspeed_opt,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(self.data_loader) / self.args.gradient_accumulation_steps
|
||||
)
|
||||
if self.args.train_steps is None:
|
||||
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
|
||||
self.state.overwrote_max_train_steps = True
|
||||
@ -322,12 +330,16 @@ class Trainer:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
def prepare_for_training(self) -> None:
|
||||
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
|
||||
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
|
||||
self.accelerator.prepare(
|
||||
self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
|
||||
)
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(self.data_loader) / self.args.gradient_accumulation_steps
|
||||
)
|
||||
if self.state.overwrote_max_train_steps:
|
||||
self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
@ -364,7 +376,9 @@ class Trainer:
|
||||
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
|
||||
|
||||
self.state.total_batch_size_count = (
|
||||
self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
|
||||
self.args.batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.args.gradient_accumulation_steps
|
||||
)
|
||||
info = {
|
||||
"trainable parameters": self.state.num_trainable_parameters,
|
||||
@ -454,7 +468,9 @@ class Trainer:
|
||||
progress_bar.set_postfix(logs)
|
||||
|
||||
# Maybe run validation
|
||||
should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
|
||||
should_run_validation = (
|
||||
self.args.do_validation and global_step % self.args.validation_steps == 0
|
||||
)
|
||||
if should_run_validation:
|
||||
del loss
|
||||
free_memory()
|
||||
@ -466,7 +482,9 @@ class Trainer:
|
||||
break
|
||||
|
||||
memory_statistics = get_memory_statistics()
|
||||
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
|
||||
logger.info(
|
||||
f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
self.__maybe_save_checkpoint(global_step, must_save=True)
|
||||
@ -504,7 +522,9 @@ class Trainer:
|
||||
# Can't using model_cpu_offload in deepspeed,
|
||||
# so we need to move all components in pipe to device
|
||||
# pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
|
||||
self.__move_components_to_device(
|
||||
dtype=self.state.weight_dtype, ignore_list=["transformer"]
|
||||
)
|
||||
else:
|
||||
# if not using deepspeed, use model_cpu_offload to further reduce memory usage
|
||||
# Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
|
||||
@ -528,7 +548,9 @@ class Trainer:
|
||||
video = self.state.validation_videos[i]
|
||||
|
||||
if image is not None:
|
||||
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
|
||||
image = preprocess_image_with_resize(
|
||||
image, self.state.train_height, self.state.train_width
|
||||
)
|
||||
# Convert image tensor (C, H, W) to PIL images
|
||||
image = image.to(torch.uint8)
|
||||
image = image.permute(1, 2, 0).cpu().numpy()
|
||||
@ -546,7 +568,9 @@ class Trainer:
|
||||
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
|
||||
main_process_only=False,
|
||||
)
|
||||
validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
|
||||
validation_artifacts = self.validation_step(
|
||||
{"prompt": prompt, "image": image, "video": video}, pipe
|
||||
)
|
||||
|
||||
if (
|
||||
self.state.using_deepspeed
|
||||
@ -565,7 +589,9 @@ class Trainer:
|
||||
"video": {"type": "video", "value": video},
|
||||
}
|
||||
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
|
||||
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
|
||||
artifacts.update(
|
||||
{f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}
|
||||
)
|
||||
logger.debug(
|
||||
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
|
||||
main_process_only=False,
|
||||
@ -600,8 +626,12 @@ class Trainer:
|
||||
tracker_key = "validation"
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
|
||||
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
|
||||
image_artifacts = [
|
||||
artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
|
||||
]
|
||||
video_artifacts = [
|
||||
artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
|
||||
]
|
||||
tracker.log(
|
||||
{
|
||||
tracker_key: {"images": image_artifacts, "videos": video_artifacts},
|
||||
@ -618,7 +648,9 @@ class Trainer:
|
||||
pipe.remove_all_hooks()
|
||||
del pipe
|
||||
# Load models except those not needed for training
|
||||
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
|
||||
self.__move_components_to_device(
|
||||
dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
|
||||
)
|
||||
self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
||||
|
||||
# Change trainable weights back to fp32 to keep with dtype after prepare the model
|
||||
@ -687,7 +719,9 @@ class Trainer:
|
||||
for name, component in components.items():
|
||||
if not isinstance(component, type) and hasattr(component, "to"):
|
||||
if name not in ignore_list:
|
||||
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
|
||||
setattr(
|
||||
self.components, name, component.to(self.accelerator.device, dtype=dtype)
|
||||
)
|
||||
|
||||
def __move_components_to_cpu(self, unload_list: List[str] = []):
|
||||
unload_list = set(unload_list)
|
||||
@ -732,11 +766,13 @@ class Trainer:
|
||||
):
|
||||
transformer_ = unwrap_model(self.accelerator, model)
|
||||
else:
|
||||
raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
|
||||
raise ValueError(
|
||||
f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
|
||||
)
|
||||
else:
|
||||
transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
|
||||
self.args.model_path, subfolder="transformer"
|
||||
)
|
||||
transformer_ = unwrap_model(
|
||||
self.accelerator, self.components.transformer
|
||||
).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
|
||||
@ -745,7 +781,9 @@ class Trainer:
|
||||
for k, v in lora_state_dict.items()
|
||||
if k.startswith("transformer.")
|
||||
}
|
||||
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
|
||||
incompatible_keys = set_peft_model_state_dict(
|
||||
transformer_, transformer_state_dict, adapter_name="default"
|
||||
)
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
@ -759,7 +797,10 @@ class Trainer:
|
||||
self.accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
|
||||
if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
|
||||
if (
|
||||
self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
or self.accelerator.is_main_process
|
||||
):
|
||||
if must_save or global_step % self.args.checkpointing_steps == 0:
|
||||
# for training
|
||||
save_path = get_intermediate_ckpt_path(
|
||||
|
@@ -23,7 +23,9 @@ def get_latest_ckpt_path_to_resume_from(
else:
resume_from_checkpoint_path = Path(resume_from_checkpoint)
if not resume_from_checkpoint_path.exists():
logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
logger.info(
f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
)
initial_global_step = 0
global_step = 0
first_epoch = 0
@@ -55,7 +55,9 @@ def unload_model(model):
model.to("cpu")


def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
def make_contiguous(
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if isinstance(x, torch.Tensor):
return x.contiguous()
elif isinstance(x, dict):
@ -67,7 +67,9 @@ def get_optimizer(
|
||||
optimizer_name = "adamw"
|
||||
|
||||
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
|
||||
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
|
||||
raise ValueError(
|
||||
"`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
|
||||
)
|
||||
|
||||
if use_8bit:
|
||||
try:
|
||||
@ -81,7 +83,9 @@ def get_optimizer(
|
||||
if use_torchao:
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
||||
|
||||
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
|
||||
optimizer_class = (
|
||||
AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
|
||||
)
|
||||
else:
|
||||
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
|
||||
|
||||
@ -109,7 +113,9 @@ def get_optimizer(
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
||||
raise ImportError(
|
||||
"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
|
||||
)
|
||||
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
|
||||
@ -133,7 +139,9 @@ def get_optimizer(
|
||||
try:
|
||||
import came_pytorch
|
||||
except ImportError:
|
||||
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
|
||||
raise ImportError(
|
||||
"To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
|
||||
)
|
||||
|
||||
optimizer_class = came_pytorch.CAME
|
||||
|
||||
@ -151,7 +159,10 @@ def get_optimizer(
|
||||
init_kwargs.update({"fused": True})
|
||||
|
||||
optimizer = CPUOffloadOptimizer(
|
||||
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
|
||||
params_to_optimize,
|
||||
optimizer_class=optimizer_class,
|
||||
offload_gradients=offload_gradients,
|
||||
**init_kwargs,
|
||||
)
|
||||
else:
|
||||
optimizer = optimizer_class(params_to_optimize, **init_kwargs)
|
||||
|
@ -99,7 +99,9 @@ def generate_video(
|
||||
desired_resolution = RESOLUTION_MAP[model_name]
|
||||
if width is None or height is None:
|
||||
height, width = desired_resolution
|
||||
logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
|
||||
logging.info(
|
||||
f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
|
||||
)
|
||||
elif (height, width) != desired_resolution:
|
||||
if generate_type == "i2v":
|
||||
# For i2v models, use user-defined width and height
|
||||
@ -124,7 +126,9 @@ def generate_video(
|
||||
|
||||
# If you're using with lora, add this code
|
||||
if lora_path:
|
||||
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
|
||||
pipe.load_lora_weights(
|
||||
lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
|
||||
)
|
||||
pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
|
||||
|
||||
# 2. Set Scheduler.
|
||||
@ -133,7 +137,9 @@ def generate_video(
|
||||
# using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
|
||||
|
||||
# pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
||||
pipe.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
# 3. Enable CPU offload for the model.
|
||||
# turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
|
||||
@ -190,8 +196,12 @@ def generate_video(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a video from a text prompt using CogVideoX"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, required=True, help="The description of the video to be generated"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_or_video_path",
|
||||
type=str,
|
||||
@ -199,20 +209,44 @@ if __name__ == "__main__":
|
||||
help="The path of the image to be used as the background of the video",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use"
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="THUDM/CogVideoX1.5-5B",
|
||||
help="Path of the pre-trained model use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
|
||||
)
|
||||
parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
|
||||
parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
|
||||
parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="./output.mp4", help="The path save generated video"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
|
||||
parser.add_argument(
|
||||
"--num_frames", type=int, default=81, help="Number of steps for the inference process"
|
||||
)
|
||||
parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
|
||||
parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
|
||||
parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
|
||||
parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
|
||||
parser.add_argument(
|
||||
"--height", type=int, default=None, help="The height of the generated video"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps", type=int, default=16, help="The frames per second for the generated video"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_videos_per_prompt",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of videos to generate per prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generate_type", type=str, default="t2v", help="The type of video generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="bfloat16", help="The data type for computation"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -19,7 +19,12 @@ import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
|
||||
from diffusers import (
|
||||
AutoencoderKLCogVideoX,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXDPMScheduler,
|
||||
)
|
||||
from diffusers.utils import export_to_video
|
||||
from transformers import T5EncoderModel
|
||||
from torchao.quantization import quantize_, int8_weight_only
|
||||
@ -68,9 +73,13 @@ def generate_video(
|
||||
- quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
|
||||
- dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
|
||||
"""
|
||||
text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
model_path, subfolder="text_encoder", torch_dtype=dtype
|
||||
)
|
||||
text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
|
||||
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
||||
model_path, subfolder="transformer", torch_dtype=dtype
|
||||
)
|
||||
transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
|
||||
vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
|
||||
vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
|
||||
@ -81,7 +90,9 @@ def generate_video(
|
||||
vae=vae,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
||||
pipe.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
@ -100,16 +111,34 @@ def generate_video(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
|
||||
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model")
|
||||
parser.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save generated video")
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
|
||||
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt")
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a video from a text prompt using CogVideoX"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization_scheme", type=str, default="fp8", choices=["int8", "fp8"], help="Quantization scheme"
|
||||
"--prompt", type=str, required=True, help="The description of the video to be generated"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="./output.mp4", help="Path to save generated video"
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
|
||||
parser.add_argument(
|
||||
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization_scheme",
|
||||
type=str,
|
||||
default="fp8",
|
||||
choices=["int8", "fp8"],
|
||||
help="Quantization scheme",
|
||||
)
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
|
||||
parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")
|
||||
|
@ -104,18 +104,34 @@ def save_video(tensor, output_path):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
|
||||
parser.add_argument(
|
||||
"--model_path", type=str, required=True, help="The path to the CogVideoX model"
|
||||
)
|
||||
parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
|
||||
parser.add_argument("--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)")
|
||||
parser.add_argument("--output_path", type=str, default=".", help="The path to save the output file")
|
||||
parser.add_argument(
|
||||
"--mode", type=str, choices=["encode", "decode", "both"], required=True, help="Mode: encode, decode, or both"
|
||||
"--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
|
||||
"--output_path", type=str, default=".", help="The path to save the output file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
|
||||
"--mode",
|
||||
type=str,
|
||||
choices=["encode", "decode", "both"],
|
||||
required=True,
|
||||
help="Mode: encode, decode, or both",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="bfloat16",
|
||||
help="The data type for computation (e.g., 'float16' or 'bfloat16')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="The device to use for computation (e.g., 'cuda' or 'cpu')",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -126,15 +142,21 @@ if __name__ == "__main__":
|
||||
assert args.video_path, "Video path must be provided for encoding."
|
||||
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
|
||||
torch.save(encoded_output, args.output_path + "/encoded.pt")
|
||||
print(f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt")
|
||||
print(
|
||||
f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt"
|
||||
)
|
||||
elif args.mode == "decode":
|
||||
assert args.encoded_path, "Encoded tensor path must be provided for decoding."
|
||||
decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
|
||||
save_video(decoded_output, args.output_path)
|
||||
print(f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4")
|
||||
print(
|
||||
f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4"
|
||||
)
|
||||
elif args.mode == "both":
|
||||
assert args.video_path, "Video path must be provided for encoding."
|
||||
encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
|
||||
torch.save(encoded_output, args.output_path + "/encoded.pt")
|
||||
decoded_output = decode_video(args.model_path, args.output_path + "/encoded.pt", dtype, device)
|
||||
decoded_output = decode_video(
|
||||
args.model_path, args.output_path + "/encoded.pt", dtype, device
|
||||
)
|
||||
save_video(decoded_output, args.output_path)
|
||||
|
@@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
parser.add_argument(
"--retry_times", type=int, default=3, help="Number of times to retry the conversion"
)
parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
args = parser.parse_args()
@ -30,7 +30,10 @@ import torchvision.transforms as T
|
||||
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
|
||||
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
|
||||
from diffusers.models.transformers.cogvideox_transformer_3d import (
|
||||
CogVideoXBlock,
|
||||
CogVideoXTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
|
||||
from diffusers.utils import export_to_video
|
||||
@ -62,22 +65,48 @@ class DDIMInversionArguments(TypedDict):
|
||||
def get_args() -> DDIMInversionArguments:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
|
||||
parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
|
||||
parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
|
||||
parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
|
||||
parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
|
||||
parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
|
||||
parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
|
||||
parser.add_argument(
|
||||
"--model_path", type=str, required=True, help="Path of the pretrained model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, required=True, help="Prompt for the direct sample procedure"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_path", type=str, required=True, help="Path of the video for inversion"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, default="output", help="Path of the output videos"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_inference_steps", type=int, default=50, help="Number of inference steps"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_num_frames", type=int, default=81, help="Max number of sampled frames"
|
||||
)
|
||||
parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
|
||||
parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
|
||||
parser.add_argument(
|
||||
"--height", type=int, default=480, help="Resized height of the video frames"
|
||||
)
|
||||
parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
|
||||
parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
|
||||
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
|
||||
@ -116,13 +145,20 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length:], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
if key.size(2) == query.size(2): # Attention for reference hidden states
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length:], image_rotary_emb
|
||||
)
|
||||
else: # RoPE should be applied to each group of image tokens
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
|
||||
apply_rotary_emb(
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length],
|
||||
image_rotary_emb,
|
||||
)
|
||||
)
|
||||
key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
|
||||
@ -162,8 +198,12 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size
|
||||
)
|
||||
attention_mask = attention_mask.view(
|
||||
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@ -260,14 +300,18 @@ def get_video_frames(
|
||||
return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
||||
|
||||
|
||||
def encode_video_frames(vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def encode_video_frames(
|
||||
vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
|
||||
video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
|
||||
return latent_dist * vae.config.scaling_factor
|
||||
|
||||
|
||||
def export_latents_to_video(pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int):
|
||||
def export_latents_to_video(
|
||||
pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
|
||||
):
|
||||
video = pipeline.decode_latents(latents)
|
||||
frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
|
||||
export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
|
||||
@ -320,7 +364,9 @@ def sample(
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
|
||||
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
|
||||
if isinstance(
|
||||
scheduler, DDIMInverseScheduler
|
||||
): # Inverse scheduler does not accept extra kwargs
|
||||
extra_step_kwargs = {}
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
@ -344,7 +390,9 @@ def sample(
|
||||
if pipeline.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
if reference_latents is not None:
|
||||
reference = reference_latents[i]
|
||||
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
|
||||
@ -371,18 +419,31 @@ def sample(
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
pipeline._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
(
|
||||
1
|
||||
- math.cos(
|
||||
math.pi
|
||||
* ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
|
||||
)
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + pipeline.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the noisy sample x_t-1 -> x_t
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
latents = scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
||||
)[0]
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
trajectory[i] = latents
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
|
||||
# Offload all models
|
||||
@ -410,7 +471,9 @@ def ddim_inversion(
|
||||
seed: int,
|
||||
device: torch.device,
|
||||
):
|
||||
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device)
|
||||
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
|
||||
model_path, torch_dtype=dtype
|
||||
).to(device=device)
|
||||
if not pipeline.transformer.config.use_rotary_positional_embeddings:
|
||||
raise NotImplementedError("This script supports CogVideoX 5B model only.")
|
||||
video_frames = get_video_frames(
|
||||
|
@ -43,5 +43,3 @@ pip install -r requirements.txt
```bash
python app.py
```
@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
MODEL = "THUDM/CogVideoX-5b"
|
||||
|
||||
hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
|
||||
hf_hub_download(
|
||||
repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
|
||||
)
|
||||
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
||||
pipe.scheduler = CogVideoXDPMScheduler.from_config(
|
||||
pipe.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
|
||||
MODEL,
|
||||
transformer=pipe.transformer,
|
||||
@ -296,8 +300,16 @@ def delete_old_files():
|
||||
|
||||
|
||||
threading.Thread(target=delete_old_files, daemon=True).start()
|
||||
examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]]
|
||||
examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]]
|
||||
examples_videos = [
|
||||
["example_videos/horse.mp4"],
|
||||
["example_videos/kitten.mp4"],
|
||||
["example_videos/train_running.mp4"],
|
||||
]
|
||||
examples_images = [
|
||||
["example_images/beach.png"],
|
||||
["example_images/street.png"],
|
||||
["example_images/camping.png"],
|
||||
]
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("""
|
||||
@ -322,14 +334,26 @@ with gr.Blocks() as demo:
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
|
||||
with gr.Accordion(
|
||||
"I2V: Image Input (cannot be used simultaneously with video input)", open=False
|
||||
):
|
||||
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
|
||||
examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False)
|
||||
with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False):
|
||||
video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
|
||||
examples_component_images = gr.Examples(
|
||||
examples_images, inputs=[image_input], cache_examples=False
|
||||
)
|
||||
with gr.Accordion(
|
||||
"V2V: Video Input (cannot be used simultaneously with image input)", open=False
|
||||
):
|
||||
video_input = gr.Video(
|
||||
label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
|
||||
)
|
||||
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
|
||||
examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False)
|
||||
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
||||
examples_component_videos = gr.Examples(
|
||||
examples_videos, inputs=[video_input], cache_examples=False
|
||||
)
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -340,11 +364,16 @@ with gr.Blocks() as demo:
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
seed_param = gr.Number(
|
||||
label="Inference Seed (Enter a positive number, -1 for random)", value=-1
|
||||
label="Inference Seed (Enter a positive number, -1 for random)",
|
||||
value=-1,
|
||||
)
|
||||
with gr.Row():
|
||||
enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False)
|
||||
enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False)
|
||||
enable_scale = gr.Checkbox(
|
||||
label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
|
||||
)
|
||||
enable_rife = gr.Checkbox(
|
||||
label="Frame Interpolation (8fps -> 16fps)", value=False
|
||||
)
|
||||
gr.Markdown(
|
||||
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br> The entire process is based on open-source solutions."
|
||||
)
|
||||
@ -430,7 +459,7 @@ with gr.Blocks() as demo:
|
||||
seed_value,
|
||||
scale_status,
|
||||
rife_status,
|
||||
progress=gr.Progress(track_tqdm=True)
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
):
|
||||
latents, seed = infer(
|
||||
prompt,
|
||||
@ -457,7 +486,9 @@ with gr.Blocks() as demo:
|
||||
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
|
||||
batch_video_frames.append(image_pil)
|
||||
|
||||
video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6))
|
||||
video_path = utils.save_video(
|
||||
batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
|
||||
)
|
||||
video_update = gr.update(visible=True, value=video_path)
|
||||
gif_path = convert_to_gif(video_path)
|
||||
gif_update = gr.update(visible=True, value=gif_path)
|
||||
|
@ -3,7 +3,9 @@ from .refine import *
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ConvTranspose2d(
|
||||
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.PReLU(out_planes),
|
||||
)
|
||||
|
||||
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
|
||||
if scale != 1:
|
||||
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
if flow != None:
|
||||
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
* 1.0
|
||||
/ scale
|
||||
)
|
||||
x = torch.cat((x, flow), 1)
|
||||
x = self.conv0(x)
|
||||
x = self.convblock(x) + x
|
||||
@ -102,7 +108,9 @@ class IFNet(nn.Module):
|
||||
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
|
||||
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
|
||||
mask_teacher = torch.sigmoid(mask + mask_d)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
|
||||
1 - mask_teacher
|
||||
)
|
||||
else:
|
||||
flow_teacher = None
|
||||
merged_teacher = None
|
||||
@ -110,11 +118,16 @@ class IFNet(nn.Module):
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
if gt.shape[1] == 3:
|
||||
loss_mask = (
|
||||
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
|
||||
(
|
||||
(merged[i] - gt).abs().mean(1, True)
|
||||
> (merged_teacher - gt).abs().mean(1, True) + 0.01
|
||||
)
|
||||
.float()
|
||||
.detach()
|
||||
)
|
||||
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
|
||||
loss_distill += (
|
||||
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
|
||||
).mean()
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
||||
|
@ -3,7 +3,9 @@ from .refine_2R import *
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ConvTranspose2d(
|
||||
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.PReLU(out_planes),
|
||||
)
|
||||
|
||||
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
|
||||
if scale != 1:
|
||||
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
if flow != None:
|
||||
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
* 1.0
|
||||
/ scale
|
||||
)
|
||||
x = torch.cat((x, flow), 1)
|
||||
x = self.conv0(x)
|
||||
x = self.convblock(x) + x
|
||||
@ -102,7 +108,9 @@ class IFNet(nn.Module):
|
||||
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
|
||||
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
|
||||
mask_teacher = torch.sigmoid(mask + mask_d)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
|
||||
1 - mask_teacher
|
||||
)
|
||||
else:
|
||||
flow_teacher = None
|
||||
merged_teacher = None
|
||||
@ -110,11 +118,16 @@ class IFNet(nn.Module):
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
if gt.shape[1] == 3:
|
||||
loss_mask = (
|
||||
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
|
||||
(
|
||||
(merged[i] - gt).abs().mean(1, True)
|
||||
> (merged_teacher - gt).abs().mean(1, True) + 0.01
|
||||
)
|
||||
.float()
|
||||
.detach()
|
||||
)
|
||||
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
|
||||
loss_distill += (
|
||||
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
|
||||
).mean()
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
||||
|
@ -61,11 +61,19 @@ class IFBlock(nn.Module):
|
||||
|
||||
def forward(self, x, flow, scale=1):
|
||||
x = F.interpolate(
|
||||
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
|
||||
x,
|
||||
scale_factor=1.0 / scale,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
flow = (
|
||||
F.interpolate(
|
||||
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
|
||||
flow,
|
||||
scale_factor=1.0 / scale,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 1.0
|
||||
/ scale
|
||||
@ -78,11 +86,21 @@ class IFBlock(nn.Module):
|
||||
flow = self.conv1(feat)
|
||||
mask = self.conv2(feat)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=scale,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* scale
|
||||
)
|
||||
mask = F.interpolate(
|
||||
mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
|
||||
mask,
|
||||
scale_factor=scale,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
return flow, mask
|
||||
|
||||
@ -112,7 +130,11 @@ class IFNet(nn.Module):
|
||||
loss_cons = 0
|
||||
block = [self.block0, self.block1, self.block2]
|
||||
for i in range(3):
|
||||
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
|
||||
f0, m0 = block[i](
|
||||
torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
|
||||
flow,
|
||||
scale=scale_list[i],
|
||||
)
|
||||
f1, m1 = block[i](
|
||||
torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
|
||||
torch.cat((flow[:, 2:4], flow[:, :2]), 1),
|
||||
|
@ -3,7 +3,9 @@ from .refine import *
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ConvTranspose2d(
|
||||
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
|
||||
),
|
||||
nn.PReLU(out_planes),
|
||||
)
|
||||
|
||||
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
|
||||
if scale != 1:
|
||||
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
if flow != None:
|
||||
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
|
||||
* 1.0
|
||||
/ scale
|
||||
)
|
||||
x = torch.cat((x, flow), 1)
|
||||
x = self.conv0(x)
|
||||
x = self.convblock(x) + x
|
||||
@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
|
||||
for i in range(3):
|
||||
if flow != None:
|
||||
flow_d, mask_d = stu[i](
|
||||
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]
|
||||
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
|
||||
flow,
|
||||
scale=scale[i],
|
||||
)
|
||||
flow = flow + flow_d
|
||||
mask = mask + mask_d
|
||||
@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
|
||||
merged.append(merged_student)
|
||||
if gt.shape[1] == 3:
|
||||
flow_d, mask_d = self.block_tea(
|
||||
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1
|
||||
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
|
||||
flow,
|
||||
scale=1,
|
||||
)
|
||||
flow_teacher = flow + flow_d
|
||||
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
|
||||
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
|
||||
mask_teacher = torch.sigmoid(mask + mask_d)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
|
||||
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
|
||||
1 - mask_teacher
|
||||
)
|
||||
else:
|
||||
flow_teacher = None
|
||||
merged_teacher = None
|
||||
@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
|
||||
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
|
||||
if gt.shape[1] == 3:
|
||||
loss_mask = (
|
||||
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
|
||||
(
|
||||
(merged[i] - gt).abs().mean(1, True)
|
||||
> (merged_teacher - gt).abs().mean(1, True) + 0.01
|
||||
)
|
||||
.float()
|
||||
.detach()
|
||||
)
|
||||
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
|
||||
loss_distill += (
|
||||
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
|
||||
).mean()
|
||||
if returnflow:
|
||||
return flow
|
||||
else:
|
||||
|
@ -44,7 +44,9 @@ class Model:
|
||||
if torch.cuda.is_available():
|
||||
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
|
||||
else:
|
||||
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu")))
|
||||
self.flownet.load_state_dict(
|
||||
convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
|
||||
)
|
||||
|
||||
def save_model(self, path, rank=0):
|
||||
if rank == 0:
|
||||
|
@ -29,10 +29,14 @@ def downsample(x):
|
||||
|
||||
|
||||
def upsample(x):
|
||||
cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
|
||||
cc = torch.cat(
|
||||
[x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
|
||||
)
|
||||
cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
|
||||
cc = cc.permute(0, 1, 3, 2)
|
||||
cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3)
|
||||
cc = torch.cat(
|
||||
[cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
|
||||
)
|
||||
cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
|
||||
x_up = cc.permute(0, 1, 3, 2)
|
||||
return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
|
||||
@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
|
||||
self.gauss_kernel = gauss_kernel(channels=channels)
|
||||
|
||||
def forward(self, input, target):
|
||||
pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
|
||||
pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
|
||||
pyr_input = laplacian_pyramid(
|
||||
img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
|
||||
)
|
||||
pyr_target = laplacian_pyramid(
|
||||
img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
|
||||
)
|
||||
return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))
|
||||
|
@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
|
||||
gauss = torch.Tensor(
|
||||
[exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
|
||||
)
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t())
|
||||
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
|
||||
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
|
||||
window = (
|
||||
_3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
|
||||
)
|
||||
return window
|
||||
|
||||
|
||||
@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
|
||||
|
||||
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
|
||||
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
|
||||
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
|
||||
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
|
||||
mu1 = F.conv2d(
|
||||
F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
|
||||
)
|
||||
mu2 = F.conv2d(
|
||||
F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
|
||||
)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
|
||||
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
|
||||
sigma1_sq = (
|
||||
F.conv2d(
|
||||
F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
|
||||
)
|
||||
- mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv2d(
|
||||
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
|
||||
)
|
||||
- mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv2d(
|
||||
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
|
||||
)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01 * L) ** 2
|
||||
C2 = (0.03 * L) ** 2
|
||||
@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
|
||||
return ret
|
||||
|
||||
|
||||
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
||||
def ssim_matlab(
|
||||
img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
|
||||
):
|
||||
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
||||
if val_range is None:
|
||||
if torch.max(img1) > 128:
|
||||
@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
|
||||
img1 = img1.unsqueeze(1)
|
||||
img2 = img2.unsqueeze(1)
|
||||
|
||||
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
|
||||
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
|
||||
mu1 = F.conv3d(
|
||||
F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
|
||||
)
|
||||
mu2 = F.conv3d(
|
||||
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
|
||||
)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
|
||||
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
|
||||
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
|
||||
sigma1_sq = (
|
||||
F.conv3d(
|
||||
F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||
)
|
||||
- mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv3d(
|
||||
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||
)
|
||||
- mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv3d(
|
||||
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
|
||||
)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01 * L) ** 2
|
||||
C2 = (0.03 * L) ** 2
|
||||
@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
|
||||
mssim = []
|
||||
mcs = []
|
||||
for _ in range(levels):
|
||||
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
|
||||
sim, cs = ssim(
|
||||
img1,
|
||||
img2,
|
||||
window_size=window_size,
|
||||
size_average=size_average,
|
||||
full=True,
|
||||
val_range=val_range,
|
||||
)
|
||||
mssim.append(sim)
|
||||
mcs.append(cs)
|
||||
|
||||
@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
|
||||
self.window = window
|
||||
self.channel = channel
|
||||
|
||||
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
|
||||
_ssim = ssim(
|
||||
img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
|
||||
)
|
||||
dssim = (1 - _ssim) / 2
|
||||
return dssim
|
||||
|
||||
|
@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(
|
||||
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
|
||||
in_channels=in_planes,
|
||||
out_channels=out_planes,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=True,
|
||||
),
|
||||
nn.PReLU(out_planes),
|
||||
)
|
||||
@ -56,25 +61,49 @@ class Contextnet(nn.Module):
|
||||
def forward(self, x, flow):
|
||||
x = self.conv1(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f1 = warp(x, flow)
|
||||
x = self.conv2(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f2 = warp(x, flow)
|
||||
x = self.conv3(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f3 = warp(x, flow)
|
||||
x = self.conv4(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f4 = warp(x, flow)
|
||||
|
@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(
|
||||
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True
|
||||
in_channels=in_planes,
|
||||
out_channels=out_planes,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=True,
|
||||
),
|
||||
nn.PReLU(out_planes),
|
||||
)
|
||||
@ -59,19 +64,37 @@ class Contextnet(nn.Module):
|
||||
f1 = warp(x, flow)
|
||||
x = self.conv2(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f2 = warp(x, flow)
|
||||
x = self.conv3(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f3 = warp(x, flow)
|
||||
x = self.conv4(x)
|
||||
flow = (
|
||||
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
|
||||
F.interpolate(
|
||||
flow,
|
||||
scale_factor=0.5,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
* 0.5
|
||||
)
|
||||
f4 = warp(x, flow)
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
import skvideo.io
|
||||
from rife.RIFE_HDv3 import Model
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@ -19,7 +20,7 @@ def pad_image(img, scale):
|
||||
tmp = max(32, int(32 / scale))
|
||||
ph = ((h - 1) // tmp + 1) * tmp
|
||||
pw = ((w - 1) // tmp + 1) * tmp
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
return F.pad(img, padding), padding
|
||||
|
||||
|
||||
@ -78,13 +79,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
# print(f'I1[0] unpadded shape:{I1.shape}')
|
||||
I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
|
||||
ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
|
||||
if padding[3] > 0 and padding[1] >0 :
|
||||
|
||||
frame = I1[:, :, : -padding[3],:-padding[1]]
|
||||
if padding[3] > 0 and padding[1] > 0:
|
||||
frame = I1[:, :, : -padding[3], : -padding[1]]
|
||||
elif padding[3] > 0:
|
||||
frame = I1[:, :, : -padding[3],:]
|
||||
elif padding[1] >0:
|
||||
frame = I1[:, :, :,:-padding[1]]
|
||||
frame = I1[:, :, : -padding[3], :]
|
||||
elif padding[1] > 0:
|
||||
frame = I1[:, :, :, : -padding[1]]
|
||||
else:
|
||||
frame = I1
|
||||
|
||||
@ -102,7 +102,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
|
||||
frame = F.interpolate(frame, size=(h, w))
|
||||
output.append(frame.to(output_device))
|
||||
for i, tmp_frame in enumerate(tmp_output):
|
||||
|
||||
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
|
||||
tmp_frame = F.interpolate(tmp_frame, size=(h, w))
|
||||
output.append(tmp_frame.to(output_device))
|
||||
@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
|
||||
frame_rgb = frame[..., ::-1]
|
||||
frame_rgb = frame_rgb.copy()
|
||||
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
|
||||
pt_frame_data.append(
|
||||
tensor.permute(2, 0, 1)
|
||||
) # to [c, h, w,]
|
||||
pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
|
||||
|
||||
pt_frame = torch.from_numpy(np.stack(pt_frame_data))
|
||||
pt_frame = pt_frame.to(device)
|
||||
@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
|
||||
latent = latents[i]
|
||||
|
||||
frames = ssim_interpolation_rife(model, latent)
|
||||
pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h])
|
||||
pt_image = torch.stack(
|
||||
[frames[i].squeeze(0) for i in range(len(frames))]
|
||||
) # (to [f, c, w, h])
|
||||
rife_results.append(pt_image)
|
||||
|
||||
return torch.stack(rife_results)
|
||||
|
@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
else:
|
||||
if not "weights_only" in torch.load.__code__.co_varnames:
|
||||
if "weights_only" not in torch.load.__code__.co_varnames:
|
||||
logger.warning(
|
||||
"Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
|
||||
)
|
||||
@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(
|
||||
samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None
|
||||
samples,
|
||||
function,
|
||||
tile=(64, 64),
|
||||
overlap=8,
|
||||
upscale_amount=4,
|
||||
out_channels=3,
|
||||
output_device="cpu",
|
||||
pbar=None,
|
||||
):
|
||||
dims = len(tile)
|
||||
print(f"samples dtype:{samples.dtype}")
|
||||
output = torch.empty(
|
||||
[samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
|
||||
[samples.shape[0], out_channels]
|
||||
+ list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
|
||||
device=output_device,
|
||||
)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b : b + 1]
|
||||
out = torch.zeros(
|
||||
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
|
||||
[s.shape[0], out_channels]
|
||||
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
|
||||
device=output_device,
|
||||
)
|
||||
out_div = torch.zeros(
|
||||
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
|
||||
[s.shape[0], out_channels]
|
||||
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
|
||||
device=output_device,
|
||||
)
|
||||
|
||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||
for it in itertools.product(
|
||||
*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
|
||||
):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
@ -142,7 +154,14 @@ def tiled_scale(
|
||||
pbar=None,
|
||||
):
|
||||
return tiled_scale_multidim(
|
||||
samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar
|
||||
samples,
|
||||
function,
|
||||
(tile_y, tile_x),
|
||||
overlap,
|
||||
upscale_amount,
|
||||
out_channels,
|
||||
output_device,
|
||||
pbar,
|
||||
)
|
||||
|
||||
|
||||
@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
|
||||
return s
|
||||
|
||||
|
||||
def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor:
|
||||
def upscale_batch_and_concatenate(
|
||||
upscale_model, latents, inf_device, output_device="cpu"
|
||||
) -> torch.Tensor:
|
||||
upscaled_latents = []
|
||||
for i in range(latents.size(0)):
|
||||
latent = latents[i]
|
||||
@ -207,7 +228,9 @@ class ProgressBar:
|
||||
def __init__(self, total, desc=None):
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc)
|
||||
self.b_unit = tqdm.tqdm(
|
||||
total=total, desc="ProgressBar context index: 0" if desc is None else desc
|
||||
)
|
||||
|
||||
def update(self, value):
|
||||
if value > self.total:
|
||||
|
@ -22,7 +22,9 @@ from datetime import datetime, timedelta
|
||||
from openai import OpenAI
|
||||
from moviepy import VideoFileClip
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
|
||||
"cuda"
|
||||
)
|
||||
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)):
|
||||
def infer(
|
||||
prompt: str,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
@ -151,7 +158,9 @@ with gr.Blocks() as demo:
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -176,7 +185,13 @@ with gr.Blocks() as demo:
|
||||
download_video_button = gr.File(label="📥 Download Video", visible=False)
|
||||
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
|
||||
|
||||
def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)):
|
||||
def generate(
|
||||
prompt,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
model_choice,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
):
|
||||
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
||||
video_path = save_video(tensor)
|
||||
video_update = gr.update(visible=True, value=video_path)
|
||||
|
@ -4,4 +4,3 @@
|
||||
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
<p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
</div>
@ -1,49 +0,0 @@
|
||||
# Contribution Guide

There may still be many incomplete aspects in this project.

We look forward to your contributions to the repository in the following areas. If you complete any of the work listed below
and are willing to submit a PR and share it with the community, upon review, we
will acknowledge your contribution on the project homepage.

## Model Algorithms

- Support for model quantization inference (Int4 quantization project)
- Optimization of model fine-tuning data loading (replacing the existing decord tool)

## Model Engineering

- Model fine-tuning examples / Best prompt practices
- Inference adaptation on different devices (e.g., MLX framework)
- Any tools related to the model
- Any minimal fully open-source project using the CogVideoX open-source model

## Code Standards

Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
style. You can organize the code according to the following specifications:

1. Install the `ruff` tool

```shell
pip install ruff
```

Then, run the `ruff` tool

```shell
ruff check tools sat inference
```

Check the code style. If there are issues, you can automatically fix them using the `ruff format` command.

```shell
ruff format tools sat inference
```

Once your code meets the standard, there should be no errors.

## Naming Conventions
1. Please use English names, do not use Pinyin or other language names. All comments should be in English.
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like a, b, c.
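A minimal sketch of these conventions in practice (the function and variable names below are invented purely for illustration and are not taken from the repository):

```python
# Illustrative only: hypothetical names, not part of the CogVideoX codebase.

# Avoid: single-letter or Pinyin-style names with no English comment.
def js(a, b):
    return a / b

# Prefer: descriptive English snake_case names with an English comment.
def compute_frames_per_second(frame_count: int, duration_seconds: float) -> float:
    # Guard against zero-length clips before dividing.
    if duration_seconds <= 0:
        raise ValueError("duration_seconds must be positive")
    return frame_count / duration_seconds
```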
|
@ -1,47 +0,0 @@
|
||||
# コントリビューションガイド
|
||||
|
||||
本プロジェクトにはまだ多くの未完成の部分があります。
|
||||
|
||||
以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
|
||||
|
||||
## モデルアルゴリズム
|
||||
|
||||
- モデル量子化推論のサポート (Int4量子化プロジェクト)
|
||||
- モデルのファインチューニングデータロードの最適化(既存のdecordツールの置き換え)
|
||||
|
||||
## モデルエンジニアリング
|
||||
|
||||
- モデルのファインチューニング例 / 最適なプロンプトの実践
|
||||
- 異なるデバイスでの推論適応(例: MLXフレームワーク)
|
||||
- モデルに関連するツール
|
||||
- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
|
||||
|
||||
## コード標準
|
||||
|
||||
良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
|
||||
設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
|
||||
|
||||
1. `ruff` ツールをインストールする
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
次に、`ruff` ツールを実行します
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
コードが標準に準拠したら、エラーはなくなるはずです。
|
||||
|
||||
## 命名規則
|
||||
|
||||
1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
|
||||
2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。
|
@ -1,44 +0,0 @@
|
||||
# 贡献指南
|
||||
|
||||
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区,在通过审核后,我们将在项目首页感谢您的贡献。
|
||||
|
||||
## 模型算法
|
||||
|
||||
- 模型量化推理支持 (Int4量化工程)
|
||||
- 模型微调数据载入优化支持(替换现有的decord工具)
|
||||
|
||||
## 模型工程
|
||||
|
||||
- 模型微调示例 / 最佳提示词实践
|
||||
- 不同设备上的推理适配(MLX等框架)
|
||||
- 任何模型周边工具
|
||||
- 任何使用CogVideoX开源模型制作的最小完整开源项目
|
||||
|
||||
## 代码规范
|
||||
|
||||
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
|
||||
|
||||
1. 安装`ruff`工具
|
||||
|
||||
```shell
|
||||
pip install ruff
|
||||
```
|
||||
|
||||
接着,运行`ruff`工具
|
||||
|
||||
```shell
|
||||
ruff check tools sat inference
|
||||
```
|
||||
|
||||
检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
|
||||
|
||||
```shell
|
||||
ruff format tools sat inference
|
||||
```
|
||||
|
||||
如果您的代码符合规范,应该不会出现任何的错误。
|
||||
|
||||
## 命名规范
|
||||
|
||||
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
|
||||
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。
|
@ -18,7 +18,10 @@ def add_model_config_args(parser):
|
||||
group = parser.add_argument_group("model", "model configuration")
|
||||
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
|
||||
group.add_argument(
|
||||
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
|
||||
"--model-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="size of the model parallel. only use if you are an expert.",
|
||||
)
|
||||
group.add_argument("--force-pretrain", action="store_true")
|
||||
group.add_argument("--device", type=int, default=-1)
|
||||
@ -74,10 +77,15 @@ def get_args(args_list=None, parser=None):
|
||||
if not args.train_data:
|
||||
print_rank0("No training data specified", level="WARNING")
|
||||
|
||||
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
|
||||
assert (args.train_iters is None) or (
|
||||
args.epochs is None
|
||||
), "only one of train_iters and epochs should be set."
|
||||
if args.train_iters is None and args.epochs is None:
|
||||
args.train_iters = 10000 # default 10k iters
|
||||
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
|
||||
print_rank0(
|
||||
"No train_iters (recommended) or epochs specified, use default 10k iters.",
|
||||
level="WARNING",
|
||||
)
|
||||
|
||||
args.cuda = torch.cuda.is_available()
|
||||
|
||||
@ -213,7 +221,10 @@ def initialize_distributed(args):
|
||||
args.master_port = os.getenv("MASTER_PORT", default_master_port)
|
||||
init_method += args.master_ip + ":" + args.master_port
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
||||
backend=args.distributed_backend,
|
||||
world_size=args.world_size,
|
||||
rank=args.rank,
|
||||
init_method=init_method,
|
||||
)
|
||||
|
||||
# Set the model-parallel / data-parallel communicators.
|
||||
@ -232,7 +243,10 @@ def initialize_distributed(args):
|
||||
import deepspeed
|
||||
|
||||
deepspeed.init_distributed(
|
||||
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
|
||||
dist_backend=args.distributed_backend,
|
||||
world_size=args.world_size,
|
||||
rank=args.rank,
|
||||
init_method=init_method,
|
||||
)
|
||||
# # It seems that it has no negative influence to configure it even without using checkpointing.
|
||||
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
|
||||
@ -262,7 +276,9 @@ def process_config_to_args(args):
|
||||
|
||||
args_config = config.pop("args", OmegaConf.create())
|
||||
for key in args_config:
|
||||
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
|
||||
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
|
||||
args_config[key], omegaconf.ListConfig
|
||||
):
|
||||
arg = OmegaConf.to_object(args_config[key])
|
||||
else:
|
||||
arg = args_config[key]
|
||||
|
@ -56,7 +56,9 @@ def read_video(
|
||||
end_pts = float("inf")
|
||||
|
||||
if end_pts < start_pts:
|
||||
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
|
||||
raise ValueError(
|
||||
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
|
||||
)
|
||||
|
||||
info = {}
|
||||
audio_frames = []
|
||||
@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
|
||||
super().__init__(
|
||||
path,
|
||||
partial(
|
||||
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
|
||||
process_fn_video,
|
||||
num_frames=num_frames,
|
||||
image_size=image_size,
|
||||
fps=fps,
|
||||
skip_frms_num=skip_frms_num,
|
||||
),
|
||||
seed,
|
||||
meta_names=meta_names,
|
||||
@ -400,7 +406,9 @@ class SFTDataset(Dataset):
|
||||
indices = np.arange(start, end, (end - start) // num_frames).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end_safty))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = (
|
||||
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
)
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
else:
|
||||
if ori_vlen > self.max_num_frames:
|
||||
@ -410,7 +418,11 @@ class SFTDataset(Dataset):
|
||||
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = (
|
||||
torch.from_numpy(temp_frms)
|
||||
if type(temp_frms) is not torch.Tensor
|
||||
else temp_frms
|
||||
)
|
||||
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
|
||||
else:
|
||||
|
||||
@ -423,11 +435,17 @@ class SFTDataset(Dataset):
|
||||
|
||||
start = int(self.skip_frms_num)
|
||||
end = int(ori_vlen - self.skip_frms_num)
|
||||
num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1
|
||||
num_frames = nearest_smaller_4k_plus_1(
|
||||
end - start
|
||||
) # 3D VAE requires the number of frames to be 4k+1
|
||||
end = int(start + num_frames)
|
||||
temp_frms = vr.get_batch(np.arange(start, end))
|
||||
assert temp_frms is not None
|
||||
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
|
||||
tensor_frms = (
|
||||
torch.from_numpy(temp_frms)
|
||||
if type(temp_frms) is not torch.Tensor
|
||||
else temp_frms
|
||||
)
|
||||
|
||||
tensor_frms = pad_last_frame(
|
||||
tensor_frms, self.max_num_frames
|
||||
|
@@ -41,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("disable_first_stage_autocast", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
not_trainable_prefixes = model_config.get(
"not_trainable_prefixes", ["first_stage_model", "conditioner"]
)
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
@@ -76,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
)

self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
self.sampler = (
instantiate_from_config(sampler_config) if sampler_config is not None else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)

self._init_first_stage(first_stage_config)

self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.loss_fn = (
instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
)

self.latent_input = latent_input
self.scale_factor = scale_factor
@@ -151,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(
x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
)
lr_x = F.interpolate(
lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
)
lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z

@@ -195,7 +207,11 @@ class SATVideoDiffusionEngine(nn.Module):
recons = []
start_frame = 0
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
end_frame = (
start_frame
+ latent_time // fake_cp_size
+ (1 if i < latent_time % fake_cp_size else 0)
)

use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
@@ -264,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
)

samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs)
samples = self.sampler(
denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
)
samples = samples.to(self.dtype)
return samples

@@ -278,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
log = dict()

for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
@@ -354,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
c["concat"] = image
uc["concat"] = image
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
@@ -364,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples
else:
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
@@ -94,7 +94,9 @@ def get_3d_sincos_pos_embed(

# concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
pos_embed_temporal = np.repeat(
pos_embed_temporal, grid_height * grid_width, axis=1
) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]

@@ -160,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
self.width = width
self.spatial_length = height * width
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False,
)

def position_embedding_forward(self, position_ids, **kwargs):
@@ -169,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
self.pos_embedding.data[:, -self.spatial_length :].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)


class Basic3DPositionEmbeddingMixin(BaseMixin):
@@ -192,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False,
)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
@@ -285,7 +291,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)

freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = broadcat(
(freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
dim=-1,
)

freqs = freqs.contiguous()
self.freqs_sin = freqs.sin().cuda()
@@ -293,7 +302,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
self.pos_embedding = nn.Parameter(
torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
)
else:
self.pos_embedding = None

@@ -440,16 +451,26 @@ class FinalLayerMixin(BaseMixin):
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
)

def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d), keep only the trailing image tokens of x
x, emb = (
logits[:, kwargs["text_length"] :, :],
kwargs["emb"],
) # x:(b,(t n),d), keep only the trailing image tokens of x
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)

return unpatchify(
x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs
x,
c=self.out_channels,
patch_size=self.patch_size,
w=kwargs["rope_W"],
h=kwargs["rope_H"],
**kwargs,
)

def reinit(self, parent_model=None):
@@ -500,7 +521,10 @@ class AdaLNMixin(BaseMixin):
self.compressed_num_frames = compressed_num_frames

self.adaLN_modulations = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
[
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
for _ in range(num_layers)
]
)

self.qk_ln = qk_ln
@@ -560,7 +584,9 @@ class AdaLNMixin(BaseMixin):
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)

attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
attention_input = torch.cat(
(text_attention_input, img_attention_input), dim=1
) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
@@ -584,9 +610,13 @@ class AdaLNMixin(BaseMixin):
img_mlp_output = layer.fourth_layernorm(img_mlp_output)

img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
text_hidden_states = (
text_hidden_states + text_gate_mlp * text_mlp_output
) # language (b,n,d)

hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
hidden_states = torch.cat(
(text_hidden_states, img_hidden_states), dim=1
) # (b,(n_t+t*n_i),d)
return hidden_states

def reinit(self, parent_model=None):
@@ -694,7 +724,9 @@ class DiffusionTransformer(BaseModel):
if use_RMSNorm:
kwargs["layernorm"] = RMSNorm
else:
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
kwargs["layernorm"] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
)

transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
@@ -707,7 +739,9 @@ class DiffusionTransformer(BaseModel):

if use_SwiGLU:
self.add_mixin(
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
"swiglu",
SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
reinit=True,
)

def _build_modules(self, module_configs):
@@ -813,7 +847,9 @@ class DiffusionTransformer(BaseModel):
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
self.add_mixin(
"lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
)
return

def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@@ -829,7 +865,9 @@ class DiffusionTransformer(BaseModel):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
@@ -838,7 +876,9 @@ class DiffusionTransformer(BaseModel):
emb = emb + self.label_emb(y)

if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
ofs_emb = timestep_embedding(
kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb

@@ -852,6 +892,8 @@ class DiffusionTransformer(BaseModel):
kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2]

kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
(1, 1)
).to(x.dtype)
output = super().forward(**kwargs)[0]
return output

@@ -19,6 +19,7 @@ from sat import mpu
|
||||
from diffusion_video import SATVideoDiffusionEngine
|
||||
from arguments import get_args
|
||||
|
||||
|
||||
def read_from_cli():
|
||||
cnt = 0
|
||||
try:
|
||||
@ -50,34 +51,50 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
batch["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "fps":
|
||||
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
|
||||
elif key == "fps_id":
|
||||
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
|
||||
elif key == "motion_bucket_id":
|
||||
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
|
||||
batch[key] = (
|
||||
torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
|
||||
)
|
||||
elif key == "pool_image":
|
||||
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half)
|
||||
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
|
||||
device, dtype=torch.half
|
||||
)
|
||||
elif key == "cond_aug":
|
||||
batch[key] = repeat(
|
||||
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
|
||||
@ -100,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
|
||||
def save_video_as_grid_and_mp4(
|
||||
video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
|
||||
):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
for i, vid in enumerate(video_batch):
|
||||
@ -160,7 +179,9 @@ def sampling_main(args, model_cls):
|
||||
W = 96
|
||||
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
|
||||
chained_trainsforms = []
|
||||
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
|
||||
chained_trainsforms.append(
|
||||
TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
|
||||
)
|
||||
chained_trainsforms.append(TT.ToTensor())
|
||||
transform = TT.Compose(chained_trainsforms)
|
||||
image = transform(image).unsqueeze(0).to("cuda")
|
||||
@ -170,7 +191,9 @@ def sampling_main(args, model_cls):
|
||||
image = image / model.scale_factor
|
||||
image = image.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pad_shape = (image.shape[0], T - 1, C, H, W)
|
||||
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
|
||||
image = torch.concat(
|
||||
[image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
|
||||
)
|
||||
else:
|
||||
image_size = args.sampling_image_size
|
||||
H, W = image_size[0], image_size[1]
|
||||
@ -181,12 +204,20 @@ def sampling_main(args, model_cls):
|
||||
mp_size = mpu.get_model_parallel_world_size()
|
||||
global_rank = torch.distributed.get_rank() // mp_size
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast_object_list(
|
||||
text_cast, src=src, group=mpu.get_model_parallel_group()
|
||||
)
|
||||
text = text_cast[0]
|
||||
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
|
||||
value_dict = {
|
||||
"prompt": text,
|
||||
"negative_prompt": "",
|
||||
"num_frames": torch.tensor(T).unsqueeze(0),
|
||||
}
|
||||
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
num_samples,
|
||||
)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
@ -212,7 +243,11 @@ def sampling_main(args, model_cls):
|
||||
for index in range(args.batch_size):
|
||||
if args.image2video:
|
||||
samples_z = sample_func(
|
||||
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
|
||||
c,
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H, W),
|
||||
ofs=torch.tensor([2.0]).to("cuda"),
|
||||
)
|
||||
else:
|
||||
samples_z = sample_func(
|
||||
@ -226,7 +261,9 @@ def sampling_main(args, model_cls):
|
||||
if args.only_save_latents:
|
||||
samples_z = 1.0 / model.scale_factor * samples_z
|
||||
save_path = os.path.join(
|
||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
||||
args.output_dir,
|
||||
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
|
||||
str(index),
|
||||
)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
torch.save(samples_z, os.path.join(save_path, "latent.pt"))
|
||||
@ -237,7 +274,9 @@ def sampling_main(args, model_cls):
|
||||
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
||||
save_path = os.path.join(
|
||||
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
|
||||
args.output_dir,
|
||||
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
|
||||
str(index),
|
||||
)
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||
|
@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (
|
||||
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi)
|
||||
)
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
|
@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
||||
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
||||
def inner_training_step(
|
||||
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
||||
) -> torch.Tensor:
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
additional_decode_kwargs = {
|
||||
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
||||
}
|
||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
if self.trainable_ae_params is None:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
else:
|
||||
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
||||
ae_params, num_ae_params = self.get_param_groups(
|
||||
self.trainable_ae_params, self.ae_optimizer_args
|
||||
)
|
||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||
if self.trainable_disc_params is None:
|
||||
disc_params = self.get_discriminator_params()
|
||||
else:
|
||||
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
||||
disc_params, num_disc_params = self.get_param_groups(
|
||||
self.trainable_disc_params, self.disc_optimizer_args
|
||||
)
|
||||
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
||||
opt_ae = self.instantiate_optimizer_from_config(
|
||||
ae_params,
|
||||
@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
)
|
||||
opts = [opt_ae]
|
||||
if len(disc_params) > 0:
|
||||
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
||||
opt_disc = self.instantiate_optimizer_from_config(
|
||||
disc_params, self.learning_rate, self.optimizer_config
|
||||
)
|
||||
opts.append(opt_disc)
|
||||
|
||||
return opts
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
def log_images(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
log = dict()
|
||||
additional_decode_kwargs = {}
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
||||
additional_decode_kwargs.update(
|
||||
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
)
|
||||
|
||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||
log["inputs"] = x
|
||||
@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||
diff_ema.clamp_(0, 1.0)
|
||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
log["diff_boost_ema"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
)
|
||||
if additional_log_kwargs:
|
||||
additional_decode_kwargs.update(additional_log_kwargs)
|
||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||
@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
def log_videos(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
||||
|
||||
def get_input(self, batch: dict) -> torch.Tensor:
|
||||
@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
batch = batch[self.input_key]
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
||||
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
|
||||
torch.distributed.broadcast(
|
||||
batch, src=global_src_rank, group=get_context_parallel_group()
|
||||
)
|
||||
|
||||
batch = _conv_split(batch, dim=2, kernel_size=1)
|
||||
return batch
|
||||
|
@ -94,7 +94,11 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
|
||||
# new
|
||||
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
||||
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask
|
||||
) # scale is dim_head ** -0.5 per default
|
||||
|
||||
del q, k, v
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
additional_tokens=additional_tokens,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
||||
if not self.disable_self_attn
|
||||
else 0,
|
||||
)
|
||||
+ x
|
||||
)
|
||||
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
)
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
|
||||
]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
else:
|
||||
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
|
@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
yield from ()
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
def log_images(
|
||||
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# calc logits of real/fake
|
||||
logits_real = self.discriminator(inputs.contiguous().detach())
|
||||
if len(logits_real.shape) < 4:
|
||||
@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
weights: Union[None, float, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
if self.scale_input_to_tgt_size:
|
||||
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True)
|
||||
inputs = torch.nn.functional.interpolate(
|
||||
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
||||
)
|
||||
|
||||
if self.dims > 2:
|
||||
inputs, reconstructions = map(
|
||||
@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
input_frames = pick_video_frame(inputs, frame_indices)
|
||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
||||
p_loss = self.perceptual_loss(
|
||||
input_frames.contiguous(), recon_frames.contiguous()
|
||||
).mean()
|
||||
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
||||
|
||||
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
||||
@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
logits_fake = self.discriminator(reconstructions.contiguous())
|
||||
g_loss = -torch.mean(logits_fake)
|
||||
if self.training:
|
||||
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
||||
d_weight = self.calculate_adaptive_weight(
|
||||
nll_loss, g_loss, last_layer=last_layer
|
||||
)
|
||||
else:
|
||||
d_weight = torch.tensor(1.0)
|
||||
else:
|
||||
|
@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
|
||||
if self.perceptual_weight > 0.0:
|
||||
image_reconstructions = self.decoder.decode(latent_predictions)
|
||||
image_targets = self.decoder.decode(latent_inputs)
|
||||
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous())
|
||||
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
||||
perceptual_loss = self.perceptual_loss(
|
||||
image_targets.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = (
|
||||
self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
|
||||
)
|
||||
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
||||
|
||||
if self.perceptual_weight_on_inputs > 0.0:
|
||||
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions))
|
||||
image_reconstructions = default(
|
||||
image_reconstructions, self.decoder.decode(latent_predictions)
|
||||
)
|
||||
if self.scale_input_to_tgt_size:
|
||||
image_inputs = torch.nn.functional.interpolate(
|
||||
image_inputs,
|
||||
@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous())
|
||||
perceptual_loss2 = self.perceptual_loss(
|
||||
image_inputs.contiguous(), image_reconstructions.contiguous()
|
||||
)
|
||||
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
||||
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
||||
return loss, log
|
||||
|
@ -45,7 +45,9 @@ def hinge_gen_loss(fake):
|
||||
@autocast(enabled=False)
|
||||
@beartype
|
||||
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
||||
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
|
||||
return torch_grad(
|
||||
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
|
||||
)[0].detach()
|
||||
|
||||
|
||||
def pick_video_frame(video, frame_indices):
|
||||
@ -126,7 +128,8 @@ class DiscriminatorBlock(nn.Module):
|
||||
|
||||
self.downsample = (
|
||||
nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
|
||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
|
||||
nn.Conv2d(filters * 4, filters, 1),
|
||||
)
|
||||
if downsample
|
||||
else None
|
||||
@ -185,11 +188,18 @@ class Discriminator(nn.Module):
|
||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(
|
||||
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
|
||||
in_chan,
|
||||
out_chan,
|
||||
downsample=is_not_last,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
|
||||
attn_block = nn.Sequential(
|
||||
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
|
||||
Residual(
|
||||
LinearSpaceAttention(
|
||||
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||
)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
|
||||
@ -363,7 +373,9 @@ class Discriminator3D(nn.Module):
|
||||
)
|
||||
attn_block = nn.Sequential(
|
||||
Residual(
|
||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
||||
LinearSpaceAttention(
|
||||
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||
)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
@ -458,7 +470,9 @@ class Discriminator3DWithfirstframe(nn.Module):
|
||||
)
|
||||
attn_block = nn.Sequential(
|
||||
Residual(
|
||||
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)
|
||||
LinearSpaceAttention(
|
||||
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||
)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
@ -581,11 +595,17 @@ class VideoAutoencoderLoss(nn.Module):
|
||||
input_frames = pick_video_frame(inputs, frame_indices)
|
||||
recon_frames = pick_video_frame(reconstructions, frame_indices)
|
||||
|
||||
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean()
|
||||
perceptual_loss = self.perceptual_model(
|
||||
input_frames.contiguous(), recon_frames.contiguous()
|
||||
).mean()
|
||||
else:
|
||||
perceptual_loss = self.zero
|
||||
|
||||
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0:
|
||||
if (
|
||||
global_step >= self.disc_start
|
||||
or not self.training
|
||||
or self.adversarial_loss_weight == 0
|
||||
):
|
||||
gen_loss = self.zero
|
||||
adaptive_weight = 0
|
||||
else:
|
||||
@ -598,9 +618,13 @@ class VideoAutoencoderLoss(nn.Module):
|
||||
|
||||
adaptive_weight = 1
|
||||
if self.perceptual_weight > 0 and last_layer is not None:
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
|
||||
perceptual_loss, last_layer
|
||||
).norm(p=2)
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
|
||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
|
||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
|
||||
min=1e-3
|
||||
)
|
||||
adaptive_weight.clamp_(max=1e3)
|
||||
|
||||
if torch.isnan(adaptive_weight).any():
|
||||
|
@ -48,7 +48,9 @@ class LPIPS(nn.Module):
|
||||
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
||||
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
||||
|
||||
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
||||
res = [
|
||||
spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))
|
||||
]
|
||||
val = res[0]
|
||||
for l in range(1, len(self.chns)):
|
||||
val += res[l]
|
||||
@ -118,7 +120,9 @@ class vgg16(torch.nn.Module):
|
||||
h_relu4_3 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5_3 = h
|
||||
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
|
||||
vgg_outputs = namedtuple(
|
||||
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
||||
)
|
||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||
return out
|
||||
|
||||
|
@ -35,7 +35,9 @@ class NLayerDiscriminator(nn.Module):
|
||||
norm_layer = nn.BatchNorm2d
|
||||
else:
|
||||
norm_layer = ActNorm
|
||||
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
||||
if (
|
||||
type(norm_layer) == functools.partial
|
||||
): # no need to use bias as BatchNorm2d has affine parameters
|
||||
use_bias = norm_layer.func != nn.BatchNorm2d
|
||||
else:
|
||||
use_bias = norm_layer != nn.BatchNorm2d
|
||||
|
@ -11,6 +11,7 @@ def hinge_d_loss(logits_real, logits_fake):
|
||||
|
||||
def vanilla_d_loss(logits_real, logits_fake):
|
||||
d_loss = 0.5 * (
|
||||
torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
|
||||
torch.mean(torch.nn.functional.softplus(-logits_real))
|
||||
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
||||
)
|
||||
return d_loss
|
||||
|
@ -147,7 +147,9 @@ def hinge_gen_loss(fake):
|
||||
@autocast(enabled=False)
|
||||
@beartype
|
||||
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
|
||||
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
|
||||
return torch_grad(
|
||||
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
|
||||
)[0].detach()
|
||||
|
||||
|
||||
# helper decorators
|
||||
@ -223,7 +225,10 @@ class SqueezeExcite(Module):
|
||||
dim_hidden = max(dim_hidden_min, dim_out // 2)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
|
||||
nn.Conv2d(dim, dim_hidden, 1),
|
||||
nn.LeakyReLU(0.1),
|
||||
nn.Conv2d(dim_hidden, dim_out, 1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
nn.init.zeros_(self.net[-2].weight)
|
||||
@ -282,7 +287,10 @@ class RMSNorm(Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
return (
|
||||
F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma
|
||||
+ self.bias
|
||||
)
|
||||
|
||||
|
||||
class AdaptiveRMSNorm(Module):
|
||||
@ -353,7 +361,8 @@ class Attention(Module):
|
||||
self.norm = RMSNorm(dim)
|
||||
|
||||
self.to_qkv = nn.Sequential(
|
||||
nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
|
||||
nn.Linear(dim, dim_inner * 3, bias=False),
|
||||
Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
|
||||
)
|
||||
|
||||
assert num_memory_kv > 0
|
||||
@ -361,7 +370,9 @@ class Attention(Module):
|
||||
|
||||
self.attend = Attend(causal=causal, dropout=dropout, flash=flash)
|
||||
|
||||
self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
|
||||
self.to_out = nn.Sequential(
|
||||
Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
|
||||
)
|
||||
|
||||
@beartype
|
||||
def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
|
||||
@ -455,7 +466,9 @@ class FeedForward(Module):
|
||||
super().__init__()
|
||||
conv_klass = nn.Conv2d if images else nn.Conv3d
|
||||
|
||||
rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
|
||||
rmsnorm_klass = (
|
||||
RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
|
||||
)
|
||||
|
||||
maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
|
||||
|
||||
@ -463,7 +476,9 @@ class FeedForward(Module):
|
||||
|
||||
self.norm = maybe_adaptive_norm_klass(dim)
|
||||
|
||||
self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
|
||||
self.net = Sequential(
|
||||
conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)
|
||||
)
|
||||
|
||||
@beartype
|
||||
def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
|
||||
@ -525,7 +540,8 @@ class DiscriminatorBlock(Module):
|
||||
|
||||
self.downsample = (
|
||||
nn.Sequential(
|
||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
|
||||
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
|
||||
nn.Conv2d(filters * 4, filters, 1),
|
||||
)
|
||||
if downsample
|
||||
else None
|
||||
@ -584,11 +600,18 @@ class Discriminator(Module):
|
||||
is_not_last = ind != (len(layer_dims_in_out) - 1)
|
||||
|
||||
block = DiscriminatorBlock(
|
||||
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
|
||||
in_chan,
|
||||
out_chan,
|
||||
downsample=is_not_last,
|
||||
antialiased_downsample=antialiased_downsample,
|
||||
)
|
||||
|
||||
attn_block = Sequential(
|
||||
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
|
||||
Residual(
|
||||
LinearSpaceAttention(
|
||||
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
|
||||
)
|
||||
),
|
||||
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
|
||||
)
|
||||
|
||||
@ -628,7 +651,16 @@ class Discriminator(Module):
|
||||
class Conv3DMod(Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
spatial_kernel,
|
||||
time_kernel,
|
||||
causal=True,
|
||||
dim_out=None,
|
||||
demod=True,
|
||||
eps=1e-8,
|
||||
pad_mode="zeros",
|
||||
):
|
||||
super().__init__()
|
||||
dim_out = default(dim_out, dim)
|
||||
@ -644,7 +676,9 @@ class Conv3DMod(Module):
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
|
||||
self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
|
||||
self.weights = nn.Parameter(
|
||||
torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))
|
||||
)
|
||||
|
||||
self.demod = demod
|
||||
|
||||
@ -675,7 +709,11 @@ class Conv3DMod(Module):
|
||||
weights = weights * (cond + 1)
|
||||
|
||||
if self.demod:
|
||||
inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
|
||||
inv_norm = (
|
||||
reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum")
|
||||
.clamp(min=self.eps)
|
||||
.rsqrt()
|
||||
)
|
||||
weights = weights * inv_norm
|
||||
|
||||
fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
|
||||
@ -742,7 +780,9 @@ class SpatialUpsample2x(Module):
|
||||
dim_out = default(dim_out, dim)
|
||||
conv = nn.Conv2d(dim, dim_out * 4, 1)
|
||||
|
||||
self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
|
||||
self.net = nn.Sequential(
|
||||
conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)
|
||||
)
|
||||
|
||||
self.init_conv_(conv)
|
||||
|
||||
@ -808,7 +848,12 @@ def SameConv2d(dim_in, dim_out, kernel_size):
|
||||
class CausalConv3d(Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
|
||||
self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
pad_mode="constant",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
@ -830,7 +875,9 @@ class CausalConv3d(Module):
|
||||
|
||||
stride = (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
self.conv = nn.Conv3d(
|
||||
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
|
||||
@ -855,7 +902,13 @@ def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: s
|
||||
@beartype
|
||||
class ResidualUnitMod(Module):
|
||||
def __init__(
|
||||
self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
|
||||
self,
|
||||
dim,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
*,
|
||||
dim_cond,
|
||||
pad_mode: str = "constant",
|
||||
demod=True,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
@ -892,7 +945,15 @@ class ResidualUnitMod(Module):
|
||||
|
||||
|
||||
class CausalConvTranspose3d(Module):
|
||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
*,
|
||||
time_stride,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
@ -908,7 +969,9 @@ class CausalConvTranspose3d(Module):
|
||||
stride = (time_stride, 1, 1)
|
||||
padding = (0, height_pad, width_pad)
|
||||
|
||||
self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
|
||||
self.conv = nn.ConvTranspose3d(
|
||||
chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.ndim == 5
|
||||
@ -936,7 +999,9 @@ LossBreakdown = namedtuple(
|
||||
],
|
||||
)
|
||||
|
||||
DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
|
||||
DiscrLossBreakdown = namedtuple(
|
||||
"DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]
|
||||
)
|
||||
|
||||
|
||||
class VideoTokenizer(Module):
|
||||
@ -1050,10 +1115,14 @@ class VideoTokenizer(Module):
|
||||
has_cond = True
|
||||
|
||||
encoder_layer = ResidualUnitMod(
|
||||
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
|
||||
dim,
|
||||
residual_conv_kernel_size,
|
||||
dim_cond=int(dim_cond * dim_cond_expansion_factor),
|
||||
)
|
||||
decoder_layer = ResidualUnitMod(
|
||||
dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
|
||||
dim,
|
||||
residual_conv_kernel_size,
|
||||
dim_cond=int(dim_cond * dim_cond_expansion_factor),
|
||||
)
|
||||
dim_out = dim
|
||||
|
||||
@ -1080,15 +1149,25 @@ class VideoTokenizer(Module):
|
||||
|
||||
elif layer_type == "attend_space":
|
||||
attn_kwargs = dict(
|
||||
dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
|
||||
dim=dim,
|
||||
dim_head=attn_dim_head,
|
||||
heads=attn_heads,
|
||||
dropout=attn_dropout,
|
||||
flash=flash_attn,
|
||||
)
|
||||
|
||||
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
||||
encoder_layer = Sequential(
|
||||
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||
)
|
||||
|
||||
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
||||
decoder_layer = Sequential(
|
||||
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||
)
|
||||
|
||||
elif layer_type == "linear_attend_space":
|
||||
linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
|
||||
linear_attn_kwargs = dict(
|
||||
dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads
|
||||
)
|
||||
|
||||
encoder_layer = Sequential(
|
||||
Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
|
||||
@ -1136,9 +1215,13 @@ class VideoTokenizer(Module):
|
||||
flash=flash_attn,
|
||||
)
|
||||
|
||||
encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
||||
encoder_layer = Sequential(
|
||||
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||
)
|
||||
|
||||
decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
|
||||
decoder_layer = Sequential(
|
||||
Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
|
||||
)
|
||||
|
||||
elif layer_type == "cond_linear_attend_space":
|
||||
has_cond = True
|
||||
@ -1153,11 +1236,13 @@ class VideoTokenizer(Module):
|
||||
)
|
||||
|
||||
encoder_layer = Sequential(
|
||||
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
|
||||
Residual(LinearSpaceAttention(**attn_kwargs)),
|
||||
Residual(FeedForward(dim, dim_cond=dim_cond)),
|
||||
)
|
||||
|
||||
decoder_layer = Sequential(
|
||||
Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
|
||||
Residual(LinearSpaceAttention(**attn_kwargs)),
|
||||
Residual(FeedForward(dim, dim_cond=dim_cond)),
|
||||
)
|
||||
|
||||
elif layer_type == "cond_attend_time":
|
||||
@ -1283,7 +1368,9 @@ class VideoTokenizer(Module):
|
||||
|
||||
# discriminator
|
||||
|
||||
discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
|
||||
discr_kwargs = default(
|
||||
discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)
|
||||
)
|
||||
|
||||
self.discr = Discriminator(**discr_kwargs)
|
||||
|
||||
@ -1380,8 +1467,16 @@ class VideoTokenizer(Module):
|
||||
self.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
@beartype
|
||||
def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
||||
encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
|
||||
def encode(
|
||||
self,
|
||||
video: Tensor,
|
||||
quantize=False,
|
||||
cond: Optional[Tensor] = None,
|
||||
video_contains_first_frame=True,
|
||||
):
|
||||
encode_first_frame_separately = (
|
||||
self.separate_first_frame_encoding and video_contains_first_frame
|
||||
)
|
||||
|
||||
# whether to pad video or not
|
||||
|
||||
@ -1389,12 +1484,16 @@ class VideoTokenizer(Module):
|
||||
video_len = video.shape[2]
|
||||
|
||||
video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
|
||||
video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
|
||||
video_packed_shape = [
|
||||
torch.Size([self.time_padding]),
|
||||
torch.Size([]),
|
||||
torch.Size([video_len - 1]),
|
||||
]
|
||||
|
||||
# conditioning, if needed
|
||||
|
||||
assert (not self.has_cond) or exists(
|
||||
cond
|
||||
assert (
|
||||
(not self.has_cond) or exists(cond)
|
||||
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
||||
|
||||
if exists(cond):
|
||||
@ -1431,7 +1530,9 @@ class VideoTokenizer(Module):
|
||||
return maybe_quantize(video)
|
||||
|
||||
@beartype
|
||||
def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
||||
def decode_from_code_indices(
|
||||
self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
|
||||
):
|
||||
assert codes.dtype in (torch.long, torch.int32)
|
||||
|
||||
if codes.ndim == 2:
|
||||
@ -1444,18 +1545,24 @@ class VideoTokenizer(Module):
|
||||
|
||||
quantized = self.quantizers.indices_to_codes(codes)
|
||||
|
||||
return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
|
||||
return self.decode(
|
||||
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
|
||||
)
|
||||
|
||||
@beartype
|
||||
def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
|
||||
decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
|
||||
def decode(
|
||||
self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
|
||||
):
|
||||
decode_first_frame_separately = (
|
||||
self.separate_first_frame_encoding and video_contains_first_frame
|
||||
)
|
||||
|
||||
batch = quantized.shape[0]
|
||||
|
||||
# conditioning, if needed
|
||||
|
||||
assert (not self.has_cond) or exists(
|
||||
cond
|
||||
assert (
|
||||
(not self.has_cond) or exists(cond)
|
||||
), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"
|
||||
|
||||
if exists(cond):
|
||||
@ -1558,14 +1665,18 @@ class VideoTokenizer(Module):
|
||||
aux_losses = self.zero
|
||||
quantizer_loss_breakdown = None
|
||||
else:
|
||||
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
|
||||
(quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(
|
||||
x, return_loss_breakdown=True
|
||||
)
|
||||
|
||||
if return_codes and not return_recon:
|
||||
return codes
|
||||
|
||||
# decoder
|
||||
|
||||
recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
|
||||
recon_video = self.decode(
|
||||
quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
|
||||
)
|
||||
|
||||
if return_codes:
|
||||
return codes, recon_video
|
||||
@ -1613,7 +1724,9 @@ class VideoTokenizer(Module):
|
||||
multiscale_real_logits = discr(video)
|
||||
multiscale_fake_logits = discr(recon_video.detach())
|
||||
|
||||
multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
|
||||
multiscale_discr_loss = hinge_discr_loss(
|
||||
multiscale_fake_logits, multiscale_real_logits
|
||||
)
|
||||
|
||||
multiscale_discr_losses.append(multiscale_discr_loss)
|
||||
else:
|
||||
@ -1634,7 +1747,9 @@ class VideoTokenizer(Module):
|
||||
+ sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
|
||||
)
|
||||
|
||||
discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
|
||||
discr_loss_breakdown = DiscrLossBreakdown(
|
||||
discr_loss, multiscale_discr_losses, gradient_penalty_loss
|
||||
)
|
||||
|
||||
return total_loss, discr_loss_breakdown
|
||||
|
||||
@ -1669,7 +1784,9 @@ class VideoTokenizer(Module):
|
||||
norm_grad_wrt_perceptual_loss = None
|
||||
|
||||
if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
|
||||
perceptual_loss, last_dec_layer
|
||||
).norm(p=2)
|
||||
|
||||
# per-frame image discriminator
|
||||
|
||||
@ -1686,7 +1803,9 @@ class VideoTokenizer(Module):
|
||||
|
||||
if exists(norm_grad_wrt_perceptual_loss):
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
|
||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
|
||||
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
|
||||
min=1e-3
|
||||
)
|
||||
adaptive_weight.clamp_(max=1e3)
|
||||
|
||||
if torch.isnan(adaptive_weight).any():
|
||||
@ -1713,8 +1832,12 @@ class VideoTokenizer(Module):
|
||||
multiscale_adaptive_weight = 1.0
|
||||
|
||||
if exists(norm_grad_wrt_perceptual_loss):
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
|
||||
multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
|
||||
multiscale_gen_loss, last_dec_layer
|
||||
).norm(p=2)
|
||||
multiscale_adaptive_weight = (
|
||||
norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
|
||||
)
|
||||
multiscale_adaptive_weight.clamp_(max=1e3)
|
||||
|
||||
multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
|
||||
@ -1730,10 +1853,13 @@ class VideoTokenizer(Module):
|
||||
|
||||
if self.has_multiscale_discrs:
|
||||
weighted_multiscale_gen_losses = sum(
|
||||
loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
|
||||
loss * weight
|
||||
for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
|
||||
)
|
||||
|
||||
total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
|
||||
total_loss = (
|
||||
total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
|
||||
)
|
||||
|
||||
# loss breakdown
|
||||
|
||||
|
@ -26,7 +26,9 @@ class IdentityRegularizer(AbstractRegularizer):
|
||||
yield from ()
|
||||
|
||||
|
||||
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def measure_perplexity(
|
||||
predicted_indices: torch.Tensor, num_centroids: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
|
@ -79,13 +79,19 @@ class FSQ(Module):
|
||||
self.dim = default(dim, len(_levels) * num_codebooks)
|
||||
|
||||
has_projections = self.dim != effective_codebook_dim
|
||||
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
||||
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
||||
self.project_in = (
|
||||
nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
|
||||
)
|
||||
self.has_projections = has_projections
|
||||
|
||||
self.codebook_size = self._levels.prod().item()
|
||||
|
||||
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
|
||||
implicit_codebook = self.indices_to_codes(
|
||||
torch.arange(self.codebook_size), project_out=False
|
||||
)
|
||||
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
|
||||
|
||||
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
|
||||
@ -153,7 +159,9 @@ class FSQ(Module):
|
||||
z = rearrange(z, "b d ... -> b ... d")
|
||||
z, ps = pack_one(z, "b * d")
|
||||
|
||||
assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
||||
assert (
|
||||
z.shape[-1] == self.dim
|
||||
), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
|
||||
|
||||
z = self.project_in(z)
|
||||
|
||||
|
@ -78,7 +78,9 @@ class LFQ(Module):
|
||||
|
||||
# some assert validations
|
||||
|
||||
assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
|
||||
assert exists(dim) or exists(
|
||||
codebook_size
|
||||
), "either dim or codebook_size must be specified for LFQ"
|
||||
assert (
|
||||
not exists(codebook_size) or log2(codebook_size).is_integer()
|
||||
), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
|
||||
@ -195,7 +197,9 @@ class LFQ(Module):
|
||||
x = rearrange(x, "b d ... -> b ... d")
|
||||
x, ps = pack_one(x, "b * d")
|
||||
|
||||
assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
|
||||
assert (
|
||||
x.shape[-1] == self.dim
|
||||
), f"expected dimension of {self.dim} but received {x.shape[-1]}"
|
||||
|
||||
x = self.project_in(x)
|
||||
|
||||
@ -299,7 +303,9 @@ class LFQ(Module):
|
||||
|
||||
# complete aux loss
|
||||
|
||||
aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
||||
aux_loss = (
|
||||
entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
|
||||
)
|
||||
|
||||
ret = Return(x, indices, aux_loss)
|
||||
|
||||
|
@ -33,7 +33,9 @@ class AbstractQuantizer(AbstractRegularizer):
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
@ -50,7 +52,9 @@ class AbstractQuantizer(AbstractRegularizer):
|
||||
return back.reshape(ishape)
|
||||
|
||||
@abstractmethod
|
||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
||||
def get_codebook_entry(
|
||||
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
||||
@ -239,7 +243,8 @@ class VectorQuantizer(AbstractQuantizer):
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
- 2
|
||||
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
@ -267,15 +272,21 @@ class VectorQuantizer(AbstractQuantizer):
|
||||
|
||||
if self.sane_index_shape:
|
||||
if do_reshape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||
)
|
||||
else:
|
||||
min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
|
||||
min_encoding_indices = rearrange(
|
||||
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
|
||||
)
|
||||
|
||||
loss_dict["min_encoding_indices"] = min_encoding_indices
|
||||
|
||||
return z_q, loss_dict
|
||||
|
||||
def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
|
||||
def get_codebook_entry(
|
||||
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
||||
) -> torch.Tensor:
|
||||
# shape specifying (batch, height, width, channel)
|
||||
if self.remap is not None:
|
||||
assert shape is not None, "Need to give shape for remap"
|
||||
@ -448,6 +459,8 @@ class VectorQuantizerWithInputProjection(VectorQuantizer):
|
||||
elif len(in_shape) == 5:
|
||||
z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
|
||||
else:
|
||||
raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
|
||||
raise NotImplementedError(
|
||||
f"rearranging not available for {len(in_shape)}-dimensional input."
|
||||
)
|
||||
|
||||
return z_q, loss_dict
|
||||
|
@ -248,7 +248,9 @@ def make_time_attn(
|
||||
"vanilla",
|
||||
"vanilla-xformers",
|
||||
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
||||
print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
print(
|
||||
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
|
||||
)
|
||||
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
||||
print(
|
||||
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
||||
|
@ -125,9 +125,13 @@ class ResnetBlock3D(nn.Module):
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
self.conv_shortcut = CausalConv3d(
|
||||
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv3d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
h = x
|
||||
@ -161,7 +165,9 @@ class AttnBlock2D(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
@ -380,7 +386,11 @@ class NewDecoder3D(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
|
@ -148,9 +148,13 @@ class ResnetBlock3D(nn.Module):
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
self.conv_shortcut = CausalConv3d(
|
||||
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv3d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
@ -185,7 +189,9 @@ class AttnBlock2D(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
@ -261,7 +267,11 @@ class MOVQDecoder3D(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
@ -420,7 +430,11 @@ class NewDecoder3D(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
# self.conv_in = torch.nn.Conv3d(z_channels,
|
||||
|
@ -51,7 +51,12 @@ def nonlinearity(x):
|
||||
class CausalConv3d(nn.Module):
|
||||
@beartype
|
||||
def __init__(
|
||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
|
||||
self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
pad_mode="constant",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
@ -75,11 +80,20 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
stride = (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
self.conv = nn.Conv3d(
|
||||
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.pad_mode == "constant":
|
||||
causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
causal_padding_3d = (
|
||||
self.time_pad,
|
||||
0,
|
||||
self.width_pad,
|
||||
self.width_pad,
|
||||
self.height_pad,
|
||||
self.height_pad,
|
||||
)
|
||||
x = F.pad(x, causal_padding_3d, mode="constant", value=0)
|
||||
elif self.pad_mode == "first":
|
||||
pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
|
||||
@ -91,7 +105,9 @@ class CausalConv3d(nn.Module):
|
||||
reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
|
||||
if reflect_x.shape[2] < self.time_pad:
|
||||
reflect_x = torch.cat(
|
||||
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
|
||||
[torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2])
|
||||
+ [reflect_x],
|
||||
dim=2,
|
||||
)
|
||||
x = torch.cat([reflect_x, x], dim=2)
|
||||
causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
@ -110,7 +126,9 @@ class Upsample3D(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
@ -149,7 +167,9 @@ class DownSample3D(nn.Module):
|
||||
out_channels = in_channels
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
@ -182,7 +202,14 @@ class DownSample3D(nn.Module):
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
pad_mode="constant",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@ -214,9 +241,13 @@ class ResnetBlock3D(nn.Module):
|
||||
# kernel_size=3,
|
||||
# stride=1,
|
||||
# padding=1)
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
self.conv_shortcut = CausalConv3d(
|
||||
in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv3d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
# self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, temb):
|
||||
@ -251,7 +282,9 @@ class AttnBlock2D(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -365,12 +398,20 @@ class Encoder3D(nn.Module):
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
# remove attention block
|
||||
# self.mid.attn_1 = AttnBlock2D(block_in)
|
||||
self.mid.block_2 = ResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
# end
|
||||
|
@ -80,7 +80,9 @@ class Upsample(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
@ -95,7 +97,9 @@ class Downsample(nn.Module):
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
@ -134,9 +138,13 @@ class ResnetBlock(nn.Module):
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq):
|
||||
h = x
|
||||
@ -170,7 +178,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, zq):
|
||||
h_ = x
|
||||
@ -232,7 +242,11 @@ class MOVQDecoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
@ -15,7 +15,16 @@ class VectorQuantizer2(nn.Module):
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
def __init__(
|
||||
self,
|
||||
n_e,
|
||||
e_dim,
|
||||
beta,
|
||||
remap=None,
|
||||
unknown_index="random",
|
||||
sane_index_shape=False,
|
||||
legacy=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
@ -51,7 +60,9 @@ class VectorQuantizer2(nn.Module):
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
@ -78,7 +89,8 @@ class VectorQuantizer2(nn.Module):
|
||||
d = (
|
||||
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
- 2
|
||||
* torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(d, dim=1)
|
||||
@ -88,9 +100,13 @@ class VectorQuantizer2(nn.Module):
|
||||
|
||||
# compute loss for embedding
|
||||
if not self.legacy:
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
||||
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
else:
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
||||
(z_q - z.detach()) ** 2
|
||||
)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
@ -104,7 +120,9 @@ class VectorQuantizer2(nn.Module):
|
||||
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
||||
|
||||
if self.sane_index_shape:
|
||||
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
||||
min_encoding_indices = min_encoding_indices.reshape(
|
||||
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
||||
)
|
||||
|
||||
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
||||
|
||||
@ -184,7 +202,9 @@ class GumbelQuantize(nn.Module):
|
||||
new = match.argmax(-1)
|
||||
unknown = match.sum(2) < 1
|
||||
if self.unknown_index == "random":
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
||||
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
||||
device=new.device
|
||||
)
|
||||
else:
|
||||
new[unknown] = self.unknown_index
|
||||
return new.reshape(ishape)
|
||||
|
@ -40,7 +40,9 @@ class Upsample(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
@ -55,7 +57,9 @@ class Downsample(nn.Module):
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
@ -68,7 +72,9 @@ class Downsample(nn.Module):
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
def __init__(
|
||||
self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@ -84,9 +90,13 @@ class ResnetBlock(nn.Module):
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
@ -120,7 +130,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -194,7 +206,10 @@ class Encoder(nn.Module):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
@ -326,7 +341,11 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
@ -350,7 +369,10 @@ class Decoder(nn.Module):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
|
@ -136,9 +136,9 @@ def _conv_split(input_, dim, kernel_size):
|
||||
if cp_rank == 0:
|
||||
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
||||
else:
|
||||
output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
|
||||
dim, 0
|
||||
)
|
||||
output = input_.transpose(dim, 0)[
|
||||
cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size
|
||||
].transpose(dim, 0)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
@ -35,7 +35,9 @@ class Denoiser(nn.Module):
|
||||
sigma = append_dims(sigma, input.ndim)
|
||||
c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
|
||||
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
|
||||
return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
|
||||
return (
|
||||
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
|
||||
)
|
||||
|
||||
|
||||
class DiscreteDenoiser(Denoiser):
|
||||
@ -50,7 +52,9 @@ class DiscreteDenoiser(Denoiser):
|
||||
flip=True,
|
||||
):
|
||||
super().__init__(weighting_config, scaling_config)
|
||||
sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
|
||||
sigmas = instantiate_from_config(discretization_config)(
|
||||
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
)
|
||||
self.sigmas = sigmas
|
||||
# self.register_buffer("sigmas", sigmas)
|
||||
self.quantize_c_noise = quantize_c_noise
|
||||
|
@ -6,7 +6,9 @@ import torch
|
||||
|
||||
class DenoiserScaling(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
@ -14,7 +16,9 @@ class EDMScaling:
|
||||
def __init__(self, sigma_data: float = 0.5):
|
||||
self.sigma_data = sigma_data
|
||||
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
||||
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
||||
@ -23,7 +27,9 @@ class EDMScaling:
|
||||
|
||||
|
||||
class EpsScaling:
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = torch.ones_like(sigma, device=sigma.device)
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
||||
@ -32,7 +38,9 @@ class EpsScaling:
|
||||
|
||||
|
||||
class VScaling:
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
@ -41,7 +49,9 @@ class VScaling:
|
||||
|
||||
|
||||
class VScalingWithEDMcNoise(DenoiserScaling):
|
||||
def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def __call__(
|
||||
self, sigma: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
c_skip = 1.0 / (sigma**2 + 1.0)
|
||||
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
||||
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
||||
|
@ -52,7 +52,9 @@ class LegacyDDPMDiscretization(Discretization):
|
||||
):
|
||||
super().__init__()
|
||||
self.num_timesteps = num_timesteps
|
||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
||||
betas = make_beta_schedule(
|
||||
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
||||
)
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
@ -85,14 +87,18 @@ class ZeroSNRDDPMDiscretization(Discretization):
|
||||
if keep_start and not post_shift:
|
||||
linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
|
||||
self.num_timesteps = num_timesteps
|
||||
betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
|
||||
betas = make_beta_schedule(
|
||||
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
||||
)
|
||||
alphas = 1.0 - betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
# SNR shift
|
||||
if not post_shift:
|
||||
self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
|
||||
self.alphas_cumprod = self.alphas_cumprod / (
|
||||
shift_scale + (1 - shift_scale) * self.alphas_cumprod
|
||||
)
|
||||
|
||||
self.post_shift = post_shift
|
||||
self.shift_scale = shift_scale
|
||||
@ -113,11 +119,14 @@ class ZeroSNRDDPMDiscretization(Discretization):
|
||||
alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
|
||||
|
||||
alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
|
||||
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
|
||||
alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (
|
||||
alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T
|
||||
)
|
||||
|
||||
if self.post_shift:
|
||||
alphas_cumprod_sqrt = (
|
||||
alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
|
||||
alphas_cumprod_sqrt**2
|
||||
/ (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
|
||||
) ** 0.5
|
||||
|
||||
if return_idx:
|
||||
|
@ -15,7 +15,9 @@ class Guider(ABC):
|
||||
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]:
|
||||
def prepare_inputs(
|
||||
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
||||
) -> Tuple[torch.Tensor, float, Dict]:
|
||||
pass
|
||||
|
||||
|
||||
@ -57,7 +59,8 @@ class DynamicCFG(VanillaCFG):
|
||||
def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
|
||||
super().__init__(scale, dyn_thresh_config)
|
||||
scale_schedule = (
|
||||
lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
|
||||
lambda scale, sigma, step_index: 1
|
||||
+ scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
|
||||
)
|
||||
self.scale_schedule = partial(scale_schedule, scale)
|
||||
self.dyn_thresh = instantiate_from_config(
|
||||
|
@ -20,7 +20,9 @@ from torch import nn
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
def __init__(
|
||||
self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
@ -50,11 +52,20 @@ class LoRALinearLayer(nn.Module):
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
rank=4,
|
||||
kernel_size=(1, 1),
|
||||
stride=(1, 1),
|
||||
padding=0,
|
||||
network_alpha=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
|
||||
self.down = nn.Conv2d(
|
||||
in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False
|
||||
)
|
||||
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
|
||||
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
|
||||
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
|
||||
@ -85,7 +96,9 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
A convolutional layer that can be used with LoRA.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs):
|
||||
def __init__(
|
||||
self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
self.scale = scale
|
||||
@ -144,7 +157,13 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
return F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
hidden_states,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
else:
|
||||
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
||||
@ -155,7 +174,9 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
A Linear layer that can be used with LoRA.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs):
|
||||
def __init__(
|
||||
self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
self.scale = scale
|
||||
@ -197,7 +218,9 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
w_up = self.w_up.to(device=device).float()
|
||||
w_down = self.w_down.to(device).float()
|
||||
|
||||
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
||||
unfused_weight = fused_weight.float() - (
|
||||
self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
|
||||
)
|
||||
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
self.w_up = None
|
||||
@ -252,7 +275,9 @@ def _find_modules_v2(
|
||||
|
||||
# Get the targets we should replace all linears under
|
||||
if ancestor_class is not None:
|
||||
ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class)
|
||||
ancestors = (
|
||||
module for module in model.modules() if module.__class__.__name__ in ancestor_class
|
||||
)
|
||||
else:
|
||||
# this, incase you want to naively iterate over all modules.
|
||||
ancestors = [module for module in model.modules()]
|
||||
@ -274,7 +299,9 @@ def _find_modules_v2(
|
||||
if flag:
|
||||
continue
|
||||
# Skip this linear if it's a child of a LoraInjectedLinear
|
||||
if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]):
|
||||
if exclude_children_of and any(
|
||||
[isinstance(parent, _class) for _class in exclude_children_of]
|
||||
):
|
||||
continue
|
||||
# Otherwise, yield it
|
||||
yield parent, name, module
|
||||
|
@ -38,13 +38,17 @@ class StandardDiffusionLoss(nn.Module):
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
cond = conditioner(batch)
|
||||
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
|
||||
additional_model_inputs = {
|
||||
key: batch[key] for key in self.batch2model_keys.intersection(batch)
|
||||
}
|
||||
|
||||
sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
|
||||
noise = torch.randn_like(input)
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = (
|
||||
noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
|
||||
noise
|
||||
+ append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
|
||||
* self.offset_noise_level
|
||||
)
|
||||
noise = noise.to(input.dtype)
|
||||
noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
|
||||
@ -63,7 +67,9 @@ class StandardDiffusionLoss(nn.Module):
|
||||
|
||||
|
||||
class VideoDiffusionLoss(StandardDiffusionLoss):
|
||||
def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs):
|
||||
def __init__(
|
||||
self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs
|
||||
):
|
||||
self.fixed_frames = fixed_frames
|
||||
self.block_scale = block_scale
|
||||
self.block_size = block_size
|
||||
@ -72,7 +78,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
|
||||
|
||||
def __call__(self, network, denoiser, conditioner, input, batch):
|
||||
cond = conditioner(batch)
|
||||
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)}
|
||||
additional_model_inputs = {
|
||||
key: batch[key] for key in self.batch2model_keys.intersection(batch)
|
||||
}
|
||||
|
||||
alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
|
||||
alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
|
||||
@ -86,24 +94,30 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast(alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group())
|
||||
torch.distributed.broadcast(
|
||||
alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()
|
||||
)
|
||||
|
||||
additional_model_inputs["idx"] = idx
|
||||
|
||||
if self.offset_noise_level > 0.0:
|
||||
noise = (
|
||||
noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level
|
||||
noise
|
||||
+ append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
|
||||
* self.offset_noise_level
|
||||
)
|
||||
|
||||
noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims(
|
||||
(1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim
|
||||
)
|
||||
noised_input = input.float() * append_dims(
|
||||
alphas_cumprod_sqrt, input.ndim
|
||||
) + noise * append_dims((1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim)
|
||||
|
||||
if "concat_images" in batch.keys():
|
||||
cond["concat"] = batch["concat_images"]
|
||||
|
||||
# [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
|
||||
model_output = denoiser(network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs)
|
||||
model_output = denoiser(
|
||||
network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
|
||||
)
|
||||
w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim) # v-pred
|
||||
|
||||
if self.min_snr_value is not None:
|
||||
|
@ -47,7 +47,9 @@ def nonlinearity(x):
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
@ -55,7 +57,9 @@ class Upsample(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
@ -70,7 +74,9 @@ class Downsample(nn.Module):
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
@ -107,9 +113,13 @@ class ResnetBlock(nn.Module):
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
@ -150,7 +160,9 @@ class AttnBlock(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(h_)
|
||||
@ -160,7 +172,9 @@ class AttnBlock(nn.Module):
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
|
||||
h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
h_ = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v
|
||||
) # scale is dim ** -0.5 per default
|
||||
# compute attention
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
@ -188,7 +202,9 @@ class MemoryEfficientAttnBlock(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
||||
@ -211,7 +227,12 @@ class MemoryEfficientAttnBlock(nn.Module):
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(B, 1, out.shape[1], C)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B, out.shape[1], C)
|
||||
)
|
||||
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
@ -581,7 +602,11 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
make_resblock_cls = self._make_resblock()
|
||||
|
@ -47,7 +47,9 @@ class AttentionPool2d(nn.Module):
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
|
||||
)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
@ -303,7 +305,9 @@ class ResBlock(TimestepBlock):
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding)
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
@ -437,7 +441,9 @@ class QKVAttentionLegacy(nn.Module):
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
@ -574,9 +580,7 @@ class UNetModel(nn.Module):
|
||||
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
|
||||
|
||||
if context_dim is not None:
|
||||
assert (
|
||||
use_spatial_transformer
|
||||
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
@ -640,7 +644,9 @@ class UNetModel(nn.Module):
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint)
|
||||
assert use_fairscale_checkpoint != use_checkpoint or not (
|
||||
use_checkpoint or use_fairscale_checkpoint
|
||||
)
|
||||
|
||||
self.use_fairscale_checkpoint = False
|
||||
checkpoint_wrapper_fn = (
|
||||
@ -942,7 +948,9 @@ class UNetModel(nn.Module):
|
||||
print(f"loading lora from {ckpt_path}")
|
||||
sd = th.load(ckpt_path)["module"]
|
||||
sd = {
|
||||
key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model")
|
||||
key[len("model.diffusion_model") :]: sd[key]
|
||||
for key in sd
|
||||
if key.startswith("model.diffusion_model")
|
||||
}
|
||||
self.load_state_dict(sd, strict=False)
|
||||
|
||||
@ -978,7 +986,9 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
|
||||
)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
|
@ -1,8 +1,7 @@
|
||||
"""
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
||||
"""
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
@ -85,9 +84,7 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):
|
||||
|
||||
|
||||
class EDMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(
|
||||
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
|
||||
):
|
||||
def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_churn = s_churn
|
||||
@ -106,15 +103,11 @@ class EDMSampler(SingleStepDiffusionSampler):
|
||||
dt = append_dims(next_sigma - sigma_hat, x.ndim)
|
||||
|
||||
euler_step = self.euler_step(x, d, dt)
|
||||
x = self.possible_correction_step(
|
||||
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
)
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
gamma = (
|
||||
@ -136,30 +129,23 @@ class EDMSampler(SingleStepDiffusionSampler):
|
||||
|
||||
|
||||
class DDIMSampler(SingleStepDiffusionSampler):
|
||||
def __init__(
|
||||
self, s_noise=0.1, *args, **kwargs
|
||||
):
|
||||
def __init__(self, s_noise=0.1, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.s_noise = s_noise
|
||||
|
||||
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
|
||||
|
||||
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
||||
d = to_d(x, sigma, denoised)
|
||||
dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)
|
||||
dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
|
||||
|
||||
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
|
||||
|
||||
x = self.possible_correction_step(
|
||||
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
)
|
||||
x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
@ -198,9 +184,7 @@ class AncestralSampler(SingleStepDiffusionSampler):
|
||||
return x
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
x = self.sampler_step(
|
||||
@ -227,43 +211,32 @@ class LinearMultistepSampler(BaseDiffusionSampler):
|
||||
self.order = order
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
ds = []
|
||||
sigmas_cpu = sigmas.detach().cpu().numpy()
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
sigma = s_in * sigmas[i]
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
|
||||
)
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
|
||||
denoised = self.guider(denoised, sigma)
|
||||
d = to_d(x, sigma, denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > self.order:
|
||||
ds.pop(0)
|
||||
cur_order = min(i + 1, self.order)
|
||||
coeffs = [
|
||||
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
|
||||
for j in range(cur_order)
|
||||
]
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerEDMSampler(EDMSampler):
|
||||
def possible_correction_step(
|
||||
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
return euler_step
|
||||
|
||||
|
||||
class HeunEDMSampler(EDMSampler):
|
||||
def possible_correction_step(
|
||||
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
||||
):
|
||||
def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
|
||||
if torch.sum(next_sigma) < 1e-14:
|
||||
# Save a network evaluation if all noise levels are 0
|
||||
return euler_step
|
||||
@ -273,9 +246,7 @@ class HeunEDMSampler(EDMSampler):
|
||||
d_prime = (d + d_new) / 2.0
|
||||
|
||||
# apply correction if noise level is not 0
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
|
||||
)
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
|
||||
return x
|
||||
|
||||
|
||||
@ -314,9 +285,7 @@ class DPMPP2SAncestralSampler(AncestralSampler):
|
||||
x = x_euler
|
||||
else:
|
||||
h, s, t, t_next = self.get_variables(sigma, sigma_down)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
|
||||
]
|
||||
mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
|
||||
|
||||
x2 = mult[0] * x - mult[1] * denoised
|
||||
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
|
||||
@ -367,8 +336,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
|
||||
append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
|
||||
]
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised
|
||||
@ -380,16 +348,12 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
|
||||
)
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
@ -406,6 +370,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@@ -420,7 +385,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

def get_mult(self, h, r, t, t_next, previous_sigma):
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
mult2 = (-2*h).expm1()
mult2 = (-2 * h).expm1()

if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r)
@@ -444,10 +409,9 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):

h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [
append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
]
mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim)
mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)

x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
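The `mult1`, `mult2`, and `mult_noise` terms being re-wrapped here are the DPM-Solver++(2M) SDE coefficients in negative-log-sigma space: with λ = −log σ and h = λ_next − λ, the first-order update is x ← (σ_next/σ)·e^(−h)·x − (e^(−2h) − 1)·x̂₀ + σ_next·√(1 − e^(−2h))·ε. A hedged standalone sketch of that first-order step (placeholder names, second-order correction omitted):

```python
import torch

def sde_dpmpp_2m_first_order(x, denoised, sigma, next_sigma):
    # lambda = -log(sigma); h = lambda_next - lambda (> 0 as sigma decreases)
    h = (-next_sigma.log()) - (-sigma.log())
    mult1 = next_sigma / sigma * (-h).exp()                 # decay of the current state
    mult2 = (-2 * h).expm1()                                # negative weight -> adds denoised
    mult_noise = next_sigma * (1 - (-2 * h).exp()) ** 0.5   # fresh-noise scale
    return mult1 * x - mult2 * denoised + mult_noise * torch.randn_like(x)
```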
@ -458,16 +422,12 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
|
||||
|
||||
# apply correction if noise level is not 0 and not first step
|
||||
x = torch.where(
|
||||
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
|
||||
)
|
||||
x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
|
||||
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
@ -484,6 +444,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SdeditEDMSampler(EulerEDMSampler):
|
||||
def __init__(self, edit_ratio=0.5, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -525,8 +486,8 @@ class SdeditEDMSampler(EulerEDMSampler):
|
||||
|
||||
return x
|
||||
|
||||
class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.fixed_frames = fixed_frames
|
||||
@ -534,10 +495,15 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
alpha_cumprod_sqrt, timesteps = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False
|
||||
self.num_steps if num_steps is None else num_steps,
|
||||
device=self.device,
|
||||
return_idx=True,
|
||||
do_append_zero=False,
|
||||
)
|
||||
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
|
||||
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))])
|
||||
timesteps = torch.cat(
|
||||
[torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]
|
||||
)
|
||||
|
||||
uc = default(uc, cond)
|
||||
|
||||
@ -547,7 +513,19 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps
|
||||
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None):
|
||||
def denoise(
|
||||
self,
|
||||
x,
|
||||
denoiser,
|
||||
alpha_cumprod_sqrt,
|
||||
cond,
|
||||
uc,
|
||||
timestep=None,
|
||||
idx=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
ofs=None,
|
||||
):
|
||||
additional_model_inputs = {}
|
||||
|
||||
if ofs is not None:
|
||||
@ -557,26 +535,62 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
|
||||
if scale_emb is not None:
|
||||
additional_model_inputs['scale_emb'] = scale_emb
|
||||
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
|
||||
denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(
|
||||
torch.float32
|
||||
)
|
||||
else:
|
||||
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32)
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc),
|
||||
**additional_model_inputs,
|
||||
).to(torch.float32)
|
||||
if isinstance(self.guider, DynamicCFG):
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale)
|
||||
denoised = self.guider(
|
||||
denoised,
|
||||
(1 - alpha_cumprod_sqrt**2) ** 0.5,
|
||||
step_index=self.num_steps - timestep,
|
||||
scale=scale,
|
||||
)
|
||||
else:
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale)
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
|
||||
return denoised
|
||||
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
|
||||
def sampler_step(
|
||||
self,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
ofs=None,
|
||||
):
|
||||
denoised = self.denoise(
|
||||
x,
|
||||
denoiser,
|
||||
alpha_cumprod_sqrt,
|
||||
cond,
|
||||
uc,
|
||||
timestep,
|
||||
idx,
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs,
|
||||
).to(torch.float32) # 1020
|
||||
|
||||
a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
return x

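The `a_t`/`b_t` pair above (shown before and after re-wrapping) is the deterministic DDIM update written in terms of √ᾱ: x ← a_t·x + b_t·x̂₀ with a_t = √((1 − ᾱ_next)/(1 − ᾱ)) and b_t = √ᾱ_next − √ᾱ·a_t. A small hedged sketch of the same step as a standalone function:

```python
def ddim_step(x, denoised, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt):
    # Deterministic DDIM update parameterized by sqrt(alpha_cumprod),
    # matching the a_t / b_t form in the hunk above.
    a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
    b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
    return a_t * x + b_t * denoised
```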
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
|
||||
def __call__(
|
||||
self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
|
||||
): # 1020
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
@ -590,17 +604,16 @@ class VideoDDIMSampler(BaseDiffusionSampler):
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i+1)],
|
||||
timestep=timesteps[-(i + 1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs # 1020
|
||||
ofs=ofs, # 1020
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Image2VideoDDIMSampler(BaseDiffusionSampler):
|
||||
|
||||
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
||||
alpha_cumprod_sqrt, timesteps = self.discretization(
|
||||
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
|
||||
@ -616,22 +629,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
|
||||
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
|
||||
additional_model_inputs = {}
|
||||
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(
|
||||
torch.float32)
|
||||
denoised = denoiser(
|
||||
*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
|
||||
).to(torch.float32)
|
||||
if isinstance(self.guider, DynamicCFG):
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep)
|
||||
denoised = self.guider(
|
||||
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep
|
||||
)
|
||||
else:
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5)
|
||||
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5)
|
||||
return denoised
|
||||
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None,
|
||||
timestep=None):
|
||||
def sampler_step(
|
||||
self,
|
||||
alpha_cumprod_sqrt,
|
||||
next_alpha_cumprod_sqrt,
|
||||
denoiser,
|
||||
x,
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None,
|
||||
):
|
||||
# The sigma here is actually alpha_cumprod_sqrt
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32)
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(
|
||||
torch.float32
|
||||
)
|
||||
if idx == 1:
|
||||
return denoised
|
||||
|
||||
a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5
|
||||
a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
|
||||
|
||||
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
|
||||
@ -651,31 +678,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i + 1)]
|
||||
timestep=timesteps[-(i + 1)],
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt ** 2
|
||||
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
|
||||
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
|
||||
def get_variables(
|
||||
self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
|
||||
):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
else:
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp()
|
||||
mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt
|
||||
def get_mult(
|
||||
self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
|
||||
mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
mult3 = 1 + 1 / (2 * r)
|
||||
@ -698,18 +730,35 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
timestep=None,
|
||||
scale=None,
|
||||
scale_emb=None,
|
||||
ofs=None # 1020
|
||||
ofs=None, # 1020
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020
|
||||
denoised = self.denoise(
|
||||
x,
|
||||
denoiser,
|
||||
alpha_cumprod_sqrt,
|
||||
cond,
|
||||
uc,
|
||||
timestep,
|
||||
idx,
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs,
|
||||
).to(torch.float32) # 1020
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
for mult in self.get_mult(
|
||||
h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
]
|
||||
mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim)
|
||||
mult_noise = append_dims(
|
||||
(1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim
|
||||
)
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
|
||||
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
|
||||
@ -723,23 +772,26 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020
|
||||
def __call__(
|
||||
self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
|
||||
): # 1020
|
||||
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
|
||||
x, cond, uc, num_steps
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
prefix_frames = x[:, :self.fixed_frames]
|
||||
prefix_frames = x[:, : self.fixed_frames]
|
||||
old_denoised = None
|
||||
for i in self.get_sigma_gen(num_sigmas):
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
if self.sdedit:
|
||||
rd = torch.randn_like(prefix_frames)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape))
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
|
||||
s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
|
||||
)
|
||||
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
else:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
x, old_denoised = self.sampler_step(
|
||||
old_denoised,
|
||||
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
|
||||
@ -750,37 +802,41 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i+1)],
|
||||
timestep=timesteps[-(i + 1)],
|
||||
scale=scale,
|
||||
scale_emb=scale_emb,
|
||||
ofs=ofs # 1020
|
||||
ofs=ofs, # 1020
|
||||
)
|
||||
|
||||
if self.fixed_frames > 0:
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1)
|
||||
x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
|
||||
alpha_cumprod = alpha_cumprod_sqrt ** 2
|
||||
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2
|
||||
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log()
|
||||
def get_variables(
|
||||
self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
|
||||
):
|
||||
alpha_cumprod = alpha_cumprod_sqrt**2
|
||||
lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
|
||||
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
|
||||
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
|
||||
h = lamb_next - lamb
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log()
|
||||
previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
|
||||
lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
|
||||
h_last = lamb - lamb_previous
|
||||
r = h_last / h
|
||||
return h, r, lamb, lamb_next
|
||||
else:
|
||||
return h, None, lamb, lamb_next
|
||||
|
||||
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
|
||||
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5
|
||||
def get_mult(
|
||||
self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
):
|
||||
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
|
||||
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
|
||||
|
||||
if previous_alpha_cumprod_sqrt is not None:
|
||||
@ -801,16 +857,22 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=None,
|
||||
idx=None,
|
||||
timestep=None
|
||||
timestep=None,
|
||||
):
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(
|
||||
torch.float32
|
||||
)
|
||||
if idx == 1:
|
||||
return denoised, denoised
|
||||
|
||||
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
h, r, lamb, lamb_next = self.get_variables(
|
||||
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
mult = [
|
||||
append_dims(mult, x.ndim)
|
||||
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
|
||||
for mult in self.get_mult(
|
||||
h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
|
||||
)
|
||||
]
|
||||
|
||||
x_standard = mult[0] * x - mult[1] * denoised
|
||||
@ -842,22 +904,44 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
|
||||
cond,
|
||||
uc=uc,
|
||||
idx=self.num_steps - i,
|
||||
timestep=timesteps[-(i+1)]
|
||||
timestep=timesteps[-(i + 1)],
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VideoDDPMSampler(VideoDDIMSampler):
|
||||
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None):
|
||||
def sampler_step(
|
||||
self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None
|
||||
):
|
||||
# The sigma here is actually alpha_cumprod_sqrt
|
||||
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32)
|
||||
denoised = self.denoise(
|
||||
x, denoiser, alpha_cumprod_sqrt, cond, uc, idx * 1000 // self.num_steps
|
||||
).to(torch.float32)
|
||||
if idx == 1:
|
||||
return denoised
|
||||
|
||||
alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
|
||||
x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \
|
||||
+ append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \
|
||||
+ append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x)
|
||||
x = (
|
||||
append_dims(
|
||||
alpha_sqrt * (1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
|
||||
)
|
||||
* x
|
||||
+ append_dims(
|
||||
next_alpha_cumprod_sqrt * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
|
||||
)
|
||||
* denoised
|
||||
+ append_dims(
|
||||
(
|
||||
(1 - next_alpha_cumprod_sqrt**2)
|
||||
* (1 - alpha_sqrt**2)
|
||||
/ (1 - alpha_cumprod_sqrt**2)
|
||||
)
|
||||
** 0.5,
|
||||
x.ndim,
|
||||
)
|
||||
* torch.randn_like(x)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@ -874,7 +958,7 @@ class VideoDDPMSampler(VideoDDIMSampler):
|
||||
x,
|
||||
cond,
|
||||
uc,
|
||||
idx=self.num_steps - i
|
||||
idx=self.num_steps - i,
|
||||
)
|
||||
|
||||
return x
|
@ -17,7 +17,15 @@ class EDMSampling:
|
||||
|
||||
|
||||
class DiscreteSampling:
|
||||
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
|
||||
def __init__(
|
||||
self,
|
||||
discretization_config,
|
||||
num_idx,
|
||||
do_append_zero=False,
|
||||
flip=True,
|
||||
uniform_sampling=False,
|
||||
group_num=0,
|
||||
):
|
||||
self.num_idx = num_idx
|
||||
self.sigmas = instantiate_from_config(discretization_config)(
|
||||
num_idx, do_append_zero=do_append_zero, flip=flip
|
||||
@@ -30,7 +38,7 @@ class DiscreteSampling:
if self.uniform_sampling:
assert self.group_num > 0
assert world_size % group_num == 0
self.group_width = world_size // group_num # the number of rank in one group
self.group_width = world_size // group_num  # the number of rank in one group
self.sigma_interval = self.num_idx // self.group_num

def idx_to_sigma(self, idx):
@@ -42,7 +50,11 @@ class DiscreteSampling:
group_index = rank // self.group_width
idx = default(
rand,
torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
torch.randint(
group_index * self.sigma_interval,
(group_index + 1) * self.sigma_interval,
(n_samples,),
),
)
else:
idx = default(
@@ -54,8 +66,11 @@ class DiscreteSampling:
else:
return self.idx_to_sigma(idx)


class PartialDiscreteSampling:
def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
def __init__(
self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True
):
self.total_num_idx = total_num_idx
self.partial_num_idx = partial_num_idx
self.sigmas = instantiate_from_config(discretization_config)(
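The `uniform_sampling` branch re-wrapped above splits the data-parallel ranks into `group_num` groups and restricts each group to its own contiguous slice of the timestep indices, so the noise levels drawn in a single optimization step cover the schedule more evenly. A minimal hedged sketch of that idea (hypothetical helper, not the `DiscreteSampling` class itself):

```python
import torch

def sample_group_timestep(rank, world_size, group_num, num_idx, n_samples=1):
    # Split ranks into group_num groups; each group draws indices only from
    # its own contiguous slice of [0, num_idx), which spreads sampled noise
    # levels across the schedule when many ranks sample in the same step.
    assert world_size % group_num == 0
    group_width = world_size // group_num      # ranks per group
    sigma_interval = num_idx // group_num      # index range per group
    group_index = rank // group_width
    lo = group_index * sigma_interval
    hi = (group_index + 1) * sigma_interval
    return torch.randint(lo, hi, (n_samples,))
```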
@ -24,7 +24,9 @@ def make_beta_schedule(
|
||||
linear_end=2e-2,
|
||||
):
|
||||
if schedule == "linear":
|
||||
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
betas = (
|
||||
torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
)
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
@ -50,7 +52,9 @@ def mixed_checkpoint(func, inputs: dict, params, flag):
|
||||
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
|
||||
tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
|
||||
non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
|
||||
non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
|
||||
non_tensor_inputs = [
|
||||
inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
|
||||
]
|
||||
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
|
||||
return MixedCheckpointFunction.apply(
|
||||
func,
|
||||
@ -84,9 +88,14 @@ class MixedCheckpointFunction(torch.autograd.Function):
|
||||
}
|
||||
assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
|
||||
|
||||
ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
|
||||
ctx.input_tensors = {
|
||||
key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
|
||||
}
|
||||
ctx.input_non_tensors = {
|
||||
key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
|
||||
key: val
|
||||
for (key, val) in zip(
|
||||
non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
|
||||
)
|
||||
}
|
||||
ctx.run_function = run_function
|
||||
ctx.input_params = list(args[ctx.end_non_tensors :])
|
||||
@ -98,13 +107,18 @@ class MixedCheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
|
||||
ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
|
||||
ctx.input_tensors = {
|
||||
key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors
|
||||
}
|
||||
|
||||
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
|
||||
shallow_copies = {
|
||||
key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
|
||||
for key in ctx.input_tensors
|
||||
}
|
||||
# shallow_copies.update(additional_args)
|
||||
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
|
||||
input_grads = torch.autograd.grad(
|
||||
@@ -188,9 +202,9 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtyp
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
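The `freqs` expression being re-wrapped is the standard sinusoidal timestep embedding: log-spaced frequencies between 1 and 1/max_period are multiplied by the timestep and passed through cos/sin, with one zero column of padding when `dim` is odd. A self-contained hedged sketch for reference (standalone function, not the repository file):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # timesteps: 1-D tensor of N indices; returns an [N, dim] embedding.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero column when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
```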
@ -6,7 +6,9 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
|
||||
|
||||
|
||||
class IdentityWrapper(nn.Module):
|
||||
def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
|
||||
def __init__(
|
||||
self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32
|
||||
):
|
||||
super().__init__()
|
||||
compile = (
|
||||
torch.compile
|
||||
|
@@ -87,8 +87,14 @@ def normal_kl(mean1, logvar1, mean2, logvar2):

# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)
]

return 0.5 * (
-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

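The return expression re-wrapped above is the closed-form KL divergence between two diagonal Gaussians, KL(N(μ₁, σ₁²) ‖ N(μ₂, σ₂²)) = ½(−1 + log σ₂² − log σ₁² + σ₁²/σ₂² + (μ₁ − μ₂)²/σ₂²), applied elementwise. A hedged sanity-check sketch against `torch.distributions` (illustrative values only):

```python
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.zeros(3), torch.zeros(3)
mean2, logvar2 = torch.ones(3), torch.full((3,), 0.5)

closed_form = 0.5 * (
    -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
    + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
)
reference = kl_divergence(
    Normal(mean1, (0.5 * logvar1).exp()), Normal(mean2, (0.5 * logvar2).exp())
)
assert torch.allclose(closed_form, reference, atol=1e-6)
```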
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)

for name, p in model.named_parameters():
@@ -45,9 +47,11 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name

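The in-place `sub_` above is the usual exponential-moving-average update, shadow ← shadow − (1 − decay)·(shadow − param), i.e. shadow ← decay·shadow + (1 − decay)·param. A small hedged sketch of the same update outside the `LitEma` class:

```python
import torch

@torch.no_grad()
def ema_update(shadow_params, model_params, decay=0.9999):
    # shadow <- decay * shadow + (1 - decay) * param, written in the same
    # "subtract the gap" form used in the diff above.
    one_minus_decay = 1.0 - decay
    for shadow, param in zip(shadow_params, model_params):
        shadow.sub_(one_minus_decay * (shadow - param))
```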
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
@ -56,7 +60,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
|
@ -99,7 +99,9 @@ class GeneralConditioner(nn.Module):
|
||||
elif "input_keys" in embconfig:
|
||||
embedder.input_keys = embconfig["input_keys"]
|
||||
else:
|
||||
raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
|
||||
raise KeyError(
|
||||
f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
|
||||
)
|
||||
|
||||
embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
|
||||
if embedder.legacy_ucg_val is not None:
|
||||
@ -160,7 +162,10 @@ class GeneralConditioner(nn.Module):
|
||||
if cond_or_not is None:
|
||||
emb = (
|
||||
expand_dims_like(
|
||||
torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
|
||||
torch.bernoulli(
|
||||
(1.0 - embedder.ucg_rate)
|
||||
* torch.ones(emb.shape[0], device=emb.device)
|
||||
),
|
||||
emb,
|
||||
)
|
||||
* emb
|
||||
|
@ -96,7 +96,9 @@ class VideoTransformerBlock(nn.Module):
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
|
||||
) -> torch.Tensor:
|
||||
if self.checkpoint:
|
||||
return checkpoint(self._forward, x, context, timesteps)
|
||||
else:
|
||||
@ -239,7 +241,9 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
spatial_context = context
|
||||
|
||||
if self.use_spatial_context:
|
||||
assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
assert (
|
||||
context.ndim == 3
|
||||
), f"n dims of spatial context should be 3 but are {context.ndim}"
|
||||
|
||||
time_context = context
|
||||
time_context_first_timestep = time_context[::timesteps]
|
||||
|
@@ -86,7 +86,9 @@ class SafeConv3d(torch.nn.Conv3d):
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
torch.cat(
(input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
)
for i in range(1, len(input_chunks))
]

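The chunking logic re-wrapped above splits the NCTHW input along time and prepends the last `kernel_size − 1` frames of the previous chunk to each later chunk, so a temporal convolution applied chunk-by-chunk sees the same receptive field as one applied to the whole tensor. A hedged standalone sketch of that overlap trick, assuming temporal stride 1 and no temporal padding (hypothetical helper, not the `SafeConv3d` class):

```python
import torch

def conv3d_in_time_chunks(conv3d: torch.nn.Conv3d, x: torch.Tensor, part_num: int) -> torch.Tensor:
    # x: (N, C, T, H, W). Apply conv3d over time in part_num chunks while
    # overlapping kernel_size - 1 frames so the concatenated result matches
    # conv3d(x) when temporal stride is 1 and temporal padding is 0.
    kt = conv3d.kernel_size[0]
    chunks = list(torch.chunk(x, part_num, dim=2))
    if kt > 1:
        chunks = [chunks[0]] + [
            torch.cat((chunks[i - 1][:, :, -kt + 1:], chunks[i]), dim=2)
            for i in range(1, len(chunks))
        ]
    return torch.cat([conv3d(c) for c in chunks], dim=2)
```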
@ -252,7 +254,7 @@ def count_params(model, verbose=False):
|
||||
|
||||
|
||||
def instantiate_from_config(config, **extra_kwargs):
|
||||
if not "target" in config:
|
||||
if "target" not in config:
|
||||
if config == "__is_first_stage__":
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
|
@ -93,7 +93,12 @@ class SimpleDistributedWebDataset(DataPipeline):
|
||||
|
||||
|
||||
def tar_file_iterator_with_meta(
|
||||
fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None
|
||||
fileobj,
|
||||
meta_names,
|
||||
skip_meta=r"__[^/]*__($|/)",
|
||||
suffix=None,
|
||||
handler=reraise_exception,
|
||||
meta_stream=None,
|
||||
):
|
||||
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
|
||||
|
||||
@ -122,10 +127,13 @@ def tar_file_iterator_with_meta(
|
||||
except Exception as exn:
|
||||
from sat.helpers import print_rank0
|
||||
|
||||
print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
|
||||
print_rank0(
|
||||
f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}",
|
||||
level="DEBUG",
|
||||
)
|
||||
continue
|
||||
for item in meta_list:
|
||||
if not item["key"] in meta_data:
|
||||
if item["key"] not in meta_data:
|
||||
meta_data[item["key"]] = {}
|
||||
for meta_name in meta_names:
|
||||
if meta_name in item:
|
||||
@ -186,7 +194,9 @@ def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
|
||||
try:
|
||||
assert isinstance(source, dict)
|
||||
assert "stream" in source
|
||||
for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
|
||||
for sample in tar_file_iterator_with_meta(
|
||||
source["stream"], meta_names, meta_stream=source["meta_stream"]
|
||||
):
|
||||
assert isinstance(sample, dict) and "data" in sample and "fname" in sample
|
||||
sample["__url__"] = url
|
||||
yield sample
|
||||
@ -250,7 +260,15 @@ class MetaDistributedWebDataset(DataPipeline):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
|
||||
self,
|
||||
path,
|
||||
process_fn,
|
||||
seed,
|
||||
*,
|
||||
meta_names=[],
|
||||
nshards=sys.maxsize,
|
||||
shuffle_buffer=1000,
|
||||
include_dirs=None,
|
||||
):
|
||||
# os.environ['WDS_SHOW_SEED'] = '1'
|
||||
import torch
|
||||
@ -361,7 +379,10 @@ def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
|
||||
|
||||
if mode[0] == "r":
|
||||
s3_client = boto3.client(
|
||||
"s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
|
||||
"s3",
|
||||
endpoint_url=endpoint_url,
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
)
|
||||
bucket, key = url.split("/", 1)
|
||||
|
||||
|
@ -37,7 +37,9 @@ def save_texts(texts, save_dir, iterations):
|
||||
f.write(text + "\n")
|
||||
|
||||
|
||||
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None):
|
||||
def save_video_as_grid_and_mp4(
|
||||
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None
|
||||
):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
for i, vid in enumerate(video_batch):
|
||||
@ -52,7 +54,8 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int
|
||||
writer.append_data(frame)
|
||||
if args is not None and args.wandb:
|
||||
wandb.log(
|
||||
{key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1
|
||||
{key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")},
|
||||
step=args.iteration + 1,
|
||||
)
|
||||
|
||||
|
||||
@ -138,7 +141,9 @@ def broad_cast_batch(batch):
|
||||
return batch
|
||||
|
||||
|
||||
def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None):
|
||||
def forward_step_eval(
|
||||
data_iterator, model, args, timers, only_log_video_latents=False, data_class=None
|
||||
):
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
timers("data loader").start()
|
||||
batch_video = next(data_iterator)
|
||||
@ -209,7 +214,9 @@ if __name__ == "__main__":
|
||||
args = argparse.Namespace(**vars(args), **vars(known))
|
||||
|
||||
data_class = get_obj_from_str(args.data_config["target"])
|
||||
create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"])
|
||||
create_dataset_function = partial(
|
||||
data_class.create_dataset_function, **args.data_config["params"]
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
@ -225,7 +232,9 @@ if __name__ == "__main__":
|
||||
model_cls=SATVideoDiffusionEngine,
|
||||
forward_step_function=partial(forward_step, data_class=data_class),
|
||||
forward_step_eval=partial(
|
||||
forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents
|
||||
forward_step_eval,
|
||||
data_class=data_class,
|
||||
only_log_video_latents=args.only_log_video_latents,
|
||||
),
|
||||
create_dataset_function=create_dataset_function,
|
||||
)
|
||||
|
@ -94,7 +94,11 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
|
||||
# new
|
||||
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
||||
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask
|
||||
) # scale is dim_head ** -0.5 per default
|
||||
|
||||
del q, k, v
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
additional_tokens=additional_tokens,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
||||
if not self.disable_self_attn
|
||||
else 0,
|
||||
)
|
||||
+ x
|
||||
)
|
||||
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
)
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
|
||||
]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
else:
|
||||
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
|
@ -97,9 +97,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -196,11 +194,11 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
**kwargs,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x, **kwargs)
|
||||
if unregularized:
|
||||
@ -214,14 +212,20 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
||||
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
||||
def inner_training_step(
|
||||
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
||||
) -> torch.Tensor:
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
additional_decode_kwargs = {
|
||||
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
||||
}
|
||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
@ -357,12 +361,16 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
if self.trainable_ae_params is None:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
else:
|
||||
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
||||
ae_params, num_ae_params = self.get_param_groups(
|
||||
self.trainable_ae_params, self.ae_optimizer_args
|
||||
)
|
||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||
if self.trainable_disc_params is None:
|
||||
disc_params = self.get_discriminator_params()
|
||||
else:
|
||||
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
||||
disc_params, num_disc_params = self.get_param_groups(
|
||||
self.trainable_disc_params, self.disc_optimizer_args
|
||||
)
|
||||
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
||||
opt_ae = self.instantiate_optimizer_from_config(
|
||||
ae_params,
|
||||
@ -371,17 +379,23 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
)
|
||||
opts = [opt_ae]
|
||||
if len(disc_params) > 0:
|
||||
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
||||
opt_disc = self.instantiate_optimizer_from_config(
|
||||
disc_params, self.learning_rate, self.optimizer_config
|
||||
)
|
||||
opts.append(opt_disc)
|
||||
|
||||
return opts
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
def log_images(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
log = dict()
|
||||
additional_decode_kwargs = {}
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
||||
additional_decode_kwargs.update(
|
||||
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
)
|
||||
|
||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||
log["inputs"] = x
|
||||
@ -400,7 +414,9 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||
diff_ema.clamp_(0, 1.0)
|
||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
log["diff_boost_ema"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
)
|
||||
if additional_log_kwargs:
|
||||
additional_decode_kwargs.update(additional_log_kwargs)
|
||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||
@ -442,7 +458,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
@ -485,7 +503,9 @@ class AutoencoderKL(AutoencodingEngineLegacy):
|
||||
if "lossconfig" in kwargs:
|
||||
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||
super().__init__(
|
||||
regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
|
||||
regularizer_config={
|
||||
"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -519,7 +539,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
def log_videos(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
||||
|
||||
def get_input(self, batch: dict) -> torch.Tensor:
|
||||
@ -530,7 +552,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
batch = batch[self.input_key]
|
||||
|
||||
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
|
||||
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
|
||||
torch.distributed.broadcast(
|
||||
batch, src=global_src_rank, group=get_context_parallel_group()
|
||||
)
|
||||
|
||||
batch = _conv_split(batch, dim=2, kernel_size=1)
|
||||
return batch
|
||||
|
@ -201,7 +201,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
recv_rank += cp_world_size
|
||||
|
||||
if cp_rank < cp_world_size - 1:
|
||||
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
|
||||
req_send = torch.distributed.isend(
|
||||
input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
|
||||
)
|
||||
if cp_rank > 0:
|
||||
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
|
||||
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
@ -246,11 +248,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
|
||||
|
||||
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
|
||||
if cp_rank < cp_world_size - 1:
|
||||
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
|
||||
req_send = torch.distributed.isend(
|
||||
input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
|
||||
)
|
||||
if cp_rank > 0:
|
||||
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
|
||||
|
||||
|
||||
if cp_rank == 0:
|
||||
if cache_padding is not None:
|
||||
input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)
|
||||
@ -334,7 +337,9 @@ def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
|
||||
|
||||
|
||||
class ContextParallelCausalConv3d(nn.Module):
|
||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
|
||||
def __init__(
|
||||
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
@ -354,7 +359,9 @@ class ContextParallelCausalConv3d(nn.Module):
|
||||
|
||||
stride = (stride, stride, stride)
|
||||
dilation = (1, 1, 1)
|
||||
-self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+self.conv = Conv3d(
+    chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
+)
 self.cache_padding = None

 def forward(self, input_, clear_cache=True):

@@ -369,7 +376,11 @@ class ContextParallelCausalConv3d(nn.Module):
 global_rank = torch.distributed.get_rank()
 if cp_world_size == 1:
     self.cache_padding = (
-        input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
+        input_parallel[:, :, -self.time_kernel_size + 1 :]
+        .contiguous()
+        .detach()
+        .clone()
+        .cpu()
     )
 else:
     if cp_rank == cp_world_size - 1:

@@ -379,9 +390,13 @@ class ContextParallelCausalConv3d(nn.Module):
     group=get_context_parallel_group(),
 )
 if cp_rank == 0:
-    recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
+    recv_buffer = torch.empty_like(
+        input_parallel[:, :, -self.time_kernel_size + 1 :]
+    ).contiguous()
     torch.distributed.recv(
-        recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
+        recv_buffer,
+        global_rank - 1 + cp_world_size,
+        group=get_context_parallel_group(),
     )
     self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()

@@ -406,7 +421,9 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):

 def Normalize(in_channels, gather=False, **kwargs):
     if gather:
-        return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        return ContextParallelGroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
     else:
         return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@@ -460,7 +477,8 @@ class SpatialNorm3D(nn.Module):

 zq_rest_splits = torch.split(zq_rest, 32, dim=1)
 interpolated_splits = [
-    torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
+    torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest")
+    for split in zq_rest_splits
 ]

 zq_rest = torch.cat(interpolated_splits, dim=1)

@@ -471,7 +489,8 @@ class SpatialNorm3D(nn.Module):

 zq_splits = torch.split(zq, 32, dim=1)
 interpolated_splits = [
-    torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
+    torch.nn.functional.interpolate(split, size=f_size, mode="nearest")
+    for split in zq_splits
 ]
 zq = torch.cat(interpolated_splits, dim=1)

@@ -511,7 +530,9 @@ class Upsample3D(nn.Module):
 super().__init__()
 self.with_conv = with_conv
 if self.with_conv:
-    self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+    self.conv = torch.nn.Conv2d(
+        in_channels, in_channels, kernel_size=3, stride=1, padding=1
+    )
 self.compress_time = compress_time

 def forward(self, x, fake_cp=True):

@@ -523,14 +544,16 @@ class Upsample3D(nn.Module):

 splits = torch.split(x_rest, 32, dim=1)
 interpolated_splits = [
-    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+    for split in splits
 ]
 x_rest = torch.cat(interpolated_splits, dim=1)
 x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
 else:
     splits = torch.split(x, 32, dim=1)
     interpolated_splits = [
-        torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+        torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+        for split in splits
     ]
     x = torch.cat(interpolated_splits, dim=1)

@@ -541,7 +564,8 @@ class Upsample3D(nn.Module):

 splits = torch.split(x, 32, dim=1)
 interpolated_splits = [
-    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
+    for split in splits
 ]
 x = torch.cat(interpolated_splits, dim=1)

@@ -563,7 +587,9 @@ class DownSample3D(nn.Module):
 out_channels = in_channels
 if self.with_conv:
     # no asymmetric padding in torch conv, must do it ourselves
-    self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+    self.conv = torch.nn.Conv2d(
+        in_channels, out_channels, kernel_size=3, stride=2, padding=0
+    )
 self.compress_time = compress_time

 def forward(self, x, fake_cp=True):

@@ -578,7 +604,8 @@ class DownSample3D(nn.Module):
 if x_rest.shape[-1] > 0:
     splits = torch.split(x_rest, 32, dim=1)
     interpolated_splits = [
-        torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+        torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
+        for split in splits
     ]
     x_rest = torch.cat(interpolated_splits, dim=1)
     x = torch.cat([x_first[..., None], x_rest], dim=-1)

@@ -587,7 +614,8 @@ class DownSample3D(nn.Module):
 # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
 splits = torch.split(x, 32, dim=1)
 interpolated_splits = [
-    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
+    for split in splits
 ]
 x = torch.cat(interpolated_splits, dim=1)
 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

@@ -923,9 +951,13 @@ class ContextParallelDecoder3D(nn.Module):
 up.attn = attn
 if i_level != 0:
     if i_level < self.num_resolutions - self.temporal_compress_level:
-        up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+        up.upsample = Upsample3D(
+            block_in, with_conv=resamp_with_conv, compress_time=False
+        )
     else:
-        up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+        up.upsample = Upsample3D(
+            block_in, with_conv=resamp_with_conv, compress_time=True
+        )
 self.up.insert(0, up)

 self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
 self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
 self.register_buffer(
     "num_updates",
-    torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+    torch.tensor(0, dtype=torch.int)
+    if use_num_upates
+    else torch.tensor(-1, dtype=torch.int),
 )

 for name, p in model.named_parameters():

@@ -45,9 +47,11 @@ class LitEma(nn.Module):
 if m_param[key].requires_grad:
     sname = self.m_name2s_name[key]
     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+    shadow_params[sname].sub_(
+        one_minus_decay * (shadow_params[sname] - m_param[key])
+    )
 else:
-    assert not key in self.m_name2s_name
+    assert key not in self.m_name2s_name

 def copy_to(self, model):
     m_param = dict(model.named_parameters())

@@ -56,7 +60,7 @@ class LitEma(nn.Module):
 if m_param[key].requires_grad:
     m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
 else:
-    assert not key in self.m_name2s_name
+    assert key not in self.m_name2s_name

 def store(self, parameters):
     """
@@ -77,7 +77,9 @@ class IdentityRegularizer(AbstractRegularizer):
 yield from ()


-def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def measure_perplexity(
+    predicted_indices: torch.Tensor, num_centroids: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
     # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
     encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
@@ -78,7 +78,9 @@ class SafeConv3d(torch.nn.Conv3d):
 input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
 if kernel_size > 1:
     input_chunks = [input_chunks[0]] + [
-        torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+        torch.cat(
+            (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
+        )
         for i in range(1, len(input_chunks))
     ]

@@ -244,7 +246,7 @@ def count_params(model, verbose=False):


 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
@@ -9,11 +9,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
-    0] >= 8 else torch.float16
+TORCH_TYPE = (
+    torch.bfloat16
+    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
+    else torch.float16
+)

 parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
-parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
+parser.add_argument(
+    '--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0
+)
 args = parser.parse_args([])


@@ -29,8 +34,11 @@ def load_video(video_data, strategy='chat'):
 clip_end_sec = 60
 clip_start_sec = 0
 start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
-end_frame = min(total_frames,
-    int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
+end_frame = (
+    min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps()))
+    if clip_end_sec is not None
+    else total_frames
+)
 frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
 elif strategy == 'chat':
     timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))

@@ -54,11 +62,11 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code=True,
 )

-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_PATH,
-    torch_dtype=TORCH_TYPE,
-    trust_remote_code=True
-).eval().to(DEVICE)
+model = (
+    AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=TORCH_TYPE, trust_remote_code=True)
+    .eval()
+    .to(DEVICE)
+)


 def predict(prompt, video_data, temperature):

@@ -69,11 +77,7 @@ def predict(prompt, video_data, temperature):
 history = []
 query = prompt
 inputs = model.build_conversation_input_ids(
-    tokenizer=tokenizer,
-    query=query,
-    images=[video],
-    history=history,
-    template_version=strategy
+    tokenizer=tokenizer, query=query, images=[video], history=history, template_version=strategy
 )
 inputs = {
     'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),

@@ -91,7 +95,7 @@ def predict(prompt, video_data, temperature):
 }
 with torch.no_grad():
     outputs = model.generate(**inputs, **gen_kwargs)
-    outputs = outputs[:, inputs['input_ids'].shape[1]:]
+    outputs = outputs[:, inputs['input_ids'].shape[1] :]
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return response
@ -31,9 +31,18 @@ from dataclasses import dataclass
|
||||
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
||||
# DeepSpeed data structures it has to be available in the current python environment.
|
||||
from deepspeed.utils import logger
|
||||
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
|
||||
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
|
||||
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
|
||||
from deepspeed.checkpoint.constants import (
|
||||
DS_VERSION,
|
||||
OPTIMIZER_STATE_DICT,
|
||||
SINGLE_PARTITION_OF_FP32_GROUPS,
|
||||
FP32_FLAT_GROUPS,
|
||||
ZERO_STAGE,
|
||||
PARTITION_COUNT,
|
||||
PARAM_SHAPES,
|
||||
BUFFER_NAMES,
|
||||
FROZEN_PARAM_SHAPES,
|
||||
FROZEN_PARAM_FRAGMENTS,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -134,12 +143,14 @@ def parse_model_states(files):
|
||||
|
||||
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
|
||||
|
||||
z_model_state = zero_model_state(buffers=buffers,
|
||||
param_shapes=param_shapes,
|
||||
shared_params=shared_params,
|
||||
ds_version=ds_version,
|
||||
frozen_param_shapes=frozen_param_shapes,
|
||||
frozen_param_fragments=frozen_param_fragments)
|
||||
z_model_state = zero_model_state(
|
||||
buffers=buffers,
|
||||
param_shapes=param_shapes,
|
||||
shared_params=shared_params,
|
||||
ds_version=ds_version,
|
||||
frozen_param_shapes=frozen_param_shapes,
|
||||
frozen_param_fragments=frozen_param_fragments,
|
||||
)
|
||||
zero_model_states.append(z_model_state)
|
||||
|
||||
return zero_model_states
|
||||
@ -155,7 +166,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
|
||||
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
||||
if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
|
||||
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
||||
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
|
||||
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
|
||||
@ -181,7 +192,9 @@ def parse_optim_states(files, ds_checkpoint_dir):
|
||||
else:
|
||||
raise ValueError(f"unknown zero stage {zero_stage}")
|
||||
|
||||
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
|
||||
fp32_flat_groups = [
|
||||
state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))
|
||||
]
|
||||
return zero_stage, world_size, fp32_flat_groups
|
||||
|
||||
|
||||
@ -205,15 +218,20 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_
|
||||
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
|
||||
|
||||
if zero_stage <= 2:
|
||||
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters)
|
||||
return _get_fp32_state_dict_from_zero2_checkpoint(
|
||||
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||
)
|
||||
elif zero_stage == 3:
|
||||
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters)
|
||||
return _get_fp32_state_dict_from_zero3_checkpoint(
|
||||
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||
)
|
||||
|
||||
|
||||
def _zero2_merge_frozen_params(state_dict, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
if (
|
||||
zero_model_states[0].frozen_param_shapes is None
|
||||
or len(zero_model_states[0].frozen_param_shapes) == 0
|
||||
):
|
||||
return
|
||||
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
@ -269,11 +287,17 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
||||
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
||||
avail_numel = sum(
|
||||
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
|
||||
[
|
||||
full_single_fp32_vector.numel()
|
||||
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
|
||||
]
|
||||
)
|
||||
|
||||
if debug:
|
||||
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
||||
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
||||
wanted_numel = sum(
|
||||
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]
|
||||
)
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
print(f"Have {avail_numel} numels to process.")
|
||||
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
||||
@ -283,18 +307,23 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
# out-of-core computing solution
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
||||
for shapes, full_single_fp32_vector in zip(
|
||||
param_shapes, merged_single_partition_of_fp32_groups
|
||||
):
|
||||
offset = 0
|
||||
avail_numel = full_single_fp32_vector.numel()
|
||||
for name, shape in shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
||||
unpartitioned_numel = (
|
||||
shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
|
||||
)
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
if debug:
|
||||
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
|
||||
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
|
||||
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(
|
||||
shape
|
||||
)
|
||||
offset += unpartitioned_numel
|
||||
|
||||
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
||||
@ -322,8 +351,9 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters):
|
||||
def _get_fp32_state_dict_from_zero2_checkpoint(
|
||||
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||
):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
@ -353,7 +383,10 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
||||
|
||||
|
||||
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
||||
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
|
||||
if (
|
||||
zero_model_states[0].frozen_param_shapes is None
|
||||
or len(zero_model_states[0].frozen_param_shapes) == 0
|
||||
):
|
||||
return
|
||||
|
||||
if debug:
|
||||
@ -364,7 +397,10 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
||||
frozen_param_shapes = zero_model_states[0].frozen_param_shapes
|
||||
wanted_params = len(frozen_param_shapes)
|
||||
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
|
||||
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
|
||||
avail_numel = (
|
||||
sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()])
|
||||
* world_size
|
||||
)
|
||||
print(f'Frozen params: Have {avail_numel} numels to process.')
|
||||
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
|
||||
|
||||
@ -375,10 +411,14 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
|
||||
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
|
||||
param_frags = tuple(
|
||||
model_state.frozen_param_fragments[name] for model_state in zero_model_states
|
||||
)
|
||||
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
|
||||
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
|
||||
unpartitioned_numel, world_size
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
@ -416,21 +456,32 @@ class GatheredTensor:
|
||||
start_group_id = None
|
||||
end_group_id = None
|
||||
for group_id in range(len(self.flat_groups_offset)):
|
||||
if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
|
||||
if (
|
||||
self.flat_groups_offset[group_id]
|
||||
<= self.offset
|
||||
< self.flat_groups_offset[group_id + 1]
|
||||
):
|
||||
start_group_id = group_id
|
||||
if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
|
||||
if (
|
||||
self.flat_groups_offset[group_id]
|
||||
< end_idx
|
||||
<= self.flat_groups_offset[group_id + 1]
|
||||
):
|
||||
end_group_id = group_id
|
||||
break
|
||||
# collect weights from related group/groups
|
||||
for group_id in range(start_group_id, end_group_id + 1):
|
||||
flat_tensor = flat_groups_at_rank_i[group_id]
|
||||
start_offset = self.offset - self.flat_groups_offset[group_id]
|
||||
end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
|
||||
end_offset = (
|
||||
min(end_idx, self.flat_groups_offset[group_id + 1])
|
||||
- self.flat_groups_offset[group_id]
|
||||
)
|
||||
pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
|
||||
|
||||
# collect weights from all ranks
|
||||
pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
|
||||
param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
|
||||
param = pad_flat_param[: self.shape.numel()].view(self.shape).contiguous()
|
||||
return param
|
||||
|
||||
|
||||
@ -461,12 +512,16 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
offset = 0
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
|
||||
flat_groups_offset = [0] + list(
|
||||
np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])
|
||||
)
|
||||
for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
|
||||
unpartitioned_numel, world_size
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
@ -474,7 +529,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
)
|
||||
|
||||
# memory efficient tensor
|
||||
tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
|
||||
tensor = GatheredTensor(
|
||||
fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape
|
||||
)
|
||||
state_dict[name] = tensor
|
||||
offset += partitioned_numel
|
||||
|
||||
@ -484,11 +541,14 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
|
||||
if offset != avail_numel:
|
||||
raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
|
||||
print(
|
||||
f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements"
|
||||
)
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
|
||||
exclude_frozen_parameters):
|
||||
def _get_fp32_state_dict_from_zero3_checkpoint(
|
||||
world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
|
||||
):
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
@ -530,10 +590,9 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
|
||||
return torch_state_dict
|
||||
|
||||
|
||||
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False,
|
||||
lazy_mode=False):
|
||||
def get_fp32_state_dict_from_zero_checkpoint(
|
||||
checkpoint_dir, tag=None, exclude_frozen_parameters=False, lazy_mode=False
|
||||
):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
||||
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
||||
@ -588,19 +647,23 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
||||
if not os.path.isdir(ds_checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
||||
|
||||
state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
|
||||
state_dict = _get_fp32_state_dict_from_zero_checkpoint(
|
||||
ds_checkpoint_dir, exclude_frozen_parameters
|
||||
)
|
||||
if lazy_mode:
|
||||
return state_dict
|
||||
else:
|
||||
return to_torch_tensor(state_dict)
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=False,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False):
|
||||
def convert_zero_checkpoint_to_fp32_state_dict(
|
||||
checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=False,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False,
|
||||
):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
||||
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
||||
@ -629,25 +692,28 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
||||
raise
|
||||
|
||||
# Convert zero checkpoint to state_dict
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
|
||||
tag,
|
||||
exclude_frozen_parameters,
|
||||
lazy_mode=True)
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(
|
||||
checkpoint_dir, tag, exclude_frozen_parameters, lazy_mode=True
|
||||
)
|
||||
|
||||
# Shard the model if it is too big.
|
||||
weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
|
||||
if max_shard_size is not None:
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
# an memory-efficient approach for sharding
|
||||
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
||||
state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
|
||||
filename_pattern=filename_pattern,
|
||||
max_shard_size=max_shard_size)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
else:
|
||||
from collections import namedtuple
|
||||
|
||||
StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
|
||||
state_dict_split = StateDictSplit(is_sharded=False,
|
||||
filename_to_tensors={weights_name: list(state_dict.keys())})
|
||||
state_dict_split = StateDictSplit(
|
||||
is_sharded=False, filename_to_tensors={weights_name: list(state_dict.keys())}
|
||||
)
|
||||
|
||||
# Save the model by shard
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@ -673,7 +739,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
|
||||
save_index_file = (
|
||||
"model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
|
||||
)
|
||||
save_index_file = os.path.join(output_dir, save_index_file)
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
@ -719,12 +787,14 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
||||
return model
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=True,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False):
|
||||
def convert_zero_checkpoint_to_bf16_state_dict(
|
||||
checkpoint_dir,
|
||||
output_dir,
|
||||
max_shard_size="5GB",
|
||||
safe_serialization=True,
|
||||
tag=None,
|
||||
exclude_frozen_parameters=False,
|
||||
):
|
||||
"""
|
||||
将 ZeRO 2 或 ZeRO 3 格式的 DeepSpeed 检查点转换为 BF16,并输出到指定目录下,命名规则为:
|
||||
- 如果只有一个分片:
|
||||
@ -748,10 +818,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
raise ImportError("You need `pip install huggingface_hub` to use the sharding feature.")
|
||||
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(
|
||||
checkpoint_dir,
|
||||
tag=tag,
|
||||
exclude_frozen_parameters=exclude_frozen_parameters,
|
||||
lazy_mode=True
|
||||
checkpoint_dir, tag=tag, exclude_frozen_parameters=exclude_frozen_parameters, lazy_mode=True
|
||||
)
|
||||
|
||||
state_dict = to_torch_tensor(state_dict, return_empty_tensor=False)
|
||||
@ -766,9 +833,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
|
||||
empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
empty_state_dict,
|
||||
filename_pattern=filename_pattern,
|
||||
max_shard_size=max_shard_size
|
||||
empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@ -789,7 +854,6 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
del shard_state_dict
|
||||
gc.collect()
|
||||
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
@ -801,21 +865,29 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
|
||||
else:
|
||||
only_filename = list(state_dict_split.filename_to_tensors.keys())[0]
|
||||
old_path = os.path.join(output_dir, only_filename)
|
||||
new_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors" if safe_serialization
|
||||
else "diffusion_pytorch_model.bin")
|
||||
new_path = os.path.join(
|
||||
output_dir,
|
||||
"diffusion_pytorch_model.safetensors"
|
||||
if safe_serialization
|
||||
else "diffusion_pytorch_model.bin",
|
||||
)
|
||||
if old_path != new_path:
|
||||
os.rename(old_path, new_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("checkpoint_dir",
|
||||
type=str,
|
||||
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
||||
parser.add_argument("output_dir",
|
||||
type=str,
|
||||
help="directory to the pytorch fp32 state_dict output files"
|
||||
"(e.g. path/checkpoint-12-output/)")
|
||||
parser.add_argument(
|
||||
"checkpoint_dir",
|
||||
type=str,
|
||||
help="path to the desired checkpoint folder, e.g., path/checkpoint-12",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_dir",
|
||||
type=str,
|
||||
help="directory to the pytorch fp32 state_dict output files"
|
||||
"(e.g. path/checkpoint-12-output/)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_shard_size",
|
||||
type=str,
|
||||
@ -823,26 +895,34 @@ if __name__ == "__main__":
|
||||
help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
|
||||
"lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
|
||||
"We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
|
||||
"without CPU OOM issues.")
|
||||
"without CPU OOM issues.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
|
||||
parser.add_argument("-t",
|
||||
"--tag",
|
||||
type=str,
|
||||
default=None,
|
||||
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
|
||||
parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
|
||||
help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tag",
|
||||
type=str,
|
||||
default=None,
|
||||
help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters"
|
||||
)
|
||||
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
||||
args = parser.parse_args()
|
||||
|
||||
debug = args.debug
|
||||
|
||||
convert_zero_checkpoint_to_bf16_state_dict(args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
max_shard_size=args.max_shard_size,
|
||||
safe_serialization=args.safe_serialization,
|
||||
tag=args.tag,
|
||||
exclude_frozen_parameters=args.exclude_frozen_parameters)
|
||||
convert_zero_checkpoint_to_bf16_state_dict(
|
||||
args.checkpoint_dir,
|
||||
args.output_dir,
|
||||
max_shard_size=args.max_shard_size,
|
||||
safe_serialization=args.safe_serialization,
|
||||
tag=args.tag,
|
||||
exclude_frozen_parameters=args.exclude_frozen_parameters,
|
||||
)
|
||||
|
@ -10,6 +10,7 @@ Original Script:
|
||||
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -143,7 +144,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return state_dict
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
def update_state_dict_inplace(
|
||||
state_dict: Dict[str, Any], old_key: str, new_key: str
|
||||
) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
@ -164,8 +167,11 @@ def convert_transformer(
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
|
||||
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
|
||||
ofs_embed_dim=512
|
||||
if (i2v and init_kwargs["patch_size_t"] is not None)
|
||||
else None, # CogVideoX1.5-5B-I2V
|
||||
use_learned_positional_embeddings=i2v
|
||||
and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
|
||||
**init_kwargs,
|
||||
).to(dtype=dtype)
|
||||
|
||||
@ -240,17 +246,40 @@ def get_transformer_init_kwargs(version: str):
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
|
||||
"--transformer_ckpt_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to original transformer checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
"--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, required=True, help="Path where converted model should be saved"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to save the model weights in fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bf16",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to save the model weights in bf16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to push to HF Hub after saving",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to text encoder cache directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--typecast_text_encoder",
|
||||
@ -261,15 +290,24 @@ def get_args():
|
||||
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
||||
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
||||
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
||||
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
|
||||
parser.add_argument(
|
||||
"--num_attention_heads", type=int, default=30, help="Number of attention heads"
|
||||
)
|
||||
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
|
||||
parser.add_argument(
|
||||
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
|
||||
"--use_rotary_positional_embeddings",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use RoPE or not",
|
||||
)
|
||||
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||
parser.add_argument(
|
||||
"--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE"
|
||||
)
|
||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||
parser.add_argument(
|
||||
"--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--i2v",
|
||||
action="store_true",
|
||||
@ -313,7 +351,9 @@ if __name__ == "__main__":
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
text_encoder_id, cache_dir=args.text_encoder_cache_dir
|
||||
)
|
||||
|
||||
if args.typecast_text_encoder:
|
||||
text_encoder = text_encoder.to(dtype=dtype)
|
||||
@ -355,4 +395,9 @@ if __name__ == "__main__":
|
||||
|
||||
# This is necessary This is necessary for users with insufficient memory,
|
||||
# such as those using Colab and notebooks, as it can save some memory used for model loading.
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
||||
pipe.save_pretrained(
|
||||
args.output_path,
|
||||
safe_serialization=True,
|
||||
max_shard_size="5GB",
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
@ -15,8 +15,8 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
LORA_KEYS_RENAME = {
|
||||
|
||||
LORA_KEYS_RENAME = {
|
||||
'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
|
||||
'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
|
||||
'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
|
||||
@ -24,22 +24,18 @@ LORA_KEYS_RENAME = {
|
||||
'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
|
||||
'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
|
||||
'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
|
||||
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight'
|
||||
'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight',
|
||||
}
|
||||
|
||||
|
||||
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
SAT_UNIT_KEY = "layers"
|
||||
LORA_PREFIX_KEY = "transformer_blocks"
|
||||
|
||||
|
||||
|
||||
def export_lora_weight(ckpt_path,lora_save_directory):
|
||||
|
||||
def export_lora_weight(ckpt_path, lora_save_directory):
|
||||
merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
|
||||
|
||||
lora_state_dict = {}
|
||||
for key in list(merge_original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
@ -50,8 +46,6 @@ def export_lora_weight(ckpt_path,lora_save_directory):
|
||||
|
||||
lora_state_dict[new_key] = merge_original_state_dict[key]
|
||||
|
||||
|
||||
|
||||
# final length should be 240
|
||||
if len(lora_state_dict) != 240:
|
||||
raise ValueError("lora_state_dict length is not 240")
|
||||
@ -64,7 +58,7 @@ def export_lora_weight(ckpt_path,lora_save_directory):
|
||||
is_main_process=True,
|
||||
weight_name=None,
|
||||
save_function=None,
|
||||
safe_serialization=True
|
||||
safe_serialization=True,
|
||||
)
|
||||
|
||||
|
||||
@ -73,7 +67,12 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--lora_save_directory", type=str, required=True, help="Path where converted lora should be saved")
|
||||
parser.add_argument(
|
||||
"--lora_save_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path where converted lora should be saved",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@@ -35,20 +35,16 @@ caption_generator = transformers.pipeline(
     "torch_dtype": torch.bfloat16,
 },
 trust_remote_code=True,
-tokenizer=tokenizer
+tokenizer=tokenizer,
 )

 image_generator = DiffusionPipeline.from_pretrained(
-    image_generator_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="balanced"
+    image_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
 )
 # image_generator.to("cuda")

 video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
-    video_generator_model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="balanced"
+    video_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
 )

 video_generator.vae.enable_slicing()

@@ -87,11 +83,7 @@ def generate_caption(prompt):
     {"role": "user", "content": prompt + "\n" + user_prompt},
 ]

-response = caption_generator(
-    messages,
-    max_new_tokens=226,
-    return_full_text=False
-)
+response = caption_generator(messages, max_new_tokens=226, return_full_text=False)
 caption = response[0]["generated_text"]
 if caption.startswith("\"") and caption.endswith("\""):
     caption = caption[1:-1]

@@ -109,11 +101,7 @@ def generate_image(caption, progress=gr.Progress(track_tqdm=True)):
 return image, image  # One for output One for State


-def generate_video(
-    caption,
-    image,
-    progress=gr.Progress(track_tqdm=True)
-):
+def generate_video(caption, image, progress=gr.Progress(track_tqdm=True)):
     generator = torch.Generator().manual_seed(seed)
     video_frames = video_generator(
         image=image,

@@ -181,14 +169,19 @@ with gr.Blocks() as demo:
 image_output = gr.Image(label="Generated Image")
 state_image = gr.State()
 generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption)
-generate_image_button.click(fn=generate_image, inputs=caption, outputs=[image_output, state_image])
+generate_image_button.click(
+    fn=generate_image, inputs=caption, outputs=[image_output, state_image]
+)
 with gr.Column():
     video_output = gr.Video(label="Generated Video", width=720, height=480)
     download_video_button = gr.File(label="📥 Download Video", visible=False)
     download_gif_button = gr.File(label="📥 Download GIF", visible=False)
     generate_video_button = gr.Button("Generate Video from Image")
-    generate_video_button.click(fn=generate_video, inputs=[caption, state_image],
-        outputs=[video_output, download_gif_button])
+    generate_video_button.click(
+        fn=generate_video,
+        inputs=[caption, state_image],
+        outputs=[video_output, download_gif_button],
+    )

 if __name__ == "__main__":
     demo.launch()
|
@ -65,7 +65,7 @@ def get_args():
|
||||
"--num_videos",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of unique videos you would like to generate."
|
||||
help="Number of unique videos you would like to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
@ -83,31 +83,28 @@ def get_args():
|
||||
"--caption_generator_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Cache directory for caption generation model."
|
||||
help="Cache directory for caption generation model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_generator_model_id",
|
||||
type=str,
|
||||
default="black-forest-labs/FLUX.1-dev",
|
||||
help="Image generation model."
|
||||
help="Image generation model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_generator_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Cache directory for image generation model."
|
||||
help="Cache directory for image generation model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_generator_num_inference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Caption generation model."
|
||||
help="Caption generation model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7,
|
||||
help="Guidance scale to be use for generation."
|
||||
"--guidance_scale", type=float, default=7, help="Guidance scale to be use for generation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dynamic_cfg",
|
||||
@ -123,19 +120,14 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
action="store_true",
|
||||
help="Whether or not to compile the transformer of image and video generators."
|
||||
help="Whether or not to compile the transformer of image and video generators.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_vae_tiling",
|
||||
action="store_true",
|
||||
help="Whether or not to use VAE tiling when encoding/decoding."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Seed for reproducibility."
|
||||
help="Whether or not to use VAE tiling when encoding/decoding.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -157,7 +149,9 @@ def main(args: Dict[str, Any]) -> None:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
reset_memory()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.caption_generator_model_id, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.caption_generator_model_id, trust_remote_code=True
|
||||
)
|
||||
caption_generator = transformers.pipeline(
|
||||
"text-generation",
|
||||
model=args.caption_generator_model_id,
|
||||
@ -168,7 +162,7 @@ def main(args: Dict[str, Any]) -> None:
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
trust_remote_code=True,
|
||||
tokenizer=tokenizer
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
captions = []
|
||||
@ -197,12 +191,14 @@ def main(args: Dict[str, Any]) -> None:
|
||||
image_generator = DiffusionPipeline.from_pretrained(
|
||||
args.image_generator_model_id,
|
||||
cache_dir=args.image_generator_cache_dir,
|
||||
torch_dtype=torch.bfloat16
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
image_generator.to("cuda")
|
||||
|
||||
if args.compile:
|
||||
image_generator.transformer = torch.compile(image_generator.transformer, mode="max-autotune", fullgraph=True)
|
||||
image_generator.transformer = torch.compile(
|
||||
image_generator.transformer, mode="max-autotune", fullgraph=True
|
||||
)
|
||||
|
||||
if args.enable_vae_tiling:
|
||||
image_generator.vae.enable_tiling()
|
||||
@ -216,7 +212,9 @@ def main(args: Dict[str, Any]) -> None:
|
||||
num_inference_steps=args.image_generator_num_inference_steps,
|
||||
guidance_scale=3.5,
|
||||
).images[0]
|
||||
filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
|
||||
filename = (
|
||||
caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
|
||||
)
|
||||
image.save(output_dir / f"{index}_{filename}.png")
|
||||
images.append(image)
|
||||
|
||||
@ -224,13 +222,16 @@ def main(args: Dict[str, Any]) -> None:
|
||||
reset_memory()
|
||||
|
||||
video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.bfloat16).to("cuda")
|
||||
args.model_path, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
video_generator.scheduler = CogVideoXDPMScheduler.from_config(
|
||||
video_generator.scheduler.config,
|
||||
timestep_spacing="trailing")
|
||||
video_generator.scheduler.config, timestep_spacing="trailing"
|
||||
)
|
||||
|
||||
if args.compile:
|
||||
video_generator.transformer = torch.compile(video_generator.transformer, mode="max-autotune", fullgraph=True)
|
||||
video_generator.transformer = torch.compile(
|
||||
video_generator.transformer, mode="max-autotune", fullgraph=True
|
||||
)
|
||||
|
||||
if args.enable_vae_tiling:
|
||||
video_generator.vae.enable_tiling()
|
||||
@ -248,7 +249,9 @@ def main(args: Dict[str, Any]) -> None:
|
||||
use_dynamic_cfg=args.use_dynamic_cfg,
|
||||
generator=generator,
|
||||
).frames[0]
|
||||
filename = caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
|
||||
filename = (
|
||||
caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
|
||||
)
|
||||
export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff.