CogVideo/sat/dit_video_concat.py
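# Diffusion Transformer (DiT) for CogVideoX built on the SAT BaseModel/mixin
# framework: text tokens and patchified video-latent tokens are concatenated along
# the sequence dimension, and the mixins below supply patch embedding, positional
# (sin-cos / 3-D RoPE) embeddings, adaptive LayerNorm conditioning and the final
# un-patchify projection.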
from functools import partial
from einops import rearrange, repeat
from functools import reduce
from operator import mul
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import linear, timestep_embedding
from sat.ops.layernorm import LayerNorm, RMSNorm
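# Patchifies the video latents (b, t, c, h, w) into (o, p, q) patches along
# (time, height, width), projects them to hidden_size, and optionally prepends
# projected text-encoder tokens so the sequence is [text tokens, image tokens].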
class ImagePatchEmbeddingMixin(BaseMixin):
def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size)
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else:
self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs):
images = kwargs["images"] # (b,t,c,h,w)
emb = rearrange(images, "b t c h w -> b (t h w) c")
# emb = rearrange(images, "b c t h w -> b (t h w) c")
emb = rearrange(
emb,
"b (t o h p w q) c -> b (t h w) (c o p q)",
t=kwargs["rope_T"],
h=kwargs["rope_H"],
w=kwargs["rope_W"],
o=self.patch_size[0],
p=self.patch_size[1],
q=self.patch_size[2],
)
emb = self.proj(emb)
if self.text_proj is not None:
text_emb = self.text_proj(kwargs["encoder_outputs"])
emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
emb = emb.contiguous()
return emb # (b,n_t+t*n_i,d)
def reinit(self, parent_model=None):
w = self.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.proj.bias, 0)
del self.transformer.word_embeddings
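# Factorized 3-D sin-cos position table: 1/4 of the channels encode time and 3/4
# encode space; the result has shape [t_size, grid_height * grid_width, embed_dim]
# with the temporal part placed in the leading channels.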
def get_3d_sincos_pos_embed(
embed_dim,
grid_height,
grid_width,
t_size,
cls_token=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
"""
grid_size: int of the grid height and width
t_size: int of the temporal size
return:
pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim]
(w/ or w/o cls_token)
"""
assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3
embed_dim_temporal = embed_dim // 4
# spatial
grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
# temporal
grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
# concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(
pos_embed_temporal, grid_height * grid_width, axis=1
) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
return pos_embed # [T, H*W, D]
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_height, dtype=np.float32)
grid_w = np.arange(grid_width, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
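# Frozen 2-D absolute position table: text positions stay zero, and the spatial
# slots are filled with a 2-D sin-cos embedding in reinit().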
class Basic2DPositionEmbeddingMixin(BaseMixin):
def __init__(self, height, width, compressed_num_frames, hidden_size, text_length=0):
super().__init__()
self.height = height
self.width = width
self.spatial_length = height * width
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)),
requires_grad=False,
)
def position_embedding_forward(self, position_ids, **kwargs):
return self.pos_embedding
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width)
self.pos_embedding.data[:, -self.spatial_length :].copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0)
)
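# Frozen 3-D absolute position table covering text_length + T*H*W tokens; at
# forward time it is sliced to the actual sequence length (single-frame inputs
# only use the first spatial block).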
class Basic3DPositionEmbeddingMixin(BaseMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
text_length=0,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
super().__init__()
self.height = height
self.width = width
self.text_length = text_length
self.compressed_num_frames = compressed_num_frames
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)),
requires_grad=False,
)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
def position_embedding_forward(self, position_ids, **kwargs):
if kwargs["images"].shape[1] == 1:
return self.pos_embedding[:, : self.text_length + self.spatial_length]
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_3d_sincos_pos_embed(
self.pos_embedding.shape[-1],
self.height,
self.width,
self.compressed_num_frames,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
)
pos_embed = torch.from_numpy(pos_embed).float()
pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
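# Broadcast-concatenate: expands all tensors to a common shape on every dimension
# except `dim`, then concatenates along `dim`. Used below to merge the per-axis
# rotary frequency grids into one (T, H, W, d) table.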
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
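# Standard RoPE helper: treats the last dimension as interleaved pairs (x1, x2)
# and maps them to (-x2, x1), i.e. a 90-degree rotation per pair.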
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
class Rotary3DPositionEmbeddingMixin(BaseMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
hidden_size_head,
text_length,
theta=10000,
rot_v=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
learnable_pos_embed=False,
):
super().__init__()
self.rot_v = rot_v
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
grid_h = torch.arange(height, dtype=torch.float32)
grid_w = torch.arange(width, dtype=torch.float32)
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat(
(freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]),
dim=-1,
)
freqs = freqs.contiguous()
self.freqs_sin = freqs.sin().cuda()
self.freqs_cos = freqs.cos().cuda()
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(
torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True
)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
def reshape_freq(freqs):
freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous()
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.unsqueeze(0).unsqueeze(0)
return freqs
freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype)
freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
else:
return None
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
**kwargs,
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
query_layer = torch.cat(
(
query_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
query_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
key_layer = torch.cat(
(
key_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
key_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
if self.rot_v:
value_layer = torch.cat(
(
value_layer[
:,
:,
: kwargs["text_length"],
],
self.rotary(
value_layer[
:,
:,
kwargs["text_length"] :,
],
**kwargs,
),
),
dim=2,
)
return attention_fn_default(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
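# adaLN modulation: scales and shifts x per sample, broadcasting the conditioning
# vectors over the token dimension (DiT-style x * (1 + scale) + shift).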
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def unpatchify(x, c, patch_size, w, h, **kwargs):
"""
x: (N, T/2 * S, patch_size**3 * C)
imgs: (N, T, H, W, C)
patch_size 被拆解为三个不同的维度 (o, p, q)分别对应了深度o、高度p和宽度q。这使得 patch 大小在不同维度上可以不相等,增加了灵活性。
"""
imgs = rearrange(
x,
"b (t h w) (c o p q) -> b (t o) c (h p) (w q)",
c=c,
o=patch_size[0],
p=patch_size[1],
q=patch_size[2],
t=kwargs["rope_T"],
h=kwargs["rope_H"],
w=kwargs["rope_W"],
)
return imgs
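# Final projection head: drops the text tokens, applies adaLN-modulated LayerNorm
# and a linear layer mapping each token back to patch pixels, then un-patchifies
# to (b, t, out_channels, h, w).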
class FinalLayerMixin(BaseMixin):
def __init__(
self,
hidden_size,
time_embed_dim,
patch_size,
out_channels,
latent_width,
latent_height,
elementwise_affine,
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)
)
def final_forward(self, logits, **kwargs):
x, emb = (
logits[:, kwargs["text_length"] :, :],
kwargs["emb"],
) # x: (b,(t n),d); keep only the trailing image tokens of x
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return unpatchify(
x,
c=self.out_channels,
patch_size=self.patch_size,
w=kwargs["rope_W"],
h=kwargs["rope_H"],
**kwargs,
)
def reinit(self, parent_model=None):
nn.init.xavier_uniform_(self.linear.weight)
nn.init.constant_(self.linear.bias, 0)
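# SwiGLU MLP variant: adds a per-layer gate projection w2 and replaces the default
# MLP with dense_4h_to_h(act(w2(x)) * dense_h_to_4h(x)).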
class SwiGLUMixin(BaseMixin):
def __init__(self, num_layers, in_features, hidden_features, bias=False):
super().__init__()
self.w2 = nn.ModuleList(
[
ColumnParallelLinear(
in_features,
hidden_features,
gather_output=False,
bias=bias,
module=self,
name="dense_h_to_4h_gate",
)
for i in range(num_layers)
]
)
def mlp_forward(self, hidden_states, **kw_args):
x = hidden_states
origin = self.transformer.layers[kw_args["layer_id"]].mlp
x1 = origin.dense_h_to_4h(x)
x2 = self.w2[kw_args["layer_id"]](x)
hidden = origin.activation_func(x2) * x1
x = origin.dense_4h_to_h(hidden)
return x
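# Per-layer adaptive LayerNorm (adaLN) conditioning: each layer derives 12
# shift/scale/gate vectors from the time embedding, 6 for the image tokens and 6
# for the text tokens. Attention and the MLP run over the concatenated sequence,
# but the two streams are modulated and gated separately. Optionally applies a
# per-head LayerNorm to queries and keys (qk_ln).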
class AdaLNMixin(BaseMixin):
def __init__(
self,
hidden_size,
num_layers,
time_embed_dim,
compressed_num_frames,
qk_ln=True,
hidden_size_head=None,
elementwise_affine=True,
):
super().__init__()
self.num_layers = num_layers
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList(
[
nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size))
for _ in range(num_layers)
]
)
self.qk_ln = qk_ln
if qk_ln:
self.query_layernorm_list = nn.ModuleList(
[
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
for _ in range(num_layers)
]
)
self.key_layernorm_list = nn.ModuleList(
[
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
for _ in range(num_layers)
]
)
def layer_forward(
self,
hidden_states,
mask,
*args,
**kwargs,
):
text_length = kwargs["text_length"]
# hidden_states (b,(n_t+t*n_i),d)
text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
layer = self.transformer.layers[kwargs["layer_id"]]
adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
text_shift_msa,
text_scale_msa,
text_gate_msa,
text_shift_mlp,
text_scale_mlp,
text_gate_mlp,
) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
gate_msa.unsqueeze(1),
gate_mlp.unsqueeze(1),
text_gate_msa.unsqueeze(1),
text_gate_mlp.unsqueeze(1),
)
# self full attention (b,(t n),d)
img_attention_input = layer.input_layernorm(img_hidden_states)
text_attention_input = layer.input_layernorm(text_hidden_states)
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
attention_input = torch.cat(
(text_attention_input, img_attention_input), dim=1
) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
if self.transformer.layernorm_order == "sandwich":
text_attention_output = layer.third_layernorm(text_attention_output)
img_attention_output = layer.third_layernorm(img_attention_output)
img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
# mlp (b,(t n),d)
img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d)
mlp_output = layer.mlp(mlp_input, **kwargs)
img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
if self.transformer.layernorm_order == "sandwich":
text_mlp_output = layer.fourth_layernorm(text_mlp_output)
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = (
text_hidden_states + text_gate_mlp * text_mlp_output
) # language (b,n,d)
hidden_states = torch.cat(
(text_hidden_states, img_hidden_states), dim=1
) # (b,(n_t+t*n_i),d)
return hidden_states
def reinit(self, parent_model=None):
for layer in self.adaLN_modulations:
nn.init.constant_(layer[-1].weight, 0)
nn.init.constant_(layer[-1].bias, 0)
@non_conflict
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
old_impl=attention_fn_default,
**kwargs,
):
if self.qk_ln:
query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
query_layer = query_layernorm(query_layer)
key_layer = key_layernorm(key_layer)
return old_impl(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
class DiffusionTransformer(BaseModel):
def __init__(
self,
transformer_args,
num_frames,
time_compressed_rate,
latent_width,
latent_height,
patch_size,
in_channels,
out_channels,
hidden_size,
num_layers,
num_attention_heads,
elementwise_affine,
time_embed_dim=None,
num_classes=None,
modules={},
input_time="adaln",
adm_in_channels=None,
parallel_output=True,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
ofs_embed_dim=None,
**kwargs,
):
self.latent_width = latent_width
self.latent_height = latent_height
self.patch_size = patch_size
self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate
self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:])
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.ofs_embed_dim = ofs_embed_dim
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.input_time = input_time
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.is_decoder = transformer_args.is_decoder
self.elementwise_affine = elementwise_affine
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4
try:
self.dtype = str_to_dtype[kwargs.pop("dtype")]
except KeyError:  # fall back to fp32 when no dtype is specified
self.dtype = torch.float32
if use_SwiGLU:
kwargs["activation_func"] = F.silu
elif "activation_func" not in kwargs:
approx_gelu = nn.GELU(approximate="tanh")
kwargs["activation_func"] = approx_gelu
if use_RMSNorm:
kwargs["layernorm"] = RMSNorm
else:
kwargs["layernorm"] = partial(
LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6
)
transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
transformer_args.num_attention_heads = num_attention_heads
transformer_args.parallel_output = parallel_output
super().__init__(args=transformer_args, transformer=None, **kwargs)
module_configs = modules
self._build_modules(module_configs)
if use_SwiGLU:
self.add_mixin(
"swiglu",
SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False),
reinit=True,
)
def _build_modules(self, module_configs):
model_channels = self.hidden_size
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.ofs_embed_dim is not None:
self.ofs_embed = nn.Sequential(
linear(self.ofs_embed_dim, self.ofs_embed_dim),
nn.SiLU(),
linear(self.ofs_embed_dim, self.ofs_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert self.adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(self.adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
else:
raise ValueError()
pos_embed_config = module_configs["pos_embed_config"]
self.add_mixin(
"pos_embed",
instantiate_from_config(
pos_embed_config,
height=self.latent_height // self.patch_size[1],
width=self.latent_width // self.patch_size[2],
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size=self.hidden_size,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
),
reinit=True,
)
patch_embed_config = module_configs["patch_embed_config"]
self.add_mixin(
"patch_embed",
instantiate_from_config(
patch_embed_config,
patch_size=self.patch_size,
hidden_size=self.hidden_size,
in_channels=self.in_channels,
),
reinit=True,
)
if self.input_time == "adaln":
adaln_layer_config = module_configs["adaln_layer_config"]
self.add_mixin(
"adaln_layer",
instantiate_from_config(
adaln_layer_config,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size_head=self.hidden_size // self.num_attention_heads,
time_embed_dim=self.time_embed_dim,
elementwise_affine=self.elementwise_affine,
),
)
else:
raise NotImplementedError
final_layer_config = module_configs["final_layer_config"]
self.add_mixin(
"final_layer",
instantiate_from_config(
final_layer_config,
hidden_size=self.hidden_size,
patch_size=self.patch_size,
out_channels=self.out_channels,
time_embed_dim=self.time_embed_dim,
latent_width=self.latent_width,
latent_height=self.latent_height,
elementwise_affine=self.elementwise_affine,
),
reinit=True,
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin(
"lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True
)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
if "concat_images" in kwargs and kwargs["concat_images"] is not None:
if kwargs["concat_images"].shape[0] != x.shape[0]:
concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1)
else:
concat_images = kwargs["concat_images"]
x = torch.cat([x, concat_images], dim=2)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False, dtype=self.dtype
)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
if self.ofs_embed_dim is not None:
ofs_emb = timestep_embedding(
kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype
)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
kwargs["images"] = x
kwargs["emb"] = emb
kwargs["encoder_outputs"] = context
kwargs["text_length"] = context.shape[1]
kwargs["rope_T"] = t // self.patch_size[0]
kwargs["rope_H"] = h // self.patch_size[1]
kwargs["rope_W"] = w // self.patch_size[2]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones(
(1, 1)
).to(x.dtype)
output = super().forward(**kwargs)[0]
return output
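# Rough shape walkthrough of forward() (illustrative, assuming patch_size = (1, 2, 2)):
#   x:        (b, t, c, h, w) video latents
#   context:  (b, n_text, d_text) text encoder outputs
#   sequence: n_text + t * (h // 2) * (w // 2) tokens after patch embedding
#   output:   (b, t, out_channels, h, w) latents after FinalLayerMixin's unpatchify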