import math
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange

from sgm.util import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
    get_context_parallel_group_rank,
)

from vae_modules.utils import SafeConv3d as Conv3d
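
# Context-parallel helpers: activations are sharded along a chosen dimension (the
# temporal axis for video latents) across the ranks of the context-parallel group.
# The leading frame(s) stay with rank 0's shard so that causal temporal convolutions
# always see the true start of the sequence.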


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)


def exists(v):
    return v is not None


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings (as in Fairseq).
    This matches the implementation in Denoising Diffusion Probabilistic Models and in
    tensor2tensor, but differs slightly from the description in Section 3.5 of
    "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def leaky_relu(p=0.1):
    return nn.LeakyReLU(p)
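
# _split / _gather shard activations along `dim` across the context-parallel group and
# reassemble them. The first frame (or the first `kernel_size` frames in the _conv_*
# variants) is kept with rank 0's shard so the start of the sequence is never cut off.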


def _split(input_, dim):
    cp_world_size = get_context_parallel_world_size()

    if cp_world_size == 1:
        return input_

    cp_rank = get_context_parallel_rank()

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
    dim_size = input_.size()[dim] // cp_world_size

    input_list = torch.split(input_, dim_size, dim=dim)
    output = input_list[cp_rank]

    if cp_rank == 0:
        output = torch.cat([input_first_frame_, output], dim=dim)
    output = output.contiguous()

    return output


def _gather(input_, dim):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()

    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]

    if cp_rank == 0:
        input_ = torch.cat([input_first_frame_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _conv_split(input_, dim, kernel_size):
    cp_world_size = get_context_parallel_world_size()

    if cp_world_size == 1:
        return input_

    cp_rank = get_context_parallel_rank()

    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
        output = input_.transpose(dim, 0)[
            cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size
        ].transpose(dim, 0)
    output = output.contiguous()

    return output


def _conv_gather(input_, dim, kernel_size):
    cp_world_size = get_context_parallel_world_size()

    # Bypass the function if context parallel is 1
    if cp_world_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
    else:
        input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous()

    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    tensor_list[cp_rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
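
# The two functions below exchange "halo" frames between neighbouring ranks: each rank
# sends its last (kernel_size - 1) frames to the next rank so the receiver can compute
# a causal temporal convolution without a gap at its shard boundary. Rank 0 instead
# replicates its first frame (or, in the fake-CP variant, reuses cache_padding from a
# previously processed chunk).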


def _pass_from_previous_rank(input_, dim, kernel_size):
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_group_rank = get_context_parallel_group_rank()
    cp_world_size = get_context_parallel_world_size()

    global_rank = torch.distributed.get_rank()
    global_world_size = torch.distributed.get_world_size()

    input_ = input_.transpose(0, dim)

    # pass from last rank
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    if cp_rank < cp_world_size - 1:
        req_send = torch.distributed.isend(
            input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
        )
    if cp_rank > 0:
        recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    input_ = input_.transpose(0, dim).contiguous()

    return input_


def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None):
    # Bypass the function if kernel size is 1
    if kernel_size == 1:
        return input_

    group = get_context_parallel_group()
    cp_rank = get_context_parallel_rank()
    cp_group_rank = get_context_parallel_group_rank()
    cp_world_size = get_context_parallel_world_size()

    global_rank = torch.distributed.get_rank()
    global_world_size = torch.distributed.get_world_size()

    input_ = input_.transpose(0, dim)

    # pass from last rank
    send_rank = global_rank + 1
    recv_rank = global_rank - 1
    if send_rank % cp_world_size == 0:
        send_rank -= cp_world_size
    if recv_rank % cp_world_size == cp_world_size - 1:
        recv_rank += cp_world_size

    recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous()
    if cp_rank < cp_world_size - 1:
        req_send = torch.distributed.isend(
            input_[-kernel_size + 1 :].contiguous(), send_rank, group=group
        )
    if cp_rank > 0:
        req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)

    if cp_rank == 0:
        if cache_padding is not None:
            input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0)
        else:
            input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
    else:
        req_recv.wait()
        input_ = torch.cat([recv_buffer, input_], dim=0)

    input_ = input_.transpose(0, dim).contiguous()
    return input_


def _drop_from_previous_rank(input_, dim, kernel_size):
    input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim)
    return input_
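
# Autograd wrappers: each op pairs a forward with its communication inverse so that
# gradients flow correctly across the context-parallel group (conv split <-> gather,
# and pass-from-previous-rank <-> drop-from-previous-rank).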


class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_split(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _conv_gather(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None


class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _pass_from_previous_rank(input_, dim, kernel_size)

    @staticmethod
    def backward(ctx, grad_output):
        return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None


class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, dim, kernel_size, cache_padding):
        ctx.dim = dim
        ctx.kernel_size = kernel_size
        return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding)

    @staticmethod
    def backward(ctx, grad_output):
        return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None


def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size)


def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
    return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size)


def conv_pass_from_last_rank(input_, dim, kernel_size):
    return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)


def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
    return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding)
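
# ContextParallelCausalConv3d: a causal Conv3d over (b, c, t, h, w). Time is padded
# only on the left (time_kernel_size - 1 frames, taken from the previous rank, from
# cache_padding, or by replicating the first frame), while height/width are padded
# symmetrically, so each output frame depends only on current and past frames. With
# clear_cache=False the trailing frames are stashed on CPU as cache_padding so the
# next chunk of a long video can be processed causally.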


class ContextParallelCausalConv3d(nn.Module):
    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        time_pad = time_kernel_size - 1
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        self.time_kernel_size = time_kernel_size
        self.temporal_dim = 2

        stride = (stride, stride, stride)
        dilation = (1, 1, 1)
        self.conv = Conv3d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
        )
        self.cache_padding = None

    def forward(self, input_, clear_cache=True):
        input_parallel = fake_cp_pass_from_previous_rank(
            input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
        )

        del self.cache_padding
        self.cache_padding = None
        if not clear_cache:
            cp_rank, cp_world_size = get_context_parallel_rank(), get_context_parallel_world_size()
            global_rank = torch.distributed.get_rank()
            if cp_world_size == 1:
                self.cache_padding = (
                    input_parallel[:, :, -self.time_kernel_size + 1 :]
                    .contiguous()
                    .detach()
                    .clone()
                    .cpu()
                )
            else:
                if cp_rank == cp_world_size - 1:
                    torch.distributed.isend(
                        input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(),
                        global_rank + 1 - cp_world_size,
                        group=get_context_parallel_group(),
                    )
                if cp_rank == 0:
                    recv_buffer = torch.empty_like(
                        input_parallel[:, :, -self.time_kernel_size + 1 :]
                    ).contiguous()
                    torch.distributed.recv(
                        recv_buffer,
                        global_rank - 1 + cp_world_size,
                        group=get_context_parallel_group(),
                    )
                    self.cache_padding = recv_buffer.contiguous().detach().clone().cpu()

        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)

        output_parallel = self.conv(input_parallel)
        output = output_parallel
        return output


class ContextParallelGroupNorm(torch.nn.GroupNorm):
    def forward(self, input_):
        gather_flag = input_.shape[2] > 1
        if gather_flag:
            input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1)
        output = super().forward(input_)
        if gather_flag:
            output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1)
        return output


def Normalize(in_channels, gather=False, **kwargs):
    if gather:
        return ContextParallelGroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
    else:
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
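
# SpatialNorm3D: group-normalizes f and then modulates it with two 1x1x1 causal
# convolutions of the (upsampled) latent zq, i.e. spatially conditioned normalization.
# zq is interpolated in 32-channel chunks, presumably to keep the peak memory of
# F.interpolate bounded on large video tensors.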


class SpatialNorm3D(nn.Module):
    def __init__(
        self,
        f_channels,
        zq_channels,
        freeze_norm_layer=False,
        add_conv=False,
        pad_mode="constant",
        gather=False,
        **norm_layer_params,
    ):
        super().__init__()
        if gather:
            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
        else:
            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
        if freeze_norm_layer:
            for p in self.norm_layer.parameters():
                p.requires_grad = False

        self.add_conv = add_conv
        if add_conv:
            self.conv = ContextParallelCausalConv3d(
                chan_in=zq_channels,
                chan_out=zq_channels,
                kernel_size=3,
            )

        self.conv_y = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )
        self.conv_b = ContextParallelCausalConv3d(
            chan_in=zq_channels,
            chan_out=f_channels,
            kernel_size=1,
        )

    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True):
        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp:
            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
            zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
            zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")

            zq_rest_splits = torch.split(zq_rest, 32, dim=1)
            interpolated_splits = [
                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest")
                for split in zq_rest_splits
            ]

            zq_rest = torch.cat(interpolated_splits, dim=1)
            zq = torch.cat([zq_first, zq_rest], dim=2)
        else:
            f_size = f.shape[-3:]

            zq_splits = torch.split(zq, 32, dim=1)
            interpolated_splits = [
                torch.nn.functional.interpolate(split, size=f_size, mode="nearest")
                for split in zq_splits
            ]
            zq = torch.cat(interpolated_splits, dim=1)

        if self.add_conv:
            zq = self.conv(zq, clear_cache=clear_fake_cp_cache)

        norm_f = self.norm_layer(f)
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return new_f


def Normalize3D(
    in_channels,
    zq_ch,
    add_conv,
    gather=False,
):
    return SpatialNorm3D(
        in_channels,
        zq_ch,
        gather=gather,
        freeze_norm_layer=False,
        add_conv=add_conv,
        num_groups=32,
        eps=1e-6,
        affine=True,
    )
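
# Upsample3D / DownSample3D resize spatially, and also temporally when
# compress_time=True. On rank 0 the first frame is handled separately so temporal
# resampling keeps the causal first-frame convention, and channels are processed in
# 32-channel chunks, presumably to bound peak memory during interpolation/pooling.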


class Upsample3D(nn.Module):
    def __init__(
        self,
        in_channels,
        with_conv,
        compress_time=False,
    ):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )
        self.compress_time = compress_time

    def forward(self, x, fake_cp=True):
        if self.compress_time and x.shape[2] > 1:
            if get_context_parallel_rank() == 0 and fake_cp:
                # split first frame
                x_first, x_rest = x[:, :, 0], x[:, :, 1:]
                x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")

                splits = torch.split(x_rest, 32, dim=1)
                interpolated_splits = [
                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
                    for split in splits
                ]
                x_rest = torch.cat(interpolated_splits, dim=1)
                x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
            else:
                splits = torch.split(x, 32, dim=1)
                interpolated_splits = [
                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
                    for split in splits
                ]
                x = torch.cat(interpolated_splits, dim=1)

        else:
            # only interpolate 2D
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")

            splits = torch.split(x, 32, dim=1)
            interpolated_splits = [
                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest")
                for split in splits
            ]
            x = torch.cat(interpolated_splits, dim=1)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.with_conv:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x


class DownSample3D(nn.Module):
    def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
        super().__init__()
        self.with_conv = with_conv
        if out_channels is None:
            out_channels = in_channels
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, padding=0
            )
        self.compress_time = compress_time

    def forward(self, x, fake_cp=True):
        if self.compress_time and x.shape[2] > 1:
            h, w = x.shape[-2:]
            x = rearrange(x, "b c t h w -> (b h w) c t")

            if get_context_parallel_rank() == 0 and fake_cp:
                # split first frame
                x_first, x_rest = x[..., 0], x[..., 1:]

                if x_rest.shape[-1] > 0:
                    splits = torch.split(x_rest, 32, dim=1)
                    interpolated_splits = [
                        torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
                        for split in splits
                    ]
                    x_rest = torch.cat(interpolated_splits, dim=1)
                x = torch.cat([x_first[..., None], x_rest], dim=-1)
                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
            else:
                splits = torch.split(x, 32, dim=1)
                interpolated_splits = [
                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2)
                    for split in splits
                ]
                x = torch.cat(interpolated_splits, dim=1)
                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)

        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = self.conv(x)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        else:
            t = x.shape[2]
            x = rearrange(x, "b c t h w -> (b t) c h w")
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x
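
# ContextParallelResnetBlock3D: a norm -> swish -> conv residual block built from the
# causal 3D convolutions above; when zq_ch is given, normalization is the
# zq-conditioned Normalize3D instead of plain GroupNorm.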


class ContextParallelResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
        zq_ch=None,
        add_conv=False,
        gather_norm=False,
        normalization=Normalize,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = normalization(
            in_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )

        self.conv1 = ContextParallelCausalConv3d(
            chan_in=in_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = normalization(
            out_channels,
            zq_ch=zq_ch,
            add_conv=add_conv,
            gather=gather_norm,
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = ContextParallelCausalConv3d(
            chan_in=out_channels,
            chan_out=out_channels,
            kernel_size=3,
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = ContextParallelCausalConv3d(
                    chan_in=in_channels,
                    chan_out=out_channels,
                    kernel_size=3,
                )
            else:
                self.nin_shortcut = Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )

    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True):
        h = x

        if zq is not None:
            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
        else:
            h = self.norm1(h)

        h = nonlinearity(h)
        h = self.conv1(h, clear_cache=clear_fake_cp_cache)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

        if zq is not None:
            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp)
        else:
            h = self.norm2(h)

        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h, clear_cache=clear_fake_cp_cache)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache)
            else:
                x = self.nin_shortcut(x)

        return x + h
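
# ContextParallelEncoder3D: a VQGAN-style encoder built from the causal 3D blocks
# above. The time axis is downsampled only in the first log2(temporal_compress_times)
# resolution levels; the spatial axes are downsampled at every level but the last.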


class ContextParallelEncoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        pad_mode="first",
        temporal_compress_times=4,
        gather_norm=False,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # log2 of temporal_compress_times
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        self.conv_in = ContextParallelCausalConv3d(
            chan_in=in_channels,
            chan_out=self.ch,
            kernel_size=3,
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ContextParallelResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        dropout=dropout,
                        temb_channels=self.temb_ch,
                        gather_norm=gather_norm,
                    )
                )
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level < self.temporal_compress_level:
                    down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True)
                else:
                    down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            gather_norm=gather_norm,
        )

        self.mid.block_2 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            gather_norm=gather_norm,
        )

        # end
        self.norm_out = Normalize(block_in, gather=gather_norm)

        self.conv_out = ContextParallelCausalConv3d(
            chan_in=block_in,
            chan_out=2 * z_channels if double_z else z_channels,
            kernel_size=3,
        )

    def forward(self, x, use_cp=True):
        global _USE_CP
        _USE_CP = use_cp

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)

        return h
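
# ContextParallelDecoder3D mirrors the encoder; the latent z itself is reused as the
# spatial conditioning zq for Normalize3D, and time is upsampled only in the deepest
# log2(temporal_compress_times) levels.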


class ContextParallelDecoder3D(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        zq_ch=None,
        add_conv=False,
        pad_mode="first",
        temporal_compress_times=4,
        gather_norm=False,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # log2 of temporal_compress_times
        self.temporal_compress_level = int(np.log2(temporal_compress_times))

        if zq_ch is None:
            zq_ch = z_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        self.conv_in = ContextParallelCausalConv3d(
            chan_in=z_channels,
            chan_out=block_in,
            kernel_size=3,
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
            normalization=Normalize3D,
            gather_norm=gather_norm,
        )

        self.mid.block_2 = ContextParallelResnetBlock3D(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
            normalization=Normalize3D,
            gather_norm=gather_norm,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ContextParallelResnetBlock3D(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        zq_ch=zq_ch,
                        add_conv=add_conv,
                        normalization=Normalize3D,
                        gather_norm=gather_norm,
                    )
                )
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                if i_level < self.num_resolutions - self.temporal_compress_level:
                    up.upsample = Upsample3D(
                        block_in, with_conv=resamp_with_conv, compress_time=False
                    )
                else:
                    up.upsample = Upsample3D(
                        block_in, with_conv=resamp_with_conv, compress_time=True
                    )
            self.up.insert(0, up)

        self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)

        self.conv_out = ContextParallelCausalConv3d(
            chan_in=block_in,
            chan_out=out_ch,
            kernel_size=3,
        )

    def forward(self, z, clear_fake_cp_cache=True, use_cp=True):
        global _USE_CP
        _USE_CP = use_cp
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        t = z.shape[2]
        # z to block_in

        zq = z
        h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

        # middle
        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
            if i_level != 0:
                h = self.up[i_level].upsample(h, fake_cp=use_cp)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp)
        h = nonlinearity(h)
        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)

        return h

    def get_last_layer(self):
        return self.conv_out.conv.weight
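
# Illustrative usage sketch (argument values are placeholders, not a shipped config),
# assuming torch.distributed and the context-parallel group have already been
# initialised:
#
#   encoder = ContextParallelEncoder3D(
#       ch=128, out_ch=3, ch_mult=(1, 2, 2, 4), num_res_blocks=3, attn_resolutions=[],
#       in_channels=3, resolution=256, z_channels=16, double_z=True,
#   )
#   decoder = ContextParallelDecoder3D(
#       ch=128, out_ch=3, ch_mult=(1, 2, 2, 4), num_res_blocks=3, attn_resolutions=[],
#       in_channels=3, resolution=256, z_channels=16,
#   )
#   moments = encoder(video)                      # video: (b, c, t, h, w)
#   recon = decoder(latent, clear_fake_cp_cache=True)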