mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
update 1105 sst test code with fake cp
This commit is contained in:
parent
3a9af5bdd9
commit
4a3035d64e
@ -179,19 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
|
||||
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
||||
n_rounds = math.ceil(z.shape[0] / n_samples)
|
||||
all_out = []
|
||||
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
||||
for n in range(n_rounds):
|
||||
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
||||
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
||||
else:
|
||||
kwargs = {}
|
||||
frame = z.shape[2] * 4 - 3
|
||||
if frame <= 9:
|
||||
use_cp = False
|
||||
else:
|
||||
use_cp = True
|
||||
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
|
||||
all_out.append(out)
|
||||
for n in range(n_rounds):
|
||||
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
|
||||
latent_time = z_now.shape[2] # check the time latent
|
||||
temporal_compress_times = 4
|
||||
|
||||
fake_cp_size = min(10, latent_time // 2)
|
||||
start_frame = 0
|
||||
|
||||
recons = []
|
||||
start_frame = 0
|
||||
for i in range(fake_cp_size):
|
||||
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
|
||||
|
||||
fake_cp_rank0 = True if i == 0 else False
|
||||
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
|
||||
with torch.no_grad():
|
||||
recon = self.first_stage_model.decode(
|
||||
z_now[:, :, start_frame:end_frame].contiguous(),
|
||||
clear_fake_cp_cache=clear_fake_cp_cache,
|
||||
fake_cp_rank0=fake_cp_rank0,
|
||||
)
|
||||
recons.append(recon)
|
||||
start_frame = end_frame
|
||||
recons = torch.cat(recons, dim=2)
|
||||
all_out.append(recons)
|
||||
out = torch.cat(all_out, dim=0)
|
||||
return out
|
||||
|
||||
|
@ -654,7 +654,6 @@ class DiffusionTransformer(BaseModel):
|
||||
time_interpolation=1.0,
|
||||
use_SwiGLU=False,
|
||||
use_RMSNorm=False,
|
||||
cfg_embed_dim=None,
|
||||
ofs_embed_dim=None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -669,7 +668,6 @@ class DiffusionTransformer(BaseModel):
|
||||
self.hidden_size = hidden_size
|
||||
self.model_channels = hidden_size
|
||||
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
|
||||
self.cfg_embed_dim = cfg_embed_dim
|
||||
self.ofs_embed_dim = ofs_embed_dim
|
||||
self.num_classes = num_classes
|
||||
self.adm_in_channels = adm_in_channels
|
||||
@ -728,13 +726,6 @@ class DiffusionTransformer(BaseModel):
|
||||
linear(self.ofs_embed_dim, self.ofs_embed_dim),
|
||||
)
|
||||
|
||||
if self.cfg_embed_dim is not None:
|
||||
self.cfg_embed = nn.Sequential(
|
||||
linear(self.cfg_embed_dim, self.cfg_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(self.cfg_embed_dim, self.cfg_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
|
||||
@ -848,14 +839,6 @@ class DiffusionTransformer(BaseModel):
|
||||
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
|
||||
ofs_emb = self.ofs_embed(ofs_emb)
|
||||
emb = emb + ofs_emb
|
||||
if self.cfg_embed_dim is not None:
|
||||
cfg_emb = kwargs["scale_emb"]
|
||||
cfg_emb = self.cfg_embed(cfg_emb)
|
||||
emb = emb + cfg_emb
|
||||
|
||||
if "ofs" in kwargs.keys():
|
||||
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
|
||||
ofs_emb = self.ofs_embed(ofs_emb)
|
||||
|
||||
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
|
||||
kwargs["images"] = x
|
||||
|
@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
|
||||
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
|
||||
|
||||
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
|
||||
run_cmd="$environs python sample_video.py --base configs/test_cogvideox_5b.yaml configs/test_inference.yaml --seed $RANDOM"
|
||||
|
||||
echo ${run_cmd}
|
||||
eval ${run_cmd}
|
||||
|
@ -135,14 +135,14 @@ def sampling_main(args, model_cls):
|
||||
sample_func = model.sample
|
||||
num_samples = [1]
|
||||
force_uc_zero_embeddings = ["txt"]
|
||||
|
||||
T, C = args.sampling_num_frames, args.latent_channels
|
||||
with torch.no_grad():
|
||||
for text, cnt in tqdm(data_iter):
|
||||
if args.image2video:
|
||||
# use with input image shape
|
||||
text, image_path = text.split('@@')
|
||||
text, image_path = text.split("@@")
|
||||
assert os.path.exists(image_path), image_path
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
(img_W, img_H) = image.size
|
||||
|
||||
def nearest_multiple_of_16(n):
|
||||
@ -163,7 +163,7 @@ def sampling_main(args, model_cls):
|
||||
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
|
||||
chained_trainsforms.append(TT.ToTensor())
|
||||
transform = TT.Compose(chained_trainsforms)
|
||||
image = transform(image).unsqueeze(0).to('cuda')
|
||||
image = transform(image).unsqueeze(0).to("cuda")
|
||||
image = image * 2.0 - 1.0
|
||||
image = image.unsqueeze(2).to(torch.bfloat16)
|
||||
image = model.encode_first_stage(image, None)
|
||||
@ -173,7 +173,7 @@ def sampling_main(args, model_cls):
|
||||
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
|
||||
else:
|
||||
image_size = args.sampling_image_size
|
||||
T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
|
||||
H, W = image_size[0], image_size[1]
|
||||
F = 8 # 8x downsampled
|
||||
image = None
|
||||
|
||||
@ -183,11 +183,7 @@ def sampling_main(args, model_cls):
|
||||
src = global_rank * mp_size
|
||||
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
|
||||
text = text_cast[0]
|
||||
value_dict = {
|
||||
'prompt': text,
|
||||
'negative_prompt': '',
|
||||
'num_frames': torch.tensor(T).unsqueeze(0)
|
||||
}
|
||||
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
|
||||
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
|
||||
@ -216,11 +212,7 @@ def sampling_main(args, model_cls):
|
||||
for index in range(args.batch_size):
|
||||
if args.image2video:
|
||||
samples_z = sample_func(
|
||||
c,
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H, W),
|
||||
ofs=torch.tensor([2.0]).to('cuda')
|
||||
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
|
||||
)
|
||||
else:
|
||||
samples_z = sample_func(
|
||||
@ -228,7 +220,7 @@ def sampling_main(args, model_cls):
|
||||
uc=uc,
|
||||
batch_size=1,
|
||||
shape=(T, C, H // F, W // F),
|
||||
).to('cuda')
|
||||
).to("cuda")
|
||||
|
||||
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
|
||||
if args.only_save_latents:
|
||||
@ -250,11 +242,12 @@ def sampling_main(args, model_cls):
|
||||
if mpu.get_model_parallel_rank() == 0:
|
||||
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
|
||||
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
|
||||
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
|
||||
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
|
||||
|
||||
if __name__ == "__main__":
|
||||
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
|
||||
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
|
||||
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
|
||||
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
|
||||
py_parser = argparse.ArgumentParser(add_help=False)
|
||||
known, args_list = py_parser.parse_known_args()
|
||||
|
||||
|
@ -5,7 +5,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from beartype.typing import Union, Tuple
|
||||
from beartype import beartype
|
||||
from beartype.typing import Union, Tuple, Optional, List
|
||||
from einops import rearrange
|
||||
|
||||
from sgm.util import (
|
||||
@ -76,7 +77,6 @@ def _split(input_, dim):
|
||||
|
||||
cp_rank = get_context_parallel_rank()
|
||||
|
||||
|
||||
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
|
||||
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
|
||||
dim_size = input_.size()[dim] // cp_world_size
|
||||
@ -88,7 +88,6 @@ def _split(input_, dim):
|
||||
output = torch.cat([inpu_first_frame_, output], dim=dim)
|
||||
output = output.contiguous()
|
||||
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@ -421,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
|
||||
return output
|
||||
|
||||
|
||||
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
|
||||
def Normalize(in_channels, gather=False, **kwargs):
|
||||
# same for 3D and 2D
|
||||
if gather:
|
||||
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
else:
|
||||
@ -444,6 +444,7 @@ class SpatialNorm3D(nn.Module):
|
||||
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
else:
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
|
||||
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
|
||||
if freeze_norm_layer:
|
||||
for p in self.norm_layer.parameters:
|
||||
p.requires_grad = False
|
||||
@ -467,24 +468,34 @@ class SpatialNorm3D(nn.Module):
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True):
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
|
||||
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
|
||||
]
|
||||
|
||||
zq_rest = torch.cat(interpolated_splits, dim=1)
|
||||
# zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
|
||||
f_size = f.shape[-3:]
|
||||
|
||||
zq_splits = torch.split(zq, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
|
||||
]
|
||||
zq = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
|
||||
norm_f = self.norm_layer(f)
|
||||
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
|
||||
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
@ -520,26 +531,39 @@ class Upsample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
# Process the time dimension first as x_first
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
# print(x_first.shape)
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
# split the rest of the frames to avoid MAX_INT overflow in Pytorch
|
||||
splits = torch.split(x_rest, 16, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x_rest = torch.cat(interpolated_splits, dim=1)
|
||||
# concatenate the first frame with the rest
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
# print(x.shape)
|
||||
# split first frame
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x_rest = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
else:
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
x = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
else:
|
||||
# only interpolate 2D
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
splits = torch.split(x, 16, dim=1)
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
|
||||
]
|
||||
@ -566,15 +590,17 @@ class DownSample3D(nn.Module):
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, fake_cp_rank0=True):
|
||||
if self.compress_time and x.shape[2] > 1:
|
||||
h, w = x.shape[-2:]
|
||||
x = rearrange(x, "b c t h w -> (b h w) c t")
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
if get_context_parallel_rank() == 0 and fake_cp_rank0:
|
||||
# split first frame
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
|
||||
if x_rest.shape[-1] > 0:
|
||||
splits = torch.split(x_rest, 16, dim=1)
|
||||
splits = torch.split(x_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
|
||||
]
|
||||
@ -582,7 +608,8 @@ class DownSample3D(nn.Module):
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
|
||||
else:
|
||||
splits = torch.split(x, 16, dim=1)
|
||||
# x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
splits = torch.split(x, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
|
||||
]
|
||||
@ -666,21 +693,33 @@ class ContextParallelResnetBlock3D(nn.Module):
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
|
||||
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
h = x
|
||||
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
else:
|
||||
h = self.norm1(h)
|
||||
# if isinstance(self.norm1, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
if zq is not None:
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
else:
|
||||
h = self.norm2(h)
|
||||
# if isinstance(self.norm2, torch.nn.GroupNorm):
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h, clear_cache=clear_fake_cp_cache)
|
||||
@ -788,28 +827,32 @@ class ContextParallelEncoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
print("Attention not implemented")
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
|
||||
# end
|
||||
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
h = self.norm_out(h)
|
||||
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
return h
|
||||
|
||||
@ -912,10 +955,15 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
# # Symmetrical enc-dec
|
||||
if i_level <= self.temporal_compress_level:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
else:
|
||||
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
# if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
# up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
|
||||
# else:
|
||||
# up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
|
||||
@ -926,34 +974,38 @@ class ContextParallelDecoder3D(nn.Module):
|
||||
kernel_size=3,
|
||||
)
|
||||
|
||||
def forward(self, z, clear_fake_cp_cache=True, **kwargs):
|
||||
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
t = z.shape[2]
|
||||
# z to block_in
|
||||
|
||||
zq = z
|
||||
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.up[i_level].block[i_block](
|
||||
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
|
||||
)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
|
||||
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user