Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)

Commit 3430df1a36 (parent 487a815219)
Commit message: Normalization / 规范化
.github/ISSUE_TEMPLATE/bug_report.yaml (vendored, new file, 51 lines)
@@ -0,0 +1,51 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve CogVideoX / 提交一个 Bug 问题报告来帮助我们改进 CogVideoX 开源模型
body:
  - type: textarea
    id: system-info
    attributes:
      label: System Info / 系統信息
      description: Your operating environment / 您的运行环境信息
      placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
    validations:
      required: true

  - type: checkboxes
    id: information-scripts-examples
    attributes:
      label: Information / 问题信息
      description: 'The problem arises when using: / 问题出现在'
      options:
        - label: "The official example scripts / 官方的示例脚本"
        - label: "My own modified scripts / 我自己修改的脚本和任务"

  - type: textarea
    id: reproduction
    validations:
      required: true
    attributes:
      label: Reproduction / 复现过程
      description: |
        Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
        If you have code snippets, error messages, stack traces, please provide them here as well.
        Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.

        请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
        如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
        请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
      placeholder: |
        Steps to reproduce the behavior/复现Bug的步骤:

        1.
        2.
        3.

  - type: textarea
    id: expected-behavior
    validations:
      required: true
    attributes:
      label: Expected behavior / 期待表现
      description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
.github/ISSUE_TEMPLATE/feature-request.yaml (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
name: "\U0001F680 Feature request"
description: Submit a request for a new CogVideoX feature / 提交一个新的 CogVideoX开源模型的功能建议
labels: [ "feature" ]
body:
  - type: textarea
    id: feature-request
    validations:
      required: true
    attributes:
      label: Feature request / 功能建议
      description: |
        A brief description of the functional proposal. Links to corresponding papers and code are desirable.
        对功能建议的简述。最好提供对应的论文和代码链接。

  - type: textarea
    id: motivation
    validations:
      required: true
    attributes:
      label: Motivation / 动机
      description: |
        Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
        您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。

  - type: textarea
    id: contribution
    validations:
      required: true
    attributes:
      label: Your contribution / 您的贡献
      description: |
        Your PR link or any other link you can help with.
        您的PR链接或者其他您能提供帮助的链接。
.github/PULL_REQUEST_TEMPLATE/pr_template.md (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
# Raise valuable PR / 提出有价值的PR

## Caution / 注意事项:

Users should keep the following points in mind when submitting PRs:

1. Ensure that your code meets the requirements in the [specification](../../resources/contribute.md).
2. The proposed PR should be focused; if it contains multiple ideas and optimizations, they should be split into separate PRs.

用户在提交PR时候应该注意以下几点:

1. 确保您的代码符合 [规范](../../resources/contribute_zh.md) 中的要求。
2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。

## 不应该提出的PR / PRs that should not be proposed

If a developer proposes a PR about any of the following, it may be closed or rejected:

1. PRs that do not explain the improvement they make.
2. PRs that combine multiple issues of different types.
3. PRs that are highly duplicative of already existing PRs.

如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。

1. 没有说明改进方案的。
2. 多个不同类型的问题合并在一个PR中的。
3. 提出的PR与已经存在的PR高度重复的。

# Check your PR / 检查您的PR

- [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
- [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
- [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. / 您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
- [ ] Did you write new required tests? / 您是否编写了新的必要测试?
- [ ] Is your PR for only one issue? / 您的PR是否仅针对一个问题?
@@ -425,7 +425,7 @@ class SFTDataset(Dataset):
                 self.videos_list.append(tensor_frms)
 
                 # caption
-                caption_path = os.path.join(root, filename.replace('videos', 'labels').replace('.mp4', '.txt'))
+                caption_path = os.path.join(root, filename.replace("videos", "labels").replace(".mp4", ".txt"))
                 if os.path.exists(caption_path):
                     caption = open(caption_path, "r").read().splitlines()[0]
                 else:
@@ -180,7 +180,7 @@ def sampling_main(args, model_cls):
             samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
 
             # Unload the model from GPU to save GPU memory
-            model.to('cpu')
+            model.to("cpu")
             torch.cuda.empty_cache()
             first_stage_model = model.first_stage_model
             first_stage_model = first_stage_model.to(device)
@@ -189,7 +189,7 @@ def sampling_main(args, model_cls):
 
             # Decode latent serial to save GPU memory
             recons = []
-            loop_num = (T-1)//2
+            loop_num = (T - 1) // 2
             for i in range(loop_num):
                 if i == 0:
                     start_frame, end_frame = 0, 3
@@ -200,7 +200,9 @@ def sampling_main(args, model_cls):
                 else:
                     clear_fake_cp_cache = False
                 with torch.no_grad():
-                    recon = first_stage_model.decode(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
+                    recon = first_stage_model.decode(
+                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
+                    )
 
                 recons.append(recon)
 
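The hunks above decode the latent in small temporal windows so the full video never has to sit in GPU memory at once. A minimal sketch of that windowing pattern, assuming a generic `decode_fn`; the window bounds for `i > 0` and the cache-clearing condition are illustrative assumptions inferred from the surrounding hunks, not lines taken from this diff:

```python
import torch

def decode_serially(latent: torch.Tensor, decode_fn) -> torch.Tensor:
    # latent: (B, C, T, H, W); decode 3 frames first, then 2 frames per step,
    # mirroring loop_num = (T - 1) // 2 with the first window (0, 3).
    T = latent.shape[2]
    loop_num = (T - 1) // 2
    recons = []
    for i in range(loop_num):
        if i == 0:
            start_frame, end_frame = 0, 3
        else:  # assumed continuation of the window schedule
            start_frame, end_frame = i * 2 + 1, i * 2 + 3
        clear_cache = i == loop_num - 1  # assumed: flush any cached padding on the last chunk
        with torch.no_grad():
            recons.append(
                decode_fn(latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_cache)
            )
    return torch.cat(recons, dim=2)
```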
@@ -208,12 +210,13 @@ def sampling_main(args, model_cls):
             samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
             samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
 
-            save_path = os.path.join(args.output_dir, str(cnt) + '_' + text.replace(' ', '_').replace('/', '')[:120], str(index))
+            save_path = os.path.join(
+                args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index)
+            )
             if mpu.get_model_parallel_rank() == 0:
                 save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
 
 
 if __name__ == "__main__":
     if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
         os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
@@ -52,6 +52,7 @@ except:
 
 from modules.utils import checkpoint
 
 
 def exists(val):
     return val is not None
 
@@ -93,15 +94,9 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
-        project_in = (
-            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
-            if not glu
-            else GEGLU(dim, inner_dim)
-        )
+        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
 
-        self.net = nn.Sequential(
-            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
-        )
+        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
     def forward(self, x):
         return self.net(x)
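In the `FeedForward` block above, the `glu` flag swaps the plain Linear + GELU projection for a GEGLU gate. For reference, a minimal sketch of the standard GEGLU unit (a generic formulation matching the `GEGLU(dim, inner_dim)` call, not copied from this file):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    # Gated GELU: project to twice the inner width, then gate one half with GELU of the other.
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
```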
@@ -117,9 +112,7 @@ def zero_module(module):
 
 
 def Normalize(in_channels):
-    return torch.nn.GroupNorm(
-        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
-    )
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
 class LinearAttention(nn.Module):
@@ -133,15 +126,11 @@ class LinearAttention(nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(
-            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
-        )
+        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(
-            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
-        )
+        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         return self.to_out(out)
 
 
@@ -151,18 +140,10 @@ class SpatialSelfAttention(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x):
         h_ = x
@@ -211,9 +192,7 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.backend = backend
 
     def forward(
@@ -241,12 +220,8 @@ class CrossAttention(nn.Module):
             # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
             assert x.shape[0] % n_times_crossframe_attn_in_self == 0
             n_cp = x.shape[0] // n_times_crossframe_attn_in_self
-            k = repeat(
-                k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
-            )
-            v = repeat(
-                v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
-            )
+            k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
+            v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)
 
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
@@ -269,9 +244,7 @@ class CrossAttention(nn.Module):
         ## new
         with sdp_kernel(**BACKEND_MAP[self.backend]):
             # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
-            out = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=mask
-            )  # scale is dim_head ** -0.5 per default
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # scale is dim_head ** -0.5 per default
 
         del q, k, v
         out = rearrange(out, "b h n d -> b n (h d)", h=h)
@@ -284,9 +257,7 @@ class CrossAttention(nn.Module):
 
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(
-        self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
-    ):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
         super().__init__()
         print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
@@ -302,9 +273,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(
-            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
-        )
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(
@@ -351,9 +320,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         )
 
         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(
-            q, k, v, attn_bias=None, op=self.attention_op
-        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
 
         # TODO: Use this directly in the attention operation, as a bias
         if exists(mask):
@@ -398,13 +365,9 @@ class BasicTransformerBlock(nn.Module):
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            print(
-                "We do not support vanilla attention anymore, as it is too expensive. Sorry."
-            )
+            print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
             if not XFORMERS_IS_AVAILABLE:
-                assert (
-                    False
-                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+                assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
             else:
                 print("Falling back to xformers efficient attention.")
                 attn_mode = "softmax-xformers"
@@ -438,9 +401,7 @@ class BasicTransformerBlock(nn.Module):
         if self.checkpoint:
             print(f"{self.__class__.__name__} is using checkpointing")
 
-    def forward(
-        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
-    ):
+    def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
         kwargs = {"x": x}
 
         if context is not None:
@@ -450,35 +411,22 @@ class BasicTransformerBlock(nn.Module):
             kwargs.update({"additional_tokens": additional_tokens})
 
         if n_times_crossframe_attn_in_self:
-            kwargs.update(
-                {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
-            )
+            kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
 
         # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
-        return checkpoint(
-            self._forward, (x, context), self.parameters(), self.checkpoint
-        )
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
-    def _forward(
-        self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
-    ):
+    def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
         x = (
             self.attn1(
                 self.norm1(x),
                 context=context if self.disable_self_attn else None,
                 additional_tokens=additional_tokens,
-                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
-                if not self.disable_self_attn
-                else 0,
+                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
             )
             + x
         )
-        x = (
-            self.attn2(
-                self.norm2(x), context=context, additional_tokens=additional_tokens
-            )
-            + x
-        )
+        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
         x = self.ff(self.norm3(x)) + x
         return x
 
@@ -486,7 +434,7 @@ class BasicTransformerBlock(nn.Module):
 class BasicTransformerSingleLayerBlock(nn.Module):
     ATTENTION_MODES = {
         "softmax": CrossAttention,  # vanilla attention
-        "softmax-xformers": MemoryEfficientCrossAttention  # on the A100s not quite as fast as the above version
+        "softmax-xformers": MemoryEfficientCrossAttention,  # on the A100s not quite as fast as the above version
         # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
     }
 
@@ -517,9 +465,7 @@ class BasicTransformerSingleLayerBlock(nn.Module):
         self.checkpoint = checkpoint
 
     def forward(self, x, context=None):
-        return checkpoint(
-            self._forward, (x, context), self.parameters(), self.checkpoint
-        )
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
         x = self.attn1(self.norm1(x), context=context) + x
@@ -553,9 +499,7 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        print(
-            f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
-        )
+        print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
         from omegaconf import ListConfig
 
         if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
@@ -577,9 +521,7 @@ class SpatialTransformer(nn.Module):
         inner_dim = n_heads * d_head
         self.norm = Normalize(in_channels)
         if not use_linear:
-            self.proj_in = nn.Conv2d(
-                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
-            )
+            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         else:
             self.proj_in = nn.Linear(in_channels, inner_dim)
 
@@ -600,9 +542,7 @@ class SpatialTransformer(nn.Module):
             ]
         )
         if not use_linear:
-            self.proj_out = zero_module(
-                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-            )
+            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
         else:
             # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
             self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
@@ -73,7 +73,7 @@ class AbstractAutoencoder(pl.LightningModule):
             self.init_from_ckpt(ckpt)
 
     def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")['state_dict']
+        sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
             for ik in ignore_keys:
@@ -119,9 +119,7 @@ class AbstractAutoencoder(pl.LightningModule):
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
-        return get_obj_from_str(cfg["target"])(
-            params, lr=lr, **cfg.get("params", dict())
-        )
+        return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
 
     def configure_optimizers(self) -> Any:
         raise NotImplementedError()
@@ -160,12 +158,8 @@ class AutoencodingEngine(AbstractAutoencoder):
         self.encoder = instantiate_from_config(encoder_config)
         self.decoder = instantiate_from_config(decoder_config)
         self.loss = instantiate_from_config(loss_config)
-        self.regularization = instantiate_from_config(
-            regularizer_config
-        )
-        self.optimizer_config = default(
-            optimizer_config, {"target": "torch.optim.Adam"}
-        )
+        self.regularization = instantiate_from_config(regularizer_config)
+        self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
         self.diff_boost_factor = diff_boost_factor
         self.disc_start_iter = disc_start_iter
         self.lr_g_factor = lr_g_factor
@@ -239,20 +233,14 @@ class AutoencodingEngine(AbstractAutoencoder):
         x = self.decoder(z, **kwargs)
         return x
 
-    def forward(
-        self, x: torch.Tensor, **additional_decode_kwargs
-    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+    def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
         z, reg_log = self.encode(x, return_reg_log=True)
         dec = self.decode(z, **additional_decode_kwargs)
         return z, dec, reg_log
 
-    def inner_training_step(
-        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
-    ) -> torch.Tensor:
+    def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
         x = self.get_input(batch)
-        additional_decode_kwargs = {
-            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
-        }
+        additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
         z, xrec, regularization_log = self(x, **additional_decode_kwargs)
         if hasattr(self.loss, "forward_keys"):
             extra_info = {
@@ -299,9 +287,7 @@ class AutoencodingEngine(AbstractAutoencoder):
             # discriminator
             discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
             # -> discriminator always needs to return a tuple
-            self.log_dict(
-                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
-            )
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
             return discloss
         else:
             raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
@@ -317,9 +303,7 @@ class AutoencodingEngine(AbstractAutoencoder):
         opt = opts[optimizer_idx]
         opt.zero_grad()
         with opt.toggle_model():
-            loss = self.inner_training_step(
-                batch, batch_idx, optimizer_idx=optimizer_idx
-            )
+            loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
             self.manual_backward(loss)
         opt.step()
 
@@ -392,19 +376,13 @@ class AutoencodingEngine(AbstractAutoencoder):
         if self.trainable_ae_params is None:
             ae_params = self.get_autoencoder_params()
         else:
-            ae_params, num_ae_params = self.get_param_groups(
-                self.trainable_ae_params, self.ae_optimizer_args
-            )
+            ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
             logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
         if self.trainable_disc_params is None:
             disc_params = self.get_discriminator_params()
         else:
-            disc_params, num_disc_params = self.get_param_groups(
-                self.trainable_disc_params, self.disc_optimizer_args
-            )
-            logpy.info(
-                f"Number of trainable discriminator parameters: {num_disc_params:,}"
-            )
+            disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
+            logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
         opt_ae = self.instantiate_optimizer_from_config(
             ae_params,
             default(self.lr_g_factor, 1.0) * self.learning_rate,
@@ -412,23 +390,17 @@ class AutoencodingEngine(AbstractAutoencoder):
         )
         opts = [opt_ae]
         if len(disc_params) > 0:
-            opt_disc = self.instantiate_optimizer_from_config(
-                disc_params, self.learning_rate, self.optimizer_config
-            )
+            opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
             opts.append(opt_disc)
 
         return opts
 
     @torch.no_grad()
-    def log_images(
-        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
-    ) -> dict:
+    def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
         log = dict()
         additional_decode_kwargs = {}
         x = self.get_input(batch)
-        additional_decode_kwargs.update(
-            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
-        )
+        additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
 
         _, xrec, _ = self(x, **additional_decode_kwargs)
         log["inputs"] = x
@@ -438,9 +410,7 @@ class AutoencodingEngine(AbstractAutoencoder):
         log["diff"] = 2.0 * diff - 1.0
         # diff_boost shows location of small errors, by boosting their
         # brightness.
-        log["diff_boost"] = (
-            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
-        )
+        log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
         if hasattr(self.loss, "log_images"):
             log.update(self.loss.log_images(x, xrec))
         with self.ema_scope():
@@ -449,9 +419,7 @@ class AutoencodingEngine(AbstractAutoencoder):
             diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
             diff_ema.clamp_(0, 1.0)
             log["diff_ema"] = 2.0 * diff_ema - 1.0
-            log["diff_boost_ema"] = (
-                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
-            )
+            log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
         if additional_log_kwargs:
             additional_decode_kwargs.update(additional_log_kwargs)
             _, xrec_add, _ = self(x, **additional_decode_kwargs)
@@ -493,9 +461,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
         params = super().get_autoencoder_params()
         return params
 
-    def encode(
-        self, x: torch.Tensor, return_reg_log: bool = False
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+    def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
         if self.max_batch_size is None:
             z = self.encoder(x)
             z = self.quant_conv(z)
@@ -538,12 +504,7 @@ class AutoencoderKL(AutoencodingEngineLegacy):
         if "lossconfig" in kwargs:
             kwargs["loss_config"] = kwargs.pop("lossconfig")
         super().__init__(
-            regularizer_config={
-                "target": (
-                    "sgm.modules.autoencoding.regularizers"
-                    ".DiagonalGaussianRegularizer"
-                )
-            },
+            regularizer_config={"target": ("sgm.modules.autoencoding.regularizers" ".DiagonalGaussianRegularizer")},
             **kwargs,
         )
 
@@ -567,7 +528,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         self,
         ckpt_path: Union[None, str] = None,
         ignore_keys: Union[Tuple, list] = (),
-        image_video_weights=[1,1],
+        image_video_weights=[1, 1],
         only_train_decoder=False,
         context_parallel_size=0,
         **kwargs,
@@ -577,9 +538,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
 
-    def log_videos(
-        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
-    ) -> dict:
+    def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
         return self.log_images(batch, additional_log_kwargs, **kwargs)
 
     def get_input(self, batch: dict) -> torch.Tensor:
@@ -603,7 +562,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
             self.init_from_ckpt(ckpt)
 
     def init_from_ckpt(self, path, ignore_keys=list()):
-        sd = torch.load(path, map_location="cpu")['state_dict']
+        sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
             for ik in ignore_keys:
@@ -615,6 +574,7 @@ class VideoAutoencodingEngine(AutoencodingEngine):
             print("Unexpected keys: ", unexpected_keys)
         print(f"Restored from {path}")
 
 
 class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
     def __init__(
         self,
@@ -633,7 +593,6 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
         input_cp: bool = False,
         output_cp: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
-
         if self.cp_size > 0 and not input_cp:
             if not is_context_parallel_initialized:
                 initialize_context_parallel(self.cp_size)
@@ -685,10 +644,8 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
         input_cp: bool = False,
         latent_cp: bool = False,
         output_cp: bool = False,
-        **additional_decode_kwargs
+        **additional_decode_kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
-
         z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
         dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
         return z, dec, reg_log
@@ -9,7 +9,12 @@ from beartype import beartype
 from beartype.typing import Union, Tuple, Optional, List
 from einops import rearrange
 
-from sgm.util import get_context_parallel_group, get_context_parallel_rank, get_context_parallel_world_size, get_context_parallel_group_rank
+from sgm.util import (
+    get_context_parallel_group,
+    get_context_parallel_rank,
+    get_context_parallel_world_size,
+    get_context_parallel_group_rank,
+)
 
 # try:
 from vae_modules.utils import SafeConv3d as Conv3d
@@ -17,12 +22,15 @@ from vae_modules.utils import SafeConv3d as Conv3d
 # # Degrade to normal Conv3d if SafeConv3d is not available
 # from torch.nn import Conv3d
 
-def cast_tuple(t, length = 1):
+
+def cast_tuple(t, length=1):
     return t if isinstance(t, tuple) else ((t,) * length)
 
 
 def divisible_by(num, den):
     return (num % den) == 0
 
 
 def is_odd(n):
     return not divisible_by(n, 2)
 
@@ -30,6 +38,7 @@ def is_odd(n):
 def exists(v):
     return v is not None
 
 
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
@@ -51,15 +60,16 @@ def get_timestep_embedding(timesteps, embedding_dim):
     emb = timesteps.float()[:, None] * emb[None, :]
     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
     if embedding_dim % 2 == 1:  # zero pad
-        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
     return emb
 
 
 def nonlinearity(x):
     # swish
-    return x*torch.sigmoid(x)
+    return x * torch.sigmoid(x)
 
-def leaky_relu(p = 0.1):
+
+def leaky_relu(p=0.1):
     return nn.LeakyReLU(p)
 
 
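`get_timestep_embedding` above is the usual sinusoidal timestep embedding, with a zero pad when the embedding width is odd. A self-contained sketch of the same idea (a generic reimplementation for illustration, not the file's full function):

```python
import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int) -> torch.Tensor:
    # frequencies spaced geometrically from 1 down to roughly 1/10000
    half_dim = embedding_dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad the last column for odd widths
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb  # shape (len(timesteps), embedding_dim)
```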
@@ -88,6 +98,7 @@ def _split(input_, dim):
 
     return output
 
 
 def _gather(input_, dim):
     cp_world_size = get_context_parallel_world_size()
 
@@ -104,7 +115,9 @@ def _gather(input_, dim):
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
 
-    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
+    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
+        torch.empty_like(input_) for _ in range(cp_world_size - 1)
+    ]
 
     if cp_rank == 0:
         input_ = torch.cat([input_first_frame_, input_], dim=dim)
@@ -118,6 +131,7 @@ def _gather(input_, dim):
 
     return output
 
 
 def _conv_split(input_, dim, kernel_size):
     cp_world_size = get_context_parallel_world_size()
 
@@ -132,16 +146,19 @@ def _conv_split(input_, dim, kernel_size):
     dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
 
     if cp_rank == 0:
-        output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(dim, 0)
+        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
     else:
         # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
-        output = input_.transpose(dim, 0)[cp_rank * dim_size + kernel_size:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
+        output = input_.transpose(dim, 0)[
+            cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
+        ].transpose(dim, 0)
     output = output.contiguous()
 
     # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
 
     return output
 
 
 def _conv_gather(input_, dim, kernel_size):
     cp_world_size = get_context_parallel_world_size()
 
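In `_conv_split`, the first context-parallel rank keeps `kernel_size` extra leading frames (the causal context) and every later rank takes an equal-sized chunk offset by that amount, so the chunks tile the temporal axis exactly. A toy, single-process sketch of just the indexing (no `torch.distributed`; rank and world size are plain integers here):

```python
import torch

def conv_split_slice(x: torch.Tensor, dim: int, kernel_size: int, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Mirrors the slicing above: rank 0 keeps dim_size + kernel_size frames,
    # later ranks keep dim_size frames starting kernel_size frames into their region.
    dim_size = (x.size(dim) - kernel_size) // cp_world_size
    xt = x.transpose(dim, 0)
    if cp_rank == 0:
        out = xt[: dim_size + kernel_size]
    else:
        out = xt[cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size]
    return out.transpose(dim, 0).contiguous()

# e.g. 17 latent frames split across 2 ranks with kernel_size=1 -> 9 and 8 frames
x = torch.randn(1, 4, 17, 8, 8)
print(conv_split_slice(x, 2, 1, 0, 2).shape)  # torch.Size([1, 4, 9, 8, 8])
print(conv_split_slice(x, 2, 1, 1, 2).shape)  # torch.Size([1, 4, 8, 8, 8])
```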
@@ -158,9 +175,11 @@ def _conv_gather(input_, dim, kernel_size):
     if cp_rank == 0:
         input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
     else:
-        input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(0, dim).contiguous()
+        input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()
 
-    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
+    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
+        torch.empty_like(input_) for _ in range(cp_world_size - 1)
+    ]
     if cp_rank == 0:
         input_ = torch.cat([input_first_kernel_, input_], dim=dim)
 
@@ -174,8 +193,8 @@ def _conv_gather(input_, dim, kernel_size):
 
     return output
 
-def _pass_from_previous_rank(input_, dim, kernel_size):
+
+def _pass_from_previous_rank(input_, dim, kernel_size):
     # Bypass the function if kernel size is 1
     if kernel_size == 1:
         return input_
@@ -201,9 +220,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
         recv_rank += cp_world_size
 
     if cp_rank < cp_world_size - 1:
-        req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
+        req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
     if cp_rank > 0:
-        recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
+        recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
 
     if cp_rank == 0:
@@ -218,8 +237,8 @@ def _pass_from_previous_rank(input_, dim, kernel_size):
 
     return input_
 
-def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
+
+def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
     # Bypass the function if kernel size is 1
     if kernel_size == 1:
         return input_
@@ -248,9 +267,9 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
     # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
     # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
     # req_recv.wait()
-    recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
+    recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
     if cp_rank < cp_world_size - 1:
-        req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
+        req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group)
     if cp_rank > 0:
         req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
         # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
@@ -270,12 +289,11 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non
 
 
 def _drop_from_previous_rank(input_, dim, kernel_size):
-    input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim)
+    input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
     return input_
 
 
 class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input_, dim, kernel_size):
         ctx.dim = dim
@@ -286,8 +304,8 @@ class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
     def backward(ctx, grad_output):
         return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
 
-class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
+
+class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, dim, kernel_size):
         ctx.dim = dim
@@ -298,8 +316,8 @@ class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
     def backward(ctx, grad_output):
         return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
 
-class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
+
+class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, dim, kernel_size):
         ctx.dim = dim
@@ -310,8 +328,8 @@ class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
     def backward(ctx, grad_output):
         return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None
 
-class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
+
+class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, dim, kernel_size, cache_padding):
         ctx.dim = dim
@@ -326,25 +344,21 @@ class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
 def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
     return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)
 
 
 def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
     return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)
 
 
 def conv_pass_from_last_rank(input_, dim, kernel_size):
     return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
 
 
 def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
     return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
 
 
 class ContextParallelCausalConv3d(nn.Module):
-    def __init__(
-        self,
-        chan_in,
-        chan_out,
-        kernel_size: Union[int, Tuple[int, int, int]],
-        stride = 1,
-        **kwargs
-    ):
+    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs):
         super().__init__()
         kernel_size = cast_tuple(kernel_size, 3)
 
@@ -364,7 +378,7 @@ class ContextParallelCausalConv3d(nn.Module):
 
         stride = (stride, stride, stride)
         dilation = (1, 1, 1)
-        self.conv = Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
+        self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
         self.cache_padding = None
 
     def forward(self, input_, clear_cache=True):
@ -381,7 +395,9 @@ class ContextParallelCausalConv3d(nn.Module):
|
|||||||
# output = output_parallel
|
# output = output_parallel
|
||||||
# return output
|
# return output
|
||||||
|
|
||||||
input_parallel = fake_cp_pass_from_previous_rank(input_, self.temporal_dim, self.time_kernel_size, self.cache_padding)
|
input_parallel = fake_cp_pass_from_previous_rank(
|
||||||
|
input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
|
||||||
|
)
|
||||||
|
|
||||||
del self.cache_padding
|
del self.cache_padding
|
||||||
self.cache_padding = None
|
self.cache_padding = None
|
||||||
@ -389,17 +405,25 @@ class ContextParallelCausalConv3d(nn.Module):
|
|||||||
cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
|
cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
|
||||||
global_rank = torch.distributed.get_rank()
|
global_rank = torch.distributed.get_rank()
|
||||||
if cp_world_size == 1:
|
if cp_world_size == 1:
|
||||||
self.cache_padding = input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu()
|
self.cache_padding = (
|
||||||
|
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if cp_rank == cp_world_size - 1:
|
if cp_rank == cp_world_size - 1:
|
||||||
torch.distributed.isend(input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), global_rank + 1 - cp_world_size, group=get_context_parallel_group())
|
torch.distributed.isend(
|
||||||
|
input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
|
||||||
|
global_rank + 1 - cp_world_size,
|
||||||
|
group=get_context_parallel_group(),
|
||||||
|
)
|
||||||
if cp_rank == 0:
|
if cp_rank == 0:
|
||||||
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1:]).contiguous()
|
recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous()
|
||||||
torch.distributed.recv(recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group())
|
torch.distributed.recv(
|
||||||
|
recv_buffer, global_rank - 1 + cp_world_size, group=get_context_parallel_group()
|
||||||
|
)
|
||||||
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
|
self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()
|
||||||
|
|
||||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||||
input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
|
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||||
|
|
||||||
output_parallel = self.conv(input_parallel)
|
output_parallel = self.conv(input_parallel)
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
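The peer ranks used by the isend/recv pair above follow from the consecutive-rank layout of a context-parallel group: the last rank in a group (cp_rank == cp_world_size - 1) hands its trailing frames to the first rank (cp_rank == 0) of the same group for the next call's cache padding. A minimal sketch of that peer arithmetic, assuming a group holds cp_world_size consecutive global ranks (the helper names below are illustrative, not part of the diff):

# Sketch only: peer computation for the cache-padding handoff.
def peer_of_last_rank(global_rank: int, cp_world_size: int) -> int:
    # the last rank of a group sends to the first rank of the same group
    return global_rank + 1 - cp_world_size

def peer_of_first_rank(global_rank: int, cp_world_size: int) -> int:
    # the first rank of a group receives from the last rank of the same group
    return global_rank - 1 + cp_world_size

# e.g. with cp_world_size = 4 and a group holding global ranks 4..7:
assert peer_of_last_rank(7, 4) == 4
assert peer_of_first_rank(4, 4) == 7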
@@ -416,16 +440,14 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
             output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
         return output


-def Normalize(
-    in_channels,
-    gather=False,
-    **kwargs
-): # same for 3D and 2D
+def Normalize(in_channels, gather=False, **kwargs):  # same for 3D and 2D
     if gather:
         return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
     else:
         return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


 class SpatialNorm3D(nn.Module):
     def __init__(
         self,
@@ -433,9 +455,9 @@ class SpatialNorm3D(nn.Module):
         zq_channels,
         freeze_norm_layer=False,
         add_conv=False,
-        pad_mode='constant',
+        pad_mode="constant",
         gather=False,
-        **norm_layer_params
+        **norm_layer_params,
     ):
         super().__init__()
         if gather:
@@ -487,6 +509,7 @@ class SpatialNorm3D(nn.Module):
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f


 def Normalize3D(
     in_channels,
     zq_ch,
@@ -515,11 +538,7 @@ class Upsample3D(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time

     def forward(self, x):
@@ -537,42 +556,33 @@ class Upsample3D(nn.Module):
         else:
             # only interpolate 2D
             t = x.shape[2]
-            x = rearrange(x, 'b c t h w -> (b t) c h w')
+            x = rearrange(x, "b c t h w -> (b t) c h w")
             x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-            x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

         if self.with_conv:
             t = x.shape[2]
-            x = rearrange(x, 'b c t h w -> (b t) c h w')
+            x = rearrange(x, "b c t h w -> (b t) c h w")
             x = self.conv(x)
-            x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
         return x


 class DownSample3D(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        with_conv,
-        compress_time=False,
-        out_channels=None
-    ):
+    def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
         super().__init__()
         self.with_conv = with_conv
         if out_channels is None:
             out_channels = in_channels
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        out_channels,
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=0)
+            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time

     def forward(self, x):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
-            x = rearrange(x, 'b c t h w -> (b h w) c t')
+            x = rearrange(x, "b c t h w -> (b h w) c t")

             if x.shape[-1] % 2 == 1:
                 # split first frame
@@ -581,29 +591,40 @@ class DownSample3D(nn.Module):
                 if x_rest.shape[-1] > 0:
                     x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
                 x = torch.cat([x_first[..., None], x_rest], dim=-1)
-                x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
+                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
             else:
                 x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
-                x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
+                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

         if self.with_conv:
-            pad = (0,1,0,1)
+            pad = (0, 1, 0, 1)
             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
             t = x.shape[2]
-            x = rearrange(x, 'b c t h w -> (b t) c h w')
+            x = rearrange(x, "b c t h w -> (b t) c h w")
             x = self.conv(x)
-            x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
         else:
             t = x.shape[2]
-            x = rearrange(x, 'b c t h w -> (b t) c h w')
+            x = rearrange(x, "b c t h w -> (b t) c h w")
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
-            x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
         return x


 class ContextParallelResnetBlock3D(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
-                 dropout, temb_channels=512, zq_ch=None, add_conv=False, gather_norm=False, normalization=Normalize):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout,
+        temb_channels=512,
+        zq_ch=None,
+        add_conv=False,
+        gather_norm=False,
+        normalization=Normalize,
+    ):
         super().__init__()
         self.in_channels = in_channels
         out_channels = in_channels if out_channels is None else out_channels
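When compress_time is set and the clip has an odd number of frames, DownSample3D keeps the first frame untouched and average-pools the remaining frames in pairs, as in the hunk above. A minimal sketch of the shape bookkeeping, assuming the (b h w) c t layout produced by the preceding rearrange (the split of x into x_first and x_rest is paraphrased here, since that line sits outside the hunk):

import torch

# Sketch only: odd-length temporal pooling as in DownSample3D.forward.
x = torch.randn(2, 8, 5)          # (b*h*w, c, t) with t = 5 (odd)
x_first, x_rest = x[..., 0], x[..., 1:]
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
assert x.shape[-1] == 3           # 1 kept frame + 4 frames pooled into 2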
@@ -623,8 +644,7 @@ class ContextParallelResnetBlock3D(nn.Module):
             kernel_size=3,
         )
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
-                                             out_channels)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
         self.norm2 = normalization(
             out_channels,
             zq_ch=zq_ch,
@@ -669,7 +689,7 @@ class ContextParallelResnetBlock3D(nn.Module):
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)

         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None,None]
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

         # if isinstance(self.norm2, torch.nn.GroupNorm):
         #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
@@ -690,13 +710,29 @@ class ContextParallelResnetBlock3D(nn.Module):
             else:
                 x = self.nin_shortcut(x)

-        return x+h
+        return x + h


 class ContextParallelEncoder3D(nn.Module):
-    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, double_z=True, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignore_kwargs):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        double_z=True,
+        pad_mode="first",
+        temporal_compress_times=4,
+        gather_norm=False,
+        **ignore_kwargs,
+    ):
         super().__init__()
         self.ch = ch
         self.temb_ch = 0
@@ -715,13 +751,13 @@ class ContextParallelEncoder3D(nn.Module):
         )

         curr_res = resolution
-        in_ch_mult = (1,)+tuple(ch_mult)
+        in_ch_mult = (1,) + tuple(ch_mult)
         self.down = nn.ModuleList()
         for i_level in range(self.num_resolutions):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_in = ch*in_ch_mult[i_level]
-            block_out = ch*ch_mult[i_level]
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
             for i_block in range(self.num_res_blocks):
                 block.append(
                     ContextParallelResnetBlock3D(
@@ -736,7 +772,7 @@ class ContextParallelEncoder3D(nn.Module):
             down = nn.Module()
             down.block = block
             down.attn = attn
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 if i_level < self.temporal_compress_level:
                     down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
                 else:
@@ -767,12 +803,11 @@ class ContextParallelEncoder3D(nn.Module):

         self.conv_out = ContextParallelCausalConv3d(
             chan_in=block_in,
-            chan_out=2*z_channels if double_z else z_channels,
+            chan_out=2 * z_channels if double_z else z_channels,
             kernel_size=3,
         )

     def forward(self, x):

         # timestep embedding
         temb = None

@@ -783,7 +818,7 @@ class ContextParallelEncoder3D(nn.Module):
                 h = self.down[i_level].block[i_block](h, temb)
                 if len(self.down[i_level].attn) > 0:
                     h = self.down[i_level].attn[i_block](h)
-            if i_level != self.num_resolutions-1:
+            if i_level != self.num_resolutions - 1:
                 h = self.down[i_level].downsample(h)

         # middle
@@ -802,8 +837,27 @@ class ContextParallelEncoder3D(nn.Module):


 class ContextParallelDecoder3D(nn.Module):
-    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, pad_mode='first', temporal_compress_times=4, gather_norm=False, **ignorekwargs):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        give_pre_end=False,
+        zq_ch=None,
+        add_conv=False,
+        pad_mode="first",
+        temporal_compress_times=4,
+        gather_norm=False,
+        **ignorekwargs,
+    ):
         super().__init__()
         self.ch = ch
         self.temb_ch = 0
@@ -820,12 +874,11 @@ class ContextParallelDecoder3D(nn.Module):
             zq_ch = z_channels

         # compute in_ch_mult, block_in and curr_res at lowest res
-        in_ch_mult = (1,)+tuple(ch_mult)
-        block_in = ch*ch_mult[self.num_resolutions-1]
-        curr_res = resolution // 2**(self.num_resolutions-1)
-        self.z_shape = (1,z_channels,curr_res,curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(
-            self.z_shape, np.prod(self.z_shape)))
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

         self.conv_in = ContextParallelCausalConv3d(
             chan_in=z_channels,
@@ -862,8 +915,8 @@ class ContextParallelDecoder3D(nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = nn.ModuleList()
             attn = nn.ModuleList()
-            block_out = ch*ch_mult[i_level]
-            for i_block in range(self.num_res_blocks+1):
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
                 block.append(
                     ContextParallelResnetBlock3D(
                         in_channels=block_in,
@@ -896,7 +949,6 @@ class ContextParallelDecoder3D(nn.Module):
         )

     def forward(self, z, clear_fake_cp_cache=True):

         self.last_z_shape = z.shape

         # timestep embedding
@@ -914,7 +966,7 @@ class ContextParallelDecoder3D(nn.Module):

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
+            for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
@@ -12,9 +12,7 @@ class LitEma(nn.Module):
         self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
         self.register_buffer(
             "num_updates",
-            torch.tensor(0, dtype=torch.int)
-            if use_num_upates
-            else torch.tensor(-1, dtype=torch.int),
+            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
         )

         for name, p in model.named_parameters():
@@ -47,9 +45,7 @@ class LitEma(nn.Module):
                 if m_param[key].requires_grad:
                     sname = self.m_name2s_name[key]
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
-                    shadow_params[sname].sub_(
-                        one_minus_decay * (shadow_params[sname] - m_param[key])
-                    )
+                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                 else:
                     assert not key in self.m_name2s_name

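The single-line sub_ above is the usual exponential-moving-average step: shadow becomes shadow - (1 - decay) * (shadow - param), i.e. decay * shadow + (1 - decay) * param. A minimal numeric sketch of that update:

import torch

# Sketch only: the EMA step LitEma applies to each shadow parameter.
decay = 0.99
shadow = torch.tensor(1.0)
param = torch.tensor(0.0)
shadow.sub_((1 - decay) * (shadow - param))
assert torch.isclose(shadow, torch.tensor(0.99))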
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn


 class DiagonalGaussianDistribution(object):
     def __init__(self, parameters, deterministic=False):
         self.parameters = parameters
@@ -15,9 +16,7 @@ class DiagonalGaussianDistribution(object):
         self.std = torch.exp(0.5 * self.logvar)
         self.var = torch.exp(self.logvar)
         if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(
-                device=self.parameters.device
-            )
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

     def sample(self):
         # x = self.mean + self.std * torch.randn(self.mean.shape).to(
@@ -57,6 +56,7 @@ class DiagonalGaussianDistribution(object):
     def mode(self):
         return self.mean


 class AbstractRegularizer(nn.Module):
     def __init__(self):
         super().__init__()
@@ -77,14 +77,10 @@ class IdentityRegularizer(AbstractRegularizer):
         yield from ()


-def measure_perplexity(
-    predicted_indices: torch.Tensor, num_centroids: int
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
     # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
     # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
-    encodings = (
-        F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
-    )
+    encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
     avg_probs = encodings.mean(0)
     perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
     cluster_use = torch.sum(avg_probs > 0)
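measure_perplexity above one-hot encodes the predicted codebook indices, averages them into per-centroid usage probabilities, and returns the exponential of their entropy, so when every centroid is used equally the perplexity equals num_centroids. A minimal sketch with hypothetical indices:

import torch
import torch.nn.functional as F

# Sketch only: codebook-usage perplexity, mirroring measure_perplexity.
predicted_indices = torch.tensor([0, 1, 2, 3])  # hypothetical: each centroid used once
num_centroids = 4
encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
assert torch.isclose(perplexity, torch.tensor(4.0), atol=1e-3)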
@@ -14,17 +14,19 @@ import torch.distributed
 _CONTEXT_PARALLEL_GROUP = None
 _CONTEXT_PARALLEL_SIZE = None


 def is_context_parallel_initialized():
     if _CONTEXT_PARALLEL_GROUP is None:
         return False
     else:
         return True


 def initialize_context_parallel(context_parallel_size):
     global _CONTEXT_PARALLEL_GROUP
     global _CONTEXT_PARALLEL_SIZE

-    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
+    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
     _CONTEXT_PARALLEL_SIZE = context_parallel_size

     rank = torch.distributed.get_rank()
@@ -37,40 +39,48 @@ def initialize_context_parallel(context_parallel_size):
             _CONTEXT_PARALLEL_GROUP = group
             break


 def get_context_parallel_group():
-    assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
+    assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"

     return _CONTEXT_PARALLEL_GROUP


 def get_context_parallel_world_size():
-    assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
+    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

     return _CONTEXT_PARALLEL_SIZE


 def get_context_parallel_rank():
-    assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
+    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

     rank = torch.distributed.get_rank()
     cp_rank = rank % _CONTEXT_PARALLEL_SIZE
     return cp_rank


 def get_context_parallel_group_rank():
-    assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
+    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

     rank = torch.distributed.get_rank()
     cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE

     return cp_group_rank


 class SafeConv3d(torch.nn.Conv3d):
     def forward(self, input):
         memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
         if memory_count > 2:
             kernel_size = self.kernel_size[0]
             part_num = int(memory_count / 2) + 1
             input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
             if kernel_size > 1:
-                input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i-1][:, :, -kernel_size+1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))]
+                input_chunks = [input_chunks[0]] + [
+                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+                    for i in range(1, len(input_chunks))
+                ]

             output_chunks = []
             for input_chunk in input_chunks:
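SafeConv3d above estimates the activation size in GiB (the factor of 2 presumably counts bytes per half-precision element) and, when it exceeds 2 GiB, splits the input along the time axis, overlapping each chunk with the last kernel_size - 1 frames of the previous one so the chunked convolution matches the unchunked result. A minimal sketch of that chunk bookkeeping, with an illustrative helper name:

import torch

# Sketch only: time-axis chunking with (kernel_size - 1)-frame overlap, as in SafeConv3d.
def chunk_with_overlap(x: torch.Tensor, part_num: int, kernel_size: int):
    chunks = list(torch.chunk(x, part_num, dim=2))  # NCTHW
    if kernel_size > 1:
        chunks = [chunks[0]] + [
            torch.cat((chunks[i - 1][:, :, -kernel_size + 1 :], chunks[i]), dim=2)
            for i in range(1, len(chunks))
        ]
    return chunks

x = torch.randn(1, 3, 8, 4, 4)
parts = chunk_with_overlap(x, part_num=2, kernel_size=3)
assert parts[0].shape[2] == 4 and parts[1].shape[2] == 6  # second chunk carries 2 overlap frames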
@@ -149,9 +159,7 @@ def log_txt_as_img(wh, xc, size=10):
             text_seq = xc[bi][0]
         else:
             text_seq = xc[bi]
-        lines = "\n".join(
-            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
-        )
+        lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))

         try:
             draw.text((0, 0), lines, fill="black", font=font)
@@ -263,9 +271,7 @@ def append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
     dims_to_append = target_dims - x.ndim
     if dims_to_append < 0:
-        raise ValueError(
-            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
-        )
+        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
     return x[(...,) + (None,) * dims_to_append]

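append_dims above right-pads a tensor's shape with singleton dimensions via x[(...,) + (None,) * dims_to_append], which is handy for broadcasting per-sample scalars against higher-rank tensors. A quick usage sketch:

import torch

# Sketch only: the broadcasting behaviour of append_dims.
def append_dims(x, target_dims):
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]

assert append_dims(torch.randn(4), 3).shape == (4, 1, 1)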
@@ -361,6 +367,7 @@ def checkpoint(func, inputs, params, flag):
     else:
         return func(*inputs)


 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, length, *args):
@@ -395,4 +402,3 @@ class CheckpointFunction(torch.autograd.Function):
         del ctx.input_params
         del output_tensors
         return (None, None) + input_grads