Standardize formatting (规范化)

zR 2024-08-06 18:50:03 +08:00
parent 487a815219
commit 3430df1a36
11 changed files with 436 additions and 367 deletions

.github/ISSUE_TEMPLATE/bug_report.yaml

@ -0,0 +1,51 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
body:
- type: textarea
id: system-info
attributes:
label: System Info / 系统信息
description: Your operating environment / 您的运行环境信息
placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括 Cuda 版本、Diffusers 版本、Python 版本、操作系统、硬件信息(如果您怀疑是硬件方面的问题)...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information / 问题信息
description: 'The problem arises when using: / 问题出现在'
options:
- label: "The official example scripts / 官方的示例脚本"
- label: "My own modified scripts / 我自己修改的脚本和任务"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction / 复现过程
description: |
Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
If you have code snippets, error messages, stack traces, please provide them here as well.
Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
placeholder: |
Steps to reproduce the behavior/复现Bug的步骤:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior / 期待表现
description: "A clear and concise description of what you would expect to happen. / 简单描述您期望发生的事情。"

@ -0,0 +1,34 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
labels: [ "feature" ]
body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request / 功能建议
description: |
A brief description of the functional proposal. Links to corresponding papers and code are desirable.
对功能建议的简述。最好提供对应的论文和代码链接。
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation / 动机
description: |
Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
- type: textarea
id: contribution
validations:
required: true
attributes:
label: Your contribution / 您的贡献
description: |
Your PR link or any other link you can help with.
您的PR链接或者其他您能提供帮助的链接。

@ -0,0 +1,34 @@
# Submit a valuable PR / 提出有价值的PR
## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:
1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. Each PR should be focused. If you have several ideas or optimizations, split them into separate PRs.
用户在提交PR时候应该注意以下几点:
1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
## 不应该提出的PR / PRs that should not be proposed
A PR that falls into any of the following categories may be closed or rejected.
1. It does not describe the proposed improvement.
2. It combines multiple issues of different types in one PR.
3. It largely duplicates an existing PR.
如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。
# Check your PR / 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 GitHub 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Does your PR address only one issue? / 您的PR是否仅针对一个问题?

@ -425,7 +425,7 @@ class SFTDataset(Dataset):
self.videos_list.append(tensor_frms)
# caption
caption_path = os.path.join(root, filename.replace('videos', 'labels').replace('.mp4', '.txt'))
caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
if os.path.exists(caption_path):
caption = open(caption_path, "r").read().splitlines()[0]
else:
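
The caption lookup in this hunk rewrites the video filename into a caption filename and keeps only the first line of that file. A minimal sketch of the same lookup follows. The helper name `caption_for` and the empty-string fallback are illustrative assumptions, since the `else` branch is cut off in the hunk.

```python
import os


def caption_for(root: str, filename: str, default: str = "") -> str:
    """Return the first line of the caption file paired with a video file.

    The path rewrite mirrors the hunk: 'videos' -> 'labels', '.mp4' -> '.txt'.
    `default` is an assumed fallback; the original else-branch is not shown.
    """
    caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
    if not os.path.exists(caption_path):
        return default
    # A context manager closes the handle, unlike a bare open(...).read().
    with open(caption_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    # Guard against an empty caption file, where [0] would raise IndexError.
    return lines[0] if lines else default
```

Reading through a `with` block avoids leaving the file handle open when many captions are loaded in a loop.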

@ -178,18 +178,18 @@ def sampling_main(args, model_cls):
shape=(T, C, H // F, W // F),
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
# Unload the model from GPU to save GPU memory
model.to('cpu')
model.to("cpu")
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
# Decode latent serial to save GPU memory
recons = []
loop_num = (T-1)//2
loop_num = (T - 1) // 2
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
@ -200,7 +200,9 @@ def sampling_main(args, model_cls):
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
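
The decoding loop above walks the latent video in small temporal chunks so that only a few frames sit on the GPU at once. The sketch below reproduces just the index arithmetic. The first chunk `(0, 3)` and `loop_num = (T - 1) // 2` come from the hunk; the two-frame stride for later chunks is an assumption, because that branch is elided, and `decode_chunk_ranges` is a hypothetical helper name.

```python
from typing import List, Tuple


def decode_chunk_ranges(num_latent_frames: int) -> List[Tuple[int, int]]:
    """Frame ranges [start, end) used to decode a latent video chunk by chunk."""
    loop_num = (num_latent_frames - 1) // 2
    ranges = []
    for i in range(loop_num):
        if i == 0:
            start_frame, end_frame = 0, 3
        else:
            # Assumed stride of two latent frames; not visible in the diff.
            start_frame, end_frame = i * 2 + 1, i * 2 + 3
        ranges.append((start_frame, end_frame))
    return ranges


print(decode_chunk_ranges(13))
# [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
```

Under that assumption the chunks tile an odd number of latent frames exactly, with no overlap and no gap.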
@ -208,12 +210,13 @@ def sampling_main(args, model_cls):
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:120], str(index))
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
if __name__ == "__main__":
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
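
The `save_path` expression in this hunk turns the prompt into a directory name. Note that the `[:120]` slice binds to the sanitized prompt only, not to the `cnt` prefix. A small sketch of the same composition, with `build_save_path` as a hypothetical helper name:

```python
import os


def build_save_path(output_dir: str, cnt: int, text: str, index: int) -> str:
    """Compose <output_dir>/<cnt>_<sanitized prompt>/<index>."""
    # Spaces become underscores, slashes are dropped, and the prompt part is
    # capped at 120 characters before the counter prefix is attached.
    prompt_part = text.replace(" ", "_").replace("/", "")[:120]
    return os.path.join(output_dir, str(cnt) + "_" + prompt_part, str(index))
```

Only spaces and forward slashes are handled here, as in the original expression; other characters that are awkward in file names pass through unchanged.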

@ -52,6 +52,7 @@ except:
from modules.utils import checkpoint
def exists(val):
return val is not None
@ -93,15 +94,9 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
@ -117,9 +112,7 @@ def zero_module(module):
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
@ -133,15 +126,11 @@ class LinearAttention(nn.Module):
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
return self.to_out(out)
@ -151,18 +140,10 @@ class SpatialSelfAttention(nn.Module):
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
@ -211,9 +192,7 @@ class CrossAttention(nn.Module):
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.backend = backend
def forward(
@ -241,12 +220,8 @@ class CrossAttention(nn.Module):
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
v = repeat(
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
)
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
@ -269,9 +244,7 @@ class CrossAttention(nn.Module):
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask
) # scale is dim_head ** -0.5 per default
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, "b h n d -> b n (h d)", h=h)
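
The comment on the `F.scaled_dot_product_attention` call notes that the default scale is `dim_head ** -0.5`. To make that concrete, here is a dependency-free sketch of single-head attention without a mask. It is an illustration of the formula only, not the fused kernel that PyTorch dispatches to, and the list-of-lists `Matrix` type is an assumption made to keep it runnable without torch.

```python
import math
from typing import List

Matrix = List[List[float]]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Compute softmax(q @ k^T * d ** -0.5) @ v for one head, no mask."""
    scale = len(q[0]) ** -0.5  # d is the per-head feature dimension
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out
```

A zero query scores every key equally, so the output is the plain average of the value rows, which is a convenient sanity check.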
@ -284,9 +257,7 @@ class CrossAttention(nn.Module):
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
super().__init__()
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
@ -302,9 +273,7 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(
@ -351,9 +320,7 @@ class MemoryEfficientCrossAttention(nn.Module):
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
@ -398,13 +365,9 @@ class BasicTransformerBlock(nn.Module):
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
)
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
if not XFORMERS_IS_AVAILABLE:
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
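
The fallback logic above chooses an attention backend from two availability flags. The sketch below restates it as a pure function so the cases are easy to test. The first condition is an assumption, since only its body is visible in the hunk, and `resolve_attn_mode` is a hypothetical name. It raises an exception where the original uses `assert False`, because assertions are skipped when Python runs with `-O`.

```python
def resolve_attn_mode(attn_mode: str, sdp_available: bool, xformers_available: bool) -> str:
    """Pick the attention backend that can actually run."""
    if attn_mode != "softmax" and not xformers_available:
        # Assumed condition: this branch header is cut off in the hunk.
        return "softmax"
    if attn_mode == "softmax" and not sdp_available:
        if not xformers_available:
            raise RuntimeError("Please install xformers via e.g. 'pip install xformers==0.0.16'")
        return "softmax-xformers"
    return attn_mode
```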
@ -438,9 +401,7 @@ class BasicTransformerBlock(nn.Module):
if self.checkpoint:
print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
kwargs = {"x": x}
if context is not None:
@ -450,35 +411,22 @@ class BasicTransformerBlock(nn.Module):
kwargs.update({"additional_tokens": additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update(
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
)
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
x = (
self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn
else 0,
)
+ x
)
x = (
self.attn2(
self.norm2(x), context=context, additional_tokens=additional_tokens
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
)
+ x
)
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
x = self.ff(self.norm3(x)) + x
return x
@ -486,7 +434,7 @@ class BasicTransformerBlock(nn.Module):
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
"softmax": CrossAttention, # vanilla attention
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
@ -517,9 +465,7 @@ class BasicTransformerSingleLayerBlock(nn.Module):
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
@ -553,9 +499,7 @@ class SpatialTransformer(nn.Module):
sdp_backend=None,
):
super().__init__()
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@ -577,9 +521,7 @@ class SpatialTransformer(nn.Module):
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
@ -600,9 +542,7 @@ class SpatialTransformer(nn.Module):
]
)
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))

@ -16,7 +16,7 @@ from packaging import version
from vae_modules.ema import LitEma
from sgm.util import (
instantiate_from_config,
get_obj_from_str,
default,
is_context_parallel_initialized,
@ -48,14 +48,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
# def apply_ckpt(self, ckpt: Union[None, str, dict]):
# if ckpt is None:
# return
@ -66,14 +66,14 @@ class AbstractAutoencoder(pl.LightningModule):
# }
# engine = instantiate_from_config(ckpt)
# engine(self)
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")['state_dict']
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
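
`init_from_ckpt` loops over the checkpoint keys and the `ignore_keys` list, but the body of that loop is cut off in this hunk. A minimal sketch of one way such filtering can work follows. Prefix matching with `str.startswith` is an assumption, and `drop_ignored_keys` is a hypothetical helper, not a function from the repository.

```python
from typing import Dict, Iterable


def drop_ignored_keys(state_dict: Dict[str, object], ignore_keys: Iterable[str] = ()) -> Dict[str, object]:
    """Return a copy of state_dict without keys that start with an ignored prefix."""
    prefixes = tuple(ignore_keys)
    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}
```

The tuple default also sidesteps the shared-state risk of a mutable default such as `ignore_keys=list()`.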
@ -119,9 +119,7 @@ class AbstractAutoencoder(pl.LightningModule):
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
def configure_optimizers(self) -> Any:
raise NotImplementedError()
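
`instantiate_optimizer_from_config` resolves a dotted `target` string to a callable and invokes it with positional arguments, `lr`, and the optional `params` dict. The sketch below shows that pattern with the standard library only. `resolve_dotted_path` and `instantiate_from_target` are stand-ins written for this illustration, not the `sgm.util` implementations.

```python
import importlib
from typing import Any


def resolve_dotted_path(path: str) -> Any:
    """Import 'package.module.Name' and return the attribute Name."""
    module_name, attr = path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)


def instantiate_from_target(cfg: dict, *args: Any, **kwargs: Any) -> Any:
    """Call cfg['target'] with args, kwargs, and the optional cfg['params']."""
    return resolve_dotted_path(cfg["target"])(*args, **kwargs, **cfg.get("params", dict()))


# Usage with a stdlib class standing in for an optimizer:
cfg = {"target": "datetime.timedelta", "params": {"minutes": 30}}
delta = instantiate_from_target(cfg, 1, hours=2)  # timedelta(1, hours=2, minutes=30)
```

Keeping the class name in configuration means the optimizer can be swapped without touching the training code, at the cost of errors surfacing only when the string is resolved.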
@ -160,12 +158,8 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(
regularizer_config
)
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.Adam"}
)
self.regularization = instantiate_from_config(regularizer_config)
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
@ -178,7 +172,7 @@ class AutoencodingEngine(AbstractAutoencoder):
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
@ -210,7 +204,7 @@ class AutoencodingEngine(AbstractAutoencoder):
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
@ -234,25 +228,19 @@ class AutoencodingEngine(AbstractAutoencoder):
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
) -> torch.Tensor:
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
}
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
@ -267,7 +255,7 @@ class AutoencodingEngine(AbstractAutoencoder):
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
out_loss = self.loss(x, xrec, **extra_info)
@ -299,13 +287,11 @@ class AutoencodingEngine(AbstractAutoencoder):
# discriminator
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
@ -317,9 +303,7 @@ class AutoencodingEngine(AbstractAutoencoder):
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(
batch, batch_idx, optimizer_idx=optimizer_idx
)
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
self.manual_backward(loss)
opt.step()
@ -329,7 +313,7 @@ class AutoencodingEngine(AbstractAutoencoder):
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch)
@ -387,24 +371,18 @@ class AutoencodingEngine(AbstractAutoencoder):
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args
)
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args
)
logpy.info(
f"Number of trainable discriminator parameters: {num_disc_params:,}"
)
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
@ -412,23 +390,17 @@ class AutoencodingEngine(AbstractAutoencoder):
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config
)
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update(
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
@ -438,9 +410,7 @@ class AutoencodingEngine(AbstractAutoencoder):
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
)
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
@ -449,9 +419,7 @@ class AutoencodingEngine(AbstractAutoencoder):
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = (
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
)
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
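
The `diff` and `diff_boost` images logged above are rescaled error maps. The arithmetic can be checked on single pixel values; the sketch below follows the expressions in the hunk for scalars, assuming inputs in [-1, 1]. `diff_maps` and `clamp` are hypothetical helper names, and the boost factor is passed in explicitly rather than assuming a default.

```python
def clamp(value: float, low: float, high: float) -> float:
    return max(low, min(high, value))


def diff_maps(x: float, xrec: float, diff_boost_factor: float):
    """Return (diff, diff_boost) for one pixel, both mapped to [-1, 1]."""
    # Half the absolute error, so a maximal error of 2.0 maps to 1.0.
    diff = clamp(0.5 * abs(clamp(xrec, -1.0, 1.0) - x), 0.0, 1.0)
    plain = 2.0 * diff - 1.0
    # Boosting multiplies small errors before clamping, which brightens them.
    boosted = 2.0 * clamp(diff_boost_factor * diff, 0.0, 1.0) - 1.0
    return plain, boosted


print(diff_maps(0.0, 0.5, 3.0))  # (-0.5, 0.5)
```

A perfect reconstruction maps to -1.0 in both outputs, and any error of at least `1 / diff_boost_factor` saturates the boosted map at 1.0.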
@ -493,9 +461,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
@ -538,16 +504,11 @@ class AutoencoderKL(AutoencodingEngineLegacy):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"sgm.modules.autoencoding.regularizers"
".DiagonalGaussianRegularizer"
)
},
regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -560,33 +521,31 @@ class IdentityFirstStage(AbstractAutoencoder):
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class VideoAutoencodingEngine(AutoencodingEngine):
def __init__(
self,
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list] = (),
image_video_weights=[1,1],
image_video_weights=[1, 1],
only_train_decoder=False,
context_parallel_size=0,
**kwargs,
):
super().__init__(**kwargs)
self.context_parallel_size = context_parallel_size
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
) -> dict:
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
if self.context_parallel_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.context_parallel_size)
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
@ -594,16 +553,16 @@ class VideoAutoencodingEngine(AutoencodingEngine):
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch
return batch[self.input_key]
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")['state_dict']
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
@ -615,6 +574,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
def __init__(
self,
@ -633,16 +593,15 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
input_cp: bool = False,
output_cp: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size > 0 and not input_cp:
if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized)
else:
@ -650,7 +609,7 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
if self.cp_size > 0 and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
return z, reg_log
return z
@ -671,7 +630,7 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
x = super().decode(z, **kwargs)
if self.cp_size > 0 and not output_cp:
@ -680,15 +639,13 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
return x
def forward(
self,
x: torch.Tensor,
input_cp: bool = False,
latent_cp: bool = False,
output_cp: bool = False,
**additional_decode_kwargs
**additional_decode_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
return z, dec, reg_log

@ -9,7 +9,12 @@ from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from sgm.util import get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, get_context_parallel_group_rank
from sgm.util import (
get_context_parallel_group,
get_context_parallel_rank,
get_context_parallel_world_size,
get_context_parallel_group_rank,
)
# try:
from vae_modules.utils import SafeConv3d as Conv3d
@ -17,12 +22,15 @@ from vae_modules.utils import SafeConv3d as Conv3d
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
def cast_tuple(t, length = 1):
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
@ -30,6 +38,7 @@ def is_odd(n):
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
@ -51,15 +60,16 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
return x * torch.sigmoid(x)
def leaky_relu(p = 0.1):
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
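
The tail of `get_timestep_embedding` shown above concatenates sine and cosine terms and zero-pads odd dimensions. A torch-free sketch for a single timestep follows. The frequency schedule is the standard sinusoidal form and is an assumption here, because that part of the function is not visible in the hunk.

```python
import math
from typing import List


def timestep_embedding(timestep: float, embedding_dim: int) -> List[float]:
    """Sinusoidal embedding of one timestep; requires embedding_dim >= 4."""
    half_dim = embedding_dim // 2
    # Assumed schedule: geometric frequencies from 1 down to 1/10000.
    step = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    args = [timestep * f for f in freqs]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if embedding_dim % 2 == 1:  # zero pad, as in the hunk
        emb.append(0.0)
    return emb
```

The padding step keeps the output length equal to `embedding_dim` even when it is odd, since the sine and cosine halves only cover `2 * (embedding_dim // 2)` entries.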
@ -68,7 +78,7 @@ def _split(input_, dim):
if cp_world_size == 1:
return input_
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
@ -88,27 +98,30 @@ def _split(input_, dim):
return output
def _gather(input_, dim):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_frame_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
@ -118,6 +131,7 @@ def _gather(input_, dim):
return output
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
@ -132,23 +146,26 @@ def _conv_split(input_, dim, kernel_size):
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[cp_rank * dim_size + kernel_size:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(dim, 0)[
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _conv_gather(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
@ -158,9 +175,11 @@ def _conv_gather(input_, dim, kernel_size):
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
torch.empty_like(input_) for _ in range(cp_world_size - 1)
]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
@ -174,12 +193,12 @@ def _conv_gather(input_, dim, kernel_size):
return output
def _pass_from_previous_rank(input_, dim, kernel_size):
def _pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
@ -191,7 +210,7 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
@ -201,11 +220,11 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
recv_rank += cp_world_size
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
@ -218,8 +237,8 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
return input_
def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
@ -248,9 +267,9 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
@ -264,18 +283,17 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim)
input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
return input_
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
@ -286,8 +304,8 @@ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
@ -298,26 +316,26 @@ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size, cache_padding):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None
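Review note on `_pass_from_previous_rank` and `_drop_from_previous_rank` above: the forward pass gives every rank a causal "halo" of `kernel_size - 1` frames, and the backward pass simply drops those frames from the gradient. A single-process sketch of that bookkeeping, with Python lists standing in for the time axis of each rank's tensor, might look like this. The helper names and the list-of-lists representation are illustrative assumptions; no `torch.distributed` calls are made.

```python
def pass_from_previous_rank(chunks, kernel_size):
    """chunks[r] holds the frames owned by context-parallel rank r."""
    if kernel_size == 1:
        return [list(c) for c in chunks]  # bypass, as in the diff
    halo = kernel_size - 1
    padded = []
    for rank, frames in enumerate(chunks):
        if rank == 0:
            # Rank 0 has no predecessor: replicate its first frame.
            prefix = [frames[0]] * halo
        else:
            # Other ranks receive the tail of the previous rank's input.
            prefix = list(chunks[rank - 1][-halo:])
        padded.append(prefix + list(frames))
    return padded


def drop_from_previous_rank(frames, kernel_size):
    """Remove the halo again (what the backward pass does to gradients)."""
    return frames[kernel_size - 1:]


chunks = [[0, 1, 2], [3, 4], [5, 6]]
padded = pass_from_previous_rank(chunks, kernel_size=3)
print(padded)  # [[0, 0, 0, 1, 2], [1, 2, 3, 4], [3, 4, 5, 6]]
```

Dropping the halo from each padded chunk recovers the original per-rank frames, which is why `backward` can return `_drop_from_previous_rank(grad_output, ...)` without any communication.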
@ -326,25 +344,21 @@ class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
def conv_pass_from_last_rank(input_, dim, kernel_size):
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
class ContextParallelCausalConv3d(nn.Module):
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
stride = 1,
**kwargs
):
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
@ -364,7 +378,7 @@ class ContextParallelCausalConv3d(nn.Module):
stride = (stride, stride, stride)
dilation = (1, 1, 1)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
self.cache_padding = None
def forward(self, input_, clear_cache=True):
@ -381,30 +395,40 @@ class ContextParallelCausalConv3d(nn.Module):
# output = output_parallel
# return output
input_parallel = fake_cp_pass_from_previous_rank(input_, self.temporal_dim, self.time_kernel_size, self.cache_padding)
input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
)
del self.cache_padding
self.cache_padding = None
if not clear_cache:
cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
global_rank = torch.distributed.get_rank()
if cp_world_size == 1:
self.cache_padding = input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu()
self.cache_padding = (
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
)
else:
if cp_rank == cp_world_size - 1:
torch.distributed.isend(input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), global_rank + 1 - cp_world_size, group=get_context_parallel_group())
torch.distributed.isend(
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
global_rank + 1 - cp_world_size,
group=get_context_parallel_group(),
)
if cp_rank == 0:
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1:]).contiguous()
torch.distributed.recv(recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group())
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
torch.distributed.recv(
recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
)
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
class ContextParallelGroupNorm(torch.nn.GroupNorm):
def forward(self, input_):
@ -416,26 +440,24 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
return output
def Normalize(
in_channels,
gather=False,
**kwargs
): # same for 3D and 2D
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode='constant',
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode="constant",
gather=False,
**norm_layer_params
**norm_layer_params,
):
super().__init__()
if gather:
@ -446,7 +468,7 @@ class SpatialNorm3D(nn.Module):
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if add_conv:
self.conv = ContextParallelCausalConv3d(
@ -454,7 +476,7 @@ class SpatialNorm3D(nn.Module):
chan_out=zq_channels,
kernel_size=3,
)
self.conv_y = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
@ -476,10 +498,10 @@ class SpatialNorm3D(nn.Module):
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
@ -487,20 +509,21 @@ class SpatialNorm3D(nn.Module):
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(
in_channels,
zq_ch,
in_channels,
zq_ch,
add_conv,
gather=False,
):
return SpatialNorm3D(
in_channels,
zq_ch,
in_channels,
zq_ch,
gather=gather,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
@ -515,13 +538,9 @@ class Upsample3D(nn.Module):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
if x.shape[2] % 2 == 1:
@ -533,46 +552,37 @@ class Upsample3D(nn.Module):
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class DownSample3D(nn.Module):
def __init__(
self,
in_channels,
with_conv,
compress_time=False,
out_channels=None
):
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=0)
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, 'b c t h w -> (b h w) c t')
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
# split first frame
@ -581,29 +591,40 @@ class DownSample3D(nn.Module):
if x_rest.shape[-1] > 0:
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
if self.with_conv:
pad = (0,1,0,1)
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
else:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
return x
class ContextParallelResnetBlock3D(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512, zq_ch=None, add_conv=False, gather_norm=False, normalization=Normalize):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
gather_norm=False,
normalization=Normalize,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -623,8 +644,7 @@ class ContextParallelResnetBlock3D(nn.Module):
kernel_size=3,
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = normalization(
out_channels,
zq_ch=zq_ch,
@ -652,7 +672,7 @@ class ContextParallelResnetBlock3D(nn.Module):
stride=1,
padding=0,
)
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
h = x
@ -669,7 +689,7 @@ class ContextParallelResnetBlock3D(nn.Module):
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None,None]
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
@ -690,13 +710,29 @@ class ContextParallelResnetBlock3D(nn.Module):
else:
x = self.nin_shortcut(x)
return x+h
return x + h
class ContextParallelEncoder3D(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignore_kwargs):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
@ -715,13 +751,13 @@ class ContextParallelEncoder3D(nn.Module):
)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ContextParallelResnetBlock3D(
@ -736,14 +772,14 @@ class ContextParallelEncoder3D(nn.Module):
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
else:
down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
@ -767,12 +803,11 @@ class ContextParallelEncoder3D(nn.Module):
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=2*z_channels if double_z else z_channels,
chan_out=2 * z_channels if double_z else z_channels,
kernel_size=3,
)
def forward(self, x):
# timestep embedding
temb = None
@ -783,9 +818,9 @@ class ContextParallelEncoder3D(nn.Module):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions-1:
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
@ -794,7 +829,7 @@ class ContextParallelEncoder3D(nn.Module):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
@ -802,8 +837,27 @@ class ContextParallelEncoder3D(nn.Module):
class ContextParallelDecoder3D(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignorekwargs):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
zq_ch=None,
add_conv=False,
pad_mode="first",
temporal_compress_times=4,
gather_norm=False,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
@ -820,12 +874,11 @@ class ContextParallelDecoder3D(nn.Module):
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
in_ch_mult = (1,) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
@ -862,8 +915,8 @@ class ContextParallelDecoder3D(nn.Module):
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
@ -886,7 +939,7 @@ class ContextParallelDecoder3D(nn.Module):
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
@ -896,7 +949,6 @@ class ContextParallelDecoder3D(nn.Module):
)
def forward(self, z, clear_fake_cp_cache=True):
self.last_z_shape = z.shape
# timestep embedding
@ -914,7 +966,7 @@ class ContextParallelDecoder3D(nn.Module):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
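Review note on the `compress_time` branch of `DownSample3D.forward` earlier in this file: when the number of frames is odd, the first frame is kept as-is and only the remaining frames are average-pooled in pairs, so 9 frames become 5 rather than 4. The sketch below shows the same arithmetic on a plain list. The exact split (`x[..., 0]` versus `x[..., 1:]`) falls outside the visible hunk and is assumed from the `# split first frame` comment; `temporal_downsample` is a hypothetical name.

```python
def avg_pool_pairs(xs):
    """1-D average pooling with kernel_size=2, stride=2 (no padding)."""
    return [(xs[i] + xs[i + 1]) / 2 for i in range(0, len(xs) - 1, 2)]


def temporal_downsample(frames):
    """Halve the time axis, keeping the first frame when the length is odd."""
    if len(frames) % 2 == 1:
        # ASSUMPTION: first frame is split off, the rest is pooled.
        first, rest = frames[0], frames[1:]
        return [first] + avg_pool_pairs(rest)
    return avg_pool_pairs(frames)


print(temporal_downsample([10, 1, 3, 5, 7]))  # [10, 2.0, 6.0]
print(len(temporal_downsample(list(range(9)))))  # 5
```

Keeping the first frame separate presumably preserves a one-frame "image" slot through each temporal compression stage.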


@ -12,9 +12,7 @@ class LitEma(nn.Module):
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
@ -47,9 +45,7 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
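Review note on the `LitEma` update above: the in-place form `shadow.sub_(one_minus_decay * (shadow - param))` is algebraically the usual exponential moving average `decay * shadow + (1 - decay) * param`. A plain-Python sketch with floats in place of tensors, assuming `one_minus_decay = 1.0 - decay` (that definition is outside the hunk) and using a hypothetical `ema_update` helper:

```python
def ema_update(shadow, params, decay):
    """Update the shadow weights in place, mirroring the sub_() form."""
    one_minus_decay = 1.0 - decay  # ASSUMPTION: defined this way upstream
    for name, value in params.items():
        shadow[name] -= one_minus_decay * (shadow[name] - value)
    return shadow


shadow = {"w": 1.0}
ema_update(shadow, {"w": 0.0}, decay=0.9)
print(shadow["w"])  # close to 0.9 == 0.9 * 1.0 + 0.1 * 0.0
```

The subtraction form needs only one temporary per parameter, which may be why it is preferred for large models.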


@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from torch import nn
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
@ -15,9 +16,7 @@ class DiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
# x = self.mean + self.std * torch.randn(self.mean.shape).to(
@ -57,6 +56,7 @@ class DiagonalGaussianDistribution(object):
def mode(self):
return self.mean
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
@ -77,14 +77,10 @@ class IdentityRegularizer(AbstractRegularizer):
yield from ()
def measure_perplexity(
predicted_indices: torch.Tensor, num_centroids: int
) -> Tuple[torch.Tensor, torch.Tensor]:
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
encodings = (
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
)
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
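Review note on `measure_perplexity` above: averaging the one-hot encodings gives the empirical usage frequency of each centroid, and the perplexity is the exponential of that distribution's entropy. The stdlib-only sketch below follows the same steps on a flat list of indices; `perplexity_from_indices` is an illustrative name, not part of the codebase.

```python
import math


def perplexity_from_indices(indices, num_centroids):
    """Codebook perplexity and number of used centroids for integer indices."""
    counts = [0] * num_centroids
    for i in indices:
        counts[i] += 1
    # Mean of the one-hot rows is the per-centroid usage frequency.
    avg_probs = [c / len(indices) for c in counts]
    entropy = -sum(p * math.log(p + 1e-10) for p in avg_probs)
    perplexity = math.exp(entropy)
    cluster_use = sum(1 for p in avg_probs if p > 0)
    return perplexity, cluster_use


print(perplexity_from_indices([0, 1, 2, 3], 4))  # perplexity close to 4, 4 used
print(perplexity_from_indices([2, 2, 2], 4))     # perplexity close to 1, 1 used
```

Perplexity approaches `num_centroids` only when every centroid is used equally often, matching the comment in the source.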


@ -14,17 +14,19 @@ import torch.distributed
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None
def is_context_parallel_initialized():
if _CONTEXT_PARALLEL_GROUP is None:
return False
else:
return True
def initialize_context_parallel(context_parallel_size):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
_CONTEXT_PARALLEL_SIZE = context_parallel_size
rank = torch.distributed.get_rank()
@ -36,41 +38,49 @@ def initialize_context_parallel(context_parallel_size):
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
break
def get_context_parallel_group():
assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_world_size():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
return _CONTEXT_PARALLEL_SIZE
def get_context_parallel_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
return cp_rank
def get_context_parallel_group_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
rank = torch.distributed.get_rank()
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
return cp_group_rank
class SafeConv3d(torch.nn.Conv3d):
def forward(self, input):
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
if memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i-1][:, :, -kernel_size+1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))]
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
@ -149,9 +159,7 @@ def log_txt_as_img(wh, xc, size=10):
text_seq = xc[bi][0]
else:
text_seq = xc[bi]
lines = "\n".join(
text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
)
lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
@ -263,9 +271,7 @@ def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
return x[(...,) + (None,) * dims_to_append]
@ -360,7 +366,8 @@ def checkpoint(func, inputs, params, flag):
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
@ -395,4 +402,3 @@ class CheckpointFunction(torch.autograd.Function):
del ctx.input_params
del output_tensors
return (None, None) + input_grads
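Review note on `SafeConv3d.forward` earlier in this file: large inputs are split along the time axis, and every chunk except the first is prefixed with the last `kernel_size - 1` frames of the previous chunk so that no convolution window is lost at a chunk boundary. The sketch below checks that idea with a 1-D sliding sum standing in for the convolution. It assumes the convolution applies no padding along time and that the per-chunk outputs are concatenated afterwards (that part of `forward` is outside the hunk); the factor `2` in the size estimate is presumably bytes per element. All helper names are hypothetical.

```python
import math


def part_count(numel):
    """Number of chunks, following the size heuristic shown in the diff."""
    memory_count = numel * 2 / 1024**3  # ASSUMPTION: 2 bytes per element, in GiB
    return int(memory_count / 2) + 1 if memory_count > 2 else 1


def chunk(seq, parts):
    """Split like torch.chunk: equal chunks of ceil(len / parts) elements."""
    size = math.ceil(len(seq) / parts)
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def overlapping_chunks(frames, parts, kernel_size):
    chunks = chunk(frames, parts)
    if kernel_size > 1:
        halo = kernel_size - 1
        chunks = [chunks[0]] + [
            chunks[i - 1][-halo:] + chunks[i] for i in range(1, len(chunks))
        ]
    return chunks


def sliding_sum(xs, k):
    """Stand-in for an unpadded 1-D convolution with an all-ones kernel."""
    return [sum(xs[i:i + k]) for i in range(len(xs) - k + 1)]


frames = list(range(8))
pieces = overlapping_chunks(frames, parts=3, kernel_size=3)
print(pieces)  # [[0, 1, 2], [1, 2, 3, 4, 5], [4, 5, 6, 7]]
chunked = [y for piece in pieces for y in sliding_sum(piece, 3)]
print(chunked == sliding_sum(frames, 3))  # True
```

The overlap is what makes the chunked result identical to the unchunked one; without it, each boundary would lose `kernel_size - 1` output frames.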