Yuxuan Zhang 2025-03-22 15:14:06 +08:00
parent b9b0539dbe
commit 39c6562dc8
144 changed files with 2619 additions and 1217 deletions

View File

@@ -0,0 +1,28 @@
# Contribution Guide
We welcome your contributions to this repository. To keep the code style consistent and the code quality high, please follow the contribution guidelines below.
## What We Accept
+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.
## Code Style Guide
Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code by following the steps below:
1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors.
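To give a concrete sense of what this setup enforces, here is a minimal, hypothetical sketch of the kind of rewrapping the formatter applies when a call grows past the configured line length. The `preprocess_video` helper and its arguments are invented for illustration and are not part of this repository:
```python
# A hypothetical helper, standing in for any project function with several parameters.
def preprocess_video(video_path: str, max_num_frames: int, height: int, width: int) -> str:
    return f"{video_path}: {max_num_frames} frames at {height}x{width}"


# Style the formatter rejects once the call exceeds the configured line length:
# result = preprocess_video("videos/sample.mp4", max_num_frames=81, height=480, width=720)

# Style the formatter produces instead: arguments wrapped inside the parentheses.
result = preprocess_video(
    "videos/sample.mp4",
    max_num_frames=81,
    height=480,
    width=720,
)
print(result)
```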
## Naming Conventions
- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP 8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`; a short illustrative example follows below.
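A quick, hypothetical illustration of these naming rules (none of the names below come from this codebase):
```python
# Hard to read: single letters carry no meaning.
# def f(a, b):
#     return (a + b - 1) // b

# Preferred: descriptive English snake_case names, per PEP 8.
def count_sampled_frames(num_frames: int, sampling_stride: int) -> int:
    """Return how many frames remain when keeping every `sampling_stride`-th frame."""
    return (num_frames + sampling_stride - 1) // sampling_stride


max_num_frames = 81
print(count_sampled_frames(max_num_frames, sampling_stride=4))  # -> 21
```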

View File

@@ -1,34 +0,0 @@
# Raise valuable PR / 提出有价值的PR
## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if it contains multiple ideas or optimizations, split them into separate PRs.
用户在提交PR时候应该注意以下几点:
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性如果具有多个不同的想法和优化方案应该分配到不同的PR中。
## PRs that should not be proposed / 不应该提出的PR
A PR that falls into any of the following categories may be closed or rejected:
1. It does not describe the proposed improvement.
2. It combines multiple unrelated issues in a single PR.
3. It largely duplicates an existing PR.
如果开发者提出关于以下方面的PR则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# Check Your PR / 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Is your PR focused on a single issue? / 您的PR是否仅针对一个问题

.pre-commit-config.yaml (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.5
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
      - id: ruff-format
        args: [--config=pyproject.toml]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements

View File

@@ -22,7 +22,7 @@ Experience the CogVideoX-5B model online at <a href="https://huggingface.co/spac
 ## Project Updates
-- 🔥🔥 **News**: ```2025/03/16```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
+- 🔥🔥 **News**: ```2025/03/24```: We have launched [CogKit](https://github.com/THUDM/CogKit), a fine-tuning and inference framework for the **CogView4** and **CogVideoX** series. This toolkit allows you to fully explore and utilize our multimodal generation models.
 - 🔥 **News**: ```2025/02/28```: DDIM Inverse is now supported in `CogVideoX-5B` and `CogVideoX1.5-5B`. Check [here](inference/ddim_inversion.py).
 - 🔥 **News**: ```2025/01/08```: We have updated the code for `Lora` fine-tuning based on the `diffusers` version model, which uses less GPU memory. For more details, please see [here](finetune/README.md).
 - 🔥 **News**: ```2024/11/15```: We released the `CogVideoX1.5` model in the diffusers version. Only minor parameter adjustments are needed to continue using previous code.
@@ -444,8 +444,6 @@ hands-on practice on text-to-video generation. *The original input is in Chinese
 }
 ```
-We welcome your contributions! You can click [here](resources/contribute.md) for more information.
 ## Model-License
 The code in this repository is released under the [Apache 2.0 License](LICENSE).

View File

@@ -21,7 +21,7 @@
 </p>
 ## 更新とニュース
-- 🔥🔥 ```2025/03/16```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
+- 🔥🔥 ```2025/03/24```: [CogKit](https://github.com/THUDM/CogKit) は **CogView4** および **CogVideoX** シリーズの微調整と推論のためのフレームワークです。このツールキットを活用することで、私たちのマルチモーダル生成モデルを最大限に活用できます。
 - **ニュース**: ```2025/02/28```: DDIM Inverse が `CogVideoX-5B` と `CogVideoX1.5-5B` でサポートされました。詳細は [こちら](inference/ddim_inversion.py) をご覧ください。
 - **ニュース**: ```2025/01/08```: 私たちは`diffusers`バージョンのモデルをベースにした`Lora`微調整用のコードを更新しました。より少ないVRAMビデオメモリで動作します。詳細については[こちら](finetune/README_ja.md)をご覧ください。
 - **ニュース**: ```2024/11/15```: `CogVideoX1.5` モデルのdiffusersバージョンをリリースしました。わずかなパラメータ調整で以前のコードをそのまま利用可能です。
@@ -418,8 +418,6 @@ CogVideoのデモは [https://models.aminer.cn/cogvideo](https://models.aminer.c
 }
 ```
-あなたの貢献をお待ちしています!詳細は[こちら](resources/contribute_ja.md)をクリックしてください。
 ## ライセンス契約
 このリポジトリのコードは [Apache 2.0 License](LICENSE) の下で公開されています。

View File

@@ -22,7 +22,7 @@
 ## 项目更新
-- 🔥🔥 **News**: ```2025/03/16```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
+- 🔥🔥 **News**: ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。
 - 🔥 **News**: ```2025/02/28```: DDIM Inverse 已经在`CogVideoX-5B` 和 `CogVideoX1.5 -5B` 支持,查看 [here](inference/ddim_inversion.py).
 - 🔥 **News**: ```2025/01/08```: 我们更新了基于`diffusers`版本模型的`Lora`微调代码,占用显存更低,详情请见[这里](finetune/README_zh.md)。
 - 🔥 **News**: ```2024/11/15```: 我们发布 `CogVideoX1.5` 模型的diffusers版本仅需调整部分参数仅可沿用之前的代码。
@@ -398,8 +398,6 @@ CogVideo的demo网站在[https://models.aminer.cn/cogvideo](https://models.amine
 }
 ```
-我们欢迎您的贡献,您可以点击[这里](resources/contribute_zh.md)查看更多信息。
 ## 模型协议
 本仓库代码使用 [Apache 2.0 协议](LICENSE) 发布。

View File

@@ -26,7 +26,11 @@ class BucketSampler(Sampler):
     """
     def __init__(
-        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+        self,
+        data_source: Dataset,
+        batch_size: int = 8,
+        shuffle: bool = True,
+        drop_last: bool = False,
     ) -> None:
         self.data_source = data_source
         self.batch_size = batch_size
@@ -48,7 +52,11 @@ class BucketSampler(Sampler):
     def __iter__(self):
         for index, data in enumerate(self.data_source):
             video_metadata = data["video_metadata"]
-            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
+            f, h, w = (
+                video_metadata["num_frames"],
+                video_metadata["height"],
+                video_metadata["width"],
+            )
             self.buckets[(f, h, w)].append(data)
             if len(self.buckets[(f, h, w)]) == self.batch_size:

View File

@@ -115,7 +115,9 @@ class BaseI2VDataset(Dataset):
         train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
         cache_dir = self.trainer.args.data_root / "cache"
-        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        video_latent_dir = (
+            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        )
         prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
         prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@@ -136,7 +138,9 @@ class BaseI2VDataset(Dataset):
             # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
             prompt_embedding = prompt_embedding[0]
             save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
-            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
+            logger.info(
+                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
+            )
         if encoded_video_path.exists():
             encoded_video = load_file(encoded_video_path)["encoded_video"]
@@ -177,7 +181,9 @@ class BaseI2VDataset(Dataset):
             },
         }
-    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(
+        self, video_path: Path | None, image_path: Path | None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Loads and preprocesses a video and an image.
         If either path is None, no preprocessing will be done for that input.
@@ -249,13 +255,19 @@ class I2VDatasetWithResize(BaseI2VDataset):
         self.height = height
         self.width = width
-        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transforms = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
         self.__image_transforms = self.__frame_transforms
     @override
-    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def preprocess(
+        self, video_path: Path | None, image_path: Path | None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if video_path is not None:
-            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+            video = preprocess_video_with_resize(
+                video_path, self.max_num_frames, self.height, self.width
+            )
         else:
             video = None
         if image_path is not None:
@@ -293,7 +305,9 @@ class I2VDatasetWithBuckets(BaseI2VDataset):
             )
             for b in video_resolution_buckets
         ]
-        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transforms = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
         self.__image_transforms = self.__frame_transforms
     @override

View File

@@ -11,7 +11,12 @@ from typing_extensions import override
 from finetune.constants import LOG_LEVEL, LOG_NAME
-from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
+from .utils import (
+    load_prompts,
+    load_videos,
+    preprocess_video_with_buckets,
+    preprocess_video_with_resize,
+)
 if TYPE_CHECKING:
@@ -93,7 +98,9 @@ class BaseT2VDataset(Dataset):
         train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
         cache_dir = self.trainer.args.data_root / "cache"
-        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        video_latent_dir = (
+            cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
+        )
         prompt_embeddings_dir = cache_dir / "prompt_embeddings"
         video_latent_dir.mkdir(parents=True, exist_ok=True)
         prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
@@ -114,7 +121,9 @@ class BaseT2VDataset(Dataset):
             # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
             prompt_embedding = prompt_embedding[0]
             save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
-            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
+            logger.info(
+                f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False
+            )
         if encoded_video_path.exists():
             # encoded_video = torch.load(encoded_video_path, weights_only=True)
@@ -202,7 +211,9 @@ class T2VDatasetWithResize(BaseT2VDataset):
         self.height = height
         self.width = width
-        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transform = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:
@@ -240,7 +251,9 @@ class T2VDatasetWithBuckets(BaseT2VDataset):
             for b in video_resolution_buckets
         ]
-        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
+        self.__frame_transform = transforms.Compose(
+            [transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)]
+        )
     @override
     def preprocess(self, video_path: Path) -> torch.Tensor:

View File

@@ -24,12 +24,16 @@ def load_prompts(prompt_path: Path) -> List[str]:
 def load_videos(video_path: Path) -> List[Path]:
     with open(video_path, "r", encoding="utf-8") as file:
-        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+        return [
+            video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
+        ]
 def load_images(image_path: Path) -> List[Path]:
     with open(image_path, "r", encoding="utf-8") as file:
-        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+        return [
+            image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0
+        ]
 def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
@@ -169,7 +173,9 @@ def preprocess_video_with_buckets(
     video_num_frames = len(video_reader)
     resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
     if len(resolution_buckets) == 0:
-        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
+        raise ValueError(
+            f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}"
+        )
     nearest_frame_bucket = min(
         resolution_buckets,
@@ -181,7 +187,9 @@ def preprocess_video_with_buckets(
     frames = frames[:nearest_frame_bucket].float()
     frames = frames.permute(0, 3, 1, 2).contiguous()
-    nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
+    nearest_res = min(
+        resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3])
+    )
     nearest_res = (nearest_res[1], nearest_res[2])
     frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

View File

@@ -32,13 +32,19 @@ class CogVideoXI2VLoraTrainer(Trainer):
         components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
-        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
+        components.text_encoder = T5EncoderModel.from_pretrained(
+            model_path, subfolder="text_encoder"
+        )
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+            model_path, subfolder="transformer"
+        )
         components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+            model_path, subfolder="scheduler"
+        )
         return components
@@ -73,7 +79,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
             return_tensors="pt",
         )
         prompt_token_ids = prompt_token_ids.input_ids
-        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        prompt_embedding = self.components.text_encoder(
+            prompt_token_ids.to(self.accelerator.device)
+        )[0]
         return prompt_embedding
     @override
@@ -122,22 +130,34 @@ class CogVideoXI2VLoraTrainer(Trainer):
         # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
         images = images.unsqueeze(2)
         # Add noise to images
-        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
+        image_noise_sigma = torch.normal(
+            mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device
+        )
         image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
-        noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
-        image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
+        noisy_images = (
+            images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
+        )
+        image_latent_dist = self.components.vae.encode(
+            noisy_images.to(dtype=self.components.vae.dtype)
+        ).latent_dist
         image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
+            0,
+            self.components.scheduler.config.num_train_timesteps,
+            (batch_size,),
+            device=self.accelerator.device,
         )
         timesteps = timesteps.long()
         # from [B, C, F, H, W] to [B, F, C, H, W]
         latent = latent.permute(0, 2, 1, 3, 4)
         image_latents = image_latents.permute(0, 2, 1, 3, 4)
-        assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
+        assert (latent.shape[0], *latent.shape[2:]) == (
+            image_latents.shape[0],
+            *image_latents.shape[2:],
+        )
         # Padding image_latents to the same frame number as latent
         padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
@@ -169,7 +189,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         # Predict noise, For CogVideoX1.5 Only.
         ofs_emb = (
-            None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
+            None
+            if self.state.transformer_config.ofs_embed_dim is None
+            else latent.new_full((1,), fill_value=2.0)
         )
         predicted_noise = self.components.transformer(
             hidden_states=latent_img_noisy,
@@ -181,7 +203,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         )[0]
         # Denoise
-        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
+        latent_pred = self.components.scheduler.get_velocity(
+            predicted_noise, latent_noisy, timesteps
+        )
         alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
         weights = 1 / (1 - alphas_cumprod)
@@ -228,7 +252,9 @@ class CogVideoXI2VLoraTrainer(Trainer):
         if transformer_config.patch_size_t is None:
             base_num_frames = num_frames
         else:
-            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+            base_num_frames = (
+                num_frames + transformer_config.patch_size_t - 1
+            ) // transformer_config.patch_size_t
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,

View File

@@ -31,13 +31,19 @@ class CogVideoXT2VLoraTrainer(Trainer):
         components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
-        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
+        components.text_encoder = T5EncoderModel.from_pretrained(
+            model_path, subfolder="text_encoder"
+        )
-        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
+        components.transformer = CogVideoXTransformer3DModel.from_pretrained(
+            model_path, subfolder="transformer"
+        )
         components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
-        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
+        components.scheduler = CogVideoXDPMScheduler.from_pretrained(
+            model_path, subfolder="scheduler"
+        )
         return components
@@ -72,7 +78,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
             return_tensors="pt",
         )
         prompt_token_ids = prompt_token_ids.input_ids
-        prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
+        prompt_embedding = self.components.text_encoder(
+            prompt_token_ids.to(self.accelerator.device)
+        )[0]
         return prompt_embedding
     @override
@@ -115,7 +123,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
         # Sample a random timestep for each sample
         timesteps = torch.randint(
-            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
+            0,
+            self.components.scheduler.config.num_train_timesteps,
+            (batch_size,),
+            device=self.accelerator.device,
         )
         timesteps = timesteps.long()
@@ -150,7 +161,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
         )[0]
         # Denoise
-        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_added_noise, timesteps)
+        latent_pred = self.components.scheduler.get_velocity(
+            predicted_noise, latent_added_noise, timesteps
+        )
         alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
         weights = 1 / (1 - alphas_cumprod)
@@ -196,7 +209,9 @@ class CogVideoXT2VLoraTrainer(Trainer):
         if transformer_config.patch_size_t is None:
             base_num_frames = num_frames
         else:
-            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
+            base_num_frames = (
+                num_frames + transformer_config.patch_size_t - 1
+            ) // transformer_config.patch_size_t
         freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
             embed_dim=transformer_config.attention_head_dim,
             crops_coords=None,

View File

@@ -52,6 +52,8 @@ def get_model_cls(model_type: str, training_type: Literal["lora", "sft"]) -> Tra
         print(f"\nSupported training types for '{model_type}' are:")
         for supported_type in SUPPORTED_MODELS[model_type]:
             print(f"{supported_type}")
-        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
+        raise ValueError(
+            f"Training type '{training_type}' is not supported for model '{model_type}'"
+        )
     return SUPPORTED_MODELS[model_type][training_type]

View File

@@ -115,14 +115,18 @@ class Args(BaseModel):
     def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
-            raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
+            raise ValueError(
+                "validation_images must be specified when do_validation is True and model_type is i2v"
+            )
         return v
     @field_validator("validation_videos")
     def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
         values = info.data
         if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
-            raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
+            raise ValueError(
+                "validation_videos must be specified when do_validation is True and model_type is v2v"
+            )
         return v
     @field_validator("validation_steps")
@@ -148,7 +152,9 @@ class Args(BaseModel):
         model_name = info.data.get("model_name", "")
         if model_name in ["cogvideox-5b-i2v", "cogvideox-5b-t2v"]:
             if (height, width) != (480, 720):
-                raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")
+                raise ValueError(
+                    "For cogvideox-5b models, height must be 480 and width must be 720"
+                )
         return v
@@ -221,7 +227,9 @@ class Args(BaseModel):
         # LoRA parameters
         parser.add_argument("--rank", type=int, default=128)
         parser.add_argument("--lora_alpha", type=int, default=64)
-        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])
+        parser.add_argument(
+            "--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"]
+        )
         # Checkpointing
         parser.add_argument("--checkpointing_steps", type=int, default=200)

View File

@@ -8,7 +8,10 @@ import cv2
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--datadir", type=str, required=True, help="Root directory containing videos.txt and video subdirectory"
+        "--datadir",
+        type=str,
+        required=True,
+        help="Root directory containing videos.txt and video subdirectory",
     )
     return parser.parse_args()

View File

@@ -88,7 +88,9 @@ class Trainer:
     def _init_distributed(self):
         logging_dir = Path(self.args.output_dir, "logs")
-        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
+        project_config = ProjectConfiguration(
+            project_dir=self.args.output_dir, logging_dir=logging_dir
+        )
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
         init_process_group_kwargs = InitProcessGroupKwargs(
             backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
@@ -183,7 +185,9 @@ class Trainer:
         # Prepare VAE and text encoder for encoding
        self.components.vae.requires_grad_(False)
        self.components.text_encoder.requires_grad_(False)
-        self.components.vae = self.components.vae.to(self.accelerator.device, dtype=self.state.weight_dtype)
+        self.components.vae = self.components.vae.to(
+            self.accelerator.device, dtype=self.state.weight_dtype
+        )
         self.components.text_encoder = self.components.text_encoder.to(
             self.accelerator.device, dtype=self.state.weight_dtype
         )
@@ -263,7 +267,9 @@ class Trainer:
         # For LoRA, we only want to train the LoRA weights
         # For SFT, we want to train all the parameters
-        trainable_parameters = list(filter(lambda p: p.requires_grad, self.components.transformer.parameters()))
+        trainable_parameters = list(
+            filter(lambda p: p.requires_grad, self.components.transformer.parameters())
+        )
         transformer_parameters_with_lr = {
             "params": trainable_parameters,
             "lr": self.args.learning_rate,
@@ -287,7 +293,9 @@ class Trainer:
             use_deepspeed=use_deepspeed_opt,
         )
-        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(
+            len(self.data_loader) / self.args.gradient_accumulation_steps
+        )
         if self.args.train_steps is None:
             self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
             self.state.overwrote_max_train_steps = True
@@ -322,12 +330,16 @@ class Trainer:
         self.lr_scheduler = lr_scheduler
     def prepare_for_training(self) -> None:
-        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare(
-            self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
-        )
+        self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler = (
+            self.accelerator.prepare(
+                self.components.transformer, self.optimizer, self.data_loader, self.lr_scheduler
+            )
+        )
         # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
+        num_update_steps_per_epoch = math.ceil(
+            len(self.data_loader) / self.args.gradient_accumulation_steps
+        )
         if self.state.overwrote_max_train_steps:
             self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
         # Afterwards we recalculate our number of training epochs
@@ -364,7 +376,9 @@ class Trainer:
         logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
         self.state.total_batch_size_count = (
-            self.args.batch_size * self.accelerator.num_processes * self.args.gradient_accumulation_steps
+            self.args.batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
         )
         info = {
             "trainable parameters": self.state.num_trainable_parameters,
@@ -454,7 +468,9 @@ class Trainer:
                 progress_bar.set_postfix(logs)
                 # Maybe run validation
-                should_run_validation = self.args.do_validation and global_step % self.args.validation_steps == 0
+                should_run_validation = (
+                    self.args.do_validation and global_step % self.args.validation_steps == 0
+                )
                 if should_run_validation:
                     del loss
                     free_memory()
@@ -466,7 +482,9 @@ class Trainer:
                     break
             memory_statistics = get_memory_statistics()
-            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
+            logger.info(
+                f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}"
+            )
         accelerator.wait_for_everyone()
         self.__maybe_save_checkpoint(global_step, must_save=True)
@@ -504,7 +522,9 @@ class Trainer:
             # Can't using model_cpu_offload in deepspeed,
             # so we need to move all components in pipe to device
             # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
-            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
+            self.__move_components_to_device(
+                dtype=self.state.weight_dtype, ignore_list=["transformer"]
+            )
         else:
             # if not using deepspeed, use model_cpu_offload to further reduce memory usage
             # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
@@ -528,7 +548,9 @@ class Trainer:
             video = self.state.validation_videos[i]
             if image is not None:
-                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
+                image = preprocess_image_with_resize(
+                    image, self.state.train_height, self.state.train_width
+                )
                 # Convert image tensor (C, H, W) to PIL images
                 image = image.to(torch.uint8)
                 image = image.permute(1, 2, 0).cpu().numpy()
@@ -546,7 +568,9 @@ class Trainer:
                 f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                 main_process_only=False,
             )
-            validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
+            validation_artifacts = self.validation_step(
+                {"prompt": prompt, "image": image, "video": video}, pipe
+            )
             if (
                 self.state.using_deepspeed
@@ -565,7 +589,9 @@ class Trainer:
                 "video": {"type": "video", "value": video},
             }
             for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
-                artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
+                artifacts.update(
+                    {f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}
+                )
             logger.debug(
                 f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                 main_process_only=False,
@@ -600,8 +626,12 @@ class Trainer:
         tracker_key = "validation"
         for tracker in accelerator.trackers:
             if tracker.name == "wandb":
-                image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
-                video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+                image_artifacts = [
+                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)
+                ]
+                video_artifacts = [
+                    artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)
+                ]
                 tracker.log(
                     {
                         tracker_key: {"images": image_artifacts, "videos": video_artifacts},
@@ -618,7 +648,9 @@ class Trainer:
         pipe.remove_all_hooks()
         del pipe
         # Load models except those not needed for training
-        self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
+        self.__move_components_to_device(
+            dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST
+        )
         self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
         # Change trainable weights back to fp32 to keep with dtype after prepare the model
@@ -687,7 +719,9 @@ class Trainer:
         for name, component in components.items():
             if not isinstance(component, type) and hasattr(component, "to"):
                 if name not in ignore_list:
-                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
+                    setattr(
+                        self.components, name, component.to(self.accelerator.device, dtype=dtype)
+                    )
     def __move_components_to_cpu(self, unload_list: List[str] = []):
         unload_list = set(unload_list)
@@ -732,11 +766,13 @@ class Trainer:
             ):
                 transformer_ = unwrap_model(self.accelerator, model)
             else:
-                raise ValueError(f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}")
+                raise ValueError(
+                    f"Unexpected save model: {unwrap_model(self.accelerator, model).__class__}"
+                )
         else:
-            transformer_ = unwrap_model(self.accelerator, self.components.transformer).__class__.from_pretrained(
-                self.args.model_path, subfolder="transformer"
-            )
+            transformer_ = unwrap_model(
+                self.accelerator, self.components.transformer
+            ).__class__.from_pretrained(self.args.model_path, subfolder="transformer")
         transformer_.add_adapter(transformer_lora_config)
         lora_state_dict = self.components.pipeline_cls.lora_state_dict(input_dir)
@@ -745,7 +781,9 @@ class Trainer:
             for k, v in lora_state_dict.items()
             if k.startswith("transformer.")
         }
-        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+        incompatible_keys = set_peft_model_state_dict(
+            transformer_, transformer_state_dict, adapter_name="default"
+        )
         if incompatible_keys is not None:
             # check only for unexpected keys
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
@@ -759,7 +797,10 @@ class Trainer:
         self.accelerator.register_load_state_pre_hook(load_model_hook)
     def __maybe_save_checkpoint(self, global_step: int, must_save: bool = False):
-        if self.accelerator.distributed_type == DistributedType.DEEPSPEED or self.accelerator.is_main_process:
+        if (
+            self.accelerator.distributed_type == DistributedType.DEEPSPEED
+            or self.accelerator.is_main_process
+        ):
             if must_save or global_step % self.args.checkpointing_steps == 0:
                 # for training
                 save_path = get_intermediate_ckpt_path(

View File

@@ -23,7 +23,9 @@ def get_latest_ckpt_path_to_resume_from(
     else:
         resume_from_checkpoint_path = Path(resume_from_checkpoint)
         if not resume_from_checkpoint_path.exists():
-            logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
+            logger.info(
+                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
             initial_global_step = 0
             global_step = 0
             first_epoch = 0

View File

@@ -55,7 +55,9 @@ def unload_model(model):
     model.to("cpu")
-def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+def make_contiguous(
+    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
     if isinstance(x, torch.Tensor):
         return x.contiguous()
     elif isinstance(x, dict):

View File

@@ -67,7 +67,9 @@ def get_optimizer(
         optimizer_name = "adamw"
     if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
-        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
+        raise ValueError(
+            "`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers."
+        )
     if use_8bit:
         try:
@@ -81,7 +83,9 @@ def get_optimizer(
         if use_torchao:
             from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
-            optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
+            optimizer_class = (
+                AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
+            )
         else:
             optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
@@ -109,7 +113,9 @@ def get_optimizer(
         try:
             import prodigyopt
         except ImportError:
-            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+            raise ImportError(
+                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
+            )
         optimizer_class = prodigyopt.Prodigy
@@ -133,7 +139,9 @@ def get_optimizer(
         try:
             import came_pytorch
         except ImportError:
-            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
+            raise ImportError(
+                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
+            )
         optimizer_class = came_pytorch.CAME
@@ -151,7 +159,10 @@ def get_optimizer(
             init_kwargs.update({"fused": True})
         optimizer = CPUOffloadOptimizer(
-            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
+            params_to_optimize,
+            optimizer_class=optimizer_class,
+            offload_gradients=offload_gradients,
+            **init_kwargs,
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

View File

@@ -99,7 +99,9 @@ def generate_video(
     desired_resolution = RESOLUTION_MAP[model_name]
     if width is None or height is None:
         height, width = desired_resolution
-        logging.info(f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m")
+        logging.info(
+            f"\033[1mUsing default resolution {desired_resolution} for {model_name}\033[0m"
+        )
     elif (height, width) != desired_resolution:
         if generate_type == "i2v":
             # For i2v models, use user-defined width and height
@@ -124,7 +126,9 @@ def generate_video(
     # If you're using with lora, add this code
     if lora_path:
-        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+        pipe.load_lora_weights(
+            lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1"
+        )
         pipe.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
     # 2. Set Scheduler.
@@ -133,7 +137,9 @@ def generate_video(
     # using `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
     # pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    pipe.scheduler = CogVideoXDPMScheduler.from_config(
+        pipe.scheduler.config, timestep_spacing="trailing"
+    )
     # 3. Enable CPU offload for the model.
     # turn off if you have multiple GPUs or enough GPU memory(such as H100) and it will cost less time in inference
@@ -190,8 +196,12 @@ def generate_video(
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
-    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+    parser = argparse.ArgumentParser(
+        description="Generate a video from a text prompt using CogVideoX"
+    )
+    parser.add_argument(
+        "--prompt", type=str, required=True, help="The description of the video to be generated"
+    )
     parser.add_argument(
         "--image_or_video_path",
         type=str,
@@ -199,20 +209,44 @@ if __name__ == "__main__":
         help="The path of the image to be used as the background of the video",
     )
     parser.add_argument(
-        "--model_path", type=str, default="THUDM/CogVideoX1.5-5B", help="Path of the pre-trained model use"
+        "--model_path",
+        type=str,
+        default="THUDM/CogVideoX1.5-5B",
+        help="Path of the pre-trained model use",
     )
-    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
+    parser.add_argument(
+        "--lora_path", type=str, default=None, help="The path of the LoRA weights to be used"
+    )
     parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
-    parser.add_argument("--output_path", type=str, default="./output.mp4", help="The path save generated video")
-    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+    parser.add_argument(
+        "--output_path", type=str, default="./output.mp4", help="The path save generated video"
+    )
+    parser.add_argument(
+        "--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance"
+    )
     parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
-    parser.add_argument("--num_frames", type=int, default=81, help="Number of steps for the inference process")
+    parser.add_argument(
+        "--num_frames", type=int, default=81, help="Number of steps for the inference process"
+    )
     parser.add_argument("--width", type=int, default=None, help="The width of the generated video")
-    parser.add_argument("--height", type=int, default=None, help="The height of the generated video")
-    parser.add_argument("--fps", type=int, default=16, help="The frames per second for the generated video")
-    parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
-    parser.add_argument("--generate_type", type=str, default="t2v", help="The type of video generation")
-    parser.add_argument("--dtype", type=str, default="bfloat16", help="The data type for computation")
+    parser.add_argument(
+        "--height", type=int, default=None, help="The height of the generated video"
+    )
+    parser.add_argument(
+        "--fps", type=int, default=16, help="The frames per second for the generated video"
+    )
+    parser.add_argument(
+        "--num_videos_per_prompt",
+        type=int,
+        default=1,
+        help="Number of videos to generate per prompt",
+    )
+    parser.add_argument(
+        "--generate_type", type=str, default="t2v", help="The type of video generation"
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="bfloat16", help="The data type for computation"
+    )
     parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
     args = parser.parse_args()

View File

@ -19,7 +19,12 @@ import argparse
import os import os
import torch import torch
import torch._dynamo import torch._dynamo
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXTransformer3DModel,
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
)
from diffusers.utils import export_to_video
from transformers import T5EncoderModel
from torchao.quantization import quantize_, int8_weight_only
@@ -68,9 +73,13 @@ def generate_video(
    - quantization_scheme (str): The quantization scheme to use ('int8', 'fp8').
    - dtype (torch.dtype): The data type for computation (default is torch.bfloat16).
    """
    text_encoder = T5EncoderModel.from_pretrained(
        model_path, subfolder="text_encoder", torch_dtype=dtype
    )
    text_encoder = quantize_model(part=text_encoder, quantization_scheme=quantization_scheme)
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=dtype
    )
    transformer = quantize_model(part=transformer, quantization_scheme=quantization_scheme)
    vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
    vae = quantize_model(part=vae, quantization_scheme=quantization_scheme)
@@ -81,7 +90,9 @@ def generate_video(
        vae=vae,
        torch_dtype=dtype,
    )
    pipe.scheduler = CogVideoXDPMScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
@@ -100,16 +111,34 @@ def generate_video(
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt using CogVideoX"
    )
    parser.add_argument(
        "--prompt", type=str, required=True, help="The description of the video to be generated"
    )
    parser.add_argument(
        "--model_path", type=str, default="THUDM/CogVideoX-5b", help="Path of the pre-trained model"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output.mp4", help="Path to save generated video"
    )
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Inference steps")
    parser.add_argument(
        "--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
    )
    parser.add_argument(
        "--num_videos_per_prompt", type=int, default=1, help="Videos to generate per prompt"
    )
    parser.add_argument(
        "--dtype", type=str, default="bfloat16", help="Data type (e.g., 'float16', 'bfloat16')"
    )
    parser.add_argument(
        "--quantization_scheme",
        type=str,
        default="fp8",
        choices=["int8", "fp8"],
        help="Quantization scheme",
    )
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in the video")
    parser.add_argument("--fps", type=int, default=16, help="Frames per second for output video")

View File

@@ -104,18 +104,34 @@ def save_video(tensor, output_path):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CogVideoX encode/decode demo")
    parser.add_argument(
        "--model_path", type=str, required=True, help="The path to the CogVideoX model"
    )
    parser.add_argument("--video_path", type=str, help="The path to the video file (for encoding)")
    parser.add_argument(
        "--encoded_path", type=str, help="The path to the encoded tensor file (for decoding)"
    )
    parser.add_argument(
        "--output_path", type=str, default=".", help="The path to save the output file"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["encode", "decode", "both"],
        required=True,
        help="Mode: encode, decode, or both",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        help="The data type for computation (e.g., 'float16' or 'bfloat16')",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="The device to use for computation (e.g., 'cuda' or 'cpu')",
    )
    args = parser.parse_args()
@@ -126,15 +142,21 @@ if __name__ == "__main__":
        assert args.video_path, "Video path must be provided for encoding."
        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
        torch.save(encoded_output, args.output_path + "/encoded.pt")
        print(
            f"Finished encoding the video to a tensor, save it to a file at {encoded_output}/encoded.pt"
        )
    elif args.mode == "decode":
        assert args.encoded_path, "Encoded tensor path must be provided for decoding."
        decoded_output = decode_video(args.model_path, args.encoded_path, dtype, device)
        save_video(decoded_output, args.output_path)
        print(
            f"Finished decoding the video and saved it to a file at {args.output_path}/output.mp4"
        )
    elif args.mode == "both":
        assert args.video_path, "Video path must be provided for encoding."
        encoded_output = encode_video(args.model_path, args.video_path, dtype, device)
        torch.save(encoded_output, args.output_path + "/encoded.pt")
        decoded_output = decode_video(
            args.model_path, args.output_path + "/encoded.pt", dtype, device
        )
        save_video(decoded_output, args.output_path)
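The `encode_video` and `decode_video` functions used above sit outside this hunk. The snippet below is a minimal sketch, not repository code, of the latent round-trip they perform with the CogVideoX VAE; the checkpoint id, clip size, and CUDA device are assumptions.

```python
# Minimal sketch of a CogVideoX VAE round-trip (encode to latents, decode back to pixels).
import torch
from diffusers.models.autoencoders import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae.enable_slicing()
vae.enable_tiling()

# One fake 9-frame clip in [-1, 1], laid out as [B, C, F, H, W].
frames = torch.rand(1, 3, 9, 256, 256, dtype=torch.bfloat16, device="cuda") * 2 - 1

with torch.no_grad():
    latents = vae.encode(frames).latent_dist.sample() * vae.config.scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor).sample

print(latents.shape, decoded.shape)
```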

View File

@@ -144,7 +144,9 @@ def convert_prompt(prompt: str, retry_times: int = 3, type: str = "t2v", image_p
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
    parser.add_argument(
        "--retry_times", type=int, default=3, help="Number of times to retry the conversion"
    )
    parser.add_argument("--type", type=str, default="t2v", help="Type of conversion (t2v or i2v)")
    parser.add_argument("--image_path", type=str, default=None, help="Path to the image file")
    args = parser.parse_args()

View File

@@ -30,7 +30,10 @@ import torchvision.transforms as T
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.cogvideox_transformer_3d import (
    CogVideoXBlock,
    CogVideoXTransformer3DModel,
)
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video
@ -62,22 +65,48 @@ class DDIMInversionArguments(TypedDict):
def get_args() -> DDIMInversionArguments: def get_args() -> DDIMInversionArguments:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model") parser.add_argument(
parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure") "--model_path", type=str, required=True, help="Path of the pretrained model"
parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion") )
parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos") parser.add_argument(
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale") "--prompt", type=str, required=True, help="Prompt for the direct sample procedure"
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps") )
parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start") parser.add_argument(
parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end") "--video_path", type=str, required=True, help="Path of the video for inversion"
parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames") )
parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames") parser.add_argument(
"--output_path", type=str, default="output", help="Path of the output videos"
)
parser.add_argument(
"--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale"
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
"--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start"
)
parser.add_argument(
"--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end"
)
parser.add_argument(
"--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames"
)
parser.add_argument(
"--max_num_frames", type=int, default=81, help="Max number of sampled frames"
)
parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames") parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames") parser.add_argument(
"--height", type=int, default=480, help="Resized height of the video frames"
)
parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos") parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model") parser.add_argument(
"--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model"
)
parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator") parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference") parser.add_argument(
"--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference"
)
args = parser.parse_args() args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
@ -116,13 +145,20 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
# Apply RoPE if needed # Apply RoPE if needed
if image_rotary_emb is not None: if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) query[:, :, text_seq_length:] = apply_rotary_emb(
query[:, :, text_seq_length:], image_rotary_emb
)
if not attn.is_cross_attention: if not attn.is_cross_attention:
if key.size(2) == query.size(2): # Attention for reference hidden states if key.size(2) == query.size(2): # Attention for reference hidden states
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) key[:, :, text_seq_length:] = apply_rotary_emb(
key[:, :, text_seq_length:], image_rotary_emb
)
else: # RoPE should be applied to each group of image tokens else: # RoPE should be applied to each group of image tokens
key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb( key[:, :, text_seq_length : text_seq_length + image_seq_length] = (
key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb apply_rotary_emb(
key[:, :, text_seq_length : text_seq_length + image_seq_length],
image_rotary_emb,
)
) )
key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb( key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
@ -162,8 +198,12 @@ class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
) )
if attention_mask is not None: if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) attention_mask = attn.prepare_attention_mask(
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) attention_mask, sequence_length, batch_size
)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states) query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states) key = attn.to_k(hidden_states)
@ -260,14 +300,18 @@ def get_video_frames(
return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W] return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
def encode_video_frames(vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor) -> torch.FloatTensor: def encode_video_frames(
vae: AutoencoderKLCogVideoX, video_frames: torch.FloatTensor
) -> torch.FloatTensor:
video_frames = video_frames.to(device=vae.device, dtype=vae.dtype) video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W] video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2) latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
return latent_dist * vae.config.scaling_factor return latent_dist * vae.config.scaling_factor
def export_latents_to_video(pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int): def export_latents_to_video(
pipeline: CogVideoXPipeline, latents: torch.FloatTensor, video_path: str, fps: int
):
video = pipeline.decode_latents(latents) video = pipeline.decode_latents(latents)
frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil") frames = pipeline.video_processor.postprocess_video(video=video, output_type="pil")
export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps) export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
@ -320,7 +364,9 @@ def sample(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs if isinstance(
scheduler, DDIMInverseScheduler
): # Inverse scheduler does not accept extra kwargs
extra_step_kwargs = {} extra_step_kwargs = {}
# 7. Create rotary embeds if required # 7. Create rotary embeds if required
@ -344,7 +390,9 @@ def sample(
if pipeline.interrupt: if pipeline.interrupt:
continue continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
if reference_latents is not None: if reference_latents is not None:
reference = reference_latents[i] reference = reference_latents[i]
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
@ -371,18 +419,31 @@ def sample(
# perform guidance # perform guidance
if use_dynamic_cfg: if use_dynamic_cfg:
pipeline._guidance_scale = 1 + guidance_scale * ( pipeline._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 (
1
- math.cos(
math.pi
* ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0
)
)
/ 2
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + pipeline.guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = noise_pred_uncond + pipeline.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the noisy sample x_t-1 -> x_t # compute the noisy sample x_t-1 -> x_t
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
latents = latents.to(prompt_embeds.dtype) latents = latents.to(prompt_embeds.dtype)
trajectory[i] = latents trajectory[i] = latents
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
):
progress_bar.update() progress_bar.update()
# Offload all models # Offload all models
@ -410,7 +471,9 @@ def ddim_inversion(
seed: int, seed: int,
device: torch.device, device: torch.device,
): ):
pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device=device) pipeline: CogVideoXPipeline = CogVideoXPipeline.from_pretrained(
model_path, torch_dtype=dtype
).to(device=device)
if not pipeline.transformer.config.use_rotary_positional_embeddings: if not pipeline.transformer.config.use_rotary_positional_embeddings:
raise NotImplementedError("This script supports CogVideoX 5B model only.") raise NotImplementedError("This script supports CogVideoX 5B model only.")
video_frames = get_video_frames( video_frames = get_video_frames(
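One line that is easy to misread after the reflow above is the dynamic classifier-free guidance schedule in the sampling loop. The standalone restatement below mirrors that formula; the timestep values in the demo loop are illustrative only, since the real `t` comes from the scheduler.

```python
# Restates the dynamic CFG schedule from the sampling hunk: guidance ramps from about 1
# toward 1 + guidance_scale as (num_inference_steps - t) / num_inference_steps approaches 1.
import math


def dynamic_guidance(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    progress = (num_inference_steps - t) / num_inference_steps
    return 1 + guidance_scale * ((1 - math.cos(math.pi * progress**5.0)) / 2)


for t in (50, 40, 30, 20, 10, 0):  # made-up countdown; the scheduler supplies real timesteps
    print(t, round(dynamic_guidance(6.0, t, num_inference_steps=50), 3))
```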

View File

@@ -43,5 +43,3 @@ pip install -r requirements.txt
```bash
python app.py
```

View File

@@ -39,11 +39,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = "THUDM/CogVideoX-5b"
hf_hub_download(
    repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran"
)
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
pipe = CogVideoXPipeline.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
pipe.scheduler = CogVideoXDPMScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
    MODEL,
    transformer=pipe.transformer,
@ -296,8 +300,16 @@ def delete_old_files():
threading.Thread(target=delete_old_files, daemon=True).start() threading.Thread(target=delete_old_files, daemon=True).start()
examples_videos = [["example_videos/horse.mp4"], ["example_videos/kitten.mp4"], ["example_videos/train_running.mp4"]] examples_videos = [
examples_images = [["example_images/beach.png"], ["example_images/street.png"], ["example_images/camping.png"]] ["example_videos/horse.mp4"],
["example_videos/kitten.mp4"],
["example_videos/train_running.mp4"],
]
examples_images = [
["example_images/beach.png"],
["example_images/street.png"],
["example_images/camping.png"],
]
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown(""" gr.Markdown("""
@ -322,14 +334,26 @@ with gr.Blocks() as demo:
""") """)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False): with gr.Accordion(
"I2V: Image Input (cannot be used simultaneously with video input)", open=False
):
image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)") image_input = gr.Image(label="Input Image (will be cropped to 720 * 480)")
examples_component_images = gr.Examples(examples_images, inputs=[image_input], cache_examples=False) examples_component_images = gr.Examples(
with gr.Accordion("V2V: Video Input (cannot be used simultaneously with image input)", open=False): examples_images, inputs=[image_input], cache_examples=False
video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)") )
with gr.Accordion(
"V2V: Video Input (cannot be used simultaneously with image input)", open=False
):
video_input = gr.Video(
label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)"
)
strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength") strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
examples_component_videos = gr.Examples(examples_videos, inputs=[video_input], cache_examples=False) examples_component_videos = gr.Examples(
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) examples_videos, inputs=[video_input], cache_examples=False
)
prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
@ -340,11 +364,16 @@ with gr.Blocks() as demo:
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
seed_param = gr.Number( seed_param = gr.Number(
label="Inference Seed (Enter a positive number, -1 for random)", value=-1 label="Inference Seed (Enter a positive number, -1 for random)",
value=-1,
) )
with gr.Row(): with gr.Row():
enable_scale = gr.Checkbox(label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False) enable_scale = gr.Checkbox(
enable_rife = gr.Checkbox(label="Frame Interpolation (8fps -> 16fps)", value=False) label="Super-Resolution (720 × 480 -> 2880 × 1920)", value=False
)
enable_rife = gr.Checkbox(
label="Frame Interpolation (8fps -> 16fps)", value=False
)
gr.Markdown( gr.Markdown(
"✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br>&nbsp;&nbsp;&nbsp;&nbsp;The entire process is based on open-source solutions." "✨In this demo, we use [RIFE](https://github.com/hzwer/ECCV2022-RIFE) for frame interpolation and [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) for upscaling(Super-Resolution).<br>&nbsp;&nbsp;&nbsp;&nbsp;The entire process is based on open-source solutions."
) )
@ -430,7 +459,7 @@ with gr.Blocks() as demo:
seed_value, seed_value,
scale_status, scale_status,
rife_status, rife_status,
progress=gr.Progress(track_tqdm=True) progress=gr.Progress(track_tqdm=True),
): ):
latents, seed = infer( latents, seed = infer(
prompt, prompt,
@ -457,7 +486,9 @@ with gr.Blocks() as demo:
image_pil = VaeImageProcessor.numpy_to_pil(image_np) image_pil = VaeImageProcessor.numpy_to_pil(image_np)
batch_video_frames.append(image_pil) batch_video_frames.append(image_pil)
video_path = utils.save_video(batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)) video_path = utils.save_video(
batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6)
)
video_update = gr.update(visible=True, value=video_path) video_update = gr.update(visible=True, value=video_path)
gif_path = convert_to_gif(video_path) gif_path = convert_to_gif(video_path)
gif_update = gr.update(visible=True, value=gif_path) gif_update = gr.update(visible=True, value=gif_path)

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential( return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes), nn.PReLU(out_planes),
) )
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1: if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None: if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1) x = torch.cat((x, flow), 1)
x = self.conv0(x) x = self.conv0(x)
x = self.convblock(x) + x x = self.convblock(x) + x
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d) mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else: else:
flow_teacher = None flow_teacher = None
merged_teacher = None merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3: if gt.shape[1] == 3:
loss_mask = ( loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) (
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float() .float()
.detach() .detach()
) )
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2]) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4]) c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
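As a side note on the distillation hunks above: the mask keeps only the pixels where the student's blended frame is clearly worse than the teacher's, and the flow-distillation penalty is applied there. The snippet below restates that logic outside the diff with toy tensors; shapes and names are illustrative.

```python
# Restatement of the IFNet teacher-student distillation term with toy tensors.
import torch


def distill_term(merged_i, merged_teacher, gt, flow_i, flow_teacher):
    # Pixels where the student reconstruction is worse than the teacher's by > 0.01 (L1).
    loss_mask = (
        ((merged_i - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01)
        .float()
        .detach()
    )
    # Penalize flow deviation from the (detached) teacher flow only inside that mask.
    return (((flow_teacher.detach() - flow_i) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()


gt = torch.rand(2, 3, 32, 32)
print(
    distill_term(
        torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32), gt,
        torch.rand(2, 4, 32, 32), torch.rand(2, 4, 32, 32),
    )
)
```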

View File

@ -3,7 +3,9 @@ from .refine_2R import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential( return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes), nn.PReLU(out_planes),
) )
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1: if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None: if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1) x = torch.cat((x, flow), 1)
x = self.conv0(x) x = self.conv0(x)
x = self.convblock(x) + x x = self.convblock(x) + x
@ -102,7 +108,9 @@ class IFNet(nn.Module):
warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d) mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else: else:
flow_teacher = None flow_teacher = None
merged_teacher = None merged_teacher = None
@ -110,11 +118,16 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3: if gt.shape[1] == 3:
loss_mask = ( loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) (
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float() .float()
.detach() .detach()
) )
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
c0 = self.contextnet(img0, flow[:, :2]) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4]) c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@ -61,11 +61,19 @@ class IFBlock(nn.Module):
def forward(self, x, flow, scale=1): def forward(self, x, flow, scale=1):
x = F.interpolate( x = F.interpolate(
x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False x,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
) )
flow = ( flow = (
F.interpolate( F.interpolate(
flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False flow,
scale_factor=1.0 / scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
) )
* 1.0 * 1.0
/ scale / scale
@ -78,11 +86,21 @@ class IFBlock(nn.Module):
flow = self.conv1(feat) flow = self.conv1(feat)
mask = self.conv2(feat) mask = self.conv2(feat)
flow = ( flow = (
F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* scale * scale
) )
mask = F.interpolate( mask = F.interpolate(
mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False mask,
scale_factor=scale,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
) )
return flow, mask return flow, mask
@ -112,7 +130,11 @@ class IFNet(nn.Module):
loss_cons = 0 loss_cons = 0
block = [self.block0, self.block1, self.block2] block = [self.block0, self.block1, self.block2]
for i in range(3): for i in range(3):
f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) f0, m0 = block[i](
torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1),
flow,
scale=scale_list[i],
)
f1, m1 = block[i]( f1, m1 = block[i](
torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
torch.cat((flow[:, 2:4], flow[:, :2]), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1),

View File

@ -3,7 +3,9 @@ from .refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential( return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1
),
nn.PReLU(out_planes), nn.PReLU(out_planes),
) )
@ -46,7 +48,11 @@ class IFBlock(nn.Module):
if scale != 1: if scale != 1:
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow != None: if flow != None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale flow = (
F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
* 1.0
/ scale
)
x = torch.cat((x, flow), 1) x = torch.cat((x, flow), 1)
x = self.conv0(x) x = self.conv0(x)
x = self.convblock(x) + x x = self.convblock(x) + x
@ -83,7 +89,9 @@ class IFNet_m(nn.Module):
for i in range(3): for i in range(3):
if flow != None: if flow != None:
flow_d, mask_d = stu[i]( flow_d, mask_d = stu[i](
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i] torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1),
flow,
scale=scale[i],
) )
flow = flow + flow_d flow = flow + flow_d
mask = mask + mask_d mask = mask + mask_d
@ -97,13 +105,17 @@ class IFNet_m(nn.Module):
merged.append(merged_student) merged.append(merged_student)
if gt.shape[1] == 3: if gt.shape[1] == 3:
flow_d, mask_d = self.block_tea( flow_d, mask_d = self.block_tea(
torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1 torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1),
flow,
scale=1,
) )
flow_teacher = flow + flow_d flow_teacher = flow + flow_d
warped_img0_teacher = warp(img0, flow_teacher[:, :2]) warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d) mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (
1 - mask_teacher
)
else: else:
flow_teacher = None flow_teacher = None
merged_teacher = None merged_teacher = None
@ -111,11 +123,16 @@ class IFNet_m(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3: if gt.shape[1] == 3:
loss_mask = ( loss_mask = (
((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01) (
(merged[i] - gt).abs().mean(1, True)
> (merged_teacher - gt).abs().mean(1, True) + 0.01
)
.float() .float()
.detach() .detach()
) )
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() loss_distill += (
((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask
).mean()
if returnflow: if returnflow:
return flow return flow
else: else:

View File

@ -44,7 +44,9 @@ class Model:
if torch.cuda.is_available(): if torch.cuda.is_available():
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path)))) self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path))))
else: else:
self.flownet.load_state_dict(convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))) self.flownet.load_state_dict(
convert(torch.load("{}/flownet.pkl".format(path), map_location="cpu"))
)
def save_model(self, path, rank=0): def save_model(self, path, rank=0):
if rank == 0: if rank == 0:

View File

@ -29,10 +29,14 @@ def downsample(x):
def upsample(x): def upsample(x):
cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) cc = torch.cat(
[x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3]) cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
cc = cc.permute(0, 1, 3, 2) cc = cc.permute(0, 1, 3, 2)
cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3) cc = torch.cat(
[cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2).to(device)], dim=3
)
cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2) cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
x_up = cc.permute(0, 1, 3, 2) x_up = cc.permute(0, 1, 3, 2)
return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1])) return conv_gauss(x_up, 4 * gauss_kernel(channels=x.shape[1]))
@ -64,6 +68,10 @@ class LapLoss(torch.nn.Module):
self.gauss_kernel = gauss_kernel(channels=channels) self.gauss_kernel = gauss_kernel(channels=channels)
def forward(self, input, target): def forward(self, input, target):
pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) pyr_input = laplacian_pyramid(
pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) img=input, kernel=self.gauss_kernel, max_levels=self.max_levels
)
pyr_target = laplacian_pyramid(
img=target, kernel=self.gauss_kernel, max_levels=self.max_levels
)
return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

View File

@ -7,7 +7,9 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma): def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) gauss = torch.Tensor(
[exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
)
return gauss / gauss.sum() return gauss / gauss.sum()
@ -22,7 +24,9 @@ def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1) _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()) _2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device) window = (
_3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
)
return window return window
@ -50,16 +54,35 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel) # mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel) # mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) mu1 = F.conv2d(
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel) F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu2 = F.conv2d(
F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel
)
mu1_sq = mu1.pow(2) mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2) mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2 mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq sigma1_sq = (
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq F.conv2d(
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2 F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_sq
)
sigma2_sq = (
F.conv2d(
F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu2_sq
)
sigma12 = (
F.conv2d(
F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2 C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2 C2 = (0.03 * L) ** 2
@ -80,7 +103,9 @@ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False,
return ret return ret
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): def ssim_matlab(
img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None
):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None: if val_range is None:
if torch.max(img1) > 128: if torch.max(img1) > 128:
@ -106,16 +131,35 @@ def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full
img1 = img1.unsqueeze(1) img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1) img2 = img2.unsqueeze(1)
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) mu1 = F.conv3d(
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu2 = F.conv3d(
F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1
)
mu1_sq = mu1.pow(2) mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2) mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2 mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq sigma1_sq = (
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq F.conv3d(
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2 F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_sq
)
sigma2_sq = (
F.conv3d(
F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu2_sq
)
sigma12 = (
F.conv3d(
F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1
)
- mu1_mu2
)
C1 = (0.01 * L) ** 2 C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2 C2 = (0.03 * L) ** 2
@ -143,7 +187,14 @@ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normal
mssim = [] mssim = []
mcs = [] mcs = []
for _ in range(levels): for _ in range(levels):
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) sim, cs = ssim(
img1,
img2,
window_size=window_size,
size_average=size_average,
full=True,
val_range=val_range,
)
mssim.append(sim) mssim.append(sim)
mcs.append(cs) mcs.append(cs)
@ -187,7 +238,9 @@ class SSIM(torch.nn.Module):
self.window = window self.window = window
self.channel = channel self.channel = channel
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) _ssim = ssim(
img1, img2, window=window, window_size=self.window_size, size_average=self.size_average
)
dssim = (1 - _ssim) / 2 dssim = (1 - _ssim) / 2
return dssim return dssim
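The `gaussian` helper at the top of this file is the building block for every SSIM window used later; restated on its own (names copied from the hunk), the normalization to unit sum is easier to see.

```python
# 1-D Gaussian window builder from the SSIM hunk above; the window sums to 1 so the
# later convolutions compute weighted local means.
from math import exp

import torch


def gaussian(window_size, sigma):
    gauss = torch.Tensor(
        [exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]
    )
    return gauss / gauss.sum()


print(gaussian(11, 1.5).sum())  # tensor(1.) up to floating point error
```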

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential( return nn.Sequential(
torch.nn.ConvTranspose2d( torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
), ),
nn.PReLU(out_planes), nn.PReLU(out_planes),
) )
@ -56,25 +61,49 @@ class Contextnet(nn.Module):
def forward(self, x, flow): def forward(self, x, flow):
x = self.conv1(x) x = self.conv1(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f1 = warp(x, flow) f1 = warp(x, flow)
x = self.conv2(x) x = self.conv2(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f2 = warp(x, flow) f2 = warp(x, flow)
x = self.conv3(x) x = self.conv3(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f3 = warp(x, flow) f3 = warp(x, flow)
x = self.conv4(x) x = self.conv4(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f4 = warp(x, flow) f4 = warp(x, flow)

View File

@ -24,7 +24,12 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential( return nn.Sequential(
torch.nn.ConvTranspose2d( torch.nn.ConvTranspose2d(
in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
), ),
nn.PReLU(out_planes), nn.PReLU(out_planes),
) )
@ -59,19 +64,37 @@ class Contextnet(nn.Module):
f1 = warp(x, flow) f1 = warp(x, flow)
x = self.conv2(x) x = self.conv2(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f2 = warp(x, flow) f2 = warp(x, flow)
x = self.conv3(x) x = self.conv3(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f3 = warp(x, flow) f3 = warp(x, flow)
x = self.conv4(x) x = self.conv4(x)
flow = ( flow = (
F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) F.interpolate(
flow,
scale_factor=0.5,
mode="bilinear",
align_corners=False,
recompute_scale_factor=False,
)
* 0.5 * 0.5
) )
f4 = warp(x, flow) f4 = warp(x, flow)
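A note on why the Contextnet hunks above multiply the flow by 0.5 after every 0.5x interpolation: optical flow is measured in pixels, so halving the feature resolution also halves every displacement. A tiny self-contained check (the values are made up):

```python
# Halving resolution must also halve pixel displacements, hence the "* 0.5" in the hunks above.
import torch
import torch.nn.functional as F

flow = torch.zeros(1, 2, 8, 8)
flow[:, 0] = 4.0  # every pixel moves 4 px to the right at full resolution

half = (
    F.interpolate(
        flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False
    )
    * 0.5
)
print(half[0, 0, 0, 0])  # tensor(2.) -- 4 px at full resolution is 2 px at half resolution
```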

View File

@@ -9,6 +9,7 @@ import logging
import skvideo.io
from rife.RIFE_HDv3 import Model
from huggingface_hub import hf_hub_download, snapshot_download

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -78,13 +79,12 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
        # print(f'I1[0] unpadded shape:{I1.shape}')
        I1_small = F.interpolate(I1, (32, 32), mode="bilinear", align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        if padding[3] > 0 and padding[1] > 0:
            frame = I1[:, :, : -padding[3], : -padding[1]]
        elif padding[3] > 0:
            frame = I1[:, :, : -padding[3], :]
        elif padding[1] > 0:
            frame = I1[:, :, :, : -padding[1]]
        else:
            frame = I1
@ -102,7 +102,6 @@ def ssim_interpolation_rife(model, samples, exp=1, upscale_amount=1, output_devi
frame = F.interpolate(frame, size=(h, w)) frame = F.interpolate(frame, size=(h, w))
output.append(frame.to(output_device)) output.append(frame.to(output_device))
for i, tmp_frame in enumerate(tmp_output): for i, tmp_frame in enumerate(tmp_output):
# tmp_frame, _ = pad_image(tmp_frame, upscale_amount) # tmp_frame, _ = pad_image(tmp_frame, upscale_amount)
tmp_frame = F.interpolate(tmp_frame, size=(h, w)) tmp_frame = F.interpolate(tmp_frame, size=(h, w))
output.append(tmp_frame.to(output_device)) output.append(tmp_frame.to(output_device))
@ -145,9 +144,7 @@ def rife_inference_with_path(model, video_path):
frame_rgb = frame[..., ::-1] frame_rgb = frame[..., ::-1]
frame_rgb = frame_rgb.copy() frame_rgb = frame_rgb.copy()
tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0 tensor = torch.from_numpy(frame_rgb).float().to("cpu", non_blocking=True).float() / 255.0
pt_frame_data.append( pt_frame_data.append(tensor.permute(2, 0, 1)) # to [c, h, w,]
tensor.permute(2, 0, 1)
) # to [c, h, w,]
pt_frame = torch.from_numpy(np.stack(pt_frame_data)) pt_frame = torch.from_numpy(np.stack(pt_frame_data))
pt_frame = pt_frame.to(device) pt_frame = pt_frame.to(device)
@ -170,7 +167,9 @@ def rife_inference_with_latents(model, latents):
latent = latents[i] latent = latents[i]
frames = ssim_interpolation_rife(model, latent) frames = ssim_interpolation_rife(model, latent)
pt_image = torch.stack([frames[i].squeeze(0) for i in range(len(frames))]) # (to [f, c, w, h]) pt_image = torch.stack(
[frames[i].squeeze(0) for i in range(len(frames))]
) # (to [f, c, w, h])
rife_results.append(pt_image) rife_results.append(pt_image)
return torch.stack(rife_results) return torch.stack(rife_results)

View File

@@ -22,7 +22,7 @@ def load_torch_file(ckpt, device=None, dtype=torch.float16):
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if "weights_only" not in torch.load.__code__.co_varnames:
            logger.warning(
                "Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely."
            )
@ -74,27 +74,39 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
@torch.inference_mode() @torch.inference_mode()
def tiled_scale_multidim( def tiled_scale_multidim(
samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None samples,
function,
tile=(64, 64),
overlap=8,
upscale_amount=4,
out_channels=3,
output_device="cpu",
pbar=None,
): ):
dims = len(tile) dims = len(tile)
print(f"samples dtype:{samples.dtype}") print(f"samples dtype:{samples.dtype}")
output = torch.empty( output = torch.empty(
[samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), [samples.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), samples.shape[2:])),
device=output_device, device=output_device,
) )
for b in range(samples.shape[0]): for b in range(samples.shape[0]):
s = samples[b : b + 1] s = samples[b : b + 1]
out = torch.zeros( out = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), [s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device, device=output_device,
) )
out_div = torch.zeros( out_div = torch.zeros(
[s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), [s.shape[0], out_channels]
+ list(map(lambda a: round(a * upscale_amount), s.shape[2:])),
device=output_device, device=output_device,
) )
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))): for it in itertools.product(
*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))
):
s_in = s s_in = s
upscaled = [] upscaled = []
@ -142,7 +154,14 @@ def tiled_scale(
pbar=None, pbar=None,
): ):
return tiled_scale_multidim( return tiled_scale_multidim(
samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar samples,
function,
(tile_y, tile_x),
overlap,
upscale_amount,
out_channels,
output_device,
pbar,
) )
@ -186,7 +205,9 @@ def upscale(upscale_model, tensor: torch.Tensor, inf_device, output_device="cpu"
return s return s
def upscale_batch_and_concatenate(upscale_model, latents, inf_device, output_device="cpu") -> torch.Tensor: def upscale_batch_and_concatenate(
upscale_model, latents, inf_device, output_device="cpu"
) -> torch.Tensor:
upscaled_latents = [] upscaled_latents = []
for i in range(latents.size(0)): for i in range(latents.size(0)):
latent = latents[i] latent = latents[i]
@ -207,7 +228,9 @@ class ProgressBar:
def __init__(self, total, desc=None): def __init__(self, total, desc=None):
self.total = total self.total = total
self.current = 0 self.current = 0
self.b_unit = tqdm.tqdm(total=total, desc="ProgressBar context index: 0" if desc is None else desc) self.b_unit = tqdm.tqdm(
total=total, desc="ProgressBar context index: 0" if desc is None else desc
)
def update(self, value): def update(self, value):
if value > self.total: if value > self.total:
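For readers following the `tiled_scale_multidim` hunks above, the itertools expression is the heart of the tiling: it enumerates tile origins along every spatial dimension, with `overlap` pixels shared between neighbouring tiles. A small standalone illustration (the shape is made up):

```python
# Enumerates tile start positions the same way as the tiled_scale_multidim hunk above.
import itertools

shape = (48, 80)   # spatial size of one sample (illustrative)
tile = (64, 64)    # default tile size in the hunk
overlap = 8

origins = list(
    itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(shape, tile)))
)
print(origins)  # [(0, 0), (0, 56)] -- one row of tiles, the second starting 56 px in
```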

View File

@ -22,7 +22,9 @@ from datetime import datetime, timedelta
from openai import OpenAI from openai import OpenAI
from moviepy import VideoFileClip from moviepy import VideoFileClip
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(
"cuda"
)
pipe.vae.enable_slicing() pipe.vae.enable_slicing()
pipe.vae.enable_tiling() pipe.vae.enable_tiling()
@ -95,7 +97,12 @@ def convert_prompt(prompt: str, retry_times: int = 3) -> str:
return prompt return prompt
def infer(prompt: str, num_inference_steps: int, guidance_scale: float, progress=gr.Progress(track_tqdm=True)): def infer(
prompt: str,
num_inference_steps: int,
guidance_scale: float,
progress=gr.Progress(track_tqdm=True),
):
torch.cuda.empty_cache() torch.cuda.empty_cache()
video = pipe( video = pipe(
prompt=prompt, prompt=prompt,
@ -151,7 +158,9 @@ with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5) prompt = gr.Textbox(
label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5
)
with gr.Row(): with gr.Row():
gr.Markdown( gr.Markdown(
@ -176,7 +185,13 @@ with gr.Blocks() as demo:
download_video_button = gr.File(label="📥 Download Video", visible=False) download_video_button = gr.File(label="📥 Download Video", visible=False)
download_gif_button = gr.File(label="📥 Download GIF", visible=False) download_gif_button = gr.File(label="📥 Download GIF", visible=False)
def generate(prompt, num_inference_steps, guidance_scale, model_choice, progress=gr.Progress(track_tqdm=True)): def generate(
prompt,
num_inference_steps,
guidance_scale,
model_choice,
progress=gr.Progress(track_tqdm=True),
):
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress) tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
video_path = save_video(tensor) video_path = save_video(tensor)
video_update = gr.update(visible=True, value=video_path) video_update = gr.update(visible=True, value=video_path)

View File

@@ -4,4 +4,3 @@
<p> 扫码关注公众号,加入「 CogVideoX 交流群」 </p>
<p> Scan the QR code to follow the official account and join the "CogVideoX Discussion Group" </p>
</div>

View File

@@ -1,49 +0,0 @@
# Contribution Guide

There may still be many incomplete aspects in this project.

We look forward to your contributions to the repository in the following areas. If you complete any of this work and are willing to submit a PR and share it with the community, we will acknowledge your contribution on the project homepage after review.
## Model Algorithms
- Support for model quantization inference (Int4 quantization project)
- Optimization of model fine-tuning data loading (replacing the existing decord tool)
## Model Engineering
- Model fine-tuning examples / Best prompt practices
- Inference adaptation on different devices (e.g., MLX framework)
- Any tools related to the model
- Any minimal fully open-source project using the CogVideoX open-source model
## Code Standards
Good code style is an art. We have prepared a `pyproject.toml` configuration file for the project to standardize code
style. You can organize the code according to the following specifications:
1. Install the `ruff` tool
```shell
pip install ruff
```
Then, run the `ruff` tool
```shell
ruff check tools sat inference
```
This checks the code style. If there are issues, you can fix them automatically with the `ruff format` command.
```shell
ruff format tools sat inference
```
Once your code meets the standard, there should be no errors.
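To make the effect concrete, `ruff format` is essentially line-length-aware wrapping; the bulk of this commit consists of exactly such changes. The snippet below is a hypothetical before/after sketch (the helper and its parameters are invented for illustration and are not part of this repository):
```python
# Hypothetical helper, defined only so the example is self-contained.
def build_sampling_config(num_inference_steps, guidance_scale, sampling_fps, output_dir, seed):
    return {"steps": num_inference_steps, "scale": guidance_scale, "fps": sampling_fps, "out": output_dir, "seed": seed}


# Before formatting: the call sits on one line that exceeds the configured length limit.
config = build_sampling_config(num_inference_steps=50, guidance_scale=6.0, sampling_fps=8, output_dir="outputs", seed=42)

# A typical result of `ruff format`: the same call wrapped, one argument per line, with a trailing comma.
config = build_sampling_config(
    num_inference_steps=50,
    guidance_scale=6.0,
    sampling_fps=8,
    output_dir="outputs",
    seed=42,
)
```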
## Naming Conventions
1. Please use English names; do not use Pinyin or names in other languages. All comments should be written in English.
2. Please strictly follow the PEP8 specification and use underscores to separate words. Do not use names like `a`, `b`, `c` (see the short example below).
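A minimal sketch of these naming rules; the function below is invented purely for illustration:
```python
# Good: English, PEP8 snake_case, descriptive names, English comments.
def count_video_frames(frame_rate: float, duration_seconds: float) -> int:
    """Return the number of frames in a clip of the given duration."""
    return int(frame_rate * duration_seconds)


# Avoid: meaningless single-letter names and missing documentation.
def f(a, b):
    return int(a * b)
```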


@ -1,47 +0,0 @@
# コントリビューションガイド
本プロジェクトにはまだ多くの未完成の部分があります。
以下の分野でリポジトリへの貢献をお待ちしています。上記の作業を完了し、PRを提出してコミュニティと共有する意志がある場合、レビュー後、プロジェクトのホームページで貢献を認識します。
## モデルアルゴリズム
- モデル量子化推論のサポート (Int4量子化プロジェクト)
- モデルのファインチューニングデータロードの最適化既存のdecordツールの置き換え
## モデルエンジニアリング
- モデルのファインチューニング例 / 最適なプロンプトの実践
- 異なるデバイスでの推論適応(例: MLXフレームワーク
- モデルに関連するツール
- CogVideoXオープンソースモデルを使用した、完全にオープンソースの最小プロジェクト
## コード標準
良いコードスタイルは一種の芸術です。本プロジェクトにはコードスタイルを標準化するための `pyproject.toml`
設定ファイルを用意しています。以下の仕様に従ってコードを整理してください。
1. `ruff` ツールをインストールする
```shell
pip install ruff
```
次に、`ruff` ツールを実行します
```shell
ruff check tools sat inference
```
コードスタイルを確認します。問題がある場合は、`ruff format` コマンドを使用して自動修正できます。
```shell
ruff format tools sat inference
```
コードが標準に準拠したら、エラーはなくなるはずです。
## 命名規則
1. 英語名を使用してください。ピンインや他の言語の名前を使用しないでください。すべてのコメントは英語で記載してください。
2. PEP8仕様に厳密に従い、単語をアンダースコアで区切ってください。a、b、cのような名前は使用しないでください。


@ -1,44 +0,0 @@
# 贡献指南
本项目可能还存在很多不完善的内容。 我们期待您在以下方面与我们共建仓库, 如果您完成了上述工作并愿意PR和分享到社区在通过审核后我们将在项目首页感谢您的贡献。
## 模型算法
- 模型量化推理支持 (Int4量化工程)
- 模型微调数据载入优化支持(替换现有的decord工具)
## 模型工程
- 模型微调示例 / 最佳提示词实践
- 不同设备上的推理适配(MLX等框架)
- 任何模型周边工具
- 任何使用CogVideoX开源模型制作的最小完整开源项目
## 代码规范
良好的代码风格是一种艺术,我们已经为项目准备好了`pyproject.toml`配置文件,用于规范代码风格。您可以按照以下规范梳理代码:
1. 安装`ruff`工具
```shell
pip install ruff
```
接着,运行`ruff`工具
```shell
ruff check tools sat inference
```
检查代码风格,如果有问题,您可以通过`ruff format .`命令自动修复。
```shell
ruff format tools sat inference
```
如果您的代码符合规范,应该不会出现任何的错误。
## 命名规范
- 请使用英文命名,不要使用拼音或者其他语言命名。所有的注释均使用英文。
- 请严格遵循 PEP8 规范,使用下划线分割单词。请勿使用 a,b,c 这样的命名。


@ -18,7 +18,10 @@ def add_model_config_args(parser):
group = parser.add_argument_group("model", "model configuration") group = parser.add_argument_group("model", "model configuration")
group.add_argument("--base", type=str, nargs="*", help="config for input and saving") group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
group.add_argument( group.add_argument(
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert." "--model-parallel-size",
type=int,
default=1,
help="size of the model parallel. only use if you are an expert.",
) )
group.add_argument("--force-pretrain", action="store_true") group.add_argument("--force-pretrain", action="store_true")
group.add_argument("--device", type=int, default=-1) group.add_argument("--device", type=int, default=-1)
@ -74,10 +77,15 @@ def get_args(args_list=None, parser=None):
if not args.train_data: if not args.train_data:
print_rank0("No training data specified", level="WARNING") print_rank0("No training data specified", level="WARNING")
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set." assert (args.train_iters is None) or (
args.epochs is None
), "only one of train_iters and epochs should be set."
if args.train_iters is None and args.epochs is None: if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters args.train_iters = 10000 # default 10k iters
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING") print_rank0(
"No train_iters (recommended) or epochs specified, use default 10k iters.",
level="WARNING",
)
args.cuda = torch.cuda.is_available() args.cuda = torch.cuda.is_available()
@ -213,7 +221,10 @@ def initialize_distributed(args):
args.master_port = os.getenv("MASTER_PORT", default_master_port) args.master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += args.master_ip + ":" + args.master_port init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
) )
# Set the model-parallel / data-parallel communicators. # Set the model-parallel / data-parallel communicators.
@ -232,7 +243,10 @@ def initialize_distributed(args):
import deepspeed import deepspeed
deepspeed.init_distributed( deepspeed.init_distributed(
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method dist_backend=args.distributed_backend,
world_size=args.world_size,
rank=args.rank,
init_method=init_method,
) )
# # It seems that it has no negative influence to configure it even without using checkpointing. # # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers) # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
@ -262,7 +276,9 @@ def process_config_to_args(args):
args_config = config.pop("args", OmegaConf.create()) args_config = config.pop("args", OmegaConf.create())
for key in args_config: for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig): if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(
args_config[key], omegaconf.ListConfig
):
arg = OmegaConf.to_object(args_config[key]) arg = OmegaConf.to_object(args_config[key])
else: else:
arg = args_config[key] arg = args_config[key]


@ -56,7 +56,9 @@ def read_video(
end_pts = float("inf") end_pts = float("inf")
if end_pts < start_pts: if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}") raise ValueError(
f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
)
info = {} info = {}
audio_frames = [] audio_frames = []
@ -342,7 +344,11 @@ class VideoDataset(MetaDistributedWebDataset):
super().__init__( super().__init__(
path, path,
partial( partial(
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num process_fn_video,
num_frames=num_frames,
image_size=image_size,
fps=fps,
skip_frms_num=skip_frms_num,
), ),
seed, seed,
meta_names=meta_names, meta_names=meta_names,
@ -400,7 +406,9 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, (end - start) // num_frames).astype(int) indices = np.arange(start, end, (end - start) // num_frames).astype(int)
temp_frms = vr.get_batch(np.arange(start, end_safty)) temp_frms = vr.get_batch(np.arange(start, end_safty))
assert temp_frms is not None assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = (
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else: else:
if ori_vlen > self.max_num_frames: if ori_vlen > self.max_num_frames:
@ -410,7 +418,11 @@ class SFTDataset(Dataset):
indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int) indices = np.arange(start, end, max((end - start) // num_frames, 1)).astype(int)
temp_frms = vr.get_batch(np.arange(start, end)) temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())] tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else: else:
@ -423,11 +435,17 @@ class SFTDataset(Dataset):
start = int(self.skip_frms_num) start = int(self.skip_frms_num)
end = int(ori_vlen - self.skip_frms_num) end = int(ori_vlen - self.skip_frms_num)
num_frames = nearest_smaller_4k_plus_1(end - start) # 3D VAE requires the number of frames to be 4k+1 num_frames = nearest_smaller_4k_plus_1(
end - start
) # 3D VAE requires the number of frames to be 4k+1
end = int(start + num_frames) end = int(start + num_frames)
temp_frms = vr.get_batch(np.arange(start, end)) temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms tensor_frms = (
torch.from_numpy(temp_frms)
if type(temp_frms) is not torch.Tensor
else temp_frms
)
tensor_frms = pad_last_frame( tensor_frms = pad_last_frame(
tensor_frms, self.max_num_frames tensor_frms, self.max_num_frames


@ -41,7 +41,9 @@ class SATVideoDiffusionEngine(nn.Module):
latent_input = model_config.get("latent_input", False) latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("disable_first_stage_autocast", False) no_cond_log = model_config.get("disable_first_stage_autocast", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"]) not_trainable_prefixes = model_config.get(
"not_trainable_prefixes", ["first_stage_model", "conditioner"]
)
compile_model = model_config.get("compile_model", False) compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None) lr_scale = model_config.get("lr_scale", None)
@ -76,12 +78,18 @@ class SATVideoDiffusionEngine(nn.Module):
) )
self.denoiser = instantiate_from_config(denoiser_config) self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None self.sampler = (
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) instantiate_from_config(sampler_config) if sampler_config is not None else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self._init_first_stage(first_stage_config) self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None self.loss_fn = (
instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
)
self.latent_input = latent_input self.latent_input = latent_input
self.scale_factor = scale_factor self.scale_factor = scale_factor
@ -151,8 +159,12 @@ class SATVideoDiffusionEngine(nn.Module):
def shared_step(self, batch: Dict) -> Any: def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch) x = self.get_input(batch)
if self.lr_scale is not None: if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False) lr_x = F.interpolate(
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False) x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False
)
lr_x = F.interpolate(
lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False
)
lr_z = self.encode_first_stage(lr_x, batch) lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z batch["lr_input"] = lr_z
@ -195,7 +207,11 @@ class SATVideoDiffusionEngine(nn.Module):
recons = [] recons = []
start_frame = 0 start_frame = 0
for i in range(fake_cp_size): for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0) end_frame = (
start_frame
+ latent_time // fake_cp_size
+ (1 if i < latent_time % fake_cp_size else 0)
)
use_cp = True if i == 0 else False use_cp = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
@ -264,7 +280,9 @@ class SATVideoDiffusionEngine(nn.Module):
self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs
) )
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs) samples = self.sampler(
denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs
)
samples = samples.to(self.dtype) samples = samples.to(self.dtype)
return samples return samples
@ -278,7 +296,9 @@ class SATVideoDiffusionEngine(nn.Module):
log = dict() log = dict()
for embedder in self.conditioner.embedders: for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log: if (
(self.log_keys is None) or (embedder.input_key in self.log_keys)
) and not self.no_cond_log:
x = batch[embedder.input_key][:n] x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
if x.dim() == 1: if x.dim() == 1:
@ -354,7 +374,9 @@ class SATVideoDiffusionEngine(nn.Module):
image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1) image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
c["concat"] = image c["concat"] = image
uc["concat"] = image uc["concat"] = image
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous() samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents: if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples latents = 1.0 / self.scale_factor * samples
@ -364,7 +386,9 @@ class SATVideoDiffusionEngine(nn.Module):
samples = samples.permute(0, 2, 1, 3, 4).contiguous() samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples log["samples"] = samples
else: else:
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous() samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents: if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples latents = 1.0 / self.scale_factor * samples


@ -94,7 +94,9 @@ def get_3d_sincos_pos_embed(
# concate: [T, H, W] order # concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4] pos_embed_temporal = np.repeat(
pos_embed_temporal, grid_height * grid_width, axis=1
) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
@ -160,7 +162,8 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
self.width = width self.width = width
self.spatial_length = height * width self.spatial_length = height * width
self.pos_embedding = nn.Parameter( self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False,
) )
def position_embedding_forward(self, position_ids, **kwargs): def position_embedding_forward(self, position_ids, **kwargs):
@ -169,7 +172,9 @@ class Basic2DPositionEmbeddingMixin(BaseMixin):
def reinit(self, parent_model=None): def reinit(self, parent_model=None):
del self.transformer.position_embeddings del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) self.pos_embedding.data[:, -self.spatial_length :].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
class Basic3DPositionEmbeddingMixin(BaseMixin): class Basic3DPositionEmbeddingMixin(BaseMixin):
@ -192,7 +197,8 @@ class Basic3DPositionEmbeddingMixin(BaseMixin):
self.spatial_length = height * width self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter( self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False,
) )
self.height_interpolation = height_interpolation self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation self.width_interpolation = width_interpolation
@ -285,7 +291,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) freqs = broadcat(
(freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
dim=-1,
)
freqs = freqs.contiguous() freqs = freqs.contiguous()
self.freqs_sin = freqs.sin().cuda() self.freqs_sin = freqs.sin().cuda()
@ -293,7 +302,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin):
self.text_length = text_length self.text_length = text_length
if learnable_pos_embed: if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True) self.pos_embedding = nn.Parameter(
torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
)
else: else:
self.pos_embedding = None self.pos_embedding = None
@ -440,16 +451,26 @@ class FinalLayerMixin(BaseMixin):
self.out_channels = out_channels self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True) self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
)
def final_forward(self, logits, **kwargs): def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d),只取了x中后面images的部分 x, emb = (
logits[:, kwargs["text_length"] :, :],
kwargs["emb"],
) # x:(b,(t n),d),只取了x中后面images的部分
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale) x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x) x = self.linear(x)
return unpatchify( return unpatchify(
x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs x,
c=self.out_channels,
patch_size=self.patch_size,
w=kwargs["rope_W"],
h=kwargs["rope_H"],
**kwargs,
) )
def reinit(self, parent_model=None): def reinit(self, parent_model=None):
@ -500,7 +521,10 @@ class AdaLNMixin(BaseMixin):
self.compressed_num_frames = compressed_num_frames self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList( self.adaLN_modulations = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)] [
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
for _ in range(num_layers)
]
) )
self.qk_ln = qk_ln self.qk_ln = qk_ln
@ -560,7 +584,9 @@ class AdaLNMixin(BaseMixin):
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa) text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d) attention_input = torch.cat(
(text_attention_input, img_attention_input), dim=1
) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs) attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d) text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d) img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
@ -584,9 +610,13 @@ class AdaLNMixin(BaseMixin):
img_mlp_output = layer.fourth_layernorm(img_mlp_output) img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d) img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) text_hidden_states = (
text_hidden_states + text_gate_mlp * text_mlp_output
) # language (b,n,d)
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) hidden_states = torch.cat(
(text_hidden_states, img_hidden_states), dim=1
) # (b,(n_t+t*n_i),d)
return hidden_states return hidden_states
def reinit(self, parent_model=None): def reinit(self, parent_model=None):
@ -694,7 +724,9 @@ class DiffusionTransformer(BaseModel):
if use_RMSNorm: if use_RMSNorm:
kwargs["layernorm"] = RMSNorm kwargs["layernorm"] = RMSNorm
else: else:
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6) kwargs["layernorm"] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
)
transformer_args.num_layers = num_layers transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size transformer_args.hidden_size = hidden_size
@ -707,7 +739,9 @@ class DiffusionTransformer(BaseModel):
if use_SwiGLU: if use_SwiGLU:
self.add_mixin( self.add_mixin(
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True "swiglu",
SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
reinit=True,
) )
def _build_modules(self, module_configs): def _build_modules(self, module_configs):
@ -813,7 +847,9 @@ class DiffusionTransformer(BaseModel):
) )
if "lora_config" in module_configs: if "lora_config" in module_configs:
lora_config = module_configs["lora_config"] lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) self.add_mixin(
"lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
)
return return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs): def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
@ -829,7 +865,9 @@ class DiffusionTransformer(BaseModel):
assert (y is not None) == ( assert (y is not None) == (
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
@ -838,7 +876,9 @@ class DiffusionTransformer(BaseModel):
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
if self.ofs_embed_dim is not None: if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype) ofs_emb = timestep_embedding(
kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
)
ofs_emb = self.ofs_embed(ofs_emb) ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb emb = emb + ofs_emb
@ -852,6 +892,8 @@ class DiffusionTransformer(BaseModel):
kwargs["rope_H"] = h // self.patch_size[1] kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2] kwargs["rope_W"] = w // self.patch_size[2]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
(1, 1)
).to(x.dtype)
output = super().forward(**kwargs)[0] output = super().forward(**kwargs)[0]
return output return output


@ -19,6 +19,7 @@ from sat import mpu
from diffusion_video import SATVideoDiffusionEngine from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args from arguments import get_args
def read_from_cli(): def read_from_cli():
cnt = 0 cnt = 0
try: try:
@ -50,34 +51,50 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
for key in keys: for key in keys:
if key == "txt": if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() batch["txt"] = (
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
elif key == "original_size_as_tuple": elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = ( batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1) torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
) )
elif key == "crop_coords_top_left": elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = ( batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1) torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
.to(device)
.repeat(*N, 1)
) )
elif key == "aesthetic_score": elif key == "aesthetic_score":
batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = ( batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
) )
elif key == "target_size_as_tuple": elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = ( batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1) torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
) )
elif key == "fps": elif key == "fps":
batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
elif key == "fps_id": elif key == "fps_id":
batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
elif key == "motion_bucket_id": elif key == "motion_bucket_id":
batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N)) batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N))
)
elif key == "pool_image": elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half) batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "cond_aug": elif key == "cond_aug":
batch[key] = repeat( batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"), torch.tensor([value_dict["cond_aug"]]).to("cuda"),
@ -100,7 +117,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda
return batch, batch_uc return batch, batch_uc
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None
):
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch): for i, vid in enumerate(video_batch):
@ -160,7 +179,9 @@ def sampling_main(args, model_cls):
W = 96 W = 96
H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8 H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8
chained_trainsforms = [] chained_trainsforms = []
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)) chained_trainsforms.append(
TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)
)
chained_trainsforms.append(TT.ToTensor()) chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms) transform = TT.Compose(chained_trainsforms)
image = transform(image).unsqueeze(0).to("cuda") image = transform(image).unsqueeze(0).to("cuda")
@ -170,7 +191,9 @@ def sampling_main(args, model_cls):
image = image / model.scale_factor image = image / model.scale_factor
image = image.permute(0, 2, 1, 3, 4).contiguous() image = image.permute(0, 2, 1, 3, 4).contiguous()
pad_shape = (image.shape[0], T - 1, C, H, W) pad_shape = (image.shape[0], T - 1, C, H, W)
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) image = torch.concat(
[image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1
)
else: else:
image_size = args.sampling_image_size image_size = args.sampling_image_size
H, W = image_size[0], image_size[1] H, W = image_size[0], image_size[1]
@ -181,12 +204,20 @@ def sampling_main(args, model_cls):
mp_size = mpu.get_model_parallel_world_size() mp_size = mpu.get_model_parallel_world_size()
global_rank = torch.distributed.get_rank() // mp_size global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size src = global_rank * mp_size
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group()) torch.distributed.broadcast_object_list(
text_cast, src=src, group=mpu.get_model_parallel_group()
)
text = text_cast[0] text = text_cast[0]
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)} value_dict = {
"prompt": text,
"negative_prompt": "",
"num_frames": torch.tensor(T).unsqueeze(0),
}
batch, batch_uc = get_batch( batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
) )
for key in batch: for key in batch:
if isinstance(batch[key], torch.Tensor): if isinstance(batch[key], torch.Tensor):
@ -212,7 +243,11 @@ def sampling_main(args, model_cls):
for index in range(args.batch_size): for index in range(args.batch_size):
if args.image2video: if args.image2video:
samples_z = sample_func( samples_z = sample_func(
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda") c,
uc=uc,
batch_size=1,
shape=(T, C, H, W),
ofs=torch.tensor([2.0]).to("cuda"),
) )
else: else:
samples_z = sample_func( samples_z = sample_func(
@ -226,7 +261,9 @@ def sampling_main(args, model_cls):
if args.only_save_latents: if args.only_save_latents:
samples_z = 1.0 / model.scale_factor * samples_z samples_z = 1.0 / model.scale_factor * samples_z
save_path = os.path.join( save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
) )
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
torch.save(samples_z, os.path.join(save_path, "latent.pt")) torch.save(samples_z, os.path.join(save_path, "latent.pt"))
@ -237,7 +274,9 @@ def sampling_main(args, model_cls):
samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
save_path = os.path.join( save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) args.output_dir,
str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120],
str(index),
) )
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)


@ -71,15 +71,24 @@ class LambdaWarmUpCosineScheduler2:
n = n - self.cum_cycles[cycle] n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0: if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]: if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f self.last_f = f
return f return f
else: else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0) t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f self.last_f = f
return f return f
@ -93,10 +102,15 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
n = n - self.cum_cycles[cycle] n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0: if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]: if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f self.last_f = f
return f return f
else: else:


@ -218,14 +218,20 @@ class AutoencodingEngine(AbstractAutoencoder):
x = self.decoder(z, **kwargs) x = self.decoder(z, **kwargs)
return x return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True) z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs) dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch) x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs) z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"): if hasattr(self.loss, "forward_keys"):
extra_info = { extra_info = {
@ -361,12 +367,16 @@ class AutoencodingEngine(AbstractAutoencoder):
if self.trainable_ae_params is None: if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params() ae_params = self.get_autoencoder_params()
else: else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None: if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params() disc_params = self.get_discriminator_params()
else: else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config( opt_ae = self.instantiate_optimizer_from_config(
ae_params, ae_params,
@ -375,17 +385,23 @@ class AutoencodingEngine(AbstractAutoencoder):
) )
opts = [opt_ae] opts = [opt_ae]
if len(disc_params) > 0: if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc) opts.append(opt_disc)
return opts return opts
@torch.no_grad() @torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict() log = dict()
additional_decode_kwargs = {} additional_decode_kwargs = {}
x = self.get_input(batch) x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs) _, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x log["inputs"] = x
@ -404,7 +420,9 @@ class AutoencodingEngine(AbstractAutoencoder):
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0) diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0 log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs: if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs) additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs)
@ -446,7 +464,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
params = super().get_autoencoder_params() params = super().get_autoencoder_params()
return params return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None: if self.max_batch_size is None:
z = self.encoder(x) z = self.encoder(x)
z = self.quant_conv(z) z = self.quant_conv(z)
@ -513,7 +533,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: def log_videos(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs) return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor: def get_input(self, batch: dict) -> torch.Tensor:
@ -524,7 +546,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
batch = batch[self.input_key] batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) torch.distributed.broadcast(
batch, src=global_src_rank, group=get_context_parallel_group()
)
batch = _conv_split(batch, dim=2, kernel_size=1) batch = _conv_split(batch, dim=2, kernel_size=1)
return batch return batch


@ -94,7 +94,11 @@ class FeedForward(nn.Module):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1) k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v) context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q) out = torch.einsum("bhde,bhdn->bhen", context, q)
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
# new # new
with sdp_kernel(**BACKEND_MAP[self.backend]): with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h) out = rearrange(out, "b h n d -> b n (h d)", h=h)
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
self.norm1(x), self.norm1(x),
context=context if self.disable_self_attn else None, context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens, additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
) )
+ x + x
) )
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
sdp_backend=None, sdp_backend=None,
): ):
super().__init__() super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
] ]
) )
if not use_linear: if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else: else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))


@ -87,7 +87,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
yield from () yield from ()
@torch.no_grad() @torch.no_grad()
def log_images(self, inputs: torch.Tensor, reconstructions: torch.Tensor) -> Dict[str, torch.Tensor]: def log_images(
self, inputs: torch.Tensor, reconstructions: torch.Tensor
) -> Dict[str, torch.Tensor]:
# calc logits of real/fake # calc logits of real/fake
logits_real = self.discriminator(inputs.contiguous().detach()) logits_real = self.discriminator(inputs.contiguous().detach())
if len(logits_real.shape) < 4: if len(logits_real.shape) < 4:
@ -209,7 +211,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
weights: Union[None, float, torch.Tensor] = None, weights: Union[None, float, torch.Tensor] = None,
) -> Tuple[torch.Tensor, dict]: ) -> Tuple[torch.Tensor, dict]:
if self.scale_input_to_tgt_size: if self.scale_input_to_tgt_size:
inputs = torch.nn.functional.interpolate(inputs, reconstructions.shape[2:], mode="bicubic", antialias=True) inputs = torch.nn.functional.interpolate(
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
)
if self.dims > 2: if self.dims > 2:
inputs, reconstructions = map( inputs, reconstructions = map(
@ -226,7 +230,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
input_frames = pick_video_frame(inputs, frame_indices) input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices) recon_frames = pick_video_frame(reconstructions, frame_indices)
p_loss = self.perceptual_loss(input_frames.contiguous(), recon_frames.contiguous()).mean() p_loss = self.perceptual_loss(
input_frames.contiguous(), recon_frames.contiguous()
).mean()
rec_loss = rec_loss + self.perceptual_weight * p_loss rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
@ -238,7 +244,9 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous()) logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake) g_loss = -torch.mean(logits_fake)
if self.training: if self.training:
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) d_weight = self.calculate_adaptive_weight(
nll_loss, g_loss, last_layer=last_layer
)
else: else:
d_weight = torch.tensor(1.0) d_weight = torch.tensor(1.0)
else: else:


@ -37,12 +37,18 @@ class LatentLPIPS(nn.Module):
if self.perceptual_weight > 0.0: if self.perceptual_weight > 0.0:
image_reconstructions = self.decoder.decode(latent_predictions) image_reconstructions = self.decoder.decode(latent_predictions)
image_targets = self.decoder.decode(latent_inputs) image_targets = self.decoder.decode(latent_inputs)
perceptual_loss = self.perceptual_loss(image_targets.contiguous(), image_reconstructions.contiguous()) perceptual_loss = self.perceptual_loss(
loss = self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean() image_targets.contiguous(), image_reconstructions.contiguous()
)
loss = (
self.latent_weight * loss.mean() + self.perceptual_weight * perceptual_loss.mean()
)
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
if self.perceptual_weight_on_inputs > 0.0: if self.perceptual_weight_on_inputs > 0.0:
image_reconstructions = default(image_reconstructions, self.decoder.decode(latent_predictions)) image_reconstructions = default(
image_reconstructions, self.decoder.decode(latent_predictions)
)
if self.scale_input_to_tgt_size: if self.scale_input_to_tgt_size:
image_inputs = torch.nn.functional.interpolate( image_inputs = torch.nn.functional.interpolate(
image_inputs, image_inputs,
@ -58,7 +64,9 @@ class LatentLPIPS(nn.Module):
antialias=True, antialias=True,
) )
perceptual_loss2 = self.perceptual_loss(image_inputs.contiguous(), image_reconstructions.contiguous()) perceptual_loss2 = self.perceptual_loss(
image_inputs.contiguous(), image_reconstructions.contiguous()
)
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
return loss, log return loss, log


@ -45,7 +45,9 @@ def hinge_gen_loss(fake):
@autocast(enabled=False) @autocast(enabled=False)
@beartype @beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter): def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach() return torch_grad(
outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
)[0].detach()
def pick_video_frame(video, frame_indices): def pick_video_frame(video, frame_indices):
@ -126,7 +128,8 @@ class DiscriminatorBlock(nn.Module):
self.downsample = ( self.downsample = (
nn.Sequential( nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1) Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(filters * 4, filters, 1),
) )
if downsample if downsample
else None else None
@ -185,11 +188,18 @@ class Discriminator(nn.Module):
is_not_last = ind != (len(layer_dims_in_out) - 1) is_not_last = ind != (len(layer_dims_in_out) - 1)
block = DiscriminatorBlock( block = DiscriminatorBlock(
in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample in_chan,
out_chan,
downsample=is_not_last,
antialiased_downsample=antialiased_downsample,
) )
attn_block = nn.Sequential( attn_block = nn.Sequential(
Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)), Residual(
LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
) )
@ -363,7 +373,9 @@ class Discriminator3D(nn.Module):
) )
attn_block = nn.Sequential( attn_block = nn.Sequential(
Residual( Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
), ),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
) )
@ -458,7 +470,9 @@ class Discriminator3DWithfirstframe(nn.Module):
) )
attn_block = nn.Sequential( attn_block = nn.Sequential(
Residual( Residual(
LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head) LinearSpaceAttention(
dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
)
), ),
Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)), Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
) )
@ -581,11 +595,17 @@ class VideoAutoencoderLoss(nn.Module):
input_frames = pick_video_frame(inputs, frame_indices) input_frames = pick_video_frame(inputs, frame_indices)
recon_frames = pick_video_frame(reconstructions, frame_indices) recon_frames = pick_video_frame(reconstructions, frame_indices)
perceptual_loss = self.perceptual_model(input_frames.contiguous(), recon_frames.contiguous()).mean() perceptual_loss = self.perceptual_model(
input_frames.contiguous(), recon_frames.contiguous()
).mean()
else: else:
perceptual_loss = self.zero perceptual_loss = self.zero
if global_step >= self.disc_start or not self.training or self.adversarial_loss_weight == 0: if (
global_step >= self.disc_start
or not self.training
or self.adversarial_loss_weight == 0
):
gen_loss = self.zero gen_loss = self.zero
adaptive_weight = 0 adaptive_weight = 0
else: else:
@ -598,9 +618,13 @@ class VideoAutoencoderLoss(nn.Module):
adaptive_weight = 1 adaptive_weight = 1
if self.perceptual_weight > 0 and last_layer is not None: if self.perceptual_weight > 0 and last_layer is not None:
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_layer).norm(p=2) norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
perceptual_loss, last_layer
).norm(p=2)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2) norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3) adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
min=1e-3
)
adaptive_weight.clamp_(max=1e3) adaptive_weight.clamp_(max=1e3)
if torch.isnan(adaptive_weight).any(): if torch.isnan(adaptive_weight).any():


@@ -48,7 +48,9 @@ class LPIPS(nn.Module):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
-       res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+       res = [
+           spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))
+       ]
        val = res[0]
        for l in range(1, len(self.chns)):
            val += res[l]
@@ -118,7 +120,9 @@ class vgg16(torch.nn.Module):
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
-       vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
+       vgg_outputs = namedtuple(
+           "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
+       )
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out

View File

@@ -35,7 +35,9 @@ class NLayerDiscriminator(nn.Module):
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
-       if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+       if (
+           type(norm_layer) == functools.partial
+       ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

View File

@@ -11,6 +11,7 @@ def hinge_d_loss(logits_real, logits_fake):
def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
-       torch.mean(torch.nn.functional.softplus(-logits_real)) + torch.mean(torch.nn.functional.softplus(logits_fake))
+       torch.mean(torch.nn.functional.softplus(-logits_real))
+       + torch.mean(torch.nn.functional.softplus(logits_fake))
    )
    return d_loss

View File

@@ -147,7 +147,9 @@ def hinge_gen_loss(fake):
@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
-   return torch_grad(outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True)[0].detach()
+   return torch_grad(
+       outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
+   )[0].detach()

# helper decorators
@@ -223,7 +225,10 @@ class SqueezeExcite(Module):
        dim_hidden = max(dim_hidden_min, dim_out // 2)

        self.net = nn.Sequential(
-           nn.Conv2d(dim, dim_hidden, 1), nn.LeakyReLU(0.1), nn.Conv2d(dim_hidden, dim_out, 1), nn.Sigmoid()
+           nn.Conv2d(dim, dim_hidden, 1),
+           nn.LeakyReLU(0.1),
+           nn.Conv2d(dim_hidden, dim_out, 1),
+           nn.Sigmoid(),
        )

        nn.init.zeros_(self.net[-2].weight)
@@ -282,7 +287,10 @@ class RMSNorm(Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
-       return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+       return (
+           F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma
+           + self.bias
+       )

class AdaptiveRMSNorm(Module):
@@ -353,7 +361,8 @@ class Attention(Module):
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
-           nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads)
+           nn.Linear(dim, dim_inner * 3, bias=False),
+           Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
        )

        assert num_memory_kv > 0
@@ -361,7 +370,9 @@ class Attention(Module):

        self.attend = Attend(causal=causal, dropout=dropout, flash=flash)

-       self.to_out = nn.Sequential(Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False))
+       self.to_out = nn.Sequential(
+           Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
+       )

    @beartype
    def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
@@ -455,7 +466,9 @@ class FeedForward(Module):
        super().__init__()
        conv_klass = nn.Conv2d if images else nn.Conv3d

-       rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
+       rmsnorm_klass = (
+           RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
+       )

        maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)
@@ -463,7 +476,9 @@ class FeedForward(Module):

        self.norm = maybe_adaptive_norm_klass(dim)

-       self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))
+       self.net = Sequential(
+           conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)
+       )

    @beartype
    def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
@@ -525,7 +540,8 @@ class DiscriminatorBlock(Module):
        self.downsample = (
            nn.Sequential(
-               Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), nn.Conv2d(filters * 4, filters, 1)
+               Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
+               nn.Conv2d(filters * 4, filters, 1),
            )
            if downsample
            else None
@@ -584,11 +600,18 @@ class Discriminator(Module):
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            block = DiscriminatorBlock(
-               in_chan, out_chan, downsample=is_not_last, antialiased_downsample=antialiased_downsample
+               in_chan,
+               out_chan,
+               downsample=is_not_last,
+               antialiased_downsample=antialiased_downsample,
            )

            attn_block = Sequential(
-               Residual(LinearSpaceAttention(dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head)),
+               Residual(
+                   LinearSpaceAttention(
+                       dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
+                   )
+               ),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )
@@ -628,7 +651,16 @@ class Discriminator(Module):
class Conv3DMod(Module):
    @beartype
    def __init__(
-       self, dim, *, spatial_kernel, time_kernel, causal=True, dim_out=None, demod=True, eps=1e-8, pad_mode="zeros"
+       self,
+       dim,
+       *,
+       spatial_kernel,
+       time_kernel,
+       causal=True,
+       dim_out=None,
+       demod=True,
+       eps=1e-8,
+       pad_mode="zeros",
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
@@ -644,7 +676,9 @@ class Conv3DMod(Module):
        self.pad_mode = pad_mode
        self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
-       self.weights = nn.Parameter(torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel)))
+       self.weights = nn.Parameter(
+           torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))
+       )

        self.demod = demod
@@ -675,7 +709,11 @@ class Conv3DMod(Module):
            weights = weights * (cond + 1)

        if self.demod:
-           inv_norm = reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum").clamp(min=self.eps).rsqrt()
+           inv_norm = (
+               reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum")
+               .clamp(min=self.eps)
+               .rsqrt()
+           )
            weights = weights * inv_norm

        fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")
@@ -742,7 +780,9 @@ class SpatialUpsample2x(Module):
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

-       self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2))
+       self.net = nn.Sequential(
+           conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)
+       )

        self.init_conv_(conv)
@@ -808,7 +848,12 @@ def SameConv2d(dim_in, dim_out, kernel_size):
class CausalConv3d(Module):
    @beartype
    def __init__(
-       self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
+       self,
+       chan_in,
+       chan_out,
+       kernel_size: Union[int, Tuple[int, int, int]],
+       pad_mode="constant",
+       **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
@@ -830,7 +875,9 @@ class CausalConv3d(Module):
            stride = (stride, 1, 1)
            dilation = (dilation, 1, 1)

-       self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+       self.conv = nn.Conv3d(
+           chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
+       )

    def forward(self, x):
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"
@@ -855,7 +902,13 @@ def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: s
@beartype
class ResidualUnitMod(Module):
    def __init__(
-       self, dim, kernel_size: Union[int, Tuple[int, int, int]], *, dim_cond, pad_mode: str = "constant", demod=True
+       self,
+       dim,
+       kernel_size: Union[int, Tuple[int, int, int]],
+       *,
+       dim_cond,
+       pad_mode: str = "constant",
+       demod=True,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
@@ -892,7 +945,15 @@ class ResidualUnitMod(Module):
class CausalConvTranspose3d(Module):
-   def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], *, time_stride, **kwargs):
+   def __init__(
+       self,
+       chan_in,
+       chan_out,
+       kernel_size: Union[int, Tuple[int, int, int]],
+       *,
+       time_stride,
+       **kwargs,
+   ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
@@ -908,7 +969,9 @@ class CausalConvTranspose3d(Module):
        stride = (time_stride, 1, 1)
        padding = (0, height_pad, width_pad)

-       self.conv = nn.ConvTranspose3d(chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs)
+       self.conv = nn.ConvTranspose3d(
+           chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs
+       )

    def forward(self, x):
        assert x.ndim == 5
@@ -936,7 +999,9 @@ LossBreakdown = namedtuple(
    ],
)

-DiscrLossBreakdown = namedtuple("DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"])
+DiscrLossBreakdown = namedtuple(
+   "DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]
+)

class VideoTokenizer(Module):
@@ -1050,10 +1115,14 @@ class VideoTokenizer(Module):
                has_cond = True

                encoder_layer = ResidualUnitMod(
-                   dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
+                   dim,
+                   residual_conv_kernel_size,
+                   dim_cond=int(dim_cond * dim_cond_expansion_factor),
                )
                decoder_layer = ResidualUnitMod(
-                   dim, residual_conv_kernel_size, dim_cond=int(dim_cond * dim_cond_expansion_factor)
+                   dim,
+                   residual_conv_kernel_size,
+                   dim_cond=int(dim_cond * dim_cond_expansion_factor),
                )

                dim_out = dim
@@ -1080,15 +1149,25 @@ class VideoTokenizer(Module):
            elif layer_type == "attend_space":
                attn_kwargs = dict(
-                   dim=dim, dim_head=attn_dim_head, heads=attn_heads, dropout=attn_dropout, flash=flash_attn
+                   dim=dim,
+                   dim_head=attn_dim_head,
+                   heads=attn_heads,
+                   dropout=attn_dropout,
+                   flash=flash_attn,
                )

-               encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+               encoder_layer = Sequential(
+                   Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
+               )

-               decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+               decoder_layer = Sequential(
+                   Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
+               )

            elif layer_type == "linear_attend_space":
-               linear_attn_kwargs = dict(dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads)
+               linear_attn_kwargs = dict(
+                   dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads
+               )

                encoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
@@ -1136,9 +1215,13 @@ class VideoTokenizer(Module):
                    flash=flash_attn,
                )

-               encoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+               encoder_layer = Sequential(
+                   Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
+               )

-               decoder_layer = Sequential(Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim)))
+               decoder_layer = Sequential(
+                   Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
+               )

            elif layer_type == "cond_linear_attend_space":
                has_cond = True
@@ -1153,11 +1236,13 @@ class VideoTokenizer(Module):
                )

                encoder_layer = Sequential(
-                   Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
+                   Residual(LinearSpaceAttention(**attn_kwargs)),
+                   Residual(FeedForward(dim, dim_cond=dim_cond)),
                )

                decoder_layer = Sequential(
-                   Residual(LinearSpaceAttention(**attn_kwargs)), Residual(FeedForward(dim, dim_cond=dim_cond))
+                   Residual(LinearSpaceAttention(**attn_kwargs)),
+                   Residual(FeedForward(dim, dim_cond=dim_cond)),
                )

            elif layer_type == "cond_attend_time":
@@ -1283,7 +1368,9 @@ class VideoTokenizer(Module):
        # discriminator

-       discr_kwargs = default(discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512))
+       discr_kwargs = default(
+           discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)
+       )

        self.discr = Discriminator(**discr_kwargs)
@@ -1380,8 +1467,16 @@ class VideoTokenizer(Module):
        self.load_state_dict(state_dict, strict=strict)

    @beartype
-   def encode(self, video: Tensor, quantize=False, cond: Optional[Tensor] = None, video_contains_first_frame=True):
-       encode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
+   def encode(
+       self,
+       video: Tensor,
+       quantize=False,
+       cond: Optional[Tensor] = None,
+       video_contains_first_frame=True,
+   ):
+       encode_first_frame_separately = (
+           self.separate_first_frame_encoding and video_contains_first_frame
+       )

        # whether to pad video or not
@@ -1389,12 +1484,16 @@ class VideoTokenizer(Module):
            video_len = video.shape[2]
            video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
-           video_packed_shape = [torch.Size([self.time_padding]), torch.Size([]), torch.Size([video_len - 1])]
+           video_packed_shape = [
+               torch.Size([self.time_padding]),
+               torch.Size([]),
+               torch.Size([video_len - 1]),
+           ]

        # conditioning, if needed

-       assert (not self.has_cond) or exists(
-           cond
+       assert (
+           (not self.has_cond) or exists(cond)
        ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"

        if exists(cond):
@@ -1431,7 +1530,9 @@ class VideoTokenizer(Module):
        return maybe_quantize(video)

    @beartype
-   def decode_from_code_indices(self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
+   def decode_from_code_indices(
+       self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
+   ):
        assert codes.dtype in (torch.long, torch.int32)

        if codes.ndim == 2:
@@ -1444,18 +1545,24 @@ class VideoTokenizer(Module):
        quantized = self.quantizers.indices_to_codes(codes)

-       return self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
+       return self.decode(
+           quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
+       )

    @beartype
-   def decode(self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True):
-       decode_first_frame_separately = self.separate_first_frame_encoding and video_contains_first_frame
+   def decode(
+       self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
+   ):
+       decode_first_frame_separately = (
+           self.separate_first_frame_encoding and video_contains_first_frame
+       )

        batch = quantized.shape[0]

        # conditioning, if needed

-       assert (not self.has_cond) or exists(
-           cond
+       assert (
+           (not self.has_cond) or exists(cond)
        ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"

        if exists(cond):
@@ -1558,14 +1665,18 @@ class VideoTokenizer(Module):
            aux_losses = self.zero
            quantizer_loss_breakdown = None
        else:
-           (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(x, return_loss_breakdown=True)
+           (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(
+               x, return_loss_breakdown=True
+           )

        if return_codes and not return_recon:
            return codes

        # decoder

-       recon_video = self.decode(quantized, cond=cond, video_contains_first_frame=video_contains_first_frame)
+       recon_video = self.decode(
+           quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
+       )

        if return_codes:
            return codes, recon_video
@@ -1613,7 +1724,9 @@ class VideoTokenizer(Module):
                multiscale_real_logits = discr(video)
                multiscale_fake_logits = discr(recon_video.detach())

-               multiscale_discr_loss = hinge_discr_loss(multiscale_fake_logits, multiscale_real_logits)
+               multiscale_discr_loss = hinge_discr_loss(
+                   multiscale_fake_logits, multiscale_real_logits
+               )

                multiscale_discr_losses.append(multiscale_discr_loss)
        else:
@@ -1634,7 +1747,9 @@ class VideoTokenizer(Module):
            + sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
        )

-       discr_loss_breakdown = DiscrLossBreakdown(discr_loss, multiscale_discr_losses, gradient_penalty_loss)
+       discr_loss_breakdown = DiscrLossBreakdown(
+           discr_loss, multiscale_discr_losses, gradient_penalty_loss
+       )

        return total_loss, discr_loss_breakdown
@@ -1669,7 +1784,9 @@ class VideoTokenizer(Module):
        norm_grad_wrt_perceptual_loss = None

        if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
-           norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
+           norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
+               perceptual_loss, last_dec_layer
+           ).norm(p=2)

        # per-frame image discriminator
@@ -1686,7 +1803,9 @@ class VideoTokenizer(Module):

            if exists(norm_grad_wrt_perceptual_loss):
                norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
-               adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-3)
+               adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
+                   min=1e-3
+               )
                adaptive_weight.clamp_(max=1e3)

                if torch.isnan(adaptive_weight).any():
@@ -1713,8 +1832,12 @@ class VideoTokenizer(Module):
                multiscale_adaptive_weight = 1.0

                if exists(norm_grad_wrt_perceptual_loss):
-                   norm_grad_wrt_gen_loss = grad_layer_wrt_loss(multiscale_gen_loss, last_dec_layer).norm(p=2)
-                   multiscale_adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
+                   norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
+                       multiscale_gen_loss, last_dec_layer
+                   ).norm(p=2)
+                   multiscale_adaptive_weight = (
+                       norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
+                   )
                    multiscale_adaptive_weight.clamp_(max=1e3)

                multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)
@@ -1730,10 +1853,13 @@ class VideoTokenizer(Module):
        if self.has_multiscale_discrs:
            weighted_multiscale_gen_losses = sum(
-               loss * weight for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
+               loss * weight
+               for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
            )

-           total_loss = total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
+           total_loss = (
+               total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
+           )

        # loss breakdown

View File

@@ -26,7 +26,9 @@ class IdentityRegularizer(AbstractRegularizer):
        yield from ()

-def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def measure_perplexity(
+   predicted_indices: torch.Tensor, num_centroids: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)

View File

@@ -79,13 +79,19 @@ class FSQ(Module):
        self.dim = default(dim, len(_levels) * num_codebooks)

        has_projections = self.dim != effective_codebook_dim
-       self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
-       self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+       self.project_in = (
+           nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
+       )
+       self.project_out = (
+           nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
+       )
        self.has_projections = has_projections

        self.codebook_size = self._levels.prod().item()

-       implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
+       implicit_codebook = self.indices_to_codes(
+           torch.arange(self.codebook_size), project_out=False
+       )
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
@@ -153,7 +159,9 @@ class FSQ(Module):
        z = rearrange(z, "b d ... -> b ... d")
        z, ps = pack_one(z, "b * d")

-       assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
+       assert (
+           z.shape[-1] == self.dim
+       ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"

        z = self.project_in(z)

View File

@@ -78,7 +78,9 @@ class LFQ(Module):

        # some assert validations

-       assert exists(dim) or exists(codebook_size), "either dim or codebook_size must be specified for LFQ"
+       assert exists(dim) or exists(
+           codebook_size
+       ), "either dim or codebook_size must be specified for LFQ"
        assert (
            not exists(codebook_size) or log2(codebook_size).is_integer()
        ), f"your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})"
@@ -195,7 +197,9 @@ class LFQ(Module):
        x = rearrange(x, "b d ... -> b ... d")
        x, ps = pack_one(x, "b * d")

-       assert x.shape[-1] == self.dim, f"expected dimension of {self.dim} but received {x.shape[-1]}"
+       assert (
+           x.shape[-1] == self.dim
+       ), f"expected dimension of {self.dim} but received {x.shape[-1]}"

        x = self.project_in(x)
@@ -299,7 +303,9 @@ class LFQ(Module):

        # complete aux loss

-       aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+       aux_loss = (
+           entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+       )

        ret = Return(x, indices, aux_loss)

View File

@@ -33,7 +33,9 @@ class AbstractQuantizer(AbstractRegularizer):
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
-           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+               device=new.device
+           )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)
@@ -50,7 +52,9 @@ class AbstractQuantizer(AbstractRegularizer):
        return back.reshape(ishape)

    @abstractmethod
-   def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
+   def get_codebook_entry(
+       self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+   ) -> torch.Tensor:
        raise NotImplementedError()

    def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
@@ -239,7 +243,8 @@ class VectorQuantizer(AbstractQuantizer):
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
-           - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
+           - 2
+           * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
        )

        min_encoding_indices = torch.argmin(d, dim=1)
@@ -267,15 +272,21 @@ class VectorQuantizer(AbstractQuantizer):
        if self.sane_index_shape:
            if do_reshape:
-               min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+               min_encoding_indices = min_encoding_indices.reshape(
+                   z_q.shape[0], z_q.shape[2], z_q.shape[3]
+               )
            else:
-               min_encoding_indices = rearrange(min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0])
+               min_encoding_indices = rearrange(
+                   min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
+               )
            loss_dict["min_encoding_indices"] = min_encoding_indices

        return z_q, loss_dict

-   def get_codebook_entry(self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
+   def get_codebook_entry(
+       self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+   ) -> torch.Tensor:
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            assert shape is not None, "Need to give shape for remap"
@@ -448,6 +459,8 @@ class VectorQuantizerWithInputProjection(VectorQuantizer):
        elif len(in_shape) == 5:
            z_q = rearrange(z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2])
        else:
-           raise NotImplementedError(f"rearranging not available for {len(in_shape)}-dimensional input.")
+           raise NotImplementedError(
+               f"rearranging not available for {len(in_shape)}-dimensional input."
+           )

        return z_q, loss_dict
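The VectorQuantizer hunks above only re-wrap the codebook distance computation; the underlying identity is the usual expansion of the squared Euclidean distance, ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e, evaluated without materialising pairwise differences. A small self-contained check of that identity, with made-up tensor sizes:

```python
import torch

z = torch.randn(5, 4)          # flattened latents, shape (N, e_dim)
codebook = torch.randn(7, 4)   # embedding weight, shape (n_e, e_dim)

# Same expansion as in the quantizer: |z|^2 + |e|^2 - 2 * z @ e^T.
d = (
    z.pow(2).sum(dim=1, keepdim=True)
    + codebook.pow(2).sum(dim=1)
    - 2 * torch.einsum("bd,dn->bn", z, codebook.t())
)

# Agrees with the explicit pairwise squared distances (up to float rounding).
assert torch.allclose(d, torch.cdist(z, codebook).pow(2), atol=1e-4)
print(d.argmin(dim=1))         # nearest codebook index per latent
```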

View File

@@ -248,7 +248,9 @@ def make_time_attn(
        "vanilla",
        "vanilla-xformers",
    ], f"attn_type {attn_type} not supported for spatio-temporal attention"
-   print(f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels")
+   print(
+       f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
+   )
    if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
        print(
            f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "

View File

@@ -125,9 +125,13 @@ class ResnetBlock3D(nn.Module):
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
-               self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+               self.conv_shortcut = CausalConv3d(
+                   in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
+               )
            else:
-               self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+               self.nin_shortcut = torch.nn.Conv3d(
+                   in_channels, out_channels, kernel_size=1, stride=1, padding=0
+               )

    def forward(self, x, temb, zq):
        h = x
@@ -161,7 +165,9 @@ class AttnBlock2D(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-       self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+       self.proj_out = torch.nn.Conv2d(
+           in_channels, in_channels, kernel_size=1, stride=1, padding=0
+       )

    def forward(self, x, zq):
        h_ = x
@@ -380,7 +386,11 @@ class NewDecoder3D(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-       print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+       print(
+           "Working with z of shape {} = {} dimensions.".format(
+               self.z_shape, np.prod(self.z_shape)
+           )
+       )

        # z to block_in
        # self.conv_in = torch.nn.Conv3d(z_channels,

View File

@@ -148,9 +148,13 @@ class ResnetBlock3D(nn.Module):
                #                                      kernel_size=3,
                #                                      stride=1,
                #                                      padding=1)
-               self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+               self.conv_shortcut = CausalConv3d(
+                   in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
+               )
            else:
-               self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+               self.nin_shortcut = torch.nn.Conv3d(
+                   in_channels, out_channels, kernel_size=1, stride=1, padding=0
+               )
                # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)

    def forward(self, x, temb, zq):
@@ -185,7 +189,9 @@ class AttnBlock2D(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-       self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+       self.proj_out = torch.nn.Conv2d(
+           in_channels, in_channels, kernel_size=1, stride=1, padding=0
+       )

    def forward(self, x, zq):
        h_ = x
@@ -261,7 +267,11 @@ class MOVQDecoder3D(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-       print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+       print(
+           "Working with z of shape {} = {} dimensions.".format(
+               self.z_shape, np.prod(self.z_shape)
+           )
+       )

        # z to block_in
        # self.conv_in = torch.nn.Conv3d(z_channels,
@@ -420,7 +430,11 @@ class NewDecoder3D(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-       print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+       print(
+           "Working with z of shape {} = {} dimensions.".format(
+               self.z_shape, np.prod(self.z_shape)
+           )
+       )

        # z to block_in
        # self.conv_in = torch.nn.Conv3d(z_channels,

View File

@@ -51,7 +51,12 @@ def nonlinearity(x):
class CausalConv3d(nn.Module):
    @beartype
    def __init__(
-       self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs
+       self,
+       chan_in,
+       chan_out,
+       kernel_size: Union[int, Tuple[int, int, int]],
+       pad_mode="constant",
+       **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
@@ -75,11 +80,20 @@ class CausalConv3d(nn.Module):
            stride = (stride, 1, 1)
            dilation = (dilation, 1, 1)

-       self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+       self.conv = nn.Conv3d(
+           chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
+       )

    def forward(self, x):
        if self.pad_mode == "constant":
-           causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+           causal_padding_3d = (
+               self.time_pad,
+               0,
+               self.width_pad,
+               self.width_pad,
+               self.height_pad,
+               self.height_pad,
+           )
            x = F.pad(x, causal_padding_3d, mode="constant", value=0)
        elif self.pad_mode == "first":
            pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
@@ -91,7 +105,9 @@ class CausalConv3d(nn.Module):
            reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
            if reflect_x.shape[2] < self.time_pad:
                reflect_x = torch.cat(
-                   [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
+                   [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2])
+                   + [reflect_x],
+                   dim=2,
                )
            x = torch.cat([reflect_x, x], dim=2)
            causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
@@ -110,7 +126,9 @@ class Upsample3D(nn.Module):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
-           self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+           self.conv = torch.nn.Conv2d(
+               in_channels, in_channels, kernel_size=3, stride=1, padding=1
+           )
        self.compress_time = compress_time

    def forward(self, x):
@@ -149,7 +167,9 @@ class DownSample3D(nn.Module):
            out_channels = in_channels
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
-           self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+           self.conv = torch.nn.Conv2d(
+               in_channels, out_channels, kernel_size=3, stride=2, padding=0
+           )
        self.compress_time = compress_time

    def forward(self, x):
@@ -182,7 +202,14 @@ class DownSample3D(nn.Module):
class ResnetBlock3D(nn.Module):
    def __init__(
-       self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
+       self,
+       *,
+       in_channels,
+       out_channels=None,
+       conv_shortcut=False,
+       dropout,
+       temb_channels=512,
+       pad_mode="constant",
    ):
        super().__init__()
        self.in_channels = in_channels
@@ -214,9 +241,13 @@ class ResnetBlock3D(nn.Module):
                #                                      kernel_size=3,
                #                                      stride=1,
                #                                      padding=1)
-               self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+               self.conv_shortcut = CausalConv3d(
+                   in_channels, out_channels, kernel_size=3, pad_mode=pad_mode
+               )
            else:
-               self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+               self.nin_shortcut = torch.nn.Conv3d(
+                   in_channels, out_channels, kernel_size=1, stride=1, padding=0
+               )
                # self.nin_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, pad_mode=pad_mode)

    def forward(self, x, temb):
@@ -251,7 +282,9 @@ class AttnBlock2D(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-       self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+       self.proj_out = torch.nn.Conv2d(
+           in_channels, in_channels, kernel_size=1, stride=1, padding=0
+       )

    def forward(self, x):
        h_ = x
@@ -365,12 +398,20 @@ class Encoder3D(nn.Module):
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock3D(
-           in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
+           in_channels=block_in,
+           out_channels=block_in,
+           temb_channels=self.temb_ch,
+           dropout=dropout,
+           pad_mode=pad_mode,
        )
        # remove attention block
        # self.mid.attn_1 = AttnBlock2D(block_in)
        self.mid.block_2 = ResnetBlock3D(
-           in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, pad_mode=pad_mode
+           in_channels=block_in,
+           out_channels=block_in,
+           temb_channels=self.temb_ch,
+           dropout=dropout,
+           pad_mode=pad_mode,
        )

        # end

View File

@@ -80,7 +80,9 @@ class Upsample(nn.Module):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
-           self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+           self.conv = torch.nn.Conv2d(
+               in_channels, in_channels, kernel_size=3, stride=1, padding=1
+           )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -95,7 +97,9 @@ class Downsample(nn.Module):
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
-           self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+           self.conv = torch.nn.Conv2d(
+               in_channels, in_channels, kernel_size=3, stride=2, padding=0
+           )

    def forward(self, x):
        if self.with_conv:
@@ -134,9 +138,13 @@ class ResnetBlock(nn.Module):
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
-               self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+               self.conv_shortcut = torch.nn.Conv2d(
+                   in_channels, out_channels, kernel_size=3, stride=1, padding=1
+               )
            else:
-               self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+               self.nin_shortcut = torch.nn.Conv2d(
+                   in_channels, out_channels, kernel_size=1, stride=1, padding=0
+               )

    def forward(self, x, temb, zq):
        h = x
@@ -170,7 +178,9 @@ class AttnBlock(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-       self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+       self.proj_out = torch.nn.Conv2d(
+           in_channels, in_channels, kernel_size=1, stride=1, padding=0
+       )

    def forward(self, x, zq):
        h_ = x
@@ -232,7 +242,11 @@ class MOVQDecoder(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-       print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+       print(
+           "Working with z of shape {} = {} dimensions.".format(
+               self.z_shape, np.prod(self.z_shape)
+           )
+       )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

View File

@@ -15,7 +15,16 @@ class VectorQuantizer2(nn.Module):
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
-   def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+   def __init__(
+       self,
+       n_e,
+       e_dim,
+       beta,
+       remap=None,
+       unknown_index="random",
+       sane_index_shape=False,
+       legacy=True,
+   ):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
@@ -51,7 +60,9 @@ class VectorQuantizer2(nn.Module):
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
-           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+               device=new.device
+           )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)
@@ -78,7 +89,8 @@ class VectorQuantizer2(nn.Module):
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
-           - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
+           - 2
+           * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
        )

        min_encoding_indices = torch.argmin(d, dim=1)
@@ -88,9 +100,13 @@ class VectorQuantizer2(nn.Module):

        # compute loss for embedding
        if not self.legacy:
-           loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+           loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
+               (z_q - z.detach()) ** 2
+           )
        else:
-           loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+           loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
+               (z_q - z.detach()) ** 2
+           )

        # preserve gradients
        z_q = z + (z_q - z).detach()
@@ -104,7 +120,9 @@ class VectorQuantizer2(nn.Module):
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
-           min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+           min_encoding_indices = min_encoding_indices.reshape(
+               z_q.shape[0], z_q.shape[2], z_q.shape[3]
+           )

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
@@ -184,7 +202,9 @@ class GumbelQuantize(nn.Module):
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
-           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+           new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+               device=new.device
+           )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

View File

@@ -40,7 +40,9 @@ class Upsample(nn.Module):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
-           self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+           self.conv = torch.nn.Conv2d(
+               in_channels, in_channels, kernel_size=3, stride=1, padding=1
+           )

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -55,7 +57,9 @@ class Downsample(nn.Module):
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
-           self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+           self.conv = torch.nn.Conv2d(
+               in_channels, in_channels, kernel_size=3, stride=2, padding=0
+           )

    def forward(self, x):
        if self.with_conv:
@@ -68,7 +72,9 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
-   def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+   def __init__(
+       self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512
+   ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
@@ -84,9 +90,13 @@ class ResnetBlock(nn.Module):
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
-               self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+               self.conv_shortcut = torch.nn.Conv2d(
+                   in_channels, out_channels, kernel_size=3, stride=1, padding=1
+               )
            else:
-               self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+               self.nin_shortcut = torch.nn.Conv2d(
+                   in_channels, out_channels, kernel_size=1, stride=1, padding=0
+               )

    def forward(self, x, temb):
        h = x
@@ -120,7 +130,9 @@ class AttnBlock(nn.Module):
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-       self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+       self.proj_out = torch.nn.Conv2d(
+           in_channels, in_channels, kernel_size=1, stride=1, padding=0
+       )

    def forward(self, x):
        h_ = x
@@ -194,7 +206,10 @@ class Encoder(nn.Module):
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
-                       in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
+                       in_channels=block_in,
+                       out_channels=block_out,
+                       temb_channels=self.temb_ch,
+                       dropout=dropout,
                    )
                )
                block_in = block_out
@@ -326,7 +341,11 @@ class Decoder(nn.Module):
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
-       print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+       print(
+           "Working with z of shape {} = {} dimensions.".format(
+               self.z_shape, np.prod(self.z_shape)
+           )
+       )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
@@ -350,7 +369,10 @@ class Decoder(nn.Module):
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
-                       in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
+                       in_channels=block_in,
+                       out_channels=block_out,
+                       temb_channels=self.temb_ch,
+                       dropout=dropout,
                    )
                )
                block_in = block_out

View File

@@ -136,9 +136,9 @@ def _conv_split(input_, dim, kernel_size):
    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
-       output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
-           dim, 0
-       )
+       output = input_.transpose(dim, 0)[
+           cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size
+       ].transpose(dim, 0)
    output = output.contiguous()

    # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)

View File

@@ -35,7 +35,9 @@ class Denoiser(nn.Module):
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
-       return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
+       return (
+           network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip
+       )

class DiscreteDenoiser(Denoiser):
@@ -50,7 +52,9 @@ class DiscreteDenoiser(Denoiser):
        flip=True,
    ):
        super().__init__(weighting_config, scaling_config)
-       sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
+       sigmas = instantiate_from_config(discretization_config)(
+           num_idx, do_append_zero=do_append_zero, flip=flip
+       )
        self.sigmas = sigmas
        # self.register_buffer("sigmas", sigmas)
        self.quantize_c_noise = quantize_c_noise

View File

@@ -6,7 +6,9 @@ import torch
class DenoiserScaling(ABC):
    @abstractmethod
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pass
@@ -14,7 +16,9 @@ class EDMScaling:
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
@@ -23,7 +27,9 @@ class EDMScaling:
class EpsScaling:
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
@@ -32,7 +38,9 @@ class EpsScaling:
class VScaling:
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
@@ -41,7 +49,9 @@ class VScaling:
class VScalingWithEDMcNoise(DenoiserScaling):
    def __call__(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = 1.0 / (sigma**2 + 1.0)
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
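Read together with the `Denoiser.__call__` hunk further up, these scalings are the usual preconditioned-denoiser coefficients; nothing changes in behaviour here, the formulas below just spell out what the reformatted code computes:

```latex
\begin{aligned}
D_\theta(x,\sigma) &= c_\mathrm{skip}(\sigma)\,x + c_\mathrm{out}(\sigma)\,
    F_\theta\big(c_\mathrm{in}(\sigma)\,x,\ c_\mathrm{noise}(\sigma)\big) \\[4pt]
\text{EDMScaling:}\quad
c_\mathrm{skip} &= \frac{\sigma_d^2}{\sigma^2+\sigma_d^2},\qquad
c_\mathrm{out} = \frac{\sigma\,\sigma_d}{\sqrt{\sigma^2+\sigma_d^2}},\qquad
c_\mathrm{in} = \frac{1}{\sqrt{\sigma^2+\sigma_d^2}} \\
\text{VScaling (v-prediction):}\quad
c_\mathrm{skip} &= \frac{1}{\sigma^2+1},\qquad
c_\mathrm{out} = \frac{-\sigma}{\sqrt{\sigma^2+1}},\qquad
c_\mathrm{in} = \frac{1}{\sqrt{\sigma^2+1}}
\end{aligned}
```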


@@ -52,7 +52,9 @@ class LegacyDDPMDiscretization(Discretization):
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)
@@ -85,14 +87,18 @@ class ZeroSNRDDPMDiscretization(Discretization):
        if keep_start and not post_shift:
            linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

        # SNR shift
        if not post_shift:
            self.alphas_cumprod = self.alphas_cumprod / (
                shift_scale + (1 - shift_scale) * self.alphas_cumprod
            )

        self.post_shift = post_shift
        self.shift_scale = shift_scale
@@ -113,11 +119,14 @@ class ZeroSNRDDPMDiscretization(Discretization):
        alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
        alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
        alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (
            alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T
        )

        if self.post_shift:
            alphas_cumprod_sqrt = (
                alphas_cumprod_sqrt**2
                / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
            ) ** 0.5

        if return_idx:
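For readers following the math rather than the reflow, the two transformations in the ZeroSNRDDPMDiscretization hunks above are, in the code's own notation (s = shift_scale, \(\bar\alpha_t\) = alphas_cumprod):

```latex
\begin{aligned}
\text{SNR shift:}\quad & \bar\alpha_t \;\leftarrow\; \frac{\bar\alpha_t}{s + (1 - s)\,\bar\alpha_t} \\
\text{zero-terminal-SNR rescale:}\quad & \sqrt{\bar\alpha_t} \;\leftarrow\;
\big(\sqrt{\bar\alpha_t} - \sqrt{\bar\alpha_T}\big)\,
\frac{\sqrt{\bar\alpha_0}}{\sqrt{\bar\alpha_0} - \sqrt{\bar\alpha_T}}
\end{aligned}
```

The rescale forces \(\sqrt{\bar\alpha_T}=0\) at the last timestep while leaving \(\sqrt{\bar\alpha_0}\) unchanged.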


@@ -15,7 +15,9 @@ class Guider(ABC):
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        pass
@@ -57,7 +59,8 @@ class DynamicCFG(VanillaCFG):
    def __init__(self, scale, exp, num_steps, dyn_thresh_config=None):
        super().__init__(scale, dyn_thresh_config)
        scale_schedule = (
            lambda scale, sigma, step_index: 1
            + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2
        )
        self.scale_schedule = partial(scale_schedule, scale)
        self.dyn_thresh = instantiate_from_config(
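The `DynamicCFG` hunk above only re-wraps the lambda; the guidance scale it encodes ramps up over the sampling trajectory as

```latex
\mathrm{scale}(i) \;=\; 1 + s \cdot \frac{1 - \cos\!\big(\pi \,(i/N)^{\,\mathrm{exp}}\big)}{2},
\qquad i = \text{step index},\ \ N = \texttt{num\_steps},\ \ s = \text{configured scale}
```

so guidance starts at 1 for the first step and approaches \(1+s\) near the end of sampling.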


@ -20,7 +20,9 @@ from torch import nn
class LoRALinearLayer(nn.Module): class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): def __init__(
self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None
):
super().__init__() super().__init__()
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@ -50,11 +52,20 @@ class LoRALinearLayer(nn.Module):
class LoRAConv2dLayer(nn.Module): class LoRAConv2dLayer(nn.Module):
def __init__( def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None self,
in_features,
out_features,
rank=4,
kernel_size=(1, 1),
stride=(1, 1),
padding=0,
network_alpha=None,
): ):
super().__init__() super().__init__()
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) self.down = nn.Conv2d(
in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False
)
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
@ -85,7 +96,9 @@ class LoRACompatibleConv(nn.Conv2d):
A convolutional layer that can be used with LoRA. A convolutional layer that can be used with LoRA.
""" """
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs): def __init__(
self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lora_layer = lora_layer self.lora_layer = lora_layer
self.scale = scale self.scale = scale
@ -144,7 +157,13 @@ class LoRACompatibleConv(nn.Conv2d):
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315 # see: https://github.com/huggingface/diffusers/pull/4315
return F.conv2d( return F.conv2d(
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups hidden_states,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
) )
else: else:
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
@ -155,7 +174,9 @@ class LoRACompatibleLinear(nn.Linear):
A Linear layer that can be used with LoRA. A Linear layer that can be used with LoRA.
""" """
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs): def __init__(
self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.lora_layer = lora_layer self.lora_layer = lora_layer
self.scale = scale self.scale = scale
@ -197,7 +218,9 @@ class LoRACompatibleLinear(nn.Linear):
w_up = self.w_up.to(device=device).float() w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float() w_down = self.w_down.to(device).float()
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) unfused_weight = fused_weight.float() - (
self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
)
self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None self.w_up = None
@ -252,7 +275,9 @@ def _find_modules_v2(
# Get the targets we should replace all linears under # Get the targets we should replace all linears under
if ancestor_class is not None: if ancestor_class is not None:
ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class) ancestors = (
module for module in model.modules() if module.__class__.__name__ in ancestor_class
)
else: else:
# this, incase you want to naively iterate over all modules. # this, incase you want to naively iterate over all modules.
ancestors = [module for module in model.modules()] ancestors = [module for module in model.modules()]
@ -274,7 +299,9 @@ def _find_modules_v2(
if flag: if flag:
continue continue
# Skip this linear if it's a child of a LoraInjectedLinear # Skip this linear if it's a child of a LoraInjectedLinear
if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]): if exclude_children_of and any(
[isinstance(parent, _class) for _class in exclude_children_of]
):
continue continue
# Otherwise, yield it # Otherwise, yield it
yield parent, name, module yield parent, name, module


@ -38,13 +38,17 @@ class StandardDiffusionLoss(nn.Module):
def __call__(self, network, denoiser, conditioner, input, batch): def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch) cond = conditioner(batch)
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
sigmas = self.sigma_sampler(input.shape[0]).to(input.device) sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
noise = torch.randn_like(input) noise = torch.randn_like(input)
if self.offset_noise_level > 0.0: if self.offset_noise_level > 0.0:
noise = ( noise = (
noise + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim) * self.offset_noise_level noise
+ append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
* self.offset_noise_level
) )
noise = noise.to(input.dtype) noise = noise.to(input.dtype)
noised_input = input.float() + noise * append_dims(sigmas, input.ndim) noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
@ -63,7 +67,9 @@ class StandardDiffusionLoss(nn.Module):
class VideoDiffusionLoss(StandardDiffusionLoss): class VideoDiffusionLoss(StandardDiffusionLoss):
def __init__(self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs): def __init__(
self, block_scale=None, block_size=None, min_snr_value=None, fixed_frames=0, **kwargs
):
self.fixed_frames = fixed_frames self.fixed_frames = fixed_frames
self.block_scale = block_scale self.block_scale = block_scale
self.block_size = block_size self.block_size = block_size
@ -72,7 +78,9 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
def __call__(self, network, denoiser, conditioner, input, batch): def __call__(self, network, denoiser, conditioner, input, batch):
cond = conditioner(batch) cond = conditioner(batch)
additional_model_inputs = {key: batch[key] for key in self.batch2model_keys.intersection(batch)} additional_model_inputs = {
key: batch[key] for key in self.batch2model_keys.intersection(batch)
}
alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True) alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device) alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
@@ -86,24 +94,30 @@ class VideoDiffusionLoss(StandardDiffusionLoss):
            src = global_rank * mp_size
            torch.distributed.broadcast(idx, src=src, group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(noise, src=src, group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(
                alphas_cumprod_sqrt, src=src, group=mpu.get_model_parallel_group()
            )

        additional_model_inputs["idx"] = idx

        if self.offset_noise_level > 0.0:
            noise = (
                noise
                + append_dims(torch.randn(input.shape[0]).to(input.device), input.ndim)
                * self.offset_noise_level
            )

        noised_input = input.float() * append_dims(
            alphas_cumprod_sqrt, input.ndim
        ) + noise * append_dims((1 - alphas_cumprod_sqrt**2) ** 0.5, input.ndim)

        if "concat_images" in batch.keys():
            cond["concat"] = batch["concat_images"]

        # [2, 13, 16, 60, 90],[2] dict_keys(['crossattn', 'concat']) dict_keys(['idx'])
        model_output = denoiser(
            network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
        )
        w = append_dims(1 / (1 - alphas_cumprod_sqrt**2), input.ndim)  # v-pred

        if self.min_snr_value is not None:
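As a reading aid for the hunk above (no behaviour change): with \(\bar\alpha_t\) the cumulative alpha product, the loss noises the clean latent and weights the v-prediction objective as

```latex
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,
\qquad
w_t = \frac{1}{1-\bar\alpha_t} = \mathrm{SNR}(t) + 1,
\qquad \mathrm{SNR}(t) = \frac{\bar\alpha_t}{1-\bar\alpha_t}
```

The \( \mathrm{SNR}+1 \) identity is just an algebraic rewrite of the `w` line in the code, not an extra assumption.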


@ -47,7 +47,9 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
class Upsample(nn.Module): class Upsample(nn.Module):
@ -55,7 +57,9 @@ class Upsample(nn.Module):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x): def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@ -70,7 +74,9 @@ class Downsample(nn.Module):
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x): def forward(self, x):
if self.with_conv: if self.with_conv:
@ -107,9 +113,13 @@ class ResnetBlock(nn.Module):
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else: else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x, temb): def forward(self, x, temb):
h = x h = x
@ -150,7 +160,9 @@ class AttnBlock(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def attention(self, h_: torch.Tensor) -> torch.Tensor: def attention(self, h_: torch.Tensor) -> torch.Tensor:
h_ = self.norm(h_) h_ = self.norm(h_)
@ -160,7 +172,9 @@ class AttnBlock(nn.Module):
b, c, h, w = q.shape b, c, h, w = q.shape
q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default h_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v
) # scale is dim ** -0.5 per default
# compute attention # compute attention
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
@ -188,7 +202,9 @@ class MemoryEfficientAttnBlock(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.attention_op: Optional[Any] = None self.attention_op: Optional[Any] = None
def attention(self, h_: torch.Tensor) -> torch.Tensor: def attention(self, h_: torch.Tensor) -> torch.Tensor:
@ -211,7 +227,12 @@ class MemoryEfficientAttnBlock(nn.Module):
) )
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) out = (
out.unsqueeze(0)
.reshape(B, 1, out.shape[1], C)
.permute(0, 2, 1, 3)
.reshape(B, out.shape[1], C)
)
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
@ -581,7 +602,11 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1] block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1) curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res) self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
make_attn_cls = self._make_attn() make_attn_cls = self._make_attn()
make_resblock_cls = self._make_resblock() make_resblock_cls = self._make_resblock()


@ -47,7 +47,9 @@ class AttentionPool2d(nn.Module):
output_dim: int = None, output_dim: int = None,
): ):
super().__init__() super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels self.num_heads = embed_dim // num_heads_channels
@ -303,7 +305,9 @@ class ResBlock(TimestepBlock):
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding) self.skip_connection = conv_nd(
dims, channels, self.out_channels, kernel_size, padding=padding
)
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@ -437,7 +441,9 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads) ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch)) scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v) a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
@ -574,9 +580,7 @@ class UNetModel(nn.Module):
), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
if context_dim is not None: if context_dim is not None:
assert ( assert use_spatial_transformer, "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
use_spatial_transformer
), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
if type(context_dim) == ListConfig: if type(context_dim) == ListConfig:
context_dim = list(context_dim) context_dim = list(context_dim)
@ -640,7 +644,9 @@ class UNetModel(nn.Module):
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None self.predict_codebook_ids = n_embed is not None
assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint) assert use_fairscale_checkpoint != use_checkpoint or not (
use_checkpoint or use_fairscale_checkpoint
)
self.use_fairscale_checkpoint = False self.use_fairscale_checkpoint = False
checkpoint_wrapper_fn = ( checkpoint_wrapper_fn = (
@ -942,7 +948,9 @@ class UNetModel(nn.Module):
print(f"loading lora from {ckpt_path}") print(f"loading lora from {ckpt_path}")
sd = th.load(ckpt_path)["module"] sd = th.load(ckpt_path)["module"]
sd = { sd = {
key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model") key[len("model.diffusion_model") :]: sd[key]
for key in sd
if key.startswith("model.diffusion_model")
} }
self.load_state_dict(sd, strict=False) self.load_state_dict(sd, strict=False)
@ -978,7 +986,9 @@ class UNetModel(nn.Module):
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:


@ -1,8 +1,7 @@
""" """
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
""" """
from typing import Dict, Union from typing import Dict, Union
import torch import torch
@ -85,9 +84,7 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler):
class EDMSampler(SingleStepDiffusionSampler): class EDMSampler(SingleStepDiffusionSampler):
def __init__( def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.s_churn = s_churn self.s_churn = s_churn
@ -106,15 +103,11 @@ class EDMSampler(SingleStepDiffusionSampler):
dt = append_dims(next_sigma - sigma_hat, x.ndim) dt = append_dims(next_sigma - sigma_hat, x.ndim)
euler_step = self.euler_step(x, d, dt) euler_step = self.euler_step(x, d, dt)
x = self.possible_correction_step( x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
gamma = ( gamma = (
@ -136,30 +129,23 @@ class EDMSampler(SingleStepDiffusionSampler):
class DDIMSampler(SingleStepDiffusionSampler): class DDIMSampler(SingleStepDiffusionSampler):
def __init__( def __init__(self, s_noise=0.1, *args, **kwargs):
self, s_noise=0.1, *args, **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.s_noise = s_noise self.s_noise = s_noise
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
denoised = self.denoise(x, denoiser, sigma, cond, uc) denoised = self.denoise(x, denoiser, sigma, cond, uc)
d = to_d(x, sigma, denoised) d = to_d(x, sigma, denoised)
dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim) dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
x = self.possible_correction_step( x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
)
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step( x = self.sampler_step(
@ -198,9 +184,7 @@ class AncestralSampler(SingleStepDiffusionSampler):
return x return x
def __call__(self, denoiser, x, cond, uc=None, num_steps=None): def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
x = self.sampler_step( x = self.sampler_step(
@ -227,43 +211,32 @@ class LinearMultistepSampler(BaseDiffusionSampler):
self.order = order self.order = order
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
ds = [] ds = []
sigmas_cpu = sigmas.detach().cpu().numpy() sigmas_cpu = sigmas.detach().cpu().numpy()
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
sigma = s_in * sigmas[i] sigma = s_in * sigmas[i]
denoised = denoiser( denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
)
denoised = self.guider(denoised, sigma) denoised = self.guider(denoised, sigma)
d = to_d(x, sigma, denoised) d = to_d(x, sigma, denoised)
ds.append(d) ds.append(d)
if len(ds) > self.order: if len(ds) > self.order:
ds.pop(0) ds.pop(0)
cur_order = min(i + 1, self.order) cur_order = min(i + 1, self.order)
coeffs = [ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
for j in range(cur_order)
]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x return x
class EulerEDMSampler(EDMSampler): class EulerEDMSampler(EDMSampler):
def possible_correction_step( def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
return euler_step return euler_step
class HeunEDMSampler(EDMSampler): class HeunEDMSampler(EDMSampler):
def possible_correction_step( def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
):
if torch.sum(next_sigma) < 1e-14: if torch.sum(next_sigma) < 1e-14:
# Save a network evaluation if all noise levels are 0 # Save a network evaluation if all noise levels are 0
return euler_step return euler_step
@ -273,9 +246,7 @@ class HeunEDMSampler(EDMSampler):
d_prime = (d + d_new) / 2.0 d_prime = (d + d_new) / 2.0
# apply correction if noise level is not 0 # apply correction if noise level is not 0
x = torch.where( x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
)
return x return x
@ -314,9 +285,7 @@ class DPMPP2SAncestralSampler(AncestralSampler):
x = x_euler x = x_euler
else: else:
h, s, t, t_next = self.get_variables(sigma, sigma_down) h, s, t, t_next = self.get_variables(sigma, sigma_down)
mult = [ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
]
x2 = mult[0] * x - mult[1] * denoised x2 = mult[0] * x - mult[1] * denoised
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
@ -367,8 +336,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [ mult = [
append_dims(mult, x.ndim) append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
] ]
x_standard = mult[0] * x - mult[1] * denoised x_standard = mult[0] * x - mult[1] * denoised
@ -380,16 +348,12 @@ class DPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d x_advanced = mult[0] * x - mult[1] * denoised_d
# apply correction if noise level is not 0 and not first step # apply correction if noise level is not 0 and not first step
x = torch.where( x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
@ -406,6 +370,7 @@ class DPMPP2MSampler(BaseDiffusionSampler):
return x return x
class SDEDPMPP2MSampler(BaseDiffusionSampler): class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_variables(self, sigma, next_sigma, previous_sigma=None): def get_variables(self, sigma, next_sigma, previous_sigma=None):
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
@ -420,7 +385,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
def get_mult(self, h, r, t, t_next, previous_sigma): def get_mult(self, h, r, t, t_next, previous_sigma):
mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
mult2 = (-2*h).expm1() mult2 = (-2 * h).expm1()
if previous_sigma is not None: if previous_sigma is not None:
mult3 = 1 + 1 / (2 * r) mult3 = 1 + 1 / (2 * r)
@ -444,10 +409,9 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
mult = [ mult = [
append_dims(mult, x.ndim) append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
] ]
mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim) mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)
x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_sigma) < 1e-14: if old_denoised is None or torch.sum(next_sigma) < 1e-14:
@ -458,16 +422,12 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)
# apply correction if noise level is not 0 and not first step # apply correction if noise level is not 0 and not first step
x = torch.where( x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
)
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
x, cond, uc, num_steps
)
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
@ -484,6 +444,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler):
return x return x
class SdeditEDMSampler(EulerEDMSampler): class SdeditEDMSampler(EulerEDMSampler):
def __init__(self, edit_ratio=0.5, *args, **kwargs): def __init__(self, edit_ratio=0.5, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -525,8 +486,8 @@ class SdeditEDMSampler(EulerEDMSampler):
return x return x
class VideoDDIMSampler(BaseDiffusionSampler):
class VideoDDIMSampler(BaseDiffusionSampler):
def __init__(self, fixed_frames=0, sdedit=False, **kwargs): def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.fixed_frames = fixed_frames self.fixed_frames = fixed_frames
@ -534,10 +495,15 @@ class VideoDDIMSampler(BaseDiffusionSampler):
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
alpha_cumprod_sqrt, timesteps = self.discretization( alpha_cumprod_sqrt, timesteps = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False self.num_steps if num_steps is None else num_steps,
device=self.device,
return_idx=True,
do_append_zero=False,
) )
alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))]) timesteps = torch.cat(
[torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]
)
uc = default(uc, cond) uc = default(uc, cond)
@@ -547,7 +513,19 @@ class VideoDDIMSampler(BaseDiffusionSampler):
        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

    def denoise(
        self,
        x,
        denoiser,
        alpha_cumprod_sqrt,
        cond,
        uc,
        timestep=None,
        idx=None,
        scale=None,
        scale_emb=None,
        ofs=None,
    ):
        additional_model_inputs = {}
        if ofs is not None:
@@ -557,26 +535,62 @@ class VideoDDIMSampler(BaseDiffusionSampler):
            additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep
            if scale_emb is not None:
                additional_model_inputs['scale_emb'] = scale_emb
            denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(
                torch.float32
            )
        else:
            additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
            denoised = denoiser(
                *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc),
                **additional_model_inputs,
            ).to(torch.float32)

        if isinstance(self.guider, DynamicCFG):
            denoised = self.guider(
                denoised,
                (1 - alpha_cumprod_sqrt**2) ** 0.5,
                step_index=self.num_steps - timestep,
                scale=scale,
            )
        else:
            denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
        return denoised

    def sampler_step(
        self,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
        scale=None,
        scale_emb=None,
        ofs=None,
    ):
        denoised = self.denoise(
            x,
            denoiser,
            alpha_cumprod_sqrt,
            cond,
            uc,
            timestep,
            idx,
            scale=scale,
            scale_emb=scale_emb,
            ofs=ofs,
        ).to(torch.float32)  # 1020

        a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
        b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

        x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
        return x

    def __call__(
        self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
    ):  # 1020
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
@@ -590,17 +604,16 @@ class VideoDDIMSampler(BaseDiffusionSampler):
                cond,
                uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
                scale=scale,
                scale_emb=scale_emb,
                ofs=ofs,  # 1020
            )

        return x
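Written out, the `sampler_step` update reformatted above is the deterministic DDIM step in the \(\sqrt{\bar\alpha}\) parameterisation this file uses, with \(\hat{x}_0\) the denoised prediction:

```latex
x_\mathrm{next} \;=\; a_t\,x_t \;+\; b_t\,\hat{x}_0,
\qquad
a_t = \sqrt{\frac{1-\bar\alpha_\mathrm{next}}{1-\bar\alpha_t}},
\qquad
b_t = \sqrt{\bar\alpha_\mathrm{next}} - \sqrt{\bar\alpha_t}\,a_t
```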
class Image2VideoDDIMSampler(BaseDiffusionSampler): class Image2VideoDDIMSampler(BaseDiffusionSampler):
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
alpha_cumprod_sqrt, timesteps = self.discretization( alpha_cumprod_sqrt, timesteps = self.discretization(
self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
@ -616,22 +629,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None): def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None):
additional_model_inputs = {} additional_model_inputs = {}
additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to( denoised = denoiser(
torch.float32) *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
).to(torch.float32)
if isinstance(self.guider, DynamicCFG): if isinstance(self.guider, DynamicCFG):
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep) denoised = self.guider(
denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep
)
else: else:
denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5) denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5)
return denoised return denoised
def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, def sampler_step(
timestep=None): self,
alpha_cumprod_sqrt,
next_alpha_cumprod_sqrt,
denoiser,
x,
cond,
uc=None,
idx=None,
timestep=None,
):
        # The "sigma" here is actually alpha_cumprod_sqrt
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32) denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(
torch.float32
)
if idx == 1: if idx == 1:
return denoised return denoised
a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
@ -651,31 +678,36 @@ class Image2VideoDDIMSampler(BaseDiffusionSampler):
cond, cond,
uc, uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i + 1)] timestep=timesteps[-(i + 1)],
) )
return x return x
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
    def get_variables(
        self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
    ):
        alpha_cumprod = alpha_cumprod_sqrt**2
        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
        h = lamb_next - lamb

        if previous_alpha_cumprod_sqrt is not None:
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
            h_last = lamb - lamb_previous
            r = h_last / h
            return h, r, lamb, lamb_next
        else:
            return h, None, lamb, lamb_next

    def get_mult(
        self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
    ):
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
        mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt

        if previous_alpha_cumprod_sqrt is not None:
            mult3 = 1 + 1 / (2 * r)
@@ -698,18 +730,35 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
        timestep=None,
        scale=None,
        scale_emb=None,
        ofs=None,  # 1020
    ):
        denoised = self.denoise(
            x,
            denoiser,
            alpha_cumprod_sqrt,
            cond,
            uc,
            timestep,
            idx,
            scale=scale,
            scale_emb=scale_emb,
            ofs=ofs,
        ).to(torch.float32)  # 1020
        if idx == 1:
            return denoised, denoised

        h, r, lamb, lamb_next = self.get_variables(
            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
        )
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(
                h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
            )
        ]
        mult_noise = append_dims(
            (1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim
        )

        x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
@ -723,23 +772,26 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
return x, denoised return x, denoised
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 def __call__(
self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None
): # 1020
x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
x, cond, uc, num_steps x, cond, uc, num_steps
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
prefix_frames = x[:, :self.fixed_frames] prefix_frames = x[:, : self.fixed_frames]
old_denoised = None old_denoised = None
for i in self.get_sigma_gen(num_sigmas): for i in self.get_sigma_gen(num_sigmas):
if self.fixed_frames > 0: if self.fixed_frames > 0:
if self.sdedit: if self.sdedit:
rd = torch.randn_like(prefix_frames) rd = torch.randn_like(prefix_frames)
noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape)) noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
)
x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
else: else:
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
x, old_denoised = self.sampler_step( x, old_denoised = self.sampler_step(
old_denoised, old_denoised,
None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
@ -750,37 +802,41 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=uc, uc=uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i+1)], timestep=timesteps[-(i + 1)],
scale=scale, scale=scale,
scale_emb=scale_emb, scale_emb=scale_emb,
ofs=ofs # 1020 ofs=ofs, # 1020
) )
if self.fixed_frames > 0: if self.fixed_frames > 0:
x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
return x return x
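A compact way to read the VPSDE DPM-Solver++(2M) step above: with the half-log-SNR \(\lambda_t = \log\!\big(\sqrt{\bar\alpha_t}/\sqrt{1-\bar\alpha_t}\big)\) and \(h = \lambda_\mathrm{next} - \lambda_t\), the multipliers re-wrapped in `get_mult` give the first-order (SDE) update below; the second-order correction built from `old_denoised` is omitted here, and the code's own sign conventions are kept:

```latex
x_\mathrm{next} \;=\;
\underbrace{\sqrt{\tfrac{1-\bar\alpha_\mathrm{next}}{1-\bar\alpha_t}}\,e^{-h}}_{\texttt{mult1}}\; x_t
\;-\;
\underbrace{\big(e^{-2h}-1\big)\sqrt{\bar\alpha_\mathrm{next}}}_{\texttt{mult2}}\;\hat{x}_0
\;+\;
\sqrt{1-\bar\alpha_\mathrm{next}}\,\sqrt{1-e^{-2h}}\;\varepsilon,
\qquad \varepsilon\sim\mathcal{N}(0,I)
```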
class VPODEDPMPP2MSampler(VideoDDIMSampler): class VPODEDPMPP2MSampler(VideoDDIMSampler):
def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): def get_variables(
alpha_cumprod = alpha_cumprod_sqrt ** 2 self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None
lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() ):
next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 alpha_cumprod = alpha_cumprod_sqrt**2
lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
next_alpha_cumprod = next_alpha_cumprod_sqrt**2
lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
h = lamb_next - lamb h = lamb_next - lamb
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
h_last = lamb - lamb_previous h_last = lamb - lamb_previous
r = h_last / h r = h_last / h
return h, r, lamb, lamb_next return h, r, lamb, lamb_next
else: else:
return h, None, lamb, lamb_next return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): def get_mult(
mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
):
mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
mult2 = (-h).expm1() * next_alpha_cumprod_sqrt mult2 = (-h).expm1() * next_alpha_cumprod_sqrt
if previous_alpha_cumprod_sqrt is not None: if previous_alpha_cumprod_sqrt is not None:
@ -801,16 +857,22 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=None, uc=None,
idx=None, idx=None,
timestep=None timestep=None,
): ):
denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(
torch.float32
)
if idx == 1: if idx == 1:
return denoised, denoised return denoised, denoised
h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) h, r, lamb, lamb_next = self.get_variables(
alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
mult = [ mult = [
append_dims(mult, x.ndim) append_dims(mult, x.ndim)
for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) for mult in self.get_mult(
h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
)
] ]
x_standard = mult[0] * x - mult[1] * denoised x_standard = mult[0] * x - mult[1] * denoised
@ -842,22 +904,44 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler):
cond, cond,
uc=uc, uc=uc,
idx=self.num_steps - i, idx=self.num_steps - i,
timestep=timesteps[-(i+1)] timestep=timesteps[-(i + 1)],
) )
return x return x
class VideoDDPMSampler(VideoDDIMSampler):
    def sampler_step(
        self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None
    ):
        # The "sigma" passed around here is actually alpha_cumprod_sqrt
        denoised = self.denoise(
            x, denoiser, alpha_cumprod_sqrt, cond, uc, idx * 1000 // self.num_steps
        ).to(torch.float32)
        if idx == 1:
            return denoised

        alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt
        x = (
            append_dims(
                alpha_sqrt * (1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
            )
            * x
            + append_dims(
                next_alpha_cumprod_sqrt * (1 - alpha_sqrt**2) / (1 - alpha_cumprod_sqrt**2), x.ndim
            )
            * denoised
            + append_dims(
                (
                    (1 - next_alpha_cumprod_sqrt**2)
                    * (1 - alpha_sqrt**2)
                    / (1 - alpha_cumprod_sqrt**2)
                )
                ** 0.5,
                x.ndim,
            )
            * torch.randn_like(x)
        )

        return x
@@ -874,7 +958,7 @@ class VideoDDPMSampler(VideoDDIMSampler):
                x,
                cond,
                uc,
                idx=self.num_steps - i,
            )

        return x
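For reference, the long expression in `sampler_step` above is the standard DDPM ancestral (posterior) sampling step written in the \(\sqrt{\bar\alpha}\) parameterisation the file uses; `alpha_sqrt` squared plays the role of the per-step \(\alpha_t\):

```latex
x_\mathrm{next} \;=\;
\frac{\sqrt{\alpha_t}\,\big(1-\bar\alpha_\mathrm{next}\big)}{1-\bar\alpha_t}\,x_t
\;+\;
\frac{\sqrt{\bar\alpha_\mathrm{next}}\,\big(1-\alpha_t\big)}{1-\bar\alpha_t}\,\hat{x}_0
\;+\;
\sqrt{\frac{\big(1-\bar\alpha_\mathrm{next}\big)\big(1-\alpha_t\big)}{1-\bar\alpha_t}}\;\varepsilon,
\qquad \varepsilon\sim\mathcal{N}(0,I)
```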


@@ -17,7 +17,15 @@ class EDMSampling:
 class DiscreteSampling:
-    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0):
+    def __init__(
+        self,
+        discretization_config,
+        num_idx,
+        do_append_zero=False,
+        flip=True,
+        uniform_sampling=False,
+        group_num=0,
+    ):
         self.num_idx = num_idx
         self.sigmas = instantiate_from_config(discretization_config)(
             num_idx, do_append_zero=do_append_zero, flip=flip
@@ -42,7 +50,11 @@ class DiscreteSampling:
             group_index = rank // self.group_width
             idx = default(
                 rand,
-                torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)),
+                torch.randint(
+                    group_index * self.sigma_interval,
+                    (group_index + 1) * self.sigma_interval,
+                    (n_samples,),
+                ),
             )
         else:
             idx = default(
@@ -54,8 +66,11 @@ class DiscreteSampling:
         else:
             return self.idx_to_sigma(idx)
 class PartialDiscreteSampling:
-    def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
+    def __init__(
+        self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True
+    ):
         self.total_num_idx = total_num_idx
         self.partial_num_idx = partial_num_idx
         self.sigmas = instantiate_from_config(discretization_config)(
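For reference, the `uniform_sampling` branch of `DiscreteSampling` gives each group of data-parallel ranks its own contiguous slice of timestep indices, so a global batch covers the noise schedule more evenly. The sketch below reconstructs that branch as a stand-alone function; `group_width` and `sigma_interval` are my assumptions about how the constructor derives them (`world_size // group_num` and `num_idx // group_num`), since those lines are outside the hunk shown.

```python
import torch

def sample_sigma_index(rank, world_size, group_num, num_idx, n_samples=1):
    """Hypothetical stand-alone version of the uniform_sampling branch above:
    each group of ranks only draws timestep indices from its own slice of [0, num_idx)."""
    group_width = world_size // group_num   # ranks per group (assumed)
    sigma_interval = num_idx // group_num   # indices per group (assumed)
    group_index = rank // group_width
    return torch.randint(
        group_index * sigma_interval,
        (group_index + 1) * sigma_interval,
        (n_samples,),
    )

# e.g. 8 ranks split into 4 groups over 1000 diffusion steps
print(sample_sigma_index(rank=5, world_size=8, group_num=4, num_idx=1000))
```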
View File
@ -24,7 +24,9 @@ def make_beta_schedule(
linear_end=2e-2, linear_end=2e-2,
): ):
if schedule == "linear": if schedule == "linear":
betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 betas = (
torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
)
return betas.numpy() return betas.numpy()
@ -50,7 +52,9 @@ def mixed_checkpoint(func, inputs: dict, params, flag):
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)] tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)] non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)] non_tensor_inputs = [
inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
]
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
return MixedCheckpointFunction.apply( return MixedCheckpointFunction.apply(
func, func,
@ -84,9 +88,14 @@ class MixedCheckpointFunction(torch.autograd.Function):
} }
assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} ctx.input_tensors = {
key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
}
ctx.input_non_tensors = { ctx.input_non_tensors = {
key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])) key: val
for (key, val) in zip(
non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
)
} }
ctx.run_function = run_function ctx.run_function = run_function
ctx.input_params = list(args[ctx.end_non_tensors :]) ctx.input_params = list(args[ctx.end_non_tensors :])
@ -98,13 +107,18 @@ class MixedCheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, *output_grads): def backward(ctx, *output_grads):
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors} ctx.input_tensors = {
key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors
}
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the # Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d # Tensor storage in place, which is not allowed for detach()'d
# Tensors. # Tensors.
shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors} shallow_copies = {
key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
for key in ctx.input_tensors
}
# shallow_copies.update(additional_args) # shallow_copies.update(additional_args)
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
input_grads = torch.autograd.grad( input_grads = torch.autograd.grad(
@ -188,9 +202,9 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtyp
""" """
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( freqs = torch.exp(
device=timesteps.device -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
) ).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
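The `timestep_embedding` hunk above only re-wraps the frequency computation; the underlying operation is the usual sinusoidal embedding. A minimal sketch (ignoring the `repeat_only` branch and the odd-`dim` zero padding that follows the lines shown):

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Minimal sketch of the embedding computed by timestep_embedding above."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=320)  # shape (3, 320)
```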
View File
@@ -6,7 +6,9 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
 class IdentityWrapper(nn.Module):
-    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
+    def __init__(
+        self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32
+    ):
         super().__init__()
         compile = (
             torch.compile
View File
@@ -87,8 +87,14 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
     # Force variances to be Tensors. Broadcasting helps convert scalars to
     # Tensors, but it does not work for torch.exp().
-    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)
+    ]
     return 0.5 * (
-        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
     )
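The re-wrapped return value of `normal_kl` is the closed-form KL divergence between two diagonal Gaussians parameterized by mean and log-variance. A quick self-contained sanity check (my own, not part of the commit) against `torch.distributions`:

```python
import torch
from torch.distributions import Normal, kl_divergence

mean1, logvar1 = torch.randn(4), torch.randn(4)
mean2, logvar2 = torch.randn(4), torch.randn(4)

# element-wise closed form, same expression as the diff above
closed_form = 0.5 * (
    -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
    + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
)
reference = kl_divergence(
    Normal(mean1, (0.5 * logvar1).exp()), Normal(mean2, (0.5 * logvar2).exp())
)
assert torch.allclose(closed_form, reference, atol=1e-5)
```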
View File
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
             "num_updates",
-            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
         )
         for name, p in model.named_parameters():
@@ -45,9 +47,11 @@ class LitEma(nn.Module):
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
                 else:
-                    assert not key in self.m_name2s_name
+                    assert key not in self.m_name2s_name
     def copy_to(self, model):
         m_param = dict(model.named_parameters())
@@ -56,7 +60,7 @@ class LitEma(nn.Module):
             if m_param[key].requires_grad:
                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
-                assert not key in self.m_name2s_name
+                assert key not in self.m_name2s_name
     def store(self, parameters):
         """
View File
@@ -99,7 +99,9 @@ class GeneralConditioner(nn.Module):
             elif "input_keys" in embconfig:
                 embedder.input_keys = embconfig["input_keys"]
             else:
-                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
+                raise KeyError(
+                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+                )
             embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
             if embedder.legacy_ucg_val is not None:
@@ -160,7 +162,10 @@ class GeneralConditioner(nn.Module):
                 if cond_or_not is None:
                     emb = (
                         expand_dims_like(
-                            torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
+                            torch.bernoulli(
+                                (1.0 - embedder.ucg_rate)
+                                * torch.ones(emb.shape[0], device=emb.device)
+                            ),
                             emb,
                         )
                         * emb
View File
@@ -96,7 +96,9 @@ class VideoTransformerBlock(nn.Module):
         if self.checkpoint:
             print(f"{self.__class__.__name__} is using checkpointing")
-    def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
+    ) -> torch.Tensor:
         if self.checkpoint:
             return checkpoint(self._forward, x, context, timesteps)
         else:
@@ -239,7 +241,9 @@ class SpatialVideoTransformer(SpatialTransformer):
         spatial_context = context
         if self.use_spatial_context:
-            assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"
+            assert (
+                context.ndim == 3
+            ), f"n dims of spatial context should be 3 but are {context.ndim}"
             time_context = context
             time_context_first_timestep = time_context[::timesteps]
View File
@@ -86,7 +86,9 @@ class SafeConv3d(torch.nn.Conv3d):
             input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
             if kernel_size > 1:
                 input_chunks = [input_chunks[0]] + [
-                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+                    torch.cat(
+                        (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
+                    )
                     for i in range(1, len(input_chunks))
                 ]
@@ -252,7 +254,7 @@ def count_params(model, verbose=False):
 def instantiate_from_config(config, **extra_kwargs):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
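`SafeConv3d` splits an oversized NCTHW activation along the time axis and overlaps neighbouring chunks by `kernel_size − 1` frames so every temporal window is still convolved exactly once. A stand-alone sketch of that trick (the function name is mine); with no temporal padding the chunked result matches the direct convolution:

```python
import torch
import torch.nn as nn

def chunked_conv3d(conv: nn.Conv3d, x: torch.Tensor, part_num: int) -> torch.Tensor:
    """Split NCTHW input along T, overlap adjacent chunks by (kernel_t - 1) frames,
    convolve each chunk, and concatenate -- same scheme as SafeConv3d above."""
    kernel_t = conv.kernel_size[0]
    chunks = list(torch.chunk(x, part_num, dim=2))
    if kernel_t > 1:
        # note: the comprehension reads from the original (pre-overlap) chunks
        chunks = [chunks[0]] + [
            torch.cat((chunks[i - 1][:, :, -kernel_t + 1 :], chunks[i]), dim=2)
            for i in range(1, len(chunks))
        ]
    return torch.cat([conv(c) for c in chunks], dim=2)

conv = nn.Conv3d(4, 8, kernel_size=3, padding=(0, 1, 1))  # no temporal padding
x = torch.randn(1, 4, 17, 16, 16)
assert torch.allclose(chunked_conv3d(conv, x, part_num=3), conv(x), atol=1e-5)
```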
View File
@ -93,7 +93,12 @@ class SimpleDistributedWebDataset(DataPipeline):
def tar_file_iterator_with_meta( def tar_file_iterator_with_meta(
fileobj, meta_names, skip_meta=r"__[^/]*__($|/)", suffix=None, handler=reraise_exception, meta_stream=None fileobj,
meta_names,
skip_meta=r"__[^/]*__($|/)",
suffix=None,
handler=reraise_exception,
meta_stream=None,
): ):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream. """Iterate over tar file, yielding filename, content pairs for the given tar stream.
@ -122,10 +127,13 @@ def tar_file_iterator_with_meta(
except Exception as exn: except Exception as exn:
from sat.helpers import print_rank0 from sat.helpers import print_rank0
print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG") print_rank0(
f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}",
level="DEBUG",
)
continue continue
for item in meta_list: for item in meta_list:
if not item["key"] in meta_data: if item["key"] not in meta_data:
meta_data[item["key"]] = {} meta_data[item["key"]] = {}
for meta_name in meta_names: for meta_name in meta_names:
if meta_name in item: if meta_name in item:
@ -186,7 +194,9 @@ def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
try: try:
assert isinstance(source, dict) assert isinstance(source, dict)
assert "stream" in source assert "stream" in source
for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]): for sample in tar_file_iterator_with_meta(
source["stream"], meta_names, meta_stream=source["meta_stream"]
):
assert isinstance(sample, dict) and "data" in sample and "fname" in sample assert isinstance(sample, dict) and "data" in sample and "fname" in sample
sample["__url__"] = url sample["__url__"] = url
yield sample yield sample
@ -250,7 +260,15 @@ class MetaDistributedWebDataset(DataPipeline):
""" """
def __init__( def __init__(
self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None self,
path,
process_fn,
seed,
*,
meta_names=[],
nshards=sys.maxsize,
shuffle_buffer=1000,
include_dirs=None,
): ):
# os.environ['WDS_SHOW_SEED'] = '1' # os.environ['WDS_SHOW_SEED'] = '1'
import torch import torch
@ -361,7 +379,10 @@ def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
if mode[0] == "r": if mode[0] == "r":
s3_client = boto3.client( s3_client = boto3.client(
"s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key "s3",
endpoint_url=endpoint_url,
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
) )
bucket, key = url.split("/", 1) bucket, key = url.split("/", 1)
View File
@ -37,7 +37,9 @@ def save_texts(texts, save_dir, iterations):
f.write(text + "\n") f.write(text + "\n")
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None): def save_video_as_grid_and_mp4(
video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None
):
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch): for i, vid in enumerate(video_batch):
@ -52,7 +54,8 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int
writer.append_data(frame) writer.append_data(frame)
if args is not None and args.wandb: if args is not None and args.wandb:
wandb.log( wandb.log(
{key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1 {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")},
step=args.iteration + 1,
) )
@ -138,7 +141,9 @@ def broad_cast_batch(batch):
return batch return batch
def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None): def forward_step_eval(
data_iterator, model, args, timers, only_log_video_latents=False, data_class=None
):
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
timers("data loader").start() timers("data loader").start()
batch_video = next(data_iterator) batch_video = next(data_iterator)
@ -209,7 +214,9 @@ if __name__ == "__main__":
args = argparse.Namespace(**vars(args), **vars(known)) args = argparse.Namespace(**vars(args), **vars(known))
data_class = get_obj_from_str(args.data_config["target"]) data_class = get_obj_from_str(args.data_config["target"])
create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"]) create_dataset_function = partial(
data_class.create_dataset_function, **args.data_config["params"]
)
import yaml import yaml
@ -225,7 +232,9 @@ if __name__ == "__main__":
model_cls=SATVideoDiffusionEngine, model_cls=SATVideoDiffusionEngine,
forward_step_function=partial(forward_step, data_class=data_class), forward_step_function=partial(forward_step, data_class=data_class),
forward_step_eval=partial( forward_step_eval=partial(
forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents forward_step_eval,
data_class=data_class,
only_log_video_latents=args.only_log_video_latents,
), ),
create_dataset_function=create_dataset_function, create_dataset_function=create_dataset_function,
) )
View File
@ -94,7 +94,11 @@ class FeedForward(nn.Module):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@ -126,7 +130,9 @@ class LinearAttention(nn.Module):
def forward(self, x): def forward(self, x):
b, c, h, w = x.shape b, c, h, w = x.shape
qkv = self.to_qkv(x) qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1) k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v) context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q) out = torch.einsum("bhde,bhdn->bhen", context, q)
@ -143,7 +149,9 @@ class SpatialSelfAttention(nn.Module):
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x): def forward(self, x):
h_ = x h_ = x
@ -244,7 +252,9 @@ class CrossAttention(nn.Module):
# new # new
with sdp_kernel(**BACKEND_MAP[self.backend]): with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
del q, k, v del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h) out = rearrange(out, "b h n d -> b n (h d)", h=h)
@ -422,7 +432,9 @@ class BasicTransformerBlock(nn.Module):
self.norm1(x), self.norm1(x),
context=context if self.disable_self_attn else None, context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens, additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
) )
+ x + x
) )
@ -499,7 +511,9 @@ class SpatialTransformer(nn.Module):
sdp_backend=None, sdp_backend=None,
): ):
super().__init__() super().__init__()
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@ -542,7 +556,9 @@ class SpatialTransformer(nn.Module):
] ]
) )
if not use_linear: if not use_linear:
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else: else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
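The `CrossAttention` change above only re-wraps the call to `F.scaled_dot_product_attention`, whose default scale is `dim_head ** -0.5`. A minimal sketch of the surrounding head split/merge (the `to_q`/`to_k`/`to_v`/`to_out` projections are omitted), assuming PyTorch ≥ 2.0 and `einops`:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def sdpa_cross_attention(q, k, v, heads: int, mask=None):
    """Split heads, run scaled dot-product attention, merge heads."""
    q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=heads) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale defaults to d**-0.5
    return rearrange(out, "b h n d -> b n (h d)", h=heads)

q = torch.randn(2, 16, 8 * 64)   # (batch, query tokens, heads * dim_head)
kv = torch.randn(2, 77, 8 * 64)  # e.g. text context
out = sdpa_cross_attention(q, kv, kv, heads=8)  # -> (2, 16, 512)
```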
View File
@ -97,9 +97,7 @@ class AbstractAutoencoder(pl.LightningModule):
def instantiate_optimizer_from_config(self, params, lr, cfg): def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])( return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self) -> Any: def configure_optimizers(self) -> Any:
raise NotImplementedError() raise NotImplementedError()
@ -214,14 +212,20 @@ class AutoencodingEngine(AbstractAutoencoder):
x = self.decoder(z, **kwargs) x = self.decoder(z, **kwargs)
return x return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True) z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs) dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
x = self.get_input(batch) x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs) z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"): if hasattr(self.loss, "forward_keys"):
extra_info = { extra_info = {
@ -357,12 +361,16 @@ class AutoencodingEngine(AbstractAutoencoder):
if self.trainable_ae_params is None: if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params() ae_params = self.get_autoencoder_params()
else: else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args) ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None: if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params() disc_params = self.get_discriminator_params()
else: else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args) disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}") logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config( opt_ae = self.instantiate_optimizer_from_config(
ae_params, ae_params,
@ -371,17 +379,23 @@ class AutoencodingEngine(AbstractAutoencoder):
) )
opts = [opt_ae] opts = [opt_ae]
if len(disc_params) > 0: if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config) opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opts.append(opt_disc) opts.append(opt_disc)
return opts return opts
@torch.no_grad() @torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
log = dict() log = dict()
additional_decode_kwargs = {} additional_decode_kwargs = {}
x = self.get_input(batch) x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)}) additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
_, xrec, _ = self(x, **additional_decode_kwargs) _, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x log["inputs"] = x
@ -400,7 +414,9 @@ class AutoencodingEngine(AbstractAutoencoder):
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0) diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0 log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
if additional_log_kwargs: if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs) additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs) _, xrec_add, _ = self(x, **additional_decode_kwargs)
@ -442,7 +458,9 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
params = super().get_autoencoder_params() params = super().get_autoencoder_params()
return params return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None: if self.max_batch_size is None:
z = self.encoder(x) z = self.encoder(x)
z = self.quant_conv(z) z = self.quant_conv(z)
@ -485,7 +503,9 @@ class AutoencoderKL(AutoencodingEngineLegacy):
if "lossconfig" in kwargs: if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig") kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__( super().__init__(
regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")}, regularizer_config={
"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")
},
**kwargs, **kwargs,
) )
@ -519,7 +539,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: def log_videos(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs) return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor: def get_input(self, batch: dict) -> torch.Tensor:
@ -530,7 +552,9 @@ class VideoAutoencodingEngine(AutoencodingEngine):
batch = batch[self.input_key] batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group()) torch.distributed.broadcast(
batch, src=global_src_rank, group=get_context_parallel_group()
)
batch = _conv_split(batch, dim=2, kernel_size=1) batch = _conv_split(batch, dim=2, kernel_size=1)
return batch return batch
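`instantiate_optimizer_from_config`, collapsed to a one-liner earlier in this file, is plain config-driven construction: look up the class named in `cfg["target"]` and call it with the parameter list, learning rate, and any extra kwargs. A self-contained sketch; `get_obj_from_str` here is my minimal stand-in for the sgm helper of the same name:

```python
import importlib
import torch

def get_obj_from_str(name: str):
    """Minimal stand-in: 'pkg.mod.Class' -> Class."""
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_optimizer_from_config(params, lr, cfg):
    # mirrors the one-liner shown above
    return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))

opt = instantiate_optimizer_from_config(
    [torch.nn.Parameter(torch.zeros(3))],
    lr=1e-4,
    cfg={"target": "torch.optim.AdamW", "params": {"weight_decay": 1e-2}},
)
```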
View File
@ -201,7 +201,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
recv_rank += cp_world_size recv_rank += cp_world_size
if cp_rank < cp_world_size - 1: if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) req_send = torch.distributed.isend(
input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
)
if cp_rank > 0: if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
@ -246,11 +248,12 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1: if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) req_send = torch.distributed.isend(
input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
)
if cp_rank > 0: if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0: if cp_rank == 0:
if cache_padding is not None: if cache_padding is not None:
input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0) input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)
@ -334,7 +337,9 @@ def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
class ContextParallelCausalConv3d(nn.Module): class ContextParallelCausalConv3d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): def __init__(
self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs
):
super().__init__() super().__init__()
kernel_size = cast_tuple(kernel_size, 3) kernel_size = cast_tuple(kernel_size, 3)
@ -354,7 +359,9 @@ class ContextParallelCausalConv3d(nn.Module):
stride = (stride, stride, stride) stride = (stride, stride, stride)
dilation = (1, 1, 1) dilation = (1, 1, 1)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) self.conv = Conv3d(
chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
)
self.cache_padding = None self.cache_padding = None
def forward(self, input_, clear_cache=True): def forward(self, input_, clear_cache=True):
@ -369,7 +376,11 @@ class ContextParallelCausalConv3d(nn.Module):
global_rank = torch.distributed.get_rank() global_rank = torch.distributed.get_rank()
if cp_world_size == 1: if cp_world_size == 1:
self.cache_padding = ( self.cache_padding = (
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() input_parallel[:, :, -self.time_kernel_size + 1 :]
.contiguous()
.detach()
.clone()
.cpu()
) )
else: else:
if cp_rank == cp_world_size - 1: if cp_rank == cp_world_size - 1:
@ -379,9 +390,13 @@ class ContextParallelCausalConv3d(nn.Module):
group=get_context_parallel_group(), group=get_context_parallel_group(),
) )
if cp_rank == 0: if cp_rank == 0:
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous() recv_buffer = torch.empty_like(
input_parallel[:, :, -self.time_kernel_size + 1 :]
).contiguous()
torch.distributed.recv( torch.distributed.recv(
recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group() recv_buffer,
global_rank - 1 + cp_world_size,
group=get_context_parallel_group(),
) )
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu() self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
@ -406,7 +421,9 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
def Normalize(in_channels, gather=False, **kwargs): def Normalize(in_channels, gather=False, **kwargs):
if gather: if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return ContextParallelGroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
else: else:
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@ -460,7 +477,8 @@ class SpatialNorm3D(nn.Module):
zq_rest_splits = torch.split(zq_rest, 32, dim=1) zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest")
for split in zq_rest_splits
] ]
zq_rest = torch.cat(interpolated_splits, dim=1) zq_rest = torch.cat(interpolated_splits, dim=1)
@ -471,7 +489,8 @@ class SpatialNorm3D(nn.Module):
zq_splits = torch.split(zq, 32, dim=1) zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits torch.nn.functional.interpolate(split, size=f_size, mode="nearest")
for split in zq_splits
] ]
zq = torch.cat(interpolated_splits, dim=1) zq = torch.cat(interpolated_splits, dim=1)
@ -511,7 +530,9 @@ class Upsample3D(nn.Module):
super().__init__() super().__init__()
self.with_conv = with_conv self.with_conv = with_conv
if self.with_conv: if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x, fake_cp=True): def forward(self, x, fake_cp=True):
@ -523,14 +544,16 @@ class Upsample3D(nn.Module):
splits = torch.split(x_rest, 32, dim=1) splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
for split in splits
] ]
x_rest = torch.cat(interpolated_splits, dim=1) x_rest = torch.cat(interpolated_splits, dim=1)
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else: else:
splits = torch.split(x, 32, dim=1) splits = torch.split(x, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
for split in splits
] ]
x = torch.cat(interpolated_splits, dim=1) x = torch.cat(interpolated_splits, dim=1)
@ -541,7 +564,8 @@ class Upsample3D(nn.Module):
splits = torch.split(x, 32, dim=1) splits = torch.split(x, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
for split in splits
] ]
x = torch.cat(interpolated_splits, dim=1) x = torch.cat(interpolated_splits, dim=1)
@ -563,7 +587,9 @@ class DownSample3D(nn.Module):
out_channels = in_channels out_channels = in_channels
if self.with_conv: if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves # no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.conv = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, padding=0
)
self.compress_time = compress_time self.compress_time = compress_time
def forward(self, x, fake_cp=True): def forward(self, x, fake_cp=True):
@ -578,7 +604,8 @@ class DownSample3D(nn.Module):
if x_rest.shape[-1] > 0: if x_rest.shape[-1] > 0:
splits = torch.split(x_rest, 32, dim=1) splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
for split in splits
] ]
x_rest = torch.cat(interpolated_splits, dim=1) x_rest = torch.cat(interpolated_splits, dim=1)
x = torch.cat([x_first[..., None], x_rest], dim=-1) x = torch.cat([x_first[..., None], x_rest], dim=-1)
@ -587,7 +614,8 @@ class DownSample3D(nn.Module):
# x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
splits = torch.split(x, 32, dim=1) splits = torch.split(x, 32, dim=1)
interpolated_splits = [ interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
for split in splits
] ]
x = torch.cat(interpolated_splits, dim=1) x = torch.cat(interpolated_splits, dim=1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
@ -923,9 +951,13 @@ class ContextParallelDecoder3D(nn.Module):
up.attn = attn up.attn = attn
if i_level != 0: if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level: if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) up.upsample = Upsample3D(
block_in, with_conv=resamp_with_conv, compress_time=False
)
else: else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) up.upsample = Upsample3D(
block_in, with_conv=resamp_with_conv, compress_time=True
)
self.up.insert(0, up) self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
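Several hunks in this file repeat the same pattern: split the channel dimension into 32-channel slices, run `F.interpolate` (or `avg_pool1d`) on each slice, and concatenate, presumably to keep each kernel launch and intermediate tensor small. For nearest-neighbour interpolation the result is identical to a single call, as this small sketch (names mine) checks:

```python
import torch
import torch.nn.functional as F

def chunked_interpolate(x: torch.Tensor, scale_factor: float = 2.0, chunk: int = 32) -> torch.Tensor:
    """Interpolate 32-channel slices separately and re-concatenate along channels."""
    splits = torch.split(x, chunk, dim=1)
    return torch.cat(
        [F.interpolate(s, scale_factor=scale_factor, mode="nearest") for s in splits], dim=1
    )

x = torch.randn(1, 128, 32, 32)
assert torch.equal(chunked_interpolate(x), F.interpolate(x, scale_factor=2.0, mode="nearest"))
```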
View File
@@ -12,7 +12,9 @@ class LitEma(nn.Module):
         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
             "num_updates",
-            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
         )
         for name, p in model.named_parameters():
@@ -45,9 +47,11 @@ class LitEma(nn.Module):
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
                 else:
-                    assert not key in self.m_name2s_name
+                    assert key not in self.m_name2s_name
     def copy_to(self, model):
         m_param = dict(model.named_parameters())
@@ -56,7 +60,7 @@ class LitEma(nn.Module):
             if m_param[key].requires_grad:
                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
-                assert not key in self.m_name2s_name
+                assert key not in self.m_name2s_name
     def store(self, parameters):
         """
View File
@@ -77,7 +77,9 @@ class IdentityRegularizer(AbstractRegularizer):
         yield from ()
-def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def measure_perplexity(
+    predicted_indices: torch.Tensor, num_centroids: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
     # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
     encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
View File
@@ -78,7 +78,9 @@ class SafeConv3d(torch.nn.Conv3d):
             input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
             if kernel_size > 1:
                 input_chunks = [input_chunks[0]] + [
-                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+                    torch.cat(
+                        (input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2
+                    )
                     for i in range(1, len(input_chunks))
                 ]
@@ -244,7 +246,7 @@ def count_params(model, verbose=False):
 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == "__is_first_stage__":
             return None
         elif config == "__is_unconditional__":
View File
@ -9,11 +9,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_PATH = "THUDM/cogvlm2-llama3-caption" MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[ TORCH_TYPE = (
0] >= 8 else torch.float16 torch.bfloat16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
else torch.float16
)
parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo") parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0) parser.add_argument(
'--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0
)
args = parser.parse_args([]) args = parser.parse_args([])
@ -29,8 +34,11 @@ def load_video(video_data, strategy='chat'):
clip_end_sec = 60 clip_end_sec = 60
clip_start_sec = 0 clip_start_sec = 0
start_frame = int(clip_start_sec * decord_vr.get_avg_fps()) start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
end_frame = min(total_frames, end_frame = (
int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps()))
if clip_end_sec is not None
else total_frames
)
frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int) frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
elif strategy == 'chat': elif strategy == 'chat':
timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames)) timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
@ -54,11 +62,11 @@ tokenizer = AutoTokenizer.from_pretrained(
trust_remote_code=True, trust_remote_code=True,
) )
model = AutoModelForCausalLM.from_pretrained( model = (
MODEL_PATH, AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=TORCH_TYPE, trust_remote_code=True)
torch_dtype=TORCH_TYPE, .eval()
trust_remote_code=True .to(DEVICE)
).eval().to(DEVICE) )
def predict(prompt, video_data, temperature): def predict(prompt, video_data, temperature):
@ -69,11 +77,7 @@ def predict(prompt, video_data, temperature):
history = [] history = []
query = prompt query = prompt
inputs = model.build_conversation_input_ids( inputs = model.build_conversation_input_ids(
tokenizer=tokenizer, tokenizer=tokenizer, query=query, images=[video], history=history, template_version=strategy
query=query,
images=[video],
history=history,
template_version=strategy
) )
inputs = { inputs = {
'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'), 'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
@ -91,7 +95,7 @@ def predict(prompt, video_data, temperature):
} }
with torch.no_grad(): with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs) outputs = model.generate(**inputs, **gen_kwargs)
outputs = outputs[:, inputs['input_ids'].shape[1]:] outputs = outputs[:, inputs['input_ids'].shape[1] :]
response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response return response
View File
@ -31,9 +31,18 @@ from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment. # DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, from deepspeed.checkpoint.constants import (
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, DS_VERSION,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) OPTIMIZER_STATE_DICT,
SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS,
ZERO_STAGE,
PARTITION_COUNT,
PARAM_SHAPES,
BUFFER_NAMES,
FROZEN_PARAM_SHAPES,
FROZEN_PARAM_FRAGMENTS,
)
@dataclass @dataclass
@ -134,12 +143,14 @@ def parse_model_states(files):
frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
z_model_state = zero_model_state(buffers=buffers, z_model_state = zero_model_state(
buffers=buffers,
param_shapes=param_shapes, param_shapes=param_shapes,
shared_params=shared_params, shared_params=shared_params,
ds_version=ds_version, ds_version=ds_version,
frozen_param_shapes=frozen_param_shapes, frozen_param_shapes=frozen_param_shapes,
frozen_param_fragments=frozen_param_fragments) frozen_param_fragments=frozen_param_fragments,
)
zero_model_states.append(z_model_state) zero_model_states.append(z_model_state)
return zero_model_states return zero_model_states
@ -155,7 +166,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
state_dicts.append(state_dict) state_dicts.append(state_dict)
if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
raise ValueError(f"{files[0]} is not a zero checkpoint") raise ValueError(f"{files[0]} is not a zero checkpoint")
zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
@ -181,7 +192,9 @@ def parse_optim_states(files, ds_checkpoint_dir):
else: else:
raise ValueError(f"unknown zero stage {zero_stage}") raise ValueError(f"unknown zero stage {zero_stage}")
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] fp32_flat_groups = [
state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))
]
return zero_stage, world_size, fp32_flat_groups return zero_stage, world_size, fp32_flat_groups
@ -205,15 +218,20 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
if zero_stage <= 2: if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, return _get_fp32_state_dict_from_zero2_checkpoint(
exclude_frozen_parameters) world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
)
elif zero_stage == 3: elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, return _get_fp32_state_dict_from_zero3_checkpoint(
exclude_frozen_parameters) world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
)
def _zero2_merge_frozen_params(state_dict, zero_model_states): def _zero2_merge_frozen_params(state_dict, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: if (
zero_model_states[0].frozen_param_shapes is None
or len(zero_model_states[0].frozen_param_shapes) == 0
):
return return
frozen_param_shapes = zero_model_states[0].frozen_param_shapes frozen_param_shapes = zero_model_states[0].frozen_param_shapes
@ -269,11 +287,17 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
full_single_fp32_vector = torch.cat(merged_partitions, 0) full_single_fp32_vector = torch.cat(merged_partitions, 0)
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
avail_numel = sum( avail_numel = sum(
[full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) [
full_single_fp32_vector.numel()
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
]
)
if debug: if debug:
wanted_params = sum([len(shapes) for shapes in param_shapes]) wanted_params = sum([len(shapes) for shapes in param_shapes])
wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) wanted_numel = sum(
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]
)
# not asserting if there is a mismatch due to possible padding # not asserting if there is a mismatch due to possible padding
print(f"Have {avail_numel} numels to process.") print(f"Have {avail_numel} numels to process.")
print(f"Need {wanted_numel} numels in {wanted_params} params.") print(f"Need {wanted_numel} numels in {wanted_params} params.")
@ -283,18 +307,23 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
# out-of-core computing solution # out-of-core computing solution
total_numel = 0 total_numel = 0
total_params = 0 total_params = 0
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): for shapes, full_single_fp32_vector in zip(
param_shapes, merged_single_partition_of_fp32_groups
):
offset = 0 offset = 0
avail_numel = full_single_fp32_vector.numel() avail_numel = full_single_fp32_vector.numel()
for name, shape in shapes.items(): for name, shape in shapes.items():
unpartitioned_numel = (
unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
)
total_numel += unpartitioned_numel total_numel += unpartitioned_numel
total_params += 1 total_params += 1
if debug: if debug:
print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(
shape
)
offset += unpartitioned_numel offset += unpartitioned_numel
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
@ -322,8 +351,9 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, def _get_fp32_state_dict_from_zero2_checkpoint(
exclude_frozen_parameters): world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
):
state_dict = OrderedDict() state_dict = OrderedDict()
# buffers # buffers
@ -353,7 +383,10 @@ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: if (
zero_model_states[0].frozen_param_shapes is None
or len(zero_model_states[0].frozen_param_shapes) == 0
):
return return
if debug: if debug:
@ -364,7 +397,10 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
frozen_param_shapes = zero_model_states[0].frozen_param_shapes frozen_param_shapes = zero_model_states[0].frozen_param_shapes
wanted_params = len(frozen_param_shapes) wanted_params = len(frozen_param_shapes)
wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size avail_numel = (
sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()])
* world_size
)
print(f'Frozen params: Have {avail_numel} numels to process.') print(f'Frozen params: Have {avail_numel} numels to process.')
print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
@ -375,10 +411,14 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
unpartitioned_numel = shape.numel() unpartitioned_numel = shape.numel()
total_numel += unpartitioned_numel total_numel += unpartitioned_numel
param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) param_frags = tuple(
model_state.frozen_param_fragments[name] for model_state in zero_model_states
)
state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
unpartitioned_numel, world_size
)
if debug: if debug:
print( print(
@@ -416,21 +456,32 @@ class GatheredTensor:
            start_group_id = None
            end_group_id = None
            for group_id in range(len(self.flat_groups_offset)):
                if (
                    self.flat_groups_offset[group_id]
                    <= self.offset
                    < self.flat_groups_offset[group_id + 1]
                ):
                    start_group_id = group_id
                if (
                    self.flat_groups_offset[group_id]
                    < end_idx
                    <= self.flat_groups_offset[group_id + 1]
                ):
                    end_group_id = group_id
                    break
            # collect weights from related group/groups
            for group_id in range(start_group_id, end_group_id + 1):
                flat_tensor = flat_groups_at_rank_i[group_id]
                start_offset = self.offset - self.flat_groups_offset[group_id]
                end_offset = (
                    min(end_idx, self.flat_groups_offset[group_id + 1])
                    - self.flat_groups_offset[group_id]
                )
                pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])

        # collect weights from all ranks
        pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
        param = pad_flat_param[: self.shape.numel()].view(self.shape).contiguous()
        return param
@@ -461,12 +512,16 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    offset = 0
    total_numel = 0
    total_params = 0
    flat_groups_offset = [0] + list(
        np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]])
    )
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
            unpartitioned_numel, world_size
        )

        if debug:
            print(
@@ -474,7 +529,9 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
            )

        # memory efficient tensor
        tensor = GatheredTensor(
            fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape
        )
        state_dict[name] = tensor
        offset += partitioned_numel
@@ -484,11 +541,14 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(
        f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements"
    )


def _get_fp32_state_dict_from_zero3_checkpoint(
    world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters
):
    state_dict = OrderedDict()

    # buffers
@@ -530,10 +590,9 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
    return torch_state_dict


def get_fp32_state_dict_from_zero_checkpoint(
    checkpoint_dir, tag=None, exclude_frozen_parameters=False, lazy_mode=False
):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -588,19 +647,23 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(
        ds_checkpoint_dir, exclude_frozen_parameters
    )
    if lazy_mode:
        return state_dict
    else:
        return to_torch_tensor(state_dict)


def convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir,
    output_dir,
    max_shard_size="5GB",
    safe_serialization=False,
    tag=None,
    exclude_frozen_parameters=False,
):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
@@ -629,25 +692,28 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
        raise

    # Convert zero checkpoint to state_dict
    state_dict = get_fp32_state_dict_from_zero_checkpoint(
        checkpoint_dir, tag, exclude_frozen_parameters, lazy_mode=True
    )

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if max_shard_size is not None:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
            ".safetensors", "{suffix}.safetensors"
        )
        # an memory-efficient approach for sharding
        empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(
            empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
        )
    else:
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(
            is_sharded=False, filename_to_tensors={weights_name: list(state_dict.keys())}
        )

    # Save the model by shard
    os.makedirs(output_dir, exist_ok=True)
@@ -673,7 +739,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = (
            "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        )
        save_index_file = os.path.join(output_dir, save_index_file)
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -719,12 +787,14 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    return model


def convert_zero_checkpoint_to_bf16_state_dict(
    checkpoint_dir,
    output_dir,
    max_shard_size="5GB",
    safe_serialization=True,
    tag=None,
    exclude_frozen_parameters=False,
):
    """
    Convert a ZeRO 2 or ZeRO 3 format DeepSpeed checkpoint to BF16 and write it to the given output directory, with the following naming rules:
    - If there is only one shard:
@@ -748,10 +818,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
        raise ImportError("You need `pip install huggingface_hub` to use the sharding feature.")

    state_dict = get_fp32_state_dict_from_zero_checkpoint(
        checkpoint_dir, tag=tag, exclude_frozen_parameters=exclude_frozen_parameters, lazy_mode=True
    )
    state_dict = to_torch_tensor(state_dict, return_empty_tensor=False)
@@ -766,9 +833,7 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
    empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
    state_dict_split = split_torch_state_dict_into_shards(
        empty_state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )

    os.makedirs(output_dir, exist_ok=True)
@@ -789,7 +854,6 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
        del shard_state_dict
        gc.collect()

    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
@@ -801,21 +865,29 @@ def convert_zero_checkpoint_to_bf16_state_dict(checkpoint_dir,
    else:
        only_filename = list(state_dict_split.filename_to_tensors.keys())[0]
        old_path = os.path.join(output_dir, only_filename)
        new_path = os.path.join(
            output_dir,
            "diffusion_pytorch_model.safetensors"
            if safe_serialization
            else "diffusion_pytorch_model.bin",
        )
        if old_path != new_path:
            os.rename(old_path, new_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir",
        type=str,
        help="path to the desired checkpoint folder, e.g., path/checkpoint-12",
    )
    parser.add_argument(
        "output_dir",
        type=str,
        help="directory to the pytorch fp32 state_dict output files"
        "(e.g. path/checkpoint-12-output/)",
    )
    parser.add_argument(
        "--max_shard_size",
        type=str,
@@ -823,26 +895,34 @@ if __name__ == "__main__":
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
        "without CPU OOM issues.",
    )
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
    )
    parser.add_argument(
        "-t",
        "--tag",
        type=str,
        default=None,
        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1",
    )
    parser.add_argument(
        "--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters"
    )
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    debug = args.debug

    convert_zero_checkpoint_to_bf16_state_dict(
        args.checkpoint_dir,
        args.output_dir,
        max_shard_size=args.max_shard_size,
        safe_serialization=args.safe_serialization,
        tag=args.tag,
        exclude_frozen_parameters=args.exclude_frozen_parameters,
    )
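
The conversion entry point above can also be driven from Python rather than the CLI. Below is a minimal, hypothetical sketch based only on the signatures shown in this diff: the paths are placeholders, and it assumes the script is importable as a module (here called `zero_to_fp32`, the name DeepSpeed conventionally drops into the checkpoint folder).

```python
# Minimal sketch: assumed module name `zero_to_fp32`, placeholder paths.
# Mirrors the __main__ block above: consolidate a ZeRO-2/3 checkpoint.
from zero_to_fp32 import (
    convert_zero_checkpoint_to_bf16_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

checkpoint_dir = "outputs/checkpoint-500"      # folder containing the global_step* sub-directory
output_dir = "outputs/checkpoint-500-merged"   # where the consolidated weights are written

# Option 1: write sharded BF16 safetensors, as the CLI entry point does.
convert_zero_checkpoint_to_bf16_state_dict(
    checkpoint_dir,
    output_dir,
    max_shard_size="5GB",
    safe_serialization=True,
)

# Option 2: keep the consolidated fp32 state_dict in memory instead of writing it.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=False)
```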

View File

@@ -10,6 +10,7 @@ Original Script:
https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
"""

import argparse
from typing import Any, Dict
@@ -143,7 +144,9 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    return state_dict


def update_state_dict_inplace(
    state_dict: Dict[str, Any], old_key: str, new_key: str
) -> Dict[str, Any]:
    state_dict[new_key] = state_dict.pop(old_key)
@@ -164,8 +167,11 @@ def convert_transformer(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        ofs_embed_dim=512
        if (i2v and init_kwargs["patch_size_t"] is not None)
        else None,  # CogVideoX1.5-5B-I2V
        use_learned_positional_embeddings=i2v
        and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
        **init_kwargs,
    ).to(dtype=dtype)
@@ -240,17 +246,40 @@ def get_transformer_init_kwargs(version: str):
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path",
        type=str,
        default=None,
        help="Path to original transformer checkpoint",
    )
    parser.add_argument(
        "--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path where converted model should be saved"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="Whether to save the model weights in fp16",
    )
    parser.add_argument(
        "--bf16",
        action="store_true",
        default=False,
        help="Whether to save the model weights in bf16",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        default=False,
        help="Whether to push to HF Hub after saving",
    )
    parser.add_argument(
        "--text_encoder_cache_dir",
        type=str,
        default=None,
        help="Path to text encoder cache directory",
    )
    parser.add_argument(
        "--typecast_text_encoder",
@@ -261,15 +290,24 @@ def get_args():
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument(
        "--num_attention_heads", type=int, default=30, help="Number of attention heads"
    )
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings",
        action="store_true",
        default=False,
        help="Whether to use RoPE or not",
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument(
        "--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE"
    )
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    parser.add_argument(
        "--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE"
    )
    parser.add_argument(
        "--i2v",
        action="store_true",
@@ -313,7 +351,9 @@ if __name__ == "__main__":
    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_id, cache_dir=args.text_encoder_cache_dir
    )

    if args.typecast_text_encoder:
        text_encoder = text_encoder.to(dtype=dtype)
@@ -355,4 +395,9 @@ if __name__ == "__main__":
    # This is necessary for users with insufficient memory,
    # such as those using Colab and notebooks, as it can save some memory used for model loading.
    pipe.save_pretrained(
        args.output_path,
        safe_serialization=True,
        max_shard_size="5GB",
        push_to_hub=args.push_to_hub,
    )
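
The `update_state_dict_inplace` helper shown above is simply a keyed `dict.pop`-and-reassign used while remapping checkpoint keys. A small self-contained illustration of that pattern (the keys here are made up for demonstration, not taken from the converter's real mapping):

```python
from typing import Any, Dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Move the value from old_key to new_key, dropping the old entry.
    state_dict[new_key] = state_dict.pop(old_key)


# Hypothetical example: remap an original checkpoint key to a diffusers-style name.
weights = {"transformer.final_layernorm.weight": [1.0, 2.0]}
update_state_dict_inplace(weights, "transformer.final_layernorm.weight", "norm_final.weight")
print(weights)  # {'norm_final.weight': [1.0, 2.0]}
```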

View File

@@ -15,8 +15,8 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
        state_dict = state_dict["state_dict"]
    return state_dict


LORA_KEYS_RENAME = {
    'attention.query_key_value.matrix_A.0': 'attn1.to_q.lora_A.weight',
    'attention.query_key_value.matrix_A.1': 'attn1.to_k.lora_A.weight',
    'attention.query_key_value.matrix_A.2': 'attn1.to_v.lora_A.weight',
@@ -24,22 +24,18 @@ LORA_KEYS_RENAME = {
    'attention.query_key_value.matrix_B.1': 'attn1.to_k.lora_B.weight',
    'attention.query_key_value.matrix_B.2': 'attn1.to_v.lora_B.weight',
    'attention.dense.matrix_A.0': 'attn1.to_out.0.lora_A.weight',
    'attention.dense.matrix_B.0': 'attn1.to_out.0.lora_B.weight',
}

PREFIX_KEY = "model.diffusion_model."
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"


def export_lora_weight(ckpt_path, lora_save_directory):
    merge_original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))

    lora_state_dict = {}
    for key in list(merge_original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
@@ -50,8 +46,6 @@ def export_lora_weight(ckpt_path,lora_save_directory):
        lora_state_dict[new_key] = merge_original_state_dict[key]

    # final length should be 240
    if len(lora_state_dict) != 240:
        raise ValueError("lora_state_dict length is not 240")
@@ -64,7 +58,7 @@ def export_lora_weight(ckpt_path,lora_save_directory):
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )
@@ -73,7 +67,12 @@ def get_args():
    parser.add_argument(
        "--sat_pt_path", type=str, required=True, help="Path to original sat transformer checkpoint"
    )
    parser.add_argument(
        "--lora_save_directory",
        type=str,
        required=True,
        help="Path where converted lora should be saved",
    )

    return parser.parse_args()
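
The `LORA_KEYS_RENAME` table above maps SAT-style attention matrices onto diffusers LoRA key suffixes. The sketch below is a hypothetical illustration of how one SAT key could be translated with the constants shown in this file; it is not the script's exact logic (the body of `export_lora_weight` is only partially visible in this hunk), and the example key is invented.

```python
PREFIX_KEY = "model.diffusion_model."
SAT_UNIT_KEY = "layers"
LORA_PREFIX_KEY = "transformer_blocks"
LORA_KEYS_RENAME = {
    "attention.query_key_value.matrix_A.0": "attn1.to_q.lora_A.weight",
    "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
}


def rename_sat_lora_key(key: str) -> str:
    # Strip the SAT prefix, swap the block-container name, then apply the suffix mapping.
    key = key[len(PREFIX_KEY):].replace(SAT_UNIT_KEY, LORA_PREFIX_KEY)
    for old_suffix, new_suffix in LORA_KEYS_RENAME.items():
        if key.endswith(old_suffix):
            return key[: -len(old_suffix)] + new_suffix
    return key


# Hypothetical SAT key -> diffusers-style LoRA key
print(rename_sat_lora_key("model.diffusion_model.layers.0.attention.query_key_value.matrix_A.0"))
# transformer_blocks.0.attn1.to_q.lora_A.weight
```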

View File

@@ -35,20 +35,16 @@ caption_generator = transformers.pipeline(
        "torch_dtype": torch.bfloat16,
    },
    trust_remote_code=True,
    tokenizer=tokenizer,
)

image_generator = DiffusionPipeline.from_pretrained(
    image_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
)
# image_generator.to("cuda")

video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
    video_generator_model_id, torch_dtype=torch.bfloat16, device_map="balanced"
)

video_generator.vae.enable_slicing()
@@ -87,11 +83,7 @@ def generate_caption(prompt):
        {"role": "user", "content": prompt + "\n" + user_prompt},
    ]

    response = caption_generator(messages, max_new_tokens=226, return_full_text=False)
    caption = response[0]["generated_text"]
    if caption.startswith("\"") and caption.endswith("\""):
        caption = caption[1:-1]
@@ -109,11 +101,7 @@ def generate_image(caption, progress=gr.Progress(track_tqdm=True)):
    return image, image  # One for output One for State


def generate_video(caption, image, progress=gr.Progress(track_tqdm=True)):
    generator = torch.Generator().manual_seed(seed)
    video_frames = video_generator(
        image=image,
@@ -181,14 +169,19 @@ with gr.Blocks() as demo:
            image_output = gr.Image(label="Generated Image")
            state_image = gr.State()
            generate_caption_button.click(fn=generate_caption, inputs=prompt, outputs=caption)
            generate_image_button.click(
                fn=generate_image, inputs=caption, outputs=[image_output, state_image]
            )
        with gr.Column():
            video_output = gr.Video(label="Generated Video", width=720, height=480)
            download_video_button = gr.File(label="📥 Download Video", visible=False)
            download_gif_button = gr.File(label="📥 Download GIF", visible=False)
            generate_video_button = gr.Button("Generate Video from Image")
            generate_video_button.click(
                fn=generate_video,
                inputs=[caption, state_image],
                outputs=[video_output, download_gif_button],
            )

if __name__ == "__main__":
    demo.launch()
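
The demo wiring above chains caption, image, and video callbacks, handing the generated image between them through `gr.State`. A stripped-down, self-contained sketch of that pattern follows; the callbacks are stand-ins for illustration, not the demo's real generators.

```python
import gradio as gr


def make_caption(prompt: str) -> str:
    # Stand-in for the LLM-based caption generator.
    return f"A cinematic shot of {prompt}"


def make_image(caption: str):
    # Stand-in: a real demo would return a generated PIL image here.
    return caption, caption  # one value for display, one for gr.State


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    caption = gr.Textbox(label="Caption")
    preview = gr.Textbox(label="Preview")
    state_image = gr.State()

    caption_btn = gr.Button("Generate Caption")
    image_btn = gr.Button("Generate Image")

    caption_btn.click(fn=make_caption, inputs=prompt, outputs=caption)
    image_btn.click(fn=make_image, inputs=caption, outputs=[preview, state_image])

if __name__ == "__main__":
    demo.launch()
```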

View File

@@ -65,7 +65,7 @@ def get_args():
        "--num_videos",
        type=int,
        default=5,
        help="Number of unique videos you would like to generate.",
    )
    parser.add_argument(
        "--model_path",
@@ -83,31 +83,28 @@ def get_args():
        "--caption_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for caption generation model.",
    )
    parser.add_argument(
        "--image_generator_model_id",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Image generation model.",
    )
    parser.add_argument(
        "--image_generator_cache_dir",
        type=str,
        default=None,
        help="Cache directory for image generation model.",
    )
    parser.add_argument(
        "--image_generator_num_inference_steps",
        type=int,
        default=50,
        help="Caption generation model.",
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7, help="Guidance scale to be use for generation."
    )
    parser.add_argument(
        "--use_dynamic_cfg",
@@ -123,19 +120,14 @@ def get_args():
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Whether or not to compile the transformer of image and video generators.",
    )
    parser.add_argument(
        "--enable_vae_tiling",
        action="store_true",
        help="Whether or not to use VAE tiling when encoding/decoding.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility.")

    return parser.parse_args()
@@ -157,7 +149,9 @@ def main(args: Dict[str, Any]) -> None:
    torch.cuda.manual_seed_all(args.seed)

    reset_memory()
    tokenizer = AutoTokenizer.from_pretrained(
        args.caption_generator_model_id, trust_remote_code=True
    )
    caption_generator = transformers.pipeline(
        "text-generation",
        model=args.caption_generator_model_id,
@@ -168,7 +162,7 @@ def main(args: Dict[str, Any]) -> None:
            "torch_dtype": torch.bfloat16,
        },
        trust_remote_code=True,
        tokenizer=tokenizer,
    )

    captions = []
@@ -197,12 +191,14 @@ def main(args: Dict[str, Any]) -> None:
    image_generator = DiffusionPipeline.from_pretrained(
        args.image_generator_model_id,
        cache_dir=args.image_generator_cache_dir,
        torch_dtype=torch.bfloat16,
    )
    image_generator.to("cuda")

    if args.compile:
        image_generator.transformer = torch.compile(
            image_generator.transformer, mode="max-autotune", fullgraph=True
        )

    if args.enable_vae_tiling:
        image_generator.vae.enable_tiling()
@@ -216,7 +212,9 @@ def main(args: Dict[str, Any]) -> None:
            num_inference_steps=args.image_generator_num_inference_steps,
            guidance_scale=3.5,
        ).images[0]
        filename = (
            caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        )
        image.save(output_dir / f"{index}_{filename}.png")
        images.append(image)
@@ -224,13 +222,16 @@ def main(args: Dict[str, Any]) -> None:
    reset_memory()
    video_generator = CogVideoXImageToVideoPipeline.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16
    ).to("cuda")
    video_generator.scheduler = CogVideoXDPMScheduler.from_config(
        video_generator.scheduler.config, timestep_spacing="trailing"
    )

    if args.compile:
        video_generator.transformer = torch.compile(
            video_generator.transformer, mode="max-autotune", fullgraph=True
        )

    if args.enable_vae_tiling:
        video_generator.vae.enable_tiling()
@@ -248,7 +249,9 @@ def main(args: Dict[str, Any]) -> None:
            use_dynamic_cfg=args.use_dynamic_cfg,
            generator=generator,
        ).frames[0]
        filename = (
            caption[:25].replace(".", "_").replace("'", "_").replace('"', "_").replace(",", "_")
        )
        export_to_video(video, output_dir / f"{index}_{filename}.mp4", fps=8)
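
For context, the batched generation above reduces to the standard diffusers image-to-video flow. The sketch below uses only calls that appear in the diff; the model id, input image, and output path are placeholders and should be swapped for your own checkpoint and assets.

```python
import torch
from diffusers import CogVideoXDPMScheduler, CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder model id; the script above loads from `args.model_path` instead.
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")
pipe.scheduler = CogVideoXDPMScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.vae.enable_tiling()  # optional, mirrors --enable_vae_tiling

image = load_image("input.png")  # placeholder conditioning image
video = pipe(
    image=image,
    prompt="A short cinematic clip",
    guidance_scale=7,
    use_dynamic_cfg=True,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```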

Some files were not shown because too many files have changed in this diff.