# mirror of https://github.com/THUDM/CogVideo.git
import copy
from pathlib import Path
from math import log2, ceil, sqrt
from functools import wraps, partial

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch import nn, einsum, Tensor
from torch.nn import Module, ModuleList
from torch.autograd import grad as torch_grad

import torchvision
from torchvision.models import VGG16_Weights

from collections import namedtuple

# from vector_quantize_pytorch import LFQ, FSQ
from .regularizers.finite_scalar_quantization import FSQ
from .regularizers.lookup_free_quantization import LFQ

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List

from magvit2_pytorch.attend import Attend
from magvit2_pytorch.version import __version__

from gateloop_transformer import SimpleGateLoopLayer

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from kornia.filters import filter3d

import pickle

# helper

def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


def safe_get_index(it, ind, default=None):
    if ind < len(it):
        return it[ind]
    return default


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def identity(t, *args, **kwargs):
    return t


def divisible_by(num, den):
    return (num % den) == 0


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def append_dims(t, ndims: int):
    return t.reshape(*t.shape, *((1,) * ndims))


def is_odd(n):
    return not divisible_by(n, 2)


def maybe_del_attr_(o, attr):
    if hasattr(o, attr):
        delattr(o, attr)


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


# tensor helpers

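# usage sketch: pack_one / unpack_one above are single-tensor wrappers around einops pack / unpack,
# used throughout to fold leading dims into one batch-like axis and restore them later, e.g.
#   x, ps = pack_one(x, "* c h w")
#   x = unpack_one(x, ps, "* c h w")
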
def l2norm(t):
    return F.normalize(t, dim=-1, p=2)


def pad_at_dim(t, pad, dim=-1, value=0.0):
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)


def pick_video_frame(video, frame_indices):
    batch, device = video.shape[0], video.device
    video = rearrange(video, "b c f ... -> b f c ...")
    batch_indices = torch.arange(batch, device=device)
    batch_indices = rearrange(batch_indices, "b -> b 1")
    images = video[batch_indices, frame_indices]
    images = rearrange(images, "b 1 c ... -> b c ...")
    return images


# gan related

def gradient_penalty(images, output):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs=output,
        inputs=images,
        grad_outputs=torch.ones(output.size(), device=images.device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = rearrange(gradients, "b ... -> b (...)")
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


def leaky_relu(p=0.1):
    return nn.LeakyReLU(p)


def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()


def hinge_gen_loss(fake):
    return -fake.mean()


@autocast(enabled=False)
@beartype
def grad_layer_wrt_loss(loss: Tensor, layer: nn.Parameter):
    return torch_grad(
        outputs=loss, inputs=layer, grad_outputs=torch.ones_like(loss), retain_graph=True
    )[0].detach()


# helper decorators

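# note: hinge_discr_loss / hinge_gen_loss implement the standard hinge GAN objective
# (D minimizes mean(relu(1 + fake) + relu(1 - real)), G minimizes -mean(fake)), and
# gradient_penalty penalizes deviation of the per-sample input gradient norm from 1
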
def remove_vgg(fn):
    @wraps(fn)
    def inner(self, *args, **kwargs):
        has_vgg = hasattr(self, "vgg")
        if has_vgg:
            vgg = self.vgg
            delattr(self, "vgg")

        out = fn(self, *args, **kwargs)

        if has_vgg:
            self.vgg = vgg

        return out

    return inner


# helper classes

def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    return nn.Sequential(*modules)


class Residual(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


# for a bunch of tensor operations to change tensor to (batch, time, feature dimension) and back

class ToTimeSequence(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = rearrange(x, "b c f ... -> b ... f c")
        x, ps = pack_one(x, "* n c")

        o = self.fn(x, **kwargs)

        o = unpack_one(o, ps, "* n c")
        return rearrange(o, "b ... f c -> b c f ...")

class SqueezeExcite(Module):
    # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)

    def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.to_k = nn.Conv2d(dim, 1, 1)
        dim_hidden = max(dim_hidden_min, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(dim, dim_hidden, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(dim_hidden, dim_out, 1),
            nn.Sigmoid(),
        )

        nn.init.zeros_(self.net[-2].weight)
        nn.init.constant_(self.net[-2].bias, init_bias)

    def forward(self, x):
        orig_input, batch = x, x.shape[0]
        is_video = x.ndim == 5

        if is_video:
            x = rearrange(x, "b c f h w -> (b f) c h w")

        context = self.to_k(x)

        context = rearrange(context, "b c h w -> b c (h w)").softmax(dim=-1)
        spatial_flattened_input = rearrange(x, "b c h w -> b c (h w)")

        out = einsum("b i n, b c n -> b c i", context, spatial_flattened_input)
        out = rearrange(out, "... -> ... 1")
        gates = self.net(out)

        if is_video:
            gates = rearrange(gates, "(b f) c h w -> b c f h w", b=batch)

        return gates * orig_input


# token shifting

class TokenShift(Module):
    @beartype
    def __init__(self, fn: Module):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        x, x_shift = x.chunk(2, dim=1)
        x_shift = pad_at_dim(x_shift, (1, -1), dim=2)  # shift time dimension
        x = torch.cat((x, x_shift), dim=1)
        return self.fn(x, **kwargs)


# rmsnorm

class RMSNorm(Module):
    def __init__(self, dim, channel_first=False, images=False, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return (
            F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma
            + self.bias
        )

class AdaptiveRMSNorm(Module):
    def __init__(self, dim, *, dim_cond, channel_first=False, images=False, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.dim_cond = dim_cond
        self.channel_first = channel_first
        self.scale = dim**0.5

        self.to_gamma = nn.Linear(dim_cond, dim)
        self.to_bias = nn.Linear(dim_cond, dim) if bias else None

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)

        if bias:
            nn.init.zeros_(self.to_bias.weight)
            nn.init.zeros_(self.to_bias.bias)

    @beartype
    def forward(self, x: Tensor, *, cond: Tensor):
        batch = x.shape[0]
        assert cond.shape == (batch, self.dim_cond)

        gamma = self.to_gamma(cond)

        bias = 0.0
        if exists(self.to_bias):
            bias = self.to_bias(cond)

        if self.channel_first:
            gamma = append_dims(gamma, x.ndim - 2)

            if exists(self.to_bias):
                bias = append_dims(bias, x.ndim - 2)

        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * gamma + bias


# attention

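# note: RMSNorm / AdaptiveRMSNorm l2-normalize over the channel axis (channel_first) or the last
# axis and rescale by sqrt(dim) times a gain; in the adaptive variant the gain (and optional bias)
# is predicted per sample from a conditioning vector of size dim_cond
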
class Attention(Module):
    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_cond: Optional[int] = None,
        causal=False,
        dim_head=32,
        heads=8,
        flash=False,
        dropout=0.0,
        num_memory_kv=4,
    ):
        super().__init__()
        dim_inner = dim_head * heads

        self.need_cond = exists(dim_cond)

        if self.need_cond:
            self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
        else:
            self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h n d", qkv=3, h=heads),
        )

        assert num_memory_kv > 0
        self.mem_kv = nn.Parameter(torch.randn(2, heads, num_memory_kv, dim_head))

        self.attend = Attend(causal=causal, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h n d -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
        )

    @beartype
    def forward(self, x, mask: Optional[Tensor] = None, cond: Optional[Tensor] = None):
        maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()

        x = self.norm(x, **maybe_cond_kwargs)

        q, k, v = self.to_qkv(x)

        mk, mv = map(lambda t: repeat(t, "h n d -> b h n d", b=q.shape[0]), self.mem_kv)
        k = torch.cat((mk, k), dim=-2)
        v = torch.cat((mv, v), dim=-2)

        out = self.attend(q, k, v, mask=mask)
        return self.to_out(out)

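# note: mem_kv above are a small set of learned "memory" key / value pairs that are prepended to
# every sequence before attention, giving each head a few always-available slots to attend to
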
class LinearAttention(Module):
    """
    using the specific linear attention proposed in https://arxiv.org/abs/2106.09681
    """

    @beartype
    def __init__(self, *, dim, dim_cond: Optional[int] = None, dim_head=8, heads=8, dropout=0.0):
        super().__init__()
        dim_inner = dim_head * heads

        self.need_cond = exists(dim_cond)

        if self.need_cond:
            self.norm = AdaptiveRMSNorm(dim, dim_cond=dim_cond)
        else:
            self.norm = RMSNorm(dim)

        self.attn = TaylorSeriesLinearAttn(dim=dim, dim_head=dim_head, heads=heads)

    def forward(self, x, cond: Optional[Tensor] = None):
        maybe_cond_kwargs = dict(cond=cond) if self.need_cond else dict()

        x = self.norm(x, **maybe_cond_kwargs)

        return self.attn(x)

class LinearSpaceAttention(LinearAttention):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c ... h w -> b ... h w c")
        x, batch_ps = pack_one(x, "* h w c")
        x, seq_ps = pack_one(x, "b * c")

        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, seq_ps, "b * c")
        x = unpack_one(x, batch_ps, "* h w c")
        return rearrange(x, "b ... h w c -> b c ... h w")

class SpaceAttention(Attention):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c t h w -> b t h w c")
        x, batch_ps = pack_one(x, "* h w c")
        x, seq_ps = pack_one(x, "b * c")

        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, seq_ps, "b * c")
        x = unpack_one(x, batch_ps, "* h w c")
        return rearrange(x, "b t h w c -> b c t h w")

class TimeAttention(Attention):
    def forward(self, x, *args, **kwargs):
        x = rearrange(x, "b c t h w -> b h w t c")
        x, batch_ps = pack_one(x, "* t c")

        x = super().forward(x, *args, **kwargs)

        x = unpack_one(x, batch_ps, "* t c")
        return rearrange(x, "b h w t c -> b c t h w")

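# note: the space / time attention wrappers only differ in how the (b, c, t, h, w) video tensor is
# reshaped - SpaceAttention attends over the h*w positions of each frame, TimeAttention attends
# over the t positions at each spatial location (causally when constructed with causal=True)
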
class GEGLU(Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=1)
        return F.gelu(gate) * x

class FeedForward(Module):
    @beartype
    def __init__(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False):
        super().__init__()
        conv_klass = nn.Conv2d if images else nn.Conv3d

        rmsnorm_klass = (
            RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)
        )

        maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)

        dim_inner = int(dim * mult * 2 / 3)

        self.norm = maybe_adaptive_norm_klass(dim)

        self.net = Sequential(
            conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1)
        )

    @beartype
    def forward(self, x: Tensor, *, cond: Optional[Tensor] = None):
        maybe_cond_kwargs = dict(cond=cond) if exists(cond) else dict()

        x = self.norm(x, **maybe_cond_kwargs)
        return self.net(x)


# discriminator with anti-aliased downsampling (blurpool Zhang et al.)

class Blur(Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer("f", f)

    def forward(self, x, space_only=False, time_only=False):
        assert not (space_only and time_only)

        f = self.f

        if space_only:
            f = einsum("i, j -> i j", f, f)
            f = rearrange(f, "... -> 1 1 ...")
        elif time_only:
            f = rearrange(f, "f -> 1 f 1 1")
        else:
            f = einsum("i, j, k -> i j k", f, f, f)
            f = rearrange(f, "... -> 1 ...")

        is_images = x.ndim == 4

        if is_images:
            x = rearrange(x, "b c h w -> b c 1 h w")

        out = filter3d(x, f, normalized=True)

        if is_images:
            out = rearrange(out, "b c 1 h w -> b c h w")

        return out

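# note: Blur applies a separable [1, 2, 1] binomial filter (spatially, temporally, or over all
# three dims) before strided downsampling, the anti-aliasing trick from Zhang et al. blurpool
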
class DiscriminatorBlock(Module):
    def __init__(self, input_channels, filters, downsample=True, antialiased_downsample=True):
        super().__init__()
        self.conv_res = nn.Conv2d(input_channels, filters, 1, stride=(2 if downsample else 1))

        self.net = nn.Sequential(
            nn.Conv2d(input_channels, filters, 3, padding=1),
            leaky_relu(),
            nn.Conv2d(filters, filters, 3, padding=1),
            leaky_relu(),
        )

        self.maybe_blur = Blur() if antialiased_downsample else None

        self.downsample = (
            nn.Sequential(
                Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
                nn.Conv2d(filters * 4, filters, 1),
            )
            if downsample
            else None
        )

    def forward(self, x):
        res = self.conv_res(x)

        x = self.net(x)

        if exists(self.downsample):
            if exists(self.maybe_blur):
                x = self.maybe_blur(x, space_only=True)

            x = self.downsample(x)

        x = (x + res) * (2**-0.5)
        return x

class Discriminator(Module):
    @beartype
    def __init__(
        self,
        *,
        dim,
        image_size,
        channels=3,
        max_dim=512,
        attn_heads=8,
        attn_dim_head=32,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        ff_mult=4,
        antialiased_downsample=False,
    ):
        super().__init__()
        image_size = pair(image_size)
        min_image_resolution = min(image_size)

        num_layers = int(log2(min_image_resolution) - 2)

        layer_dims = [channels] + [(dim * 4) * (2**i) for i in range(num_layers + 1)]
        layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
        layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

        blocks = []
        attn_blocks = []

        image_resolution = min_image_resolution

        for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
            num_layer = ind + 1
            is_not_last = ind != (len(layer_dims_in_out) - 1)

            block = DiscriminatorBlock(
                in_chan,
                out_chan,
                downsample=is_not_last,
                antialiased_downsample=antialiased_downsample,
            )

            attn_block = Sequential(
                Residual(
                    LinearSpaceAttention(
                        dim=out_chan, heads=linear_attn_heads, dim_head=linear_attn_dim_head
                    )
                ),
                Residual(FeedForward(dim=out_chan, mult=ff_mult, images=True)),
            )

            blocks.append(ModuleList([block, attn_block]))

            image_resolution //= 2

        self.blocks = ModuleList(blocks)

        dim_last = layer_dims[-1]

        downsample_factor = 2**num_layers
        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))

        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last

        self.to_logits = Sequential(
            nn.Conv2d(dim_last, dim_last, 3, padding=1),
            leaky_relu(),
            Rearrange("b ... -> b (...)"),
            nn.Linear(latent_dim, 1),
            Rearrange("b 1 -> b"),
        )

    def forward(self, x):
        for block, attn_block in self.blocks:
            x = block(x)
            x = attn_block(x)

        return self.to_logits(x)


# modulatable conv from Karras et al. Stylegan2
# for conditioning on latents

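# a minimal usage sketch for the discriminator above (the values below are illustrative
# assumptions, not settings from this file):
#   discr = Discriminator(dim=64, image_size=256)
#   logits = discr(torch.randn(2, 3, 256, 256))  # -> tensor of shape (2,)
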
class Conv3DMod(Module):
    @beartype
    def __init__(
        self,
        dim,
        *,
        spatial_kernel,
        time_kernel,
        causal=True,
        dim_out=None,
        demod=True,
        eps=1e-8,
        pad_mode="constant",
    ):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.eps = eps

        assert is_odd(spatial_kernel) and is_odd(time_kernel)

        self.spatial_kernel = spatial_kernel
        self.time_kernel = time_kernel

        time_padding = (time_kernel - 1, 0) if causal else ((time_kernel // 2,) * 2)

        self.pad_mode = pad_mode
        self.padding = (*((spatial_kernel // 2,) * 4), *time_padding)
        self.weights = nn.Parameter(
            torch.randn((dim_out, dim, time_kernel, spatial_kernel, spatial_kernel))
        )

        self.demod = demod

        nn.init.kaiming_normal_(self.weights, a=0, mode="fan_in", nonlinearity="selu")

    @beartype
    def forward(self, fmap, cond: Tensor):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """

        b = fmap.shape[0]

        # prepare weights for modulation

        weights = self.weights

        # do the modulation, demodulation, as done in stylegan2

        cond = rearrange(cond, "b i -> b 1 i 1 1 1")

        weights = weights * (cond + 1)

        if self.demod:
            inv_norm = (
                reduce(weights**2, "b o i k0 k1 k2 -> b o 1 1 1 1", "sum")
                .clamp(min=self.eps)
                .rsqrt()
            )
            weights = weights * inv_norm

        fmap = rearrange(fmap, "b c t h w -> 1 (b c) t h w")

        weights = rearrange(weights, "b o ... -> (b o) ...")

        fmap = F.pad(fmap, self.padding, mode=self.pad_mode)
        fmap = F.conv3d(fmap, weights, groups=b)

        return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)


# strided conv downsamples

class SpatialDownsample2x(Module):
    def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.maybe_blur = Blur() if antialias else identity
        self.conv = nn.Conv2d(dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2)

    def forward(self, x):
        x = self.maybe_blur(x, space_only=True)

        x = rearrange(x, "b c t h w -> b t c h w")
        x, ps = pack_one(x, "* c h w")

        out = self.conv(x)

        out = unpack_one(out, ps, "* c h w")
        out = rearrange(out, "b t c h w -> b c t h w")
        return out

class TimeDownsample2x(Module):
    def __init__(self, dim, dim_out=None, kernel_size=3, antialias=False):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.maybe_blur = Blur() if antialias else identity
        self.time_causal_padding = (kernel_size - 1, 0)
        self.conv = nn.Conv1d(dim, dim_out, kernel_size, stride=2)

    def forward(self, x):
        x = self.maybe_blur(x, time_only=True)

        x = rearrange(x, "b c t h w -> b h w c t")
        x, ps = pack_one(x, "* c t")

        x = F.pad(x, self.time_causal_padding)
        out = self.conv(x)

        out = unpack_one(out, ps, "* c t")
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


# depth to space upsamples

class SpatialUpsample2x(Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv, nn.SiLU(), Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b t c h w")
        x, ps = pack_one(x, "* c h w")

        out = self.net(x)

        out = unpack_one(out, ps, "* c h w")
        out = rearrange(out, "b t c h w -> b c t h w")
        return out

class TimeUpsample2x(Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv1d(dim, dim_out * 2, 1)

        self.net = nn.Sequential(conv, nn.SiLU(), Rearrange("b (c p) t -> b c (t p)", p=2))

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, t = conv.weight.shape
        conv_weight = torch.empty(o // 2, i, t)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        x = rearrange(x, "b c t h w -> b h w c t")
        x, ps = pack_one(x, "* c t")

        out = self.net(x)

        out = unpack_one(out, ps, "* c t")
        out = rearrange(out, "b h w c t -> b c t h w")
        return out


# autoencoder - only best variant here offered, with causal conv 3d

def SameConv2d(dim_in, dim_out, kernel_size):
    kernel_size = cast_tuple(kernel_size, 2)
    padding = [k // 2 for k in kernel_size]
    return nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, padding=padding)

class CausalConv3d(Module):
    @beartype
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop("dilation", 1)
        stride = kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(
            chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs
        )

    def forward(self, x):
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else "constant"

        x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        return self.conv(x)

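# note: CausalConv3d pads symmetrically in height / width but only on the left of the time axis,
# so the output at frame t never depends on frames later than t
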
@beartype
def ResidualUnit(dim, kernel_size: Union[int, Tuple[int, int, int]], pad_mode: str = "constant"):
    net = Sequential(
        CausalConv3d(dim, dim, kernel_size, pad_mode=pad_mode),
        nn.ELU(),
        nn.Conv3d(dim, dim, 1),
        nn.ELU(),
        SqueezeExcite(dim),
    )

    return Residual(net)

@beartype
class ResidualUnitMod(Module):
    def __init__(
        self,
        dim,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        dim_cond,
        pad_mode: str = "constant",
        demod=True,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)
        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        assert height_kernel_size == width_kernel_size

        self.to_cond = nn.Linear(dim_cond, dim)

        self.conv = Conv3DMod(
            dim=dim,
            spatial_kernel=height_kernel_size,
            time_kernel=time_kernel_size,
            causal=True,
            demod=demod,
            pad_mode=pad_mode,
        )

        self.conv_out = nn.Conv3d(dim, dim, 1)

    @beartype
    def forward(
        self,
        x,
        cond: Tensor,
    ):
        res = x
        cond = self.to_cond(cond)

        x = self.conv(x, cond=cond)
        x = F.elu(x)
        x = self.conv_out(x)
        x = F.elu(x)
        return x + res

class CausalConvTranspose3d(Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        time_stride,
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        self.upsample_factor = time_stride

        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        stride = (time_stride, 1, 1)
        padding = (0, height_pad, width_pad)

        self.conv = nn.ConvTranspose3d(
            chan_in, chan_out, kernel_size, stride, padding=padding, **kwargs
        )

    def forward(self, x):
        assert x.ndim == 5
        t = x.shape[2]

        out = self.conv(x)

        out = out[..., : (t * self.upsample_factor), :, :]
        return out


# video tokenizer class

LossBreakdown = namedtuple(
    "LossBreakdown",
    [
        "recon_loss",
        "lfq_aux_loss",
        "quantizer_loss_breakdown",
        "perceptual_loss",
        "adversarial_gen_loss",
        "adaptive_adversarial_weight",
        "multiscale_gen_losses",
        "multiscale_gen_adaptive_weights",
    ],
)

DiscrLossBreakdown = namedtuple(
    "DiscrLossBreakdown", ["discr_loss", "multiscale_discr_losses", "gradient_penalty"]
)

class VideoTokenizer(Module):
    @beartype
    def __init__(
        self,
        *,
        image_size,
        layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"),
        residual_conv_kernel_size=3,
        num_codebooks=1,
        codebook_size: Optional[int] = None,
        channels=3,
        init_dim=64,
        max_dim=float("inf"),
        dim_cond=None,
        dim_cond_expansion_factor=4.0,
        input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
        output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
        pad_mode: str = "constant",
        lfq_entropy_loss_weight=0.1,
        lfq_commitment_loss_weight=1.0,
        lfq_diversity_gamma=2.5,
        quantizer_aux_loss_weight=1.0,
        lfq_activation=nn.Identity(),
        use_fsq=False,
        fsq_levels: Optional[List[int]] = None,
        attn_dim_head=32,
        attn_heads=8,
        attn_dropout=0.0,
        linear_attn_dim_head=8,
        linear_attn_heads=16,
        vgg: Optional[Module] = None,
        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
        perceptual_loss_weight=1e-1,
        discr_kwargs: Optional[dict] = None,
        multiscale_discrs: Tuple[Module, ...] = tuple(),
        use_gan=True,
        adversarial_loss_weight=1.0,
        grad_penalty_loss_weight=10.0,
        multiscale_adversarial_loss_weight=1.0,
        flash_attn=True,
        separate_first_frame_encoding=False,
    ):
        super().__init__()

        # for autosaving the config

        _locals = locals()
        _locals.pop("self", None)
        _locals.pop("__class__", None)
        self._configs = pickle.dumps(_locals)

        # image size

        self.channels = channels
        self.image_size = image_size

        # initial encoder

        self.conv_in = CausalConv3d(channels, init_dim, input_conv_kernel_size, pad_mode=pad_mode)

        # whether to encode the first frame separately or not

        self.conv_in_first_frame = nn.Identity()
        self.conv_out_first_frame = nn.Identity()

        if separate_first_frame_encoding:
            self.conv_in_first_frame = SameConv2d(channels, init_dim, input_conv_kernel_size[-2:])
            self.conv_out_first_frame = SameConv2d(init_dim, channels, output_conv_kernel_size[-2:])

        self.separate_first_frame_encoding = separate_first_frame_encoding

        # encoder and decoder layers

        self.encoder_layers = ModuleList([])
        self.decoder_layers = ModuleList([])

        self.conv_out = CausalConv3d(init_dim, channels, output_conv_kernel_size, pad_mode=pad_mode)

        dim = init_dim
        dim_out = dim

        layer_fmap_size = image_size
        time_downsample_factor = 1
        has_cond_across_layers = []

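        # supported layer type strings (see the branches below): residual, consecutive_residual,
        # cond_residual, compress_space, compress_time, attend_space, linear_attend_space,
        # gateloop_time, attend_time, cond_attend_space, cond_linear_attend_space, cond_attend_time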
        for layer_def in layers:
            layer_type, *layer_params = cast_tuple(layer_def)

            has_cond = False

            if layer_type == "residual":
                encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
                decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)

            elif layer_type == "consecutive_residual":
                (num_consecutive,) = layer_params
                encoder_layer = Sequential(
                    *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
                )
                decoder_layer = Sequential(
                    *[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]
                )

            elif layer_type == "cond_residual":
                assert exists(
                    dim_cond
                ), "dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned"

                has_cond = True

                encoder_layer = ResidualUnitMod(
                    dim,
                    residual_conv_kernel_size,
                    dim_cond=int(dim_cond * dim_cond_expansion_factor),
                )
                decoder_layer = ResidualUnitMod(
                    dim,
                    residual_conv_kernel_size,
                    dim_cond=int(dim_cond * dim_cond_expansion_factor),
                )
                dim_out = dim

            elif layer_type == "compress_space":
                dim_out = safe_get_index(layer_params, 0)
                dim_out = default(dim_out, dim * 2)
                dim_out = min(dim_out, max_dim)

                encoder_layer = SpatialDownsample2x(dim, dim_out)
                decoder_layer = SpatialUpsample2x(dim_out, dim)

                assert layer_fmap_size > 1
                layer_fmap_size //= 2

            elif layer_type == "compress_time":
                dim_out = safe_get_index(layer_params, 0)
                dim_out = default(dim_out, dim * 2)
                dim_out = min(dim_out, max_dim)

                encoder_layer = TimeDownsample2x(dim, dim_out)
                decoder_layer = TimeUpsample2x(dim_out, dim)

                time_downsample_factor *= 2

            elif layer_type == "attend_space":
                attn_kwargs = dict(
                    dim=dim,
                    dim_head=attn_dim_head,
                    heads=attn_heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                )

                encoder_layer = Sequential(
                    Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
                )

                decoder_layer = Sequential(
                    Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
                )

            elif layer_type == "linear_attend_space":
                linear_attn_kwargs = dict(
                    dim=dim, dim_head=linear_attn_dim_head, heads=linear_attn_heads
                )

                encoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
                )

                decoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**linear_attn_kwargs)), Residual(FeedForward(dim))
                )

            elif layer_type == "gateloop_time":
                gateloop_kwargs = dict(use_heinsen=False)

                encoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))
                decoder_layer = ToTimeSequence(Residual(SimpleGateLoopLayer(dim=dim)))

            elif layer_type == "attend_time":
                attn_kwargs = dict(
                    dim=dim,
                    dim_head=attn_dim_head,
                    heads=attn_heads,
                    dropout=attn_dropout,
                    causal=True,
                    flash=flash_attn,
                )

                encoder_layer = Sequential(
                    Residual(TokenShift(TimeAttention(**attn_kwargs))),
                    Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
                )

                decoder_layer = Sequential(
                    Residual(TokenShift(TimeAttention(**attn_kwargs))),
                    Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
                )

            elif layer_type == "cond_attend_space":
                has_cond = True

                attn_kwargs = dict(
                    dim=dim,
                    dim_cond=dim_cond,
                    dim_head=attn_dim_head,
                    heads=attn_heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                )

                encoder_layer = Sequential(
                    Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
                )

                decoder_layer = Sequential(
                    Residual(SpaceAttention(**attn_kwargs)), Residual(FeedForward(dim))
                )

            elif layer_type == "cond_linear_attend_space":
                has_cond = True

                # LinearSpaceAttention does not accept a flash kwarg, so only the supported
                # arguments are passed here
                attn_kwargs = dict(
                    dim=dim,
                    dim_cond=dim_cond,
                    dim_head=attn_dim_head,
                    heads=attn_heads,
                    dropout=attn_dropout,
                )

                encoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**attn_kwargs)),
                    Residual(FeedForward(dim, dim_cond=dim_cond)),
                )

                decoder_layer = Sequential(
                    Residual(LinearSpaceAttention(**attn_kwargs)),
                    Residual(FeedForward(dim, dim_cond=dim_cond)),
                )

            elif layer_type == "cond_attend_time":
                has_cond = True

                attn_kwargs = dict(
                    dim=dim,
                    dim_cond=dim_cond,
                    dim_head=attn_dim_head,
                    heads=attn_heads,
                    dropout=attn_dropout,
                    causal=True,
                    flash=flash_attn,
                )

                encoder_layer = Sequential(
                    Residual(TokenShift(TimeAttention(**attn_kwargs))),
                    Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
                )

                decoder_layer = Sequential(
                    Residual(TokenShift(TimeAttention(**attn_kwargs))),
                    Residual(TokenShift(FeedForward(dim, dim_cond=dim_cond))),
                )

            else:
                raise ValueError(f"unknown layer type {layer_type}")

            self.encoder_layers.append(encoder_layer)
            self.decoder_layers.insert(0, decoder_layer)

            dim = dim_out
            has_cond_across_layers.append(has_cond)

        # add a final norm just before quantization layer

        self.encoder_layers.append(
            Sequential(
                Rearrange("b c ... -> b ... c"),
                nn.LayerNorm(dim),
                Rearrange("b ... c -> b c ..."),
            )
        )

        self.time_downsample_factor = time_downsample_factor
        self.time_padding = time_downsample_factor - 1

        self.fmap_size = layer_fmap_size

        # use a MLP stem for conditioning, if needed

        self.has_cond_across_layers = has_cond_across_layers
        self.has_cond = any(has_cond_across_layers)

        self.encoder_cond_in = nn.Identity()
        self.decoder_cond_in = nn.Identity()

        if self.has_cond:
            self.dim_cond = dim_cond

            self.encoder_cond_in = Sequential(
                nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
            )

            self.decoder_cond_in = Sequential(
                nn.Linear(dim_cond, int(dim_cond * dim_cond_expansion_factor)), nn.SiLU()
            )

        # quantizer related

        self.use_fsq = use_fsq

        if not use_fsq:
            assert exists(codebook_size) and not exists(
                fsq_levels
            ), "if use_fsq is set to False, `codebook_size` must be set (and not `fsq_levels`)"

            # lookup free quantizer(s) - multiple codebooks is possible
            # each codebook will get its own entropy regularization

            self.quantizers = LFQ(
                dim=dim,
                codebook_size=codebook_size,
                num_codebooks=num_codebooks,
                entropy_loss_weight=lfq_entropy_loss_weight,
                commitment_loss_weight=lfq_commitment_loss_weight,
                diversity_gamma=lfq_diversity_gamma,
            )

        else:
            assert (
                not exists(codebook_size) and exists(fsq_levels)
            ), "if use_fsq is set to True, `fsq_levels` must be set (and not `codebook_size`). the effective codebook size is the cumulative product of all the FSQ levels"

            self.quantizers = FSQ(fsq_levels, dim=dim, num_codebooks=num_codebooks)

        self.quantizer_aux_loss_weight = quantizer_aux_loss_weight

        # dummy loss

        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

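        # note: with use_fsq=False the LFQ quantizer needs an explicit codebook_size and returns
        # auxiliary (entropy / commitment) losses; with use_fsq=True the FSQ quantizer takes
        # fsq_levels instead, and its effective codebook size is the product of those levels
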
        # perceptual loss related

        use_vgg = channels in {1, 3, 4} and perceptual_loss_weight > 0.0

        self.vgg = None
        self.perceptual_loss_weight = perceptual_loss_weight

        if use_vgg:
            if not exists(vgg):
                vgg = torchvision.models.vgg16(weights=vgg_weights)

                vgg.classifier = Sequential(*vgg.classifier[:-2])

            self.vgg = vgg

        self.use_vgg = use_vgg

        # main flag for whether to use GAN at all

        self.use_gan = use_gan

        # discriminator

        discr_kwargs = default(
            discr_kwargs, dict(dim=dim, image_size=image_size, channels=channels, max_dim=512)
        )

        self.discr = Discriminator(**discr_kwargs)

        self.adversarial_loss_weight = adversarial_loss_weight
        self.grad_penalty_loss_weight = grad_penalty_loss_weight

        self.has_gan = use_gan and adversarial_loss_weight > 0.0

        # multi-scale discriminators

        self.has_multiscale_gan = use_gan and multiscale_adversarial_loss_weight > 0.0

        self.multiscale_discrs = ModuleList([*multiscale_discrs])

        self.multiscale_adversarial_loss_weight = multiscale_adversarial_loss_weight

        self.has_multiscale_discrs = (
            use_gan and multiscale_adversarial_loss_weight > 0.0 and len(multiscale_discrs) > 0
        )

    @property
    def device(self):
        return self.zero.device

    @classmethod
    def init_and_load_from(cls, path, strict=True):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path), map_location="cpu")

        assert "config" in pkg, "model configs were not found in this saved checkpoint"

        config = pickle.loads(pkg["config"])
        tokenizer = cls(**config)
        tokenizer.load(path, strict=strict)
        return tokenizer

    def parameters(self):
        return [
            *self.conv_in.parameters(),
            *self.conv_in_first_frame.parameters(),
            *self.conv_out_first_frame.parameters(),
            *self.conv_out.parameters(),
            *self.encoder_layers.parameters(),
            *self.decoder_layers.parameters(),
            *self.encoder_cond_in.parameters(),
            *self.decoder_cond_in.parameters(),
            *self.quantizers.parameters(),
        ]

    def discr_parameters(self):
        return self.discr.parameters()

    def copy_for_eval(self):
        device = self.device
        vae_copy = copy.deepcopy(self.cpu())

        maybe_del_attr_(vae_copy, "discr")
        maybe_del_attr_(vae_copy, "vgg")
        maybe_del_attr_(vae_copy, "multiscale_discrs")

        vae_copy.eval()
        return vae_copy.to(device)

    @remove_vgg
    def state_dict(self, *args, **kwargs):
        return super().state_dict(*args, **kwargs)

    @remove_vgg
    def load_state_dict(self, *args, **kwargs):
        return super().load_state_dict(*args, **kwargs)

    def save(self, path, overwrite=True):
        path = Path(path)
        assert overwrite or not path.exists(), f"{str(path)} already exists"

        pkg = dict(model_state_dict=self.state_dict(), version=__version__, config=self._configs)

        torch.save(pkg, str(path))

    def load(self, path, strict=True):
        path = Path(path)
        assert path.exists()

        pkg = torch.load(str(path))
        state_dict = pkg.get("model_state_dict")
        version = pkg.get("version")

        assert exists(state_dict)

        if exists(version):
            print(f"loading checkpointed tokenizer from version {version}")

        self.load_state_dict(state_dict, strict=strict)

    @beartype
    def encode(
        self,
        video: Tensor,
        quantize=False,
        cond: Optional[Tensor] = None,
        video_contains_first_frame=True,
    ):
        encode_first_frame_separately = (
            self.separate_first_frame_encoding and video_contains_first_frame
        )

        # whether to pad video or not

        if video_contains_first_frame:
            video_len = video.shape[2]

            video = pad_at_dim(video, (self.time_padding, 0), value=0.0, dim=2)
            video_packed_shape = [
                torch.Size([self.time_padding]),
                torch.Size([]),
                torch.Size([video_len - 1]),
            ]

        # conditioning, if needed

        assert (
            (not self.has_cond) or exists(cond)
        ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"

        if exists(cond):
            assert cond.shape == (video.shape[0], self.dim_cond)

            cond = self.encoder_cond_in(cond)
            cond_kwargs = dict(cond=cond)

        # initial conv
        # taking into account whether to encode first frame separately

        if encode_first_frame_separately:
            pad, first_frame, video = unpack(video, video_packed_shape, "b c * h w")
            first_frame = self.conv_in_first_frame(first_frame)

        video = self.conv_in(video)

        if encode_first_frame_separately:
            video, _ = pack([first_frame, video], "b c * h w")
            video = pad_at_dim(video, (self.time_padding, 0), dim=2)

        # encoder layers

        for fn, has_cond in zip(self.encoder_layers, self.has_cond_across_layers):
            layer_kwargs = dict()

            if has_cond:
                layer_kwargs = cond_kwargs

            video = fn(video, **layer_kwargs)

        maybe_quantize = identity if not quantize else self.quantizers

        return maybe_quantize(video)

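    # note: encode() returns the pre-quantization feature map by default; pass quantize=True to
    # route it through the LFQ / FSQ quantizer, and use tokenize() below for the common
    # encode -> quantize -> discrete codes path
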
    @beartype
    def decode_from_code_indices(
        self, codes: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
    ):
        assert codes.dtype in (torch.long, torch.int32)

        if codes.ndim == 2:
            video_code_len = codes.shape[-1]
            assert divisible_by(
                video_code_len, self.fmap_size**2
            ), f"flattened video ids must have a length ({video_code_len}) that is divisible by the fmap size ({self.fmap_size}) squared ({self.fmap_size ** 2})"

            codes = rearrange(codes, "b (f h w) -> b f h w", h=self.fmap_size, w=self.fmap_size)

        quantized = self.quantizers.indices_to_codes(codes)

        return self.decode(
            quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
        )

    @beartype
    def decode(
        self, quantized: Tensor, cond: Optional[Tensor] = None, video_contains_first_frame=True
    ):
        decode_first_frame_separately = (
            self.separate_first_frame_encoding and video_contains_first_frame
        )

        batch = quantized.shape[0]

        # conditioning, if needed

        assert (
            (not self.has_cond) or exists(cond)
        ), "`cond` must be passed into tokenizer forward method since conditionable layers were specified"

        if exists(cond):
            assert cond.shape == (batch, self.dim_cond)

            cond = self.decoder_cond_in(cond)
            cond_kwargs = dict(cond=cond)

        # decoder layers

        x = quantized

        for fn, has_cond in zip(self.decoder_layers, reversed(self.has_cond_across_layers)):
            layer_kwargs = dict()

            if has_cond:
                layer_kwargs = cond_kwargs

            x = fn(x, **layer_kwargs)

        # to pixels

        if decode_first_frame_separately:
            left_pad, xff, x = (
                x[:, :, : self.time_padding],
                x[:, :, self.time_padding],
                x[:, :, (self.time_padding + 1) :],
            )

            out = self.conv_out(x)
            outff = self.conv_out_first_frame(xff)

            video, _ = pack([outff, out], "b c * h w")

        else:
            video = self.conv_out(x)

        # if video were padded, remove padding

        if video_contains_first_frame:
            video = video[:, :, self.time_padding :]

        return video

    @torch.no_grad()
    def tokenize(self, video):
        self.eval()
        return self.forward(video, return_codes=True)

    @beartype
    def forward(
        self,
        video_or_images: Tensor,
        cond: Optional[Tensor] = None,
        return_loss=False,
        return_codes=False,
        return_recon=False,
        return_discr_loss=False,
        return_recon_loss_only=False,
        apply_gradient_penalty=True,
        video_contains_first_frame=True,
        adversarial_loss_weight=None,
        multiscale_adversarial_loss_weight=None,
    ):
        adversarial_loss_weight = default(adversarial_loss_weight, self.adversarial_loss_weight)
        multiscale_adversarial_loss_weight = default(
            multiscale_adversarial_loss_weight, self.multiscale_adversarial_loss_weight
        )

        assert (return_loss + return_codes + return_discr_loss) <= 1
        assert video_or_images.ndim in {4, 5}

        assert video_or_images.shape[-2:] == (self.image_size, self.image_size)

        # accept images for image pretraining (curriculum learning from images to video)

        is_image = video_or_images.ndim == 4

        if is_image:
            video = rearrange(video_or_images, "b c ... -> b c 1 ...")
            video_contains_first_frame = True
        else:
            video = video_or_images

        batch, channels, frames = video.shape[:3]

        assert divisible_by(
            frames - int(video_contains_first_frame), self.time_downsample_factor
        ), f"number of frames {frames} minus the first frame ({frames - int(video_contains_first_frame)}) must be divisible by the total downsample factor across time {self.time_downsample_factor}"

        # encoder

        x = self.encode(video, cond=cond, video_contains_first_frame=video_contains_first_frame)

        # lookup free quantization

        if self.use_fsq:
            quantized, codes = self.quantizers(x)

            aux_losses = self.zero
            quantizer_loss_breakdown = None
        else:
            (quantized, codes, aux_losses), quantizer_loss_breakdown = self.quantizers(
                x, return_loss_breakdown=True
            )

        if return_codes and not return_recon:
            return codes

        # decoder

        recon_video = self.decode(
            quantized, cond=cond, video_contains_first_frame=video_contains_first_frame
        )

        if return_codes:
            return codes, recon_video

        # reconstruction loss

        if not (return_loss or return_discr_loss or return_recon_loss_only):
            return recon_video

        recon_loss = F.mse_loss(video, recon_video)

        # for validation, only return recon loss

        if return_recon_loss_only:
            return recon_loss, recon_video

        # gan discriminator loss

        if return_discr_loss:
            assert self.has_gan
            assert exists(self.discr)

            # pick a random frame for image discriminator

            frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices

            real = pick_video_frame(video, frame_indices)

            if apply_gradient_penalty:
                real = real.requires_grad_()

            fake = pick_video_frame(recon_video, frame_indices)

            real_logits = self.discr(real)
            fake_logits = self.discr(fake.detach())

            discr_loss = hinge_discr_loss(fake_logits, real_logits)

            # multiscale discriminators

            multiscale_discr_losses = []

            if self.has_multiscale_discrs:
                for discr in self.multiscale_discrs:
                    multiscale_real_logits = discr(video)
                    multiscale_fake_logits = discr(recon_video.detach())

                    multiscale_discr_loss = hinge_discr_loss(
                        multiscale_fake_logits, multiscale_real_logits
                    )

                    multiscale_discr_losses.append(multiscale_discr_loss)
            else:
                multiscale_discr_losses.append(self.zero)

            # gradient penalty

            if apply_gradient_penalty:
                gradient_penalty_loss = gradient_penalty(real, real_logits)
            else:
                gradient_penalty_loss = self.zero

            # total loss

            total_loss = (
                discr_loss
                + gradient_penalty_loss * self.grad_penalty_loss_weight
                + sum(multiscale_discr_losses) * self.multiscale_adversarial_loss_weight
            )

            discr_loss_breakdown = DiscrLossBreakdown(
                discr_loss, multiscale_discr_losses, gradient_penalty_loss
            )

            return total_loss, discr_loss_breakdown

        # perceptual loss

        if self.use_vgg:
            frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices

            input_vgg_input = pick_video_frame(video, frame_indices)
            recon_vgg_input = pick_video_frame(recon_video, frame_indices)

            if channels == 1:
                input_vgg_input = repeat(input_vgg_input, "b 1 h w -> b c h w", c=3)
                recon_vgg_input = repeat(recon_vgg_input, "b 1 h w -> b c h w", c=3)

            elif channels == 4:
                input_vgg_input = input_vgg_input[:, :3]
                recon_vgg_input = recon_vgg_input[:, :3]

            input_vgg_feats = self.vgg(input_vgg_input)
            recon_vgg_feats = self.vgg(recon_vgg_input)

            perceptual_loss = F.mse_loss(input_vgg_feats, recon_vgg_feats)
        else:
            perceptual_loss = self.zero

        # get gradient with respect to perceptual loss for last decoder layer
        # needed for adaptive weighting

        last_dec_layer = self.conv_out.conv.weight

        norm_grad_wrt_perceptual_loss = None

        if self.training and self.use_vgg and (self.has_gan or self.has_multiscale_discrs):
            norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(
                perceptual_loss, last_dec_layer
            ).norm(p=2)

        # per-frame image discriminator

        recon_video_frames = None

        if self.has_gan:
            frame_indices = torch.randn((batch, frames)).topk(1, dim=-1).indices
            recon_video_frames = pick_video_frame(recon_video, frame_indices)

            fake_logits = self.discr(recon_video_frames)
            gen_loss = hinge_gen_loss(fake_logits)

            adaptive_weight = 1.0

            if exists(norm_grad_wrt_perceptual_loss):
                norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
                adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(
                    min=1e-3
                )
                adaptive_weight.clamp_(max=1e3)

                if torch.isnan(adaptive_weight).any():
                    adaptive_weight = 1.0
        else:
            gen_loss = self.zero
            adaptive_weight = 0.0

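        # note: the adaptive weight above follows the VQGAN-style heuristic - scale the adversarial
        # term by the ratio of the perceptual-loss and generator-loss gradient norms at the last
        # decoder layer, so that neither objective dominates the decoder updates
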
        # multiscale discriminator losses

        multiscale_gen_losses = []
        multiscale_gen_adaptive_weights = []

        if self.has_multiscale_gan and self.has_multiscale_discrs:
            if not exists(recon_video_frames):
                recon_video_frames = pick_video_frame(recon_video, frame_indices)

            for discr in self.multiscale_discrs:
                fake_logits = discr(recon_video)
                multiscale_gen_loss = hinge_gen_loss(fake_logits)

                multiscale_gen_losses.append(multiscale_gen_loss)

                multiscale_adaptive_weight = 1.0

                if exists(norm_grad_wrt_perceptual_loss):
                    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(
                        multiscale_gen_loss, last_dec_layer
                    ).norm(p=2)
                    multiscale_adaptive_weight = (
                        norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-5)
                    )
                    multiscale_adaptive_weight.clamp_(max=1e3)

                multiscale_gen_adaptive_weights.append(multiscale_adaptive_weight)

        # calculate total loss

        total_loss = (
            recon_loss
            + aux_losses * self.quantizer_aux_loss_weight
            + perceptual_loss * self.perceptual_loss_weight
            + gen_loss * adaptive_weight * adversarial_loss_weight
        )

        if self.has_multiscale_discrs:
            weighted_multiscale_gen_losses = sum(
                loss * weight
                for loss, weight in zip(multiscale_gen_losses, multiscale_gen_adaptive_weights)
            )

            total_loss = (
                total_loss + weighted_multiscale_gen_losses * multiscale_adversarial_loss_weight
            )

        # loss breakdown

        loss_breakdown = LossBreakdown(
            recon_loss,
            aux_losses,
            quantizer_loss_breakdown,
            perceptual_loss,
            gen_loss,
            adaptive_weight,
            multiscale_gen_losses,
            multiscale_gen_adaptive_weights,
        )

        return total_loss, loss_breakdown

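# a minimal usage sketch for VideoTokenizer (all values below are illustrative assumptions,
# not settings from the original repo):
#   tokenizer = VideoTokenizer(
#       image_size=128,
#       layers=("residual", "compress_space", "residual"),
#       use_fsq=True,
#       fsq_levels=[8, 5, 5, 5],
#   )
#   video = torch.randn(1, 3, 17, 128, 128)    # (batch, channels, frames, height, width)
#   codes = tokenizer.tokenize(video)          # discrete code indices
#   recon = tokenizer.decode_from_code_indices(codes)
#   loss, breakdown = tokenizer(video, return_loss=True)
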
# main class


class MagViT2(Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x