Mirror of https://github.com/THUDM/CogVideo.git
Synced 2025-04-05 19:41:59 +08:00

Commit 3430df1a36 (parent 487a815219): 规范化 (Standardization / code formatting)
51  .github/ISSUE_TEMPLATE/bug_report.yaml  (vendored, new normal file)
@@ -0,0 +1,51 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
body:
  - type: textarea
    id: system-info
    attributes:
      label: System Info / 系統信息
      description: Your operating environment / 您的运行环境信息
      placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
    validations:
      required: true

  - type: checkboxes
    id: information-scripts-examples
    attributes:
      label: Information / 问题信息
      description: 'The problem arises when using: / 问题出现在'
      options:
        - label: "The official example scripts / 官方的示例脚本"
        - label: "My own modified scripts / 我自己修改的脚本和任务"

  - type: textarea
    id: reproduction
    validations:
      required: true
    attributes:
      label: Reproduction / 复现过程
      description: |
        Please provide a code example that reproduces the problem you encountered, preferably a minimal reproduction unit.
        If you have code snippets, error messages, or stack traces, please provide them here as well.
        Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.

        请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
        如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
        请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
      placeholder: |
        Steps to reproduce the behavior/复现Bug的步骤:

        1.
        2.
        3.

  - type: textarea
    id: expected-behavior
    validations:
      required: true
    attributes:
      label: Expected behavior / 期待表现
      description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
34  .github/ISSUE_TEMPLATE/feature-request.yaml  (vendored, new normal file)
@@ -0,0 +1,34 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
labels: [ "feature" ]
body:
  - type: textarea
    id: feature-request
    validations:
      required: true
    attributes:
      label: Feature request / 功能建议
      description: |
        A brief description of the proposed feature. Links to corresponding papers and code are desirable.
        对功能建议的简述。最好提供对应的论文和代码链接。

  - type: textarea
    id: motivation
    validations:
      required: true
    attributes:
      label: Motivation / 动机
      description: |
        Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
        您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。

  - type: textarea
    id: contribution
    validations:
      required: true
    attributes:
      label: Your contribution / 您的贡献
      description: |

        Your PR link or any other link you can help with.
        您的PR链接或者其他您能提供帮助的链接。
34  .github/PULL_REQUEST_TEMPLATE/pr_template.md  (vendored, new normal file)
@@ -0,0 +1,34 @@
# Raise a valuable PR / 提出有价值的PR

## Caution / 注意事项:
Users should keep the following points in mind when submitting PRs:

1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if it contains multiple ideas or optimizations, split them into separate PRs.

用户在提交PR时候应该注意以下几点:

1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。

## PRs that should not be proposed / 不应该提出的PR

If a developer proposes a PR about any of the following, it may be closed or rejected:

1. PRs that do not describe an improvement plan.
2. PRs that combine multiple issues of different types.
3. PRs that are highly duplicative of already existing PRs.

如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。

1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。


# Check your PR / 检查您的PR
- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a GitHub issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Is your PR for only one issue? / 您的PR是否仅针对一个问题
@@ -425,7 +425,7 @@ class SFTDataset(Dataset):
             self.videos_list.append(tensor_frms)

             # caption
-            caption_path = os.path.join(root, filename.replace('videos', 'labels').replace('.mp4', '.txt'))
+            caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
             if os.path.exists(caption_path):
                 caption = open(caption_path, "r").read().splitlines()[0]
             else:
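The hunk above pairs each video with its caption by swapping the "videos" directory for "labels" and the ".mp4" suffix for ".txt", then reading the first line of the text file. A minimal sketch of that pairing convention, using hypothetical helper names that are not part of the repository:

import os

def caption_path_for(video_path: str) -> str:
    # Illustrative: labels live next to the videos, e.g. videos/clip.mp4 -> labels/clip.txt
    return video_path.replace("videos", "labels").replace(".mp4", ".txt")

def read_caption(video_path: str) -> str:
    # Return the first line of the caption file, or an empty string if it is missing.
    path = caption_path_for(video_path)
    if not os.path.exists(path):
        return ""
    with open(path, "r") as f:
        return f.read().splitlines()[0]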
@@ -180,7 +180,7 @@ def sampling_main(args, model_cls):
            samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

            # Unload the model from GPU to save GPU memory
-           model.to('cpu')
+           model.to("cpu")
            torch.cuda.empty_cache()
            first_stage_model = model.first_stage_model
            first_stage_model = first_stage_model.to(device)
@@ -200,7 +200,9 @@ def sampling_main(args, model_cls):
                else:
                    clear_fake_cp_cache = False
                with torch.no_grad():
-                   recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
+                   recon = first_stage_model.decode(
+                       latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
+                   )

                recons.append(recon)
@@ -208,12 +210,13 @@ def sampling_main(args, model_cls):
            samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

-           save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:120], str(index))
+           save_path = os.path.join(
+               args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
+           )
            if mpu.get_model_parallel_rank() == 0:
                save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)



if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
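The hunks above keep GPU memory low by moving the sampling model to CPU before decoding latents with the first-stage VAE. A minimal sketch of that offload pattern, assuming a generic model exposing first_stage_model with a decode() method as in the diff (the wrapper function itself is illustrative, not from the repository):

import torch

def decode_with_offload(model, latent, device="cuda"):
    # Free the memory held by the large sampling model before VAE decoding.
    model.to("cpu")
    torch.cuda.empty_cache()

    # Move only the first-stage VAE to the GPU and decode under no_grad.
    first_stage_model = model.first_stage_model.to(device)
    with torch.no_grad():
        recon = first_stage_model.decode(latent.to(device).contiguous())

    # Map reconstructions from [-1, 1] to [0, 1] and return them on CPU.
    return torch.clamp((recon + 1.0) / 2.0, min=0.0, max=1.0).cpu()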
@ -52,6 +52,7 @@ except:
|
||||
|
||||
from modules.utils import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
@ -93,15 +94,9 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
@ -117,9 +112,7 @@ def zero_module(module):
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
@ -133,15 +126,11 @@ class LinearAttention(nn.Module):
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||
)
|
||||
q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
out = rearrange(
|
||||
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
||||
)
|
||||
out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
@ -151,18 +140,10 @@ class SpatialSelfAttention(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
@ -211,9 +192,7 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
)
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.backend = backend
|
||||
|
||||
def forward(
|
||||
@ -241,12 +220,8 @@ class CrossAttention(nn.Module):
|
||||
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
||||
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
||||
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
||||
k = repeat(
|
||||
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
||||
)
|
||||
v = repeat(
|
||||
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
||||
)
|
||||
k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
||||
v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
||||
|
||||
@ -269,9 +244,7 @@ class CrossAttention(nn.Module):
|
||||
## new
|
||||
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
||||
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
||||
out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask
|
||||
) # scale is dim_head ** -0.5 per default
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # scale is dim_head ** -0.5 per default
|
||||
|
||||
del q, k, v
|
||||
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
||||
@ -284,9 +257,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(
|
||||
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
||||
):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
|
||||
super().__init__()
|
||||
print(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
@ -302,9 +273,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
)
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(
|
||||
@ -351,9 +320,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
# TODO: Use this directly in the attention operation, as a bias
|
||||
if exists(mask):
|
||||
@ -398,13 +365,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
attn_mode = "softmax"
|
||||
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
||||
print(
|
||||
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
|
||||
)
|
||||
print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
|
||||
if not XFORMERS_IS_AVAILABLE:
|
||||
assert (
|
||||
False
|
||||
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
else:
|
||||
print("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
@ -438,9 +401,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
):
|
||||
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
kwargs = {"x": x}
|
||||
|
||||
if context is not None:
|
||||
@ -450,35 +411,22 @@ class BasicTransformerBlock(nn.Module):
|
||||
kwargs.update({"additional_tokens": additional_tokens})
|
||||
|
||||
if n_times_crossframe_attn_in_self:
|
||||
kwargs.update(
|
||||
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
||||
)
|
||||
kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
|
||||
|
||||
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
):
|
||||
def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
x = (
|
||||
self.attn1(
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None,
|
||||
additional_tokens=additional_tokens,
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
||||
if not self.disable_self_attn
|
||||
else 0,
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = (
|
||||
self.attn2(
|
||||
self.norm2(x), context=context, additional_tokens=additional_tokens
|
||||
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
@ -486,7 +434,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
||||
"softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
|
||||
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
||||
}
|
||||
|
||||
@ -517,9 +465,7 @@ class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
@ -553,9 +499,7 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
)
|
||||
print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
@ -577,9 +521,7 @@ class SpatialTransformer(nn.Module):
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
@ -600,9 +542,7 @@ class SpatialTransformer(nn.Module):
|
||||
]
|
||||
)
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
else:
|
||||
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
||||
|
@ -73,7 +73,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
self.init_from_ckpt(ckpt)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")['state_dict']
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
@ -119,9 +119,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -160,12 +158,8 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
self.encoder = instantiate_from_config(encoder_config)
|
||||
self.decoder = instantiate_from_config(decoder_config)
|
||||
self.loss = instantiate_from_config(loss_config)
|
||||
self.regularization = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
self.optimizer_config = default(
|
||||
optimizer_config, {"target": "torch.optim.Adam"}
|
||||
)
|
||||
self.regularization = instantiate_from_config(regularizer_config)
|
||||
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
|
||||
self.diff_boost_factor = diff_boost_factor
|
||||
self.disc_start_iter = disc_start_iter
|
||||
self.lr_g_factor = lr_g_factor
|
||||
@ -239,20 +233,14 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
x = self.decoder(z, **kwargs)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
||||
def inner_training_step(
|
||||
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
||||
) -> torch.Tensor:
|
||||
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs = {
|
||||
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
||||
}
|
||||
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
||||
if hasattr(self.loss, "forward_keys"):
|
||||
extra_info = {
|
||||
@ -299,9 +287,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
||||
# -> discriminator always needs to return a tuple
|
||||
self.log_dict(
|
||||
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
||||
)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
||||
@ -317,9 +303,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
opt = opts[optimizer_idx]
|
||||
opt.zero_grad()
|
||||
with opt.toggle_model():
|
||||
loss = self.inner_training_step(
|
||||
batch, batch_idx, optimizer_idx=optimizer_idx
|
||||
)
|
||||
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
|
||||
self.manual_backward(loss)
|
||||
opt.step()
|
||||
|
||||
@ -392,19 +376,13 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
if self.trainable_ae_params is None:
|
||||
ae_params = self.get_autoencoder_params()
|
||||
else:
|
||||
ae_params, num_ae_params = self.get_param_groups(
|
||||
self.trainable_ae_params, self.ae_optimizer_args
|
||||
)
|
||||
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
|
||||
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
||||
if self.trainable_disc_params is None:
|
||||
disc_params = self.get_discriminator_params()
|
||||
else:
|
||||
disc_params, num_disc_params = self.get_param_groups(
|
||||
self.trainable_disc_params, self.disc_optimizer_args
|
||||
)
|
||||
logpy.info(
|
||||
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
||||
)
|
||||
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
|
||||
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
|
||||
opt_ae = self.instantiate_optimizer_from_config(
|
||||
ae_params,
|
||||
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
||||
@ -412,23 +390,17 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
)
|
||||
opts = [opt_ae]
|
||||
if len(disc_params) > 0:
|
||||
opt_disc = self.instantiate_optimizer_from_config(
|
||||
disc_params, self.learning_rate, self.optimizer_config
|
||||
)
|
||||
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
|
||||
opts.append(opt_disc)
|
||||
|
||||
return opts
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
log = dict()
|
||||
additional_decode_kwargs = {}
|
||||
x = self.get_input(batch)
|
||||
additional_decode_kwargs.update(
|
||||
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
||||
)
|
||||
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
|
||||
|
||||
_, xrec, _ = self(x, **additional_decode_kwargs)
|
||||
log["inputs"] = x
|
||||
@ -438,9 +410,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
log["diff"] = 2.0 * diff - 1.0
|
||||
# diff_boost shows location of small errors, by boosting their
|
||||
# brightness.
|
||||
log["diff_boost"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
||||
)
|
||||
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
||||
if hasattr(self.loss, "log_images"):
|
||||
log.update(self.loss.log_images(x, xrec))
|
||||
with self.ema_scope():
|
||||
@ -449,9 +419,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
||||
diff_ema.clamp_(0, 1.0)
|
||||
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
||||
log["diff_boost_ema"] = (
|
||||
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
)
|
||||
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
||||
if additional_log_kwargs:
|
||||
additional_decode_kwargs.update(additional_log_kwargs)
|
||||
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
||||
@ -493,9 +461,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
@ -538,12 +504,7 @@ class AutoencoderKL(AutoencodingEngineLegacy):
|
||||
if "lossconfig" in kwargs:
|
||||
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||
super().__init__(
|
||||
regularizer_config={
|
||||
"target": (
|
||||
"sgm.modules.autoencoding.regularizers"
|
||||
".DiagonalGaussianRegularizer"
|
||||
)
|
||||
},
|
||||
regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -577,9 +538,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def log_videos(
|
||||
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
||||
) -> dict:
|
||||
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
|
||||
return self.log_images(batch, additional_log_kwargs, **kwargs)
|
||||
|
||||
def get_input(self, batch: dict) -> torch.Tensor:
|
||||
@ -603,7 +562,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
self.init_from_ckpt(ckpt)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")['state_dict']
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
@ -615,6 +574,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
|
||||
print("Unexpected keys: ", unexpected_keys)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
|
||||
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
def __init__(
|
||||
self,
|
||||
@ -633,7 +593,6 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
input_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
|
||||
if self.cp_size > 0 and not input_cp:
|
||||
if not is_context_parallel_initialized:
|
||||
initialize_context_parallel(self.cp_size)
|
||||
@ -685,10 +644,8 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
|
||||
input_cp: bool = False,
|
||||
latent_cp: bool = False,
|
||||
output_cp: bool = False,
|
||||
**additional_decode_kwargs
|
||||
**additional_decode_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
|
||||
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
|
||||
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
|
||||
return z, dec, reg_log
|
||||
|
@ -9,7 +9,12 @@ from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
from sgm.util import get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, get_context_parallel_group_rank
|
||||
from sgm.util import (
|
||||
get_context_parallel_group,
|
||||
get_context_parallel_rank,
|
||||
get_context_parallel_world_size,
|
||||
get_context_parallel_group_rank,
|
||||
)
|
||||
|
||||
# try:
|
||||
from vae_modules.utils import SafeConv3d as Conv3d
|
||||
@ -17,12 +22,15 @@ from vae_modules.utils import SafeConv3d as Conv3d
|
||||
# # Degrade to normal Conv3d if SafeConv3d is not available
|
||||
# from torch.nn import Conv3d
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def divisible_by(num, den):
|
||||
return (num % den) == 0
|
||||
|
||||
|
||||
def is_odd(n):
|
||||
return not divisible_by(n, 2)
|
||||
|
||||
@ -30,6 +38,7 @@ def is_odd(n):
|
||||
def exists(v):
|
||||
return v is not None
|
||||
|
||||
|
||||
def pair(t):
|
||||
return t if isinstance(t, tuple) else (t, t)
|
||||
|
||||
@ -59,6 +68,7 @@ def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def leaky_relu(p=0.1):
|
||||
return nn.LeakyReLU(p)
|
||||
|
||||
@ -88,6 +98,7 @@ def _split(input_, dim):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_, dim):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
@ -104,7 +115,9 @@ def _gather(input_, dim):
|
||||
if cp_rank == 0:
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
|
||||
torch.empty_like(input_) for _ in range(cp_world_size - 1)
|
||||
]
|
||||
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_first_frame_, input_], dim=dim)
|
||||
@ -118,6 +131,7 @@ def _gather(input_, dim):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _conv_split(input_, dim, kernel_size):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
@ -135,13 +149,16 @@ def _conv_split(input_, dim, kernel_size):
|
||||
output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
|
||||
else:
|
||||
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
|
||||
output = input_.transpose(dim, 0)[cp_rank * dim_size + kernel_size:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
|
||||
output = input_.transpose(dim, 0)[
|
||||
cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
|
||||
].transpose(dim, 0)
|
||||
output = output.contiguous()
|
||||
|
||||
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _conv_gather(input_, dim, kernel_size):
|
||||
cp_world_size = get_context_parallel_world_size()
|
||||
|
||||
@ -160,7 +177,9 @@ def _conv_gather(input_, dim, kernel_size):
|
||||
else:
|
||||
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
|
||||
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
|
||||
tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
|
||||
torch.empty_like(input_) for _ in range(cp_world_size - 1)
|
||||
]
|
||||
if cp_rank == 0:
|
||||
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
|
||||
|
||||
@ -174,8 +193,8 @@ def _conv_gather(input_, dim, kernel_size):
|
||||
|
||||
return output
|
||||
|
||||
def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
|
||||
def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
# Bypass the function if kernel size is 1
|
||||
if kernel_size == 1:
|
||||
return input_
|
||||
@ -218,8 +237,8 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
|
||||
|
||||
return input_
|
||||
|
||||
def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
|
||||
|
||||
def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
|
||||
# Bypass the function if kernel size is 1
|
||||
if kernel_size == 1:
|
||||
return input_
|
||||
@ -275,7 +294,6 @@ def _drop_from_previous_rank(input_, dim, kernel_size):
|
||||
|
||||
|
||||
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
@ -286,8 +304,8 @@ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
|
||||
|
||||
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
@ -298,8 +316,8 @@ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
|
||||
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size):
|
||||
ctx.dim = dim
|
||||
@ -310,8 +328,8 @@ class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
|
||||
|
||||
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
|
||||
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, dim, kernel_size, cache_padding):
|
||||
ctx.dim = dim
|
||||
@ -326,25 +344,21 @@ class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
|
||||
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
|
||||
return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def conv_pass_from_last_rank(input_, dim, kernel_size):
|
||||
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
|
||||
|
||||
|
||||
def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
|
||||
return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
|
||||
|
||||
|
||||
class ContextParallelCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
chan_in,
|
||||
chan_out,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride = 1,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
|
||||
super().__init__()
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
|
||||
@ -381,7 +395,9 @@ class ContextParallelCausalConv3d(nn.Module):
|
||||
# output = output_parallel
|
||||
# return output
|
||||
|
||||
input_parallel = fake_cp_pass_from_previous_rank(input_, self.temporal_dim, self.time_kernel_size, self.cache_padding)
|
||||
input_parallel = fake_cp_pass_from_previous_rank(
|
||||
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
|
||||
)
|
||||
|
||||
del self.cache_padding
|
||||
self.cache_padding = None
|
||||
@ -389,17 +405,25 @@ class ContextParallelCausalConv3d(nn.Module):
|
||||
cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
|
||||
global_rank = torch.distributed.get_rank()
|
||||
if cp_world_size == 1:
|
||||
self.cache_padding = input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu()
|
||||
self.cache_padding = (
|
||||
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
|
||||
)
|
||||
else:
|
||||
if cp_rank == cp_world_size - 1:
|
||||
torch.distributed.isend(input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), global_rank + 1 - cp_world_size, group=get_context_parallel_group())
|
||||
torch.distributed.isend(
|
||||
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
|
||||
global_rank + 1 - cp_world_size,
|
||||
group=get_context_parallel_group(),
|
||||
)
|
||||
if cp_rank == 0:
|
||||
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
|
||||
torch.distributed.recv(recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group())
|
||||
torch.distributed.recv(
|
||||
recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
|
||||
)
|
||||
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
@ -416,16 +440,14 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
|
||||
return output
|
||||
|
||||
def Normalize(
|
||||
in_channels,
|
||||
gather=False,
|
||||
**kwargs
|
||||
): # same for 3D and 2D
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -433,9 +455,9 @@ class SpatialNorm3D(nn.Module):
|
||||
zq_channels,
|
||||
freeze_norm_layer=False,
|
||||
add_conv=False,
|
||||
pad_mode='constant',
|
||||
pad_mode="constant",
|
||||
gather=False,
|
||||
**norm_layer_params
|
||||
**norm_layer_params,
|
||||
):
|
||||
super().__init__()
|
||||
if gather:
|
||||
@ -487,6 +509,7 @@ class SpatialNorm3D(nn.Module):
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
def Normalize3D(
|
||||
in_channels,
|
||||
zq_ch,
|
||||
@ -515,11 +538,7 @@ class Upsample3D(nn.Module):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
@ -537,42 +556,33 @@ class Upsample3D(nn.Module):
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
|
||||
if self.with_conv:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
with_conv,
|
||||
compress_time=False,
|
||||
out_channels=None
|
||||
):
|
||||
def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, 'b c t h w -> (b h w) c t')
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
# split first frame
|
||||
@ -581,29 +591,40 @@ class DownSample3D(nn.Module):
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = self.conv(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
else:
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
return x
|
||||
|
||||
|
||||
class ContextParallelResnetBlock3D(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512, zq_ch=None, add_conv=False, gather_norm=False, normalization=Normalize):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
gather_norm=False,
|
||||
normalization=Normalize,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@ -623,8 +644,7 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = normalization(
|
||||
out_channels,
|
||||
zq_ch=zq_ch,
|
||||
@ -694,9 +714,25 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
|
||||
|
||||
class ContextParallelEncoder3D(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignore_kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignore_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
@ -772,7 +808,6 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
@ -802,8 +837,27 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
|
||||
|
||||
class ContextParallelDecoder3D(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignorekwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
zq_ch=None,
|
||||
add_conv=False,
|
||||
pad_mode="first",
|
||||
temporal_compress_times=4,
|
||||
gather_norm=False,
|
||||
**ignorekwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
@ -824,8 +878,7 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
self.conv_in = ContextParallelCausalConv3d(
|
||||
chan_in=z_channels,
|
||||
@ -896,7 +949,6 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, z, clear_fake_cp_cache=True):
|
||||
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
|
@ -12,9 +12,7 @@ class LitEma(nn.Module):
|
||||
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer(
|
||||
"num_updates",
|
||||
torch.tensor(0, dtype=torch.int)
|
||||
if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int),
|
||||
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
|
||||
)
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
@ -47,9 +45,7 @@ class LitEma(nn.Module):
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(
|
||||
one_minus_decay * (shadow_params[sname] - m_param[key])
|
||||
)
|
||||
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
@ -15,9 +16,7 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(
|
||||
device=self.parameters.device
|
||||
)
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
# x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
||||
@ -57,6 +56,7 @@ class DiagonalGaussianDistribution(object):
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
class AbstractRegularizer(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -77,14 +77,10 @@ class IdentityRegularizer(AbstractRegularizer):
|
||||
yield from ()
|
||||
|
||||
|
||||
def measure_perplexity(
|
||||
predicted_indices: torch.Tensor, num_centroids: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
||||
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
||||
encodings = (
|
||||
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
)
|
||||
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
||||
avg_probs = encodings.mean(0)
|
||||
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
||||
cluster_use = torch.sum(avg_probs > 0)
|
||||
|
@ -14,17 +14,19 @@ import torch.distributed
|
||||
_CONTEXT_PARALLEL_GROUP = None
|
||||
_CONTEXT_PARALLEL_SIZE = None
|
||||
|
||||
|
||||
def is_context_parallel_initialized():
|
||||
if _CONTEXT_PARALLEL_GROUP is None:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def initialize_context_parallel(context_parallel_size):
|
||||
global _CONTEXT_PARALLEL_GROUP
|
||||
global _CONTEXT_PARALLEL_SIZE
|
||||
|
||||
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
|
||||
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
|
||||
_CONTEXT_PARALLEL_SIZE = context_parallel_size
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
@ -37,31 +39,36 @@ def initialize_context_parallel(context_parallel_size):
|
||||
_CONTEXT_PARALLEL_GROUP = group
|
||||
break
|
||||
|
||||
|
||||
def get_context_parallel_group():
|
||||
assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
|
||||
assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
|
||||
|
||||
return _CONTEXT_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_context_parallel_world_size():
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
|
||||
|
||||
return _CONTEXT_PARALLEL_SIZE
|
||||
|
||||
|
||||
def get_context_parallel_rank():
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
|
||||
return cp_rank
|
||||
|
||||
|
||||
def get_context_parallel_group_rank():
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
|
||||
assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
|
||||
|
||||
return cp_group_rank
|
||||
|
||||
|
||||
class SafeConv3d(torch.nn.Conv3d):
|
||||
def forward(self, input):
|
||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||
@ -70,7 +77,10 @@ class SafeConv3d(torch.nn.Conv3d):
|
||||
part_num = int(memory_count / 2) + 1
|
||||
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
|
||||
if kernel_size > 1:
|
||||
input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i-1][:, :, -kernel_size+1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))]
|
||||
input_chunks = [input_chunks[0]] + [
|
||||
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
|
||||
for i in range(1, len(input_chunks))
|
||||
]
|
||||
|
||||
output_chunks = []
|
||||
for input_chunk in input_chunks:
|
||||
@ -149,9 +159,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
text_seq = xc[bi][0]
|
||||
else:
|
||||
text_seq = xc[bi]
|
||||
lines = "\n".join(
|
||||
text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
|
||||
)
|
||||
lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
@ -263,9 +271,7 @@ def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
@ -361,6 +367,7 @@ def checkpoint(func, inputs, params, flag):
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
@ -395,4 +402,3 @@ class CheckpointFunction(torch.autograd.Function):
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|