update 1105 sst test code with fake cp

zR 2024-11-05 12:55:54 +08:00
parent 3a9af5bdd9
commit 4a3035d64e
5 changed files with 137 additions and 97 deletions

View File

@@ -179,19 +179,31 @@ class SATVideoDiffusionEngine(nn.Module):
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
frame = z.shape[2] * 4 - 3
if frame <= 9:
use_cp = False
else:
use_cp = True
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs)
all_out.append(out)
for n in range(n_rounds):
z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
latent_time = z_now.shape[2] # check the time latent
temporal_compress_times = 4
fake_cp_size = min(10, latent_time // 2)
start_frame = 0
recons = []
for i in range(fake_cp_size):
end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
fake_cp_rank0 = True if i == 0 else False
clear_fake_cp_cache = True if i == fake_cp_size - 1 else False
with torch.no_grad():
recon = self.first_stage_model.decode(
z_now[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache,
fake_cp_rank0=fake_cp_rank0,
)
recons.append(recon)
start_frame = end_frame
recons = torch.cat(recons, dim=2)
all_out.append(recons)
out = torch.cat(all_out, dim=0)
return out
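For quick reference, the new chunked decode above splits the latent time axis into at most ten fake-CP chunks, gives any remainder frames to the earliest chunks, flags the first chunk as the fake-CP rank-0 chunk, and clears the cache on the last chunk. A minimal sketch of the boundary arithmetic (the latent_time value below is an assumed example, not from the commit):

# Sketch of the chunk boundaries computed by the decode loop above.
# latent_time = 13 is an assumed example value.
latent_time = 13
fake_cp_size = min(10, latent_time // 2)  # -> 6 chunks
chunks, start_frame = [], 0
for i in range(fake_cp_size):
    # earlier chunks absorb the remainder, one extra frame each
    end_frame = start_frame + latent_time // fake_cp_size + (1 if i < latent_time % fake_cp_size else 0)
    chunks.append((start_frame, end_frame))  # chunk i: fake_cp_rank0 = (i == 0), clear cache on the last
    start_frame = end_frame
# chunks == [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]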

View File

@@ -654,7 +654,6 @@ class DiffusionTransformer(BaseModel):
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
cfg_embed_dim=None,
ofs_embed_dim=None,
**kwargs,
):
@@ -669,7 +668,6 @@ class DiffusionTransformer(BaseModel):
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.cfg_embed_dim = cfg_embed_dim
self.ofs_embed_dim = ofs_embed_dim
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
@@ -728,13 +726,6 @@ class DiffusionTransformer(BaseModel):
linear(self.ofs_embed_dim, self.ofs_embed_dim),
)
if self.cfg_embed_dim is not None:
self.cfg_embed = nn.Sequential(
linear(self.cfg_embed_dim, self.cfg_embed_dim),
nn.SiLU(),
linear(self.cfg_embed_dim, self.cfg_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
@@ -848,14 +839,6 @@ class DiffusionTransformer(BaseModel):
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
ofs_emb = self.ofs_embed(ofs_emb)
emb = emb + ofs_emb
if self.cfg_embed_dim is not None:
cfg_emb = kwargs["scale_emb"]
cfg_emb = self.cfg_embed(cfg_emb)
emb = emb + cfg_emb
if "ofs" in kwargs.keys():
ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype)
ofs_emb = self.ofs_embed(ofs_emb)
kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size)
kwargs["images"] = x

View File

@@ -4,7 +4,7 @@ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_video.py --base configs/cogvideox_5b.yaml configs/inference.yaml --seed $RANDOM"
run_cmd="$environs python sample_video.py --base configs/test_cogvideox_5b.yaml configs/test_inference.yaml --seed $RANDOM"
echo ${run_cmd}
eval ${run_cmd}

View File

@@ -135,14 +135,14 @@ def sampling_main(args, model_cls):
sample_func = model.sample
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
T, C = args.sampling_num_frames, args.latent_channels
with torch.no_grad():
for text, cnt in tqdm(data_iter):
if args.image2video:
# use with input image shape
text, image_path = text.split('@@')
text, image_path = text.split("@@")
assert os.path.exists(image_path), image_path
image = Image.open(image_path).convert('RGB')
image = Image.open(image_path).convert("RGB")
(img_W, img_H) = image.size
def nearest_multiple_of_16(n):
@@ -163,7 +163,7 @@ def sampling_main(args, model_cls):
chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1))
chained_trainsforms.append(TT.ToTensor())
transform = TT.Compose(chained_trainsforms)
image = transform(image).unsqueeze(0).to('cuda')
image = transform(image).unsqueeze(0).to("cuda")
image = image * 2.0 - 1.0
image = image.unsqueeze(2).to(torch.bfloat16)
image = model.encode_first_stage(image, None)
@@ -173,7 +173,7 @@ def sampling_main(args, model_cls):
image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1)
else:
image_size = args.sampling_image_size
T, H, W, C = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels
H, W = image_size[0], image_size[1]
F = 8 # 8x downsampled
image = None
@@ -183,11 +183,7 @@ def sampling_main(args, model_cls):
src = global_rank * mp_size
torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group())
text = text_cast[0]
value_dict = {
'prompt': text,
'negative_prompt': '',
'num_frames': torch.tensor(T).unsqueeze(0)
}
value_dict = {"prompt": text, "negative_prompt": "", "num_frames": torch.tensor(T).unsqueeze(0)}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
@@ -216,11 +212,7 @@ def sampling_main(args, model_cls):
for index in range(args.batch_size):
if args.image2video:
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H, W),
ofs=torch.tensor([2.0]).to('cuda')
c, uc=uc, batch_size=1, shape=(T, C, H, W), ofs=torch.tensor([2.0]).to("cuda")
)
else:
samples_z = sample_func(
@@ -228,7 +220,7 @@ def sampling_main(args, model_cls):
uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
).to('cuda')
).to("cuda")
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
if args.only_save_latents:
@@ -250,11 +242,12 @@ def sampling_main(args, model_cls):
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
if __name__ == '__main__':
if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
if __name__ == "__main__":
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
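For orientation, the text-to-video branch above samples latents of shape (T, C, H // F, W // F) with an 8x spatial downsampling factor. A rough shape check under assumed argument values (none of these numbers come from this commit):

# Sketch: latent shape passed to sample_func in the text-to-video branch.
# All argument values here are assumptions for illustration.
T = 13                     # args.sampling_num_frames (assumed)
C = 16                     # args.latent_channels (assumed)
H, W = 480, 720            # args.sampling_image_size (assumed)
F = 8                      # 8x downsampled, as in the script
latent_shape = (T, C, H // F, W // F)  # -> (13, 16, 60, 90)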

View File

@@ -5,7 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from beartype.typing import Union, Tuple
from beartype import beartype
from beartype.typing import Union, Tuple, Optional, List
from einops import rearrange
from sgm.util import (
@@ -76,7 +77,6 @@ def _split(input_, dim):
cp_rank = get_context_parallel_rank()
inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
@@ -88,7 +88,6 @@ def _split(input_, dim):
output = torch.cat([inpu_first_frame_, output], dim=dim)
output = output.contiguous()
return output
@@ -421,7 +420,8 @@ class ContextParallelGroupNorm(torch.nn.GroupNorm):
return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
def Normalize(in_channels, gather=False, **kwargs):
# same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
else:
@@ -444,6 +444,7 @@ class SpatialNorm3D(nn.Module):
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
for p in self.norm_layer.parameters():
p.requires_grad = False
@@ -467,24 +468,34 @@ class SpatialNorm3D(nn.Module):
kernel_size=1,
)
def forward(self, f, zq, clear_fake_cp_cache=True):
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp_rank0=True):
if f.shape[2] > 1 and get_context_parallel_rank() == 0 and fake_cp_rank0:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest")
zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
# zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest")
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest")
f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits
]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
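The split-interpolate-concat pattern introduced above keeps each interpolation call to at most 32 channels instead of interpolating zq in one large call. A standalone sketch of the pattern (tensor sizes are assumptions, not from the commit):

import torch
import torch.nn.functional as F

# zq is upsampled to the spatial/temporal size of f, 32 channels at a time.
zq = torch.randn(1, 128, 4, 30, 45)       # (B, C, T, H, W), assumed sizes
target_size = (4, 60, 90)                 # f.shape[-3:]
splits = torch.split(zq, 32, dim=1)
interpolated = [F.interpolate(s, size=target_size, mode="nearest") for s in splits]
zq_up = torch.cat(interpolated, dim=1)    # same result as one large interpolate call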
@@ -520,26 +531,39 @@ class Upsample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.compress_time = compress_time
def forward(self, x):
def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1:
# Process the time dimension first as x_first
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
# print(x_first.shape)
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
# split the rest of the frames to avoid MAX_INT overflow in Pytorch
splits = torch.split(x_rest, 16, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
# concatenate the first frame with the rest
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
if get_context_parallel_rank() == 0 and fake_cp_rank0:
# print(x.shape)
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x_rest = torch.cat(interpolated_splits, dim=1)
# x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
x = torch.cat(interpolated_splits, dim=1)
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, "b c t h w -> (b t) c h w")
# x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
splits = torch.split(x, 16, dim=1)
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
]
@@ -566,15 +590,17 @@ class DownSample3D(nn.Module):
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.compress_time = compress_time
def forward(self, x):
def forward(self, x, fake_cp_rank0=True):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
if x.shape[-1] % 2 == 1:
if get_context_parallel_rank() == 0 and fake_cp_rank0:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
splits = torch.split(x_rest, 16, dim=1)
splits = torch.split(x_rest, 32, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
@@ -582,7 +608,8 @@ class DownSample3D(nn.Module):
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
else:
splits = torch.split(x, 16, dim=1)
# x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
splits = torch.split(x, 32, dim=1)
interpolated_splits = [
torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
]
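Similarly, the compress_time path of DownSample3D above halves the time axis with avg_pool1d while leaving the first frame untouched on the fake-CP rank-0 chunk, again pooling at most 32 channels per call. A standalone sketch (shapes are assumptions):

import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.randn(1, 64, 9, 32, 32)                     # (B, C, T, H, W), assumed sizes, odd T
h, w = x.shape[-2:]
x = rearrange(x, "b c t h w -> (b h w) c t")
x_first, x_rest = x[..., 0], x[..., 1:]               # keep frame 0 as-is
splits = torch.split(x_rest, 32, dim=1)
pooled = [F.avg_pool1d(s, kernel_size=2, stride=2) for s in splits]
x_rest = torch.cat(pooled, dim=1)                     # time 8 -> 4
x = torch.cat([x_first[..., None], x_rest], dim=-1)   # time 9 -> 5
x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)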
@@ -666,21 +693,33 @@ class ContextParallelResnetBlock3D(nn.Module):
padding=0,
)
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp_rank0=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h, clear_cache=clear_fake_cp_cache)
@@ -788,28 +827,32 @@ class ContextParallelEncoder3D(nn.Module):
kernel_size=3,
)
def forward(self, x, **kwargs):
def forward(self, x, clear_fake_cp_cache=True, fake_cp_rank0=True):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
h = self.down[i_level].block[i_block](h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.down[i_level].attn) > 0:
print("Attention not implemented")
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
h = self.down[i_level].downsample(h, fake_cp_rank0=fake_cp_rank0)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h
@@ -912,10 +955,15 @@ class ContextParallelDecoder3D(nn.Module):
up.block = block
up.attn = attn
if i_level != 0:
# # Symmetrical enc-dec
if i_level <= self.temporal_compress_level:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
else:
up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
# if i_level < self.num_resolutions - self.temporal_compress_level:
# up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
# else:
# up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@@ -926,34 +974,38 @@ class ContextParallelDecoder3D(nn.Module):
kernel_size=3,
)
def forward(self, z, clear_fake_cp_cache=True, **kwargs):
def forward(self, z, clear_fake_cp_cache=True, fake_cp_rank0=True):
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.up[i_level].block[i_block](
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0
)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
h = self.up[i_level].upsample(h, fake_cp_rank0=fake_cp_rank0)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp_rank0=fake_cp_rank0)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)