update 1105 sst test code with fake cp

zR 2024-11-05 12:55:54 +08:00
parent 3a9af5bdd9
commit 4a3035d64e
5 changed files with 137 additions and 97 deletions

View File

@@ -179,19 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
         n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
         n_rounds = math.ceil(z.shape[0] / n_samples)
         all_out = []
-        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
-            for n in range(n_rounds):
-                if isinstance(self.first_stage_model.decoder, VideoDecoder):
-                    kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
-                else:
-                    kwargs = {}
-                frame = z.shape[2] * 4 - 3
-                if frame <= 9:
-                    use_cp = False
-                else:
-                    use_cp = True
-                out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
-                all_out.append(out)
+        for n in range(n_rounds):
+            z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
+            latent_time = z_now.shape[2]  # check the time latent
+            temporal_compress_times = 4
+
+            fake_cp_size = min(10, latent_time // 2)
+            start_frame = 0
+
+            recons = []
+            start_frame = 0
+            for i in range(fake_cp_size):
+                end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
+
+                fake_cp_rank0 = True if i == 0 else False
+                clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
+                with torch.no_grad():
+                    recon = self.first_stage_model.decode(
+                        z_now[:, :, start_frame:end_frame].contiguous(),
+                        clear_fake_cp_cache=clear_fake_cp_cache,
+                        fake_cp_rank0=fake_cp_rank0,
+                    )
+                recons.append(recon)
+                start_frame = end_frame
+            recons = torch.cat(recons, dim=2)
+            all_out.append(recons)
         out = torch.cat(all_out, dim=0)
         return out
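
Note: the loop above splits the latent_time latent frames into at most 10 sequential decode chunks, flags only the first chunk as the fake context-parallel rank 0, and clears the cached conv state only on the last chunk. A standalone sketch of the same schedule (the helper name is illustrative, not part of the commit):

    # Illustrative helper reproducing the chunk schedule above (not from the repo).
    def fake_cp_chunks(latent_time, max_chunks=10):
        fake_cp_size = min(max_chunks, latent_time // 2)
        chunks, start = [], 0
        for i in range(fake_cp_size):
            end = start + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
            # (start, end, fake_cp_rank0, clear_fake_cp_cache)
            chunks.append((start, end, i == 0, i == fake_cp_size - 1))
            start = end
        return chunks

    # e.g. 13 latent frames -> frame ranges of sizes 3, 2, 2, 2, 2, 2
    print(fake_cp_chunks(13))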

View File

@@ -654,7 +654,6 @@ class DiffusionTransformer(BaseModel):
         time_interpolation=1.0,
         use_SwiGLU=False,
         use_RMSNorm=False,
-        cfg_embed_dim=None,
         ofs_embed_dim=None,
         **kwargs,
     ):
@@ -669,7 +668,6 @@ class DiffusionTransformer(BaseModel):
         self.hidden_size = hidden_size
         self.model_channels = hidden_size
         self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
-        self.cfg_embed_dim = cfg_embed_dim
         self.ofs_embed_dim = ofs_embed_dim
         self.num_classes = num_classes
         self.adm_in_channels = adm_in_channels
@@ -728,13 +726,6 @@ class DiffusionTransformer(BaseModel):
                 linear(self.ofs_embed_dim, self.ofs_embed_dim),
             )

-        if self.cfg_embed_dim is not None:
-            self.cfg_embed = nn.Sequential(
-                linear(self.cfg_embed_dim, self.cfg_embed_dim),
-                nn.SiLU(),
-                linear(self.cfg_embed_dim, self.cfg_embed_dim),
-            )
-
         if self.num_classes is not None:
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -848,14 +839,6 @@ class DiffusionTransformer(BaseModel):
             ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
             ofs_emb = self.ofs_embed(ofs_emb)
             emb = emb + ofs_emb
-        if self.cfg_embed_dim is not None:
-            cfg_emb = kwargs["scale_emb"]
-            cfg_emb = self.cfg_embed(cfg_emb)
-            emb = emb + cfg_emb
-
-        if "ofs" in kwargs.keys():
-            ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
-            ofs_emb = self.ofs_embed(ofs_emb)

         kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
         kwargs["images"] = x

View File

@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

 environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"

-run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
+run_cmd="$environs python sample_video.py --base configs/test_cogvideox_5b.yaml configs/test_inference.yaml --seed $RANDOM"

 echo ${run_cmd}
 eval ${run_cmd}

View File

@@ -135,14 +135,14 @@ def sampling_main(args, model_cls):
     sample_func = model.sample
     num_samples = [1]
     force_uc_zero_embeddings = ["txt"]
+    T, C = args.sampling_num_frames, args.latent_channels
     with torch.no_grad():
         for text, cnt in tqdm(data_iter):
             if args.image2video:
                 # use with input image shape
-                text, image_path = text.split('@@')
+                text, image_path = text.split("@@")
                 assert os.path.exists(image_path), image_path
-                image = Image.open(image_path).convert('RGB')
+                image = Image.open(image_path).convert("RGB")
                 (img_W, img_H) = image.size

                 def nearest_multiple_of_16(n):
@@ -163,7 +163,7 @@ def sampling_main(args, model_cls):
                 chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
                 chained_trainsforms.append(TT.ToTensor())
                 transform = TT.Compose(chained_trainsforms)
-                image = transform(image).unsqueeze(0).to('cuda')
+                image = transform(image).unsqueeze(0).to("cuda")
                 image = image * 2.0 - 1.0
                 image = image.unsqueeze(2).to(torch.bfloat16)
                 image = model.encode_first_stage(image, None)
@@ -173,7 +173,7 @@ def sampling_main(args, model_cls):
                 image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
             else:
                 image_size = args.sampling_image_size
-                T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
+                H, W = image_size[0], image_size[1]
                 F = 8  # 8x downsampled
                 image = None
@@ -183,11 +183,7 @@ def sampling_main(args, model_cls):
                 src = global_rank * mp_size
                 torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
                 text = text_cast[0]
-            value_dict = {
-                'prompt': text,
-                'negative_prompt': '',
-                'num_frames': torch.tensor(T).unsqueeze(0)
-            }
+            value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}

             batch, batch_uc = get_batch(
                 get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@@ -216,11 +212,7 @@ def sampling_main(args, model_cls):
             for index in range(args.batch_size):
                 if args.image2video:
                     samples_z = sample_func(
-                        c,
-                        uc=uc,
-                        batch_size=1,
-                        shape=(T, C, H, W),
-                        ofs=torch.tensor([2.0]).to('cuda')
+                        c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
                     )
                 else:
                     samples_z = sample_func(
@@ -228,7 +220,7 @@ def sampling_main(args, model_cls):
                         uc=uc,
                         batch_size=1,
                         shape=(T, C, H // F, W // F),
-                    ).to('cuda')
+                    ).to("cuda")
                 samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                 if args.only_save_latents:
@@ -250,11 +242,12 @@ def sampling_main(args, model_cls):
             if mpu.get_model_parallel_rank() == 0:
                 save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)

-if __name__ == '__main__':
-    if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
-        os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
-        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
-        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
+
+if __name__ == "__main__":
+    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
+        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

     py_parser = argparse.ArgumentParser(add_help=False)
     known, args_list = py_parser.parse_known_args()
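
Note: with T and C now read from args once at the top, the text-to-video branch only derives the spatial latent size. A quick shape check, using typical CogVideoX-5B settings purely for illustration:

    # Illustrative values; the script reads these from args at runtime.
    T, C = 13, 16            # args.sampling_num_frames, args.latent_channels
    H, W = 480, 720          # args.sampling_image_size
    F = 8                    # 8x spatial downsampling of the first-stage VAE
    print((T, C, H // F, W // F))   # (13, 16, 60, 90) -> shape passed to sample_func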

View File

@@ -5,7 +5,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np

-from beartype.typing import Union, Tuple
+from beartype import beartype
+from beartype.typing import Union, Tuple, Optional, List
 from einops import rearrange

 from sgm.util import (
@@ -76,7 +77,6 @@ def _split(input_, dim):
     cp_rank = get_context_parallel_rank()

     inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
     input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
-
     dim_size = input_.size()[dim] // cp_world_size
@@ -88,7 +88,6 @@ def _split(input_, dim):
         output = torch.cat([inpu_first_frame_, output], dim=dim)

     output = output.contiguous()
-
     return output
@@ -421,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
         return output


-def Normalize(in_channels, gather=False, **kwargs):  # same for 3D and 2D
+def Normalize(in_channels, gather=False, **kwargs):
+    # same for 3D and 2D
     if gather:
         return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
     else:
@@ -444,6 +444,7 @@ class SpatialNorm3D(nn.Module):
            self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
         else:
            self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
+        # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
         if freeze_norm_layer:
             for p in self.norm_layer.parameters:
                 p.requires_grad = False
@@ -467,24 +468,34 @@ class SpatialNorm3D(nn.Module):
                kernel_size=1,
            )

-    def forward(self, f, zq, clear_fake_cp_cache=True):
-        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+    def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
+        if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
             f_first, f_rest = f[:, :, :1], f[:, :, 1:]
             f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
             zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
             zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
-            zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
+
+            zq_rest_splits = torch.split(zq_rest, 32, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
+            ]
+
+            zq_rest = torch.cat(interpolated_splits, dim=1)
+            # zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
             zq = torch.cat([zq_first, zq_rest], dim=2)
         else:
-            zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
+            f_size = f.shape[-3:]
+
+            zq_splits = torch.split(zq, 32, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
+            ]
+            zq = torch.cat(interpolated_splits, dim=1)

         if self.add_conv:
             zq = self.conv(zq, clear_cache=clear_fake_cp_cache)

+        # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
         norm_f = self.norm_layer(f)
+        # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
         new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
         return new_f
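
Note: the split/interpolate/concat pattern above (repeated throughout this file) feeds torch.nn.functional.interpolate at most 32 channels at a time, apparently to stay clear of the MAX_INT element-count overflow mentioned in the old Upsample3D comment; concatenating the per-chunk results along the channel dimension reproduces the single-call output. A standalone sketch of the pattern (the helper name is illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    # Illustrative helper: nearest interpolation applied in channel chunks of 32.
    def chunked_nearest_interpolate(x, size, chunk=32):
        splits = torch.split(x, chunk, dim=1)
        outs = [F.interpolate(s, size=size, mode="nearest") for s in splits]
        return torch.cat(outs, dim=1)

    zq = torch.randn(1, 512, 2, 30, 45)
    out = chunked_nearest_interpolate(zq, size=(2, 60, 90))
    # Same result as one big call, but each kernel launch sees far fewer elements.
    assert torch.equal(out, F.interpolate(zq, size=(2, 60, 90), mode="nearest"))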
@@ -520,26 +531,39 @@ class Upsample3D(nn.Module):
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
         self.compress_time = compress_time

-    def forward(self, x):
+    def forward(self, x, fake_cp_rank0=True):
         if self.compress_time and x.shape[2] > 1:
-            # Process the time dimension first as x_first
-            x_first, x_rest = x[:, :, 0], x[:, :, 1:]
-            # print(x_first.shape)
-            x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
-            # split the rest of the frames to avoid MAX_INT overflow in Pytorch
-            splits = torch.split(x_rest, 16, dim=1)
-            interpolated_splits = [
-                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
-            ]
-            x_rest = torch.cat(interpolated_splits, dim=1)
-            # concatenate the first frame with the rest
-            x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+                # print(x.shape)
+                # split first frame
+                x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+
+                x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+
+                splits = torch.split(x_rest, 32, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                ]
+                x_rest = torch.cat(interpolated_splits, dim=1)
+                # x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
+                x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+            else:
+                splits = torch.split(x, 32, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+                ]
+                x = torch.cat(interpolated_splits, dim=1)
+                # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
         else:
+            # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            splits = torch.split(x, 16, dim=1)
+            # x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+            splits = torch.split(x, 32, dim=1)
             interpolated_splits = [
                 torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
             ]
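
Note: in the compress_time branch, the first chunk (fake_cp_rank0=True) upsamples frame 0 only spatially while the remaining t-1 frames are doubled in time and space, so t latent frames become 2t - 1; later chunks double every frame to 2t. A small shape check, assuming nearest interpolation as above (illustrative only):

    import torch
    import torch.nn.functional as F

    b, c, t, h, w = 1, 8, 3, 16, 16
    x = torch.randn(b, c, t, h, w)

    # First chunk: frame 0 spatially only, remaining frames in time and space.
    x_first = F.interpolate(x[:, :, 0], scale_factor=2.0, mode="nearest")   # (b, c, 2h, 2w)
    x_rest = F.interpolate(x[:, :, 1:], scale_factor=2.0, mode="nearest")   # (b, c, 2(t-1), 2h, 2w)
    print(torch.cat([x_first[:, :, None], x_rest], dim=2).shape)  # (1, 8, 5, 32, 32)

    # Later chunks: every frame is doubled in time.
    print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape)  # (1, 8, 6, 32, 32)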
@@ -566,15 +590,17 @@ class DownSample3D(nn.Module):
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
         self.compress_time = compress_time

-    def forward(self, x):
+    def forward(self, x, fake_cp_rank0=True):
         if self.compress_time and x.shape[2] > 1:
             h, w = x.shape[-2:]
             x = rearrange(x, "b c t h w -> (b h w) c t")

-            if x.shape[-1] % 2 == 1:
+            if get_context_parallel_rank() == 0 and fake_cp_rank0:
+                # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]

                 if x_rest.shape[-1] > 0:
-                    splits = torch.split(x_rest, 16, dim=1)
+                    splits = torch.split(x_rest, 32, dim=1)
                     interpolated_splits = [
                         torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
                     ]
@@ -582,7 +608,8 @@ class DownSample3D(nn.Module):
                 x = torch.cat([x_first[..., None], x_rest], dim=-1)
                 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
             else:
-                splits = torch.split(x, 16, dim=1)
+                # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+                splits = torch.split(x, 32, dim=1)
                 interpolated_splits = [
                     torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
                 ]
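
Note: the mirror image of the upsampling case: on the first chunk, frame 0 is kept and the remaining t-1 frames are average-pooled over time, giving 1 + (t - 1) // 2 frames; later chunks pool every frame to t // 2. A small shape check (illustrative only, x is already rearranged to "(b h w) c t" as in the forward pass):

    import torch
    import torch.nn.functional as F

    bhw, c, t = 4, 8, 5
    x = torch.randn(bhw, c, t)

    # First chunk (fake_cp_rank0=True): keep frame 0, pool the rest.
    x_first, x_rest = x[..., 0], x[..., 1:]
    x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
    print(torch.cat([x_first[..., None], x_rest], dim=-1).shape[-1])  # 3

    # Later chunks: pool every frame.
    print(F.avg_pool1d(x, kernel_size=2, stride=2).shape[-1])  # 2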
@@ -666,21 +693,33 @@ class ContextParallelResnetBlock3D(nn.Module):
                    padding=0,
                )

-    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
+    def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
         h = x

+        # if isinstance(self.norm1, torch.nn.GroupNorm):
+        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+            h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         else:
             h = self.norm1(h)
+        # if isinstance(self.norm1, torch.nn.GroupNorm):
+        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

         h = nonlinearity(h)
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)

         if temb is not None:
             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

+        # if isinstance(self.norm2, torch.nn.GroupNorm):
+        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
-            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+            h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         else:
             h = self.norm2(h)
+        # if isinstance(self.norm2, torch.nn.GroupNorm):
+        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)

         h = nonlinearity(h)
         h = self.dropout(h)
         h = self.conv2(h, clear_cache=clear_fake_cp_cache)
@@ -788,28 +827,32 @@ class ContextParallelEncoder3D(nn.Module):
            kernel_size=3,
        )

-    def forward(self, x, **kwargs):
+    def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
         # timestep embedding
         temb = None

         # downsampling
-        h = self.conv_in(x)
+        h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
+                h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
                 if len(self.down[i_level].attn) > 0:
+                    print("Attention not implemented")
                     h = self.down[i_level].attn[i_block](h)
             if i_level != self.num_resolutions - 1:
-                h = self.down[i_level].downsample(h)
+                h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.block_2(h, temb)
+        h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)

         # end
+        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         h = self.norm_out(h)
+        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
         h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = self.conv_out(h, clear_cache=clear_fake_cp_cache)

         return h
@@ -912,10 +955,15 @@ class ContextParallelDecoder3D(nn.Module):
            up.block = block
            up.attn = attn
            if i_level != 0:
+                # # Symmetrical enc-dec
                if i_level <= self.temporal_compress_level:
                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
                else:
                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+                # if i_level < self.num_resolutions - self.temporal_compress_level:
+                #     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
+                # else:
+                #     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
            self.up.insert(0, up)

         self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@@ -926,34 +974,38 @@ class ContextParallelDecoder3D(nn.Module):
            kernel_size=3,
        )

-    def forward(self, z, clear_fake_cp_cache=True, **kwargs):
+    def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
         self.last_z_shape = z.shape

         # timestep embedding
         temb = None

+        t = z.shape[2]
         # z to block_in

         zq = z
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)

         # middle
-        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
-        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
+        h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
-                h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+                h = self.up[i_level].block[i_block](
+                    h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
+                )
                 if len(self.up[i_level].attn) > 0:
                     h = self.up[i_level].attn[i_block](h, zq)
             if i_level != 0:
-                h = self.up[i_level].upsample(h)
+                h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)

         # end
         if self.give_pre_end:
             return h

-        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
+        h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
         h = nonlinearity(h)
         h = self.conv_out(h, clear_cache=clear_fake_cp_cache)