diff --git a/sat/arguments.py b/sat/arguments.py index 44767d3..9b0a1bb 100644 --- a/sat/arguments.py +++ b/sat/arguments.py @@ -36,6 +36,7 @@ def add_sampling_config_args(parser): group.add_argument("--input-dir", type=str, default=None) group.add_argument("--input-type", type=str, default="cli") group.add_argument("--input-file", type=str, default="input.txt") + group.add_argument("--sampling-image-size", type=list, default=[768, 1360]) group.add_argument("--final-size", type=int, default=2048) group.add_argument("--sdedit", action="store_true") group.add_argument("--grid-num-rows", type=int, default=1) diff --git a/sat/configs/images.jpg b/sat/configs/images.jpg new file mode 100644 index 0000000..2e34d1a Binary files /dev/null and b/sat/configs/images.jpg differ diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 963038b..71b9209 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -185,7 +185,12 @@ class SATVideoDiffusionEngine(nn.Module): kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} else: kwargs = {} - out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) + frame = z.shape[2] * 4 - 3 + if frame <= 9: + use_cp = False + else: + use_cp = True + out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], use_cp=use_cp, **kwargs) all_out.append(out) out = torch.cat(all_out, dim=0) return out @@ -218,6 +223,7 @@ class SATVideoDiffusionEngine(nn.Module): shape: Union[None, Tuple, List] = None, prefix=None, concat_images=None, + ofs=None, **kwargs, ): randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) @@ -241,7 +247,7 @@ class SATVideoDiffusionEngine(nn.Module): self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs ) - samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb) + samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, ofs=ofs) samples = samples.to(self.dtype) return samples diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py index 7692116..a759fa0 100644 --- a/sat/dit_video_concat.py +++ b/sat/dit_video_concat.py @@ -1,5 +1,7 @@ from functools import partial from einops import rearrange, repeat +from functools import reduce +from operator import mul import numpy as np import torch @@ -13,38 +15,34 @@ from sat.mpu.layers import ColumnParallelLinear from sgm.util import instantiate_from_config from sgm.modules.diffusionmodules.openaimodel import Timestep -from sgm.modules.diffusionmodules.util import ( - linear, - timestep_embedding, -) +from sgm.modules.diffusionmodules.util import linear, timestep_embedding from sat.ops.layernorm import LayerNorm, RMSNorm class ImagePatchEmbeddingMixin(BaseMixin): - def __init__( - self, - in_channels, - hidden_size, - patch_size, - bias=True, - text_hidden_size=None, - ): + def __init__(self, in_channels, hidden_size, patch_size, text_hidden_size=None): super().__init__() - self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) + self.patch_size = patch_size + self.proj = nn.Linear(in_channels * reduce(mul, patch_size), hidden_size) if text_hidden_size is not None: self.text_proj = nn.Linear(text_hidden_size, hidden_size) else: self.text_proj = None def word_embedding_forward(self, input_ids, **kwargs): - # now is 3d patch images = kwargs["images"] # (b,t,c,h,w) - B, T = images.shape[:2] - emb = images.view(-1, *images.shape[2:]) - emb = self.proj(emb) # ((b t),d,h/2,w/2) - 
emb = emb.view(B, T, *emb.shape[1:]) - emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) - emb = rearrange(emb, "b t n d -> b (t n) d") + emb = rearrange(images, "b t c h w -> b (t h w) c") + emb = rearrange( + emb, + "b (t o h p w q) c -> b (t h w) (c o p q)", + t=kwargs["rope_T"], + h=kwargs["rope_H"], + w=kwargs["rope_W"], + o=self.patch_size[0], + p=self.patch_size[1], + q=self.patch_size[2], + ) + emb = self.proj(emb) if self.text_proj is not None: text_emb = self.text_proj(kwargs["encoder_outputs"]) @@ -74,7 +72,8 @@ def get_3d_sincos_pos_embed( grid_size: int of the grid height and width t_size: int of the temporal size return: - pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + pos_embed: [t_size*grid_size * grid_size, embed_dim] or [1+t_size*grid_size * grid_size, embed_dim] + (w/ or w/o cls_token) """ assert embed_dim % 4 == 0 embed_dim_spatial = embed_dim // 4 * 3 @@ -100,7 +99,6 @@ def get_3d_sincos_pos_embed( pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) - # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] return pos_embed # [T, H*W, D] @@ -259,6 +257,9 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): text_length, theta=10000, rot_v=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, learnable_pos_embed=False, ): super().__init__() @@ -285,14 +286,10 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) - freqs = rearrange(freqs, "t h w d -> (t h w) d") freqs = freqs.contiguous() - freqs_sin = freqs.sin() - freqs_cos = freqs.cos() - self.register_buffer("freqs_sin", freqs_sin) - self.register_buffer("freqs_cos", freqs_cos) - + self.freqs_sin = freqs.sin().cuda() + self.freqs_cos = freqs.cos().cuda() self.text_length = text_length if learnable_pos_embed: num_patches = height * width * compressed_num_frames + text_length @@ -301,15 +298,20 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): self.pos_embedding = None def rotary(self, t, **kwargs): - seq_len = t.shape[2] - freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) - freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) + def reshape_freq(freqs): + freqs = freqs[: kwargs["rope_T"], : kwargs["rope_H"], : kwargs["rope_W"]].contiguous() + freqs = rearrange(freqs, "t h w d -> (t h w) d") + freqs = freqs.unsqueeze(0).unsqueeze(0) + return freqs + + freqs_cos = reshape_freq(self.freqs_cos).to(t.dtype) + freqs_sin = reshape_freq(self.freqs_sin).to(t.dtype) return t * freqs_cos + rotate_half(t) * freqs_sin def position_embedding_forward(self, position_ids, **kwargs): if self.pos_embedding is not None: - return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] + return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]] else: return None @@ -326,10 +328,61 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): ): attention_fn_default = HOOKS_DEFAULT["attention_fn"] - query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) - key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) + query_layer = torch.cat( + ( + query_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + query_layer[ + :, + :, + kwargs["text_length"] :, + 
], + **kwargs, + ), + ), + dim=2, + ) + key_layer = torch.cat( + ( + key_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + key_layer[ + :, + :, + kwargs["text_length"] :, + ], + **kwargs, + ), + ), + dim=2, + ) if self.rot_v: - value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) + value_layer = torch.cat( + ( + value_layer[ + :, + :, + : kwargs["text_length"], + ], + self.rotary( + value_layer[ + :, + :, + kwargs["text_length"] :, + ], + **kwargs, + ), + ), + dim=2, + ) return attention_fn_default( query_layer, @@ -347,21 +400,25 @@ def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): +def unpatchify(x, c, patch_size, w, h, **kwargs): """ x: (N, T/2 * S, patch_size**3 * C) imgs: (N, T, H, W, C) + + patch_size is unpacked into three separate dimensions (o, p, q) for depth (o), height (p), and width (q), so the patch size can differ along each dimension, which adds flexibility. """ - if rope_position_ids is not None: - assert NotImplementedError - # do pix2struct unpatchify - L = x.shape[1] - x = x.reshape(shape=(x.shape[0], L, p, p, c)) - x = torch.einsum("nlpqc->ncplq", x) - imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) - else: - b = x.shape[0] - imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) + + imgs = rearrange( + x, + "b (t h w) (c o p q) -> b (t o) c (h p) (w q)", + c=c, + o=patch_size[0], + p=patch_size[1], + q=patch_size[2], + t=kwargs["rope_T"], + h=kwargs["rope_H"], + w=kwargs["rope_W"], + ) return imgs @@ -382,27 +439,17 @@ class FinalLayerMixin(BaseMixin): self.patch_size = patch_size self.out_channels = out_channels self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.linear = nn.Linear(hidden_size, reduce(mul, patch_size) * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) - self.spatial_length = latent_width * latent_height // patch_size**2 - self.latent_width = latent_width - self.latent_height = latent_height - def final_forward(self, logits, **kwargs): - x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d) + x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d); keep only the image tokens that follow the text tokens shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return unpatchify( - x, - c=self.out_channels, - p=self.patch_size, - w=self.latent_width // self.patch_size, - h=self.latent_height // self.patch_size, - rope_position_ids=kwargs.get("rope_position_ids", None), - **kwargs, + x, c=self.out_channels, patch_size=self.patch_size, w=kwargs["rope_W"], h=kwargs["rope_H"], **kwargs ) def reinit(self, parent_model=None): @@ -440,8 +487,6 @@ class SwiGLUMixin(BaseMixin): class AdaLNMixin(BaseMixin): def __init__( self, - width, - height, hidden_size, num_layers, time_embed_dim, @@ -452,8 +497,6 @@ class AdaLNMixin(BaseMixin): ): super().__init__() self.num_layers = num_layers - self.width = width - self.height = height self.compressed_num_frames = compressed_num_frames self.adaLN_modulations = nn.ModuleList( @@ -611,7 +654,8 @@ class DiffusionTransformer(BaseModel): time_interpolation=1.0, use_SwiGLU=False, use_RMSNorm=False, - zero_init_y_embed=False, + cfg_embed_dim=None, + ofs_embed_dim=None, **kwargs, ): self.latent_width = latent_width @@ -619,12 +663,14 @@ class
DiffusionTransformer(BaseModel): self.patch_size = patch_size self.num_frames = num_frames self.time_compressed_rate = time_compressed_rate - self.spatial_length = latent_width * latent_height // patch_size**2 + self.spatial_length = latent_width * latent_height // reduce(mul, patch_size[1:]) self.in_channels = in_channels self.out_channels = out_channels self.hidden_size = hidden_size self.model_channels = hidden_size self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size + self.cfg_embed_dim = cfg_embed_dim + self.ofs_embed_dim = ofs_embed_dim self.num_classes = num_classes self.adm_in_channels = adm_in_channels self.input_time = input_time @@ -636,7 +682,6 @@ class DiffusionTransformer(BaseModel): self.width_interpolation = width_interpolation self.time_interpolation = time_interpolation self.inner_hidden_size = hidden_size * 4 - self.zero_init_y_embed = zero_init_y_embed try: self.dtype = str_to_dtype[kwargs.pop("dtype")] except: @@ -669,7 +714,6 @@ class DiffusionTransformer(BaseModel): def _build_modules(self, module_configs): model_channels = self.hidden_size - # time_embed_dim = model_channels * 4 time_embed_dim = self.time_embed_dim self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), @@ -677,6 +721,20 @@ class DiffusionTransformer(BaseModel): linear(time_embed_dim, time_embed_dim), ) + if self.ofs_embed_dim is not None: + self.ofs_embed = nn.Sequential( + linear(self.ofs_embed_dim, self.ofs_embed_dim), + nn.SiLU(), + linear(self.ofs_embed_dim, self.ofs_embed_dim), + ) + + if self.cfg_embed_dim is not None: + self.cfg_embed = nn.Sequential( + linear(self.cfg_embed_dim, self.cfg_embed_dim), + nn.SiLU(), + linear(self.cfg_embed_dim, self.cfg_embed_dim), + ) + if self.num_classes is not None: if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) @@ -701,9 +759,6 @@ class DiffusionTransformer(BaseModel): linear(time_embed_dim, time_embed_dim), ) ) - if self.zero_init_y_embed: - nn.init.constant_(self.label_emb[0][2].weight, 0) - nn.init.constant_(self.label_emb[0][2].bias, 0) else: raise ValueError() @@ -712,10 +767,13 @@ class DiffusionTransformer(BaseModel): "pos_embed", instantiate_from_config( pos_embed_config, - height=self.latent_height // self.patch_size, - width=self.latent_width // self.patch_size, + height=self.latent_height // self.patch_size[1], + width=self.latent_width // self.patch_size[2], compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, hidden_size=self.hidden_size, + height_interpolation=self.height_interpolation, + width_interpolation=self.width_interpolation, + time_interpolation=self.time_interpolation, ), reinit=True, ) @@ -737,8 +795,6 @@ class DiffusionTransformer(BaseModel): "adaln_layer", instantiate_from_config( adaln_layer_config, - height=self.latent_height // self.patch_size, - width=self.latent_width // self.patch_size, hidden_size=self.hidden_size, num_layers=self.num_layers, compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, @@ -749,7 +805,6 @@ class DiffusionTransformer(BaseModel): ) else: raise NotImplementedError - final_layer_config = module_configs["final_layer_config"] self.add_mixin( "final_layer", @@ -766,25 +821,18 @@ class DiffusionTransformer(BaseModel): reinit=True, ) - if "lora_config" in module_configs: - lora_config = module_configs["lora_config"] - self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) - return def forward(self, x, 
timesteps=None, context=None, y=None, **kwargs): b, t, d, h, w = x.shape if x.dtype != self.dtype: x = x.to(self.dtype) - - # This is not use in inference if "concat_images" in kwargs and kwargs["concat_images"] is not None: if kwargs["concat_images"].shape[0] != x.shape[0]: concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) else: concat_images = kwargs["concat_images"] x = torch.cat([x, concat_images], dim=2) - assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -792,17 +840,33 @@ class DiffusionTransformer(BaseModel): emb = self.time_embed(t_emb) if self.num_classes is not None: - # assert y.shape[0] == x.shape[0] assert x.shape[0] % y.shape[0] == 0 y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) emb = emb + self.label_emb(y) - kwargs["seq_length"] = t * h * w // (self.patch_size**2) + if self.ofs_embed_dim is not None: + ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype) + ofs_emb = self.ofs_embed(ofs_emb) + emb = emb + ofs_emb + if self.cfg_embed_dim is not None: + cfg_emb = kwargs["scale_emb"] + cfg_emb = self.cfg_embed(cfg_emb) + emb = emb + cfg_emb + + if "ofs" in kwargs.keys(): + ofs_emb = timestep_embedding(kwargs["ofs"], self.ofs_embed_dim, repeat_only=False, dtype=self.dtype) + ofs_emb = self.ofs_embed(ofs_emb) + + kwargs["seq_length"] = t * h * w // reduce(mul, self.patch_size) kwargs["images"] = x kwargs["emb"] = emb kwargs["encoder_outputs"] = context kwargs["text_length"] = context.shape[1] + kwargs["rope_T"] = t // self.patch_size[0] + kwargs["rope_H"] = h // self.patch_size[1] + kwargs["rope_W"] = w // self.patch_size[2] + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) output = super().forward(**kwargs)[0] return output diff --git a/sat/requirements.txt b/sat/requirements.txt index 75b4649..3c1c501 100644 --- a/sat/requirements.txt +++ b/sat/requirements.txt @@ -1,16 +1,11 @@ -SwissArmyTransformer==0.4.12 -omegaconf==2.3.0 -torch==2.4.0 -torchvision==0.19.0 -pytorch_lightning==2.3.3 -kornia==0.7.3 -beartype==0.18.5 -numpy==2.0.1 -fsspec==2024.5.0 -safetensors==0.4.3 -imageio-ffmpeg==0.5.1 -imageio==2.34.2 -scipy==1.14.0 -decord==0.6.0 -wandb==0.17.5 -deepspeed==0.14.4 \ No newline at end of file +SwissArmyTransformer>=0.4.12 +omegaconf>=2.3.0 +pytorch_lightning>=2.4.0 +kornia>=0.7.3 +beartype>=0.19.0 +fsspec>=2024.2.0 +safetensors>=0.4.5 +scipy>=1.14.1 +decord>=0.6.0 +wandb>=0.18.5 +deepspeed>=0.15.3 \ No newline at end of file diff --git a/sat/sample_video.py b/sat/sample_video.py index 49cfcac..58cf566 100644 --- a/sat/sample_video.py +++ b/sat/sample_video.py @@ -4,24 +4,20 @@ import argparse from typing import List, Union from tqdm import tqdm from omegaconf import ListConfig +from PIL import Image import imageio import torch import numpy as np -from einops import rearrange +from einops import rearrange, repeat import torchvision.transforms as TT - from sat.model.base_model import get_model from sat.training.model_io import load_checkpoint from sat import mpu from diffusion_video import SATVideoDiffusionEngine from arguments import get_args -from torchvision.transforms.functional import center_crop, resize -from torchvision.transforms import InterpolationMode -from PIL import Image - def read_from_cli(): cnt = 0 @@ -56,6 +52,42 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda if key == "txt": batch["txt"] = 
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]).to(device).repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]).to(device).repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]).to(device).repeat(*N, 1) + ) + elif key == "fps": + batch[key] = torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N)) + elif key == "fps_id": + batch[key] = torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N)) + elif key == "motion_bucket_id": + batch[key] = torch.tensor([value_dict["motion_bucket_id"]]).to(device).repeat(math.prod(N)) + elif key == "pool_image": + batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(device, dtype=torch.half) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to("cuda"), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames": + batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) + elif key == "cond_frames_without_noise": + batch[key] = repeat(value_dict["cond_frames_without_noise"], "1 ... 
-> b ...", b=N[0]) else: batch[key] = value_dict[key] @@ -83,37 +115,6 @@ def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: i writer.append_data(frame) -def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) - - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def sampling_main(args, model_cls): if isinstance(model_cls, type): model = get_model(args, model_cls) @@ -127,44 +128,65 @@ def sampling_main(args, model_cls): data_iter = read_from_cli() elif args.input_type == "txt": rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size() - print("rank and world_size", rank, world_size) data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size) else: raise NotImplementedError - image_size = [480, 720] - - if args.image2video: - chained_trainsforms = [] - chained_trainsforms.append(TT.ToTensor()) - transform = TT.Compose(chained_trainsforms) - sample_func = model.sample - T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8 num_samples = [1] force_uc_zero_embeddings = ["txt"] - device = model.device + with torch.no_grad(): for text, cnt in tqdm(data_iter): if args.image2video: - text, image_path = text.split("@@") + # use with input image shape + text, image_path = text.split('@@') assert os.path.exists(image_path), image_path - image = Image.open(image_path).convert("RGB") - image = transform(image).unsqueeze(0).to("cuda") - image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0) + image = Image.open(image_path).convert('RGB') + (img_W, img_H) = image.size + + def nearest_multiple_of_16(n): + lower_multiple = (n // 16) * 16 + upper_multiple = (n // 16 + 1) * 16 + if abs(n - lower_multiple) < abs(n - upper_multiple): + return lower_multiple + else: + return upper_multiple + + if img_H < img_W: + H = 96 + W = int(nearest_multiple_of_16(img_W / img_H * H * 8)) // 8 + else: + W = 96 + H = int(nearest_multiple_of_16(img_H / img_W * W * 8)) // 8 + chained_trainsforms = [] + chained_trainsforms.append(TT.Resize(size=[int(H * 8), int(W * 8)], interpolation=1)) + chained_trainsforms.append(TT.ToTensor()) + transform = TT.Compose(chained_trainsforms) + image = transform(image).unsqueeze(0).to('cuda') image = image * 2.0 - 1.0 image = image.unsqueeze(2).to(torch.bfloat16) image = model.encode_first_stage(image, None) + image = image / model.scale_factor image = image.permute(0, 2, 1, 3, 4).contiguous() - pad_shape = (image.shape[0], T - 1, C, H // F, W // F) + pad_shape = (image.shape[0], T - 1, C, H, W) image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) else: + image_size = args.sampling_image_size + T, H, W, C = 
args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels + F = 8 # 8x downsampled image = None + text_cast = [text] + mp_size = mpu.get_model_parallel_world_size() + global_rank = torch.distributed.get_rank() // mp_size + src = global_rank * mp_size + torch.distributed.broadcast_object_list(text_cast, src=src, group=mpu.get_model_parallel_group()) + text = text_cast[0] value_dict = { - "prompt": text, - "negative_prompt": "", - "num_frames": torch.tensor(T).unsqueeze(0), + 'prompt': text, + 'negative_prompt': '', + 'num_frames': torch.tensor(T).unsqueeze(0) } batch, batch_uc = get_batch( @@ -187,64 +209,52 @@ def sampling_main(args, model_cls): if not k == "crossattn": c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)) - if args.image2video and image is not None: + if args.image2video: c["concat"] = image uc["concat"] = image for index in range(args.batch_size): - # reload model on GPU - model.to(device) - samples_z = sample_func( - c, - uc=uc, - batch_size=1, - shape=(T, C, H // F, W // F), - ) + if args.image2video: + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H, W), + ofs=torch.tensor([2.0]).to('cuda') + ) + else: + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ).to('cuda') + samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() + if args.only_save_latents: + samples_z = 1.0 / model.scale_factor * samples_z + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + os.makedirs(save_path, exist_ok=True) + torch.save(samples_z, os.path.join(save_path, "latent.pt")) + with open(os.path.join(save_path, "text.txt"), "w") as f: + f.write(text) + else: + samples_x = model.decode_first_stage(samples_z).to(torch.float32) + samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + if mpu.get_model_parallel_rank() == 0: + save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) - # Unload the model from GPU to save GPU memory - model.to("cpu") - torch.cuda.empty_cache() - first_stage_model = model.first_stage_model - first_stage_model = first_stage_model.to(device) - - latent = 1.0 / model.scale_factor * samples_z - - # Decode latent serial to save GPU memory - recons = [] - loop_num = (T - 1) // 2 - for i in range(loop_num): - if i == 0: - start_frame, end_frame = 0, 3 - else: - start_frame, end_frame = i * 2 + 1, i * 2 + 3 - if i == loop_num - 1: - clear_fake_cp_cache = True - else: - clear_fake_cp_cache = False - with torch.no_grad(): - recon = first_stage_model.decode( - latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache - ) - - recons.append(recon) - - recon = torch.cat(recons, dim=2).to(torch.float32) - samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() - - save_path = os.path.join( - args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) - ) - if mpu.get_model_parallel_rank() == 0: - save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) - - -if __name__ == "__main__": - if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: - os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] - os.environ["WORLD_SIZE"] = 
os.environ["OMPI_COMM_WORLD_SIZE"] - os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] +if __name__ == '__main__': + if 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ: + os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] + os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] + os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] py_parser = argparse.ArgumentParser(add_help=False) known, args_list = py_parser.parse_known_args() diff --git a/sat/sgm/modules/diffusionmodules/sampling.py b/sat/sgm/modules/diffusionmodules/sampling.py index f0f1830..6efd154 100644 --- a/sat/sgm/modules/diffusionmodules/sampling.py +++ b/sat/sgm/modules/diffusionmodules/sampling.py @@ -1,7 +1,8 @@ """ -Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py + Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py """ + from typing import Dict, Union import torch @@ -16,7 +17,6 @@ from ...modules.diffusionmodules.sampling_utils import ( to_sigma, ) from ...util import append_dims, default, instantiate_from_config -from ...util import SeededNoise from .guiders import DynamicCFG @@ -44,7 +44,9 @@ class BaseDiffusionSampler: self.device = device def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): - sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) uc = default(uc, cond) x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) @@ -83,7 +85,9 @@ class SingleStepDiffusionSampler(BaseDiffusionSampler): class EDMSampler(SingleStepDiffusionSampler): - def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): + def __init__( + self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs + ): super().__init__(*args, **kwargs) self.s_churn = s_churn @@ -102,15 +106,21 @@ class EDMSampler(SingleStepDiffusionSampler): dt = append_dims(next_sigma - sigma_hat, x.ndim) euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -126,23 +136,30 @@ class EDMSampler(SingleStepDiffusionSampler): class DDIMSampler(SingleStepDiffusionSampler): - def __init__(self, s_noise=0.1, *args, **kwargs): + def __init__( + self, s_noise=0.1, *args, **kwargs + ): super().__init__(*args, **kwargs) self.s_noise = s_noise def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): + denoised = self.denoise(x, denoiser, sigma, cond, uc) d = to_d(x, sigma, denoised) - dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) + dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * 
torch.randn_like(x) - x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -181,7 +198,9 @@ class AncestralSampler(SingleStepDiffusionSampler): return x def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) for i in self.get_sigma_gen(num_sigmas): x = self.sampler_step( @@ -208,32 +227,43 @@ class LinearMultistepSampler(BaseDiffusionSampler): self.order = order def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) ds = [] sigmas_cpu = sigmas.detach().cpu().numpy() for i in self.get_sigma_gen(num_sigmas): sigma = s_in * sigmas[i] - denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) + denoised = denoiser( + *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs + ) denoised = self.guider(denoised, sigma) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > self.order: ds.pop(0) cur_order = min(i + 1, self.order) - coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class EulerEDMSampler(EDMSampler): - def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): return euler_step class HeunEDMSampler(EDMSampler): - def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): if torch.sum(next_sigma) < 1e-14: # Save a network evaluation if all noise levels are 0 return euler_step @@ -243,7 +273,9 @@ class HeunEDMSampler(EDMSampler): d_prime = (d + d_new) / 2.0 # apply correction if noise level is not 0 - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step + ) return x @@ -282,7 +314,9 @@ class DPMPP2SAncestralSampler(AncestralSampler): x = x_euler else: h, s, t, t_next = self.get_variables(sigma, sigma_down) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] + mult = [ + append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) + ] x2 = mult[0] * x - mult[1] * denoised denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) @@ -332,7 +366,10 @@ class DPMPP2MSampler(BaseDiffusionSampler): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, 
t_next, previous_sigma)] + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] x_standard = mult[0] * x - mult[1] * denoised if old_denoised is None or torch.sum(next_sigma) < 1e-14: @@ -343,12 +380,16 @@ class DPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d # apply correction if noise level is not 0 and not first step - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -365,7 +406,6 @@ class DPMPP2MSampler(BaseDiffusionSampler): return x - class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_variables(self, sigma, next_sigma, previous_sigma=None): t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] @@ -380,7 +420,7 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): def get_mult(self, h, r, t, t_next, previous_sigma): mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() - mult2 = (-2 * h).expm1() + mult2 = (-2*h).expm1() if previous_sigma is not None: mult3 = 1 + 1 / (2 * r) @@ -403,8 +443,11 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): denoised = self.denoise(x, denoiser, sigma, cond, uc) h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] - mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] + mult_noise = append_dims(next_sigma * (1 - (-2*h).exp())**0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_sigma) < 1e-14: @@ -415,12 +458,16 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) # apply correction if noise level is not 0 and not first step - x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) return x, denoised def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) old_denoised = None for i in self.get_sigma_gen(num_sigmas): @@ -437,7 +484,6 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): return x - class SdeditEDMSampler(EulerEDMSampler): def __init__(self, edit_ratio=0.5, *args, **kwargs): super().__init__(*args, **kwargs) @@ -446,7 +492,9 @@ class SdeditEDMSampler(EulerEDMSampler): def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): randn_unit = randn.clone() - randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) + randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + randn, cond, uc, num_steps + ) if num_steps is None: num_steps = self.num_steps @@ -461,7 +509,9 @@ class 
SdeditEDMSampler(EulerEDMSampler): x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -475,8 +525,8 @@ class SdeditEDMSampler(EulerEDMSampler): return x - class VideoDDIMSampler(BaseDiffusionSampler): + def __init__(self, fixed_frames=0, sdedit=False, **kwargs): super().__init__(**kwargs) self.fixed_frames = fixed_frames @@ -484,13 +534,10 @@ class VideoDDIMSampler(BaseDiffusionSampler): def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): alpha_cumprod_sqrt, timesteps = self.discretization( - self.num_steps if num_steps is None else num_steps, - device=self.device, - return_idx=True, - do_append_zero=False, + self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True, do_append_zero=False ) alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) - timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) + timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1])-1, torch.tensor(list(timesteps))]) uc = default(uc, cond) @@ -500,51 +547,36 @@ class VideoDDIMSampler(BaseDiffusionSampler): return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps - def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None, ofs=None): additional_model_inputs = {} + if ofs is not None: + additional_model_inputs['ofs'] = ofs + if isinstance(scale, torch.Tensor) == False and scale == 1: - additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep + additional_model_inputs['idx'] = x.new_ones([x.shape[0]]) * timestep if scale_emb is not None: - additional_model_inputs["scale_emb"] = scale_emb + additional_model_inputs['scale_emb'] = scale_emb denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) else: - additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) - denoised = denoiser( - *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs - ).to(torch.float32) + additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to(torch.float32) if isinstance(self.guider, DynamicCFG): - denoised = self.guider( - denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale - ) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, step_index=self.num_steps - timestep, scale=scale) else: - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2)**0.5, scale=scale) return denoised - def sampler_step( - self, - alpha_cumprod_sqrt, - next_alpha_cumprod_sqrt, - denoiser, - x, - cond, - uc=None, - idx=None, - timestep=None, - scale=None, - scale_emb=None, - ): - denoised = self.denoise( - x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb - ).to(torch.float32) + def sampler_step(self, alpha_cumprod_sqrt, 
next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, timestep=None, scale=None, scale_emb=None, ofs=None): + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 - a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised return x - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) @@ -558,25 +590,83 @@ class VideoDDIMSampler(BaseDiffusionSampler): cond, uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)], scale=scale, scale_emb=scale_emb, + ofs=ofs # 1020 ) return x +class Image2VideoDDIMSampler(BaseDiffusionSampler): + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + alpha_cumprod_sqrt, timesteps = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True + ) + uc = default(uc, cond) + + num_sigmas = len(alpha_cumprod_sqrt) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps + + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None): + additional_model_inputs = {} + additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs).to( + torch.float32) + if isinstance(self.guider, DynamicCFG): + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep) + else: + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5) + return denoised + + def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None, + timestep=None): + # the "sigma" here is actually alpha_cumprod_sqrt + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep).to(torch.float32) + if idx == 1: + return denoised + + a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 + b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t + + x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised + return x + + def __call__(self, image, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)] + ) + + return x + class VPSDEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 - lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 - lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod /
(1-alpha_cumprod))**0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 - lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next @@ -584,8 +674,8 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() - mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt + mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 * (-h).exp() + mult2 = (-2*h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: mult3 = 1 + 1 / (2 * r) @@ -608,21 +698,18 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): timestep=None, scale=None, scale_emb=None, + ofs=None # 1020 ): - denoised = self.denoise( - x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb - ).to(torch.float32) + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb, ofs=ofs).to(torch.float32) # 1020 if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables( - alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt - ) + h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) ] - mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + mult_noise = append_dims((1-next_alpha_cumprod_sqrt**2)**0.5 * (1 - (-2*h).exp())**0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: @@ -636,24 +723,23 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return x, denoised - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None, ofs=None): # 1020 x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( x, cond, uc, num_steps ) if self.fixed_frames > 0: - prefix_frames = x[:, : self.fixed_frames] + prefix_frames = x[:, :self.fixed_frames] old_denoised = None for i in self.get_sigma_gen(num_sigmas): + if self.fixed_frames > 0: if self.sdedit: rd = torch.randn_like(prefix_frames) - noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( - s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) - ) - x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(s_in * (1 - alpha_cumprod_sqrt[i] ** 2)**0.5, len(prefix_frames.shape)) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) else: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = 
torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], @@ -664,28 +750,29 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)], scale=scale, scale_emb=scale_emb, + ofs=ofs # 1020 ) if self.fixed_frames > 0: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) return x class VPODEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 - lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 - lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod / (1-alpha_cumprod))**0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1-next_alpha_cumprod))**0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 - lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1-previous_alpha_cumprod))**0.5).log() h_last = lamb - lamb_previous r = h_last / h return h, r, lamb, lamb_next @@ -693,7 +780,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + mult1 = ((1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2))**0.5 mult2 = (-h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: @@ -714,15 +801,13 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=None, idx=None, - timestep=None, + timestep=None ): denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) if idx == 1: return denoised, denoised - h, r, lamb, lamb_next = self.get_variables( - alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt - ) + h, r, lamb, lamb_next = self.get_variables(alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) mult = [ append_dims(mult, x.ndim) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) ] @@ -757,7 +842,39 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): cond, uc=uc, idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], + timestep=timesteps[-(i+1)] ) return x + +class VideoDDPMSampler(VideoDDIMSampler): + def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, idx=None): + # the "sigma" here is actually alpha_cumprod_sqrt + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, idx*1000//self.num_steps).to(torch.float32) + if idx == 1: + return denoised + + alpha_sqrt = alpha_cumprod_sqrt / next_alpha_cumprod_sqrt + x = append_dims(alpha_sqrt * (1-next_alpha_cumprod_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * x \ + + append_dims(next_alpha_cumprod_sqrt * (1-alpha_sqrt**2) / (1-alpha_cumprod_sqrt**2), x.ndim) * denoised \ + + append_dims(((1-next_alpha_cumprod_sqrt**2) * (1-alpha_sqrt**2) /
(1-alpha_cumprod_sqrt**2))**0.5, x.ndim) * torch.randn_like(x) + + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i + ) + + return x \ No newline at end of file diff --git a/sat/sgm/modules/diffusionmodules/sigma_sampling.py b/sat/sgm/modules/diffusionmodules/sigma_sampling.py index 770de42..8bb623e 100644 --- a/sat/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -17,23 +17,20 @@ class EDMSampling: class DiscreteSampling: - def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False, group_num=0): self.num_idx = num_idx - self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) world_size = mpu.get_data_parallel_world_size() + if world_size <= 8: + uniform_sampling = False self.uniform_sampling = uniform_sampling + self.group_num = group_num if self.uniform_sampling: - i = 1 - while True: - if world_size % i != 0 or num_idx % (world_size // i) != 0: - i += 1 - else: - self.group_num = world_size // i - break - assert self.group_num > 0 - assert world_size % self.group_num == 0 - self.group_width = world_size // self.group_num # the number of rank in one group + assert world_size % group_num == 0 + self.group_width = world_size // group_num # the number of rank in one group self.sigma_interval = self.num_idx // self.group_num def idx_to_sigma(self, idx): @@ -45,9 +42,7 @@ class DiscreteSampling: group_index = rank // self.group_width idx = default( rand, - torch.randint( - group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) - ), + torch.randint(group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)), ) else: idx = default( @@ -59,7 +54,6 @@ class DiscreteSampling: else: return self.idx_to_sigma(idx) - class PartialDiscreteSampling: def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): self.total_num_idx = total_num_idx diff --git a/sat/vae_modules/autoencoder.py b/sat/vae_modules/autoencoder.py index 7c0cc80..9642fb4 100644 --- a/sat/vae_modules/autoencoder.py +++ b/sat/vae_modules/autoencoder.py @@ -592,8 +592,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): unregularized: bool = False, input_cp: bool = False, output_cp: bool = False, + use_cp: bool = True, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.cp_size > 0 and not input_cp: + if self.cp_size <= 1: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: if not is_context_parallel_initialized: initialize_context_parallel(self.cp_size) @@ -603,11 +606,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): x = _conv_split(x, dim=2, kernel_size=1) if return_reg_log: - z, reg_log = super().encode(x, return_reg_log, unregularized) + z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) else: - z = super().encode(x, 
return_reg_log, unregularized) + z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp) - if self.cp_size > 0 and not output_cp: + if self.cp_size > 0 and use_cp and not output_cp: z = _conv_gather(z, dim=2, kernel_size=1) if return_reg_log: @@ -619,23 +622,24 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): z: torch.Tensor, input_cp: bool = False, output_cp: bool = False, - split_kernel_size: int = 1, + use_cp: bool = True, **kwargs, ): - if self.cp_size > 0 and not input_cp: + if self.cp_size <= 1: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: if not is_context_parallel_initialized: initialize_context_parallel(self.cp_size) global_src_rank = get_context_parallel_group_rank() * self.cp_size torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group()) - z = _conv_split(z, dim=2, kernel_size=split_kernel_size) + z = _conv_split(z, dim=2, kernel_size=1) - x = super().decode(z, **kwargs) - - if self.cp_size > 0 and not output_cp: - x = _conv_gather(x, dim=2, kernel_size=split_kernel_size) + x = super().decode(z, use_cp=use_cp, **kwargs) + if self.cp_size > 0 and use_cp and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=1) return x def forward( diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py index d50720d..5b32096 100644 --- a/sat/vae_modules/cp_enc_dec.py +++ b/sat/vae_modules/cp_enc_dec.py @@ -5,8 +5,7 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -from beartype import beartype -from beartype.typing import Union, Tuple, Optional, List +from beartype.typing import Union, Tuple from einops import rearrange from sgm.util import ( @@ -16,11 +15,7 @@ from sgm.util import ( get_context_parallel_group_rank, ) -# try: from vae_modules.utils import SafeConv3d as Conv3d -# except: -# # Degrade to normal Conv3d if SafeConv3d is not available -# from torch.nn import Conv3d def cast_tuple(t, length=1): @@ -81,7 +76,6 @@ def _split(input_, dim): cp_rank = get_context_parallel_rank() - # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() @@ -94,7 +88,6 @@ def _split(input_, dim): output = torch.cat([inpu_first_frame_, output], dim=dim) output = output.contiguous() - # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) return output @@ -382,19 +375,6 @@ class ContextParallelCausalConv3d(nn.Module): self.cache_padding = None def forward(self, input_, clear_cache=True): - # if input_.shape[2] == 1: # handle image - # # first frame padding - # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) - # else: - # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) - - # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0) - - # output_parallel = self.conv(input_parallel) - # output = output_parallel - # return output - input_parallel = fake_cp_pass_from_previous_rank( input_, self.temporal_dim, self.time_kernel_size, self.cache_padding ) @@ -464,7 +444,6 @@ class SpatialNorm3D(nn.Module): self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params) else: self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params) - # self.norm_layer = norm_layer(num_channels=f_channels, 
diff --git a/sat/vae_modules/cp_enc_dec.py b/sat/vae_modules/cp_enc_dec.py
index d50720d..5b32096 100644
--- a/sat/vae_modules/cp_enc_dec.py
+++ b/sat/vae_modules/cp_enc_dec.py
@@ -5,8 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 
-from beartype import beartype
-from beartype.typing import Union, Tuple, Optional, List
+from beartype.typing import Union, Tuple
 
 from einops import rearrange
 
 from sgm.util import (
@@ -16,11 +15,7 @@ from sgm.util import (
     get_context_parallel_group_rank,
 )
 
-# try:
 from vae_modules.utils import SafeConv3d as Conv3d
-# except:
-#     # Degrade to normal Conv3d if SafeConv3d is not available
-#     from torch.nn import Conv3d
 
 
 def cast_tuple(t, length=1):
@@ -81,7 +76,6 @@ def _split(input_, dim):
 
     cp_rank = get_context_parallel_rank()
 
-    # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
     inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
     input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
 
@@ -94,7 +88,6 @@ def _split(input_, dim):
     output = torch.cat([inpu_first_frame_, output], dim=dim)
     output = output.contiguous()
 
-    # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
 
     return output
 
@@ -382,19 +375,6 @@ class ContextParallelCausalConv3d(nn.Module):
         self.cache_padding = None
 
     def forward(self, input_, clear_cache=True):
-        # if input_.shape[2] == 1: # handle image
-        #     # first frame padding
-        #     input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
-        # else:
-        #     input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
-
-        # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
-        # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
-
-        # output_parallel = self.conv(input_parallel)
-        # output = output_parallel
-        # return output
-
         input_parallel = fake_cp_pass_from_previous_rank(
             input_, self.temporal_dim, self.time_kernel_size, self.cache_padding
         )
@@ -464,7 +444,6 @@ class SpatialNorm3D(nn.Module):
             self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params)
         else:
             self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params)
-        # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
         if freeze_norm_layer:
             for p in self.norm_layer.parameters:
                 p.requires_grad = False
@@ -543,21 +522,29 @@ class Upsample3D(nn.Module):
 
     def forward(self, x):
         if self.compress_time and x.shape[2] > 1:
-            if x.shape[2] % 2 == 1:
-                # split first frame
-                x_first, x_rest = x[:, :, 0], x[:, :, 1:]
-
-                x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
-                x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
-                x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
-            else:
-                x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+            # Process the time dimension first as x_first
+            x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+            # print(x_first.shape)
+            x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
+            # split the rest of the frames to avoid MAX_INT overflow in Pytorch
+            splits = torch.split(x_rest, 16, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+            ]
+            x_rest = torch.cat(interpolated_splits, dim=1)
+            # concatenate the first frame with the rest
+            x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
         else:
-            # only interpolate 2D
             t = x.shape[2]
             x = rearrange(x, "b c t h w -> (b t) c h w")
-            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+
+            splits = torch.split(x, 16, dim=1)
+            interpolated_splits = [
+                torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits
+            ]
+            x = torch.cat(interpolated_splits, dim=1)
+
             x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
 
         if self.with_conv:
@@ -585,15 +572,21 @@ class DownSample3D(nn.Module):
             x = rearrange(x, "b c t h w -> (b h w) c t")
 
             if x.shape[-1] % 2 == 1:
-                # split first frame
                 x_first, x_rest = x[..., 0], x[..., 1:]
 
-                if x_rest.shape[-1] > 0:
-                    x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
+                splits = torch.split(x_rest, 16, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                ]
+                x_rest = torch.cat(interpolated_splits, dim=1)
                 x = torch.cat([x_first[..., None], x_rest], dim=-1)
 
                 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
             else:
-                x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+                splits = torch.split(x, 16, dim=1)
+                interpolated_splits = [
+                    torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits
+                ]
+                x = torch.cat(interpolated_splits, dim=1)
                 x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
 
         if self.with_conv:
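
# A minimal, runnable sketch of the chunking pattern used in Upsample3D/DownSample3D
# above: splitting along dim=1 into chunks of 16 before interpolation keeps each
# F.interpolate call small enough to avoid the MAX_INT overflow the in-line comment
# refers to, while producing exactly the same values. The demo tensor shape is arbitrary.
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 32, 32)  # (batch * frames, channels, height, width)
splits = torch.split(x, 16, dim=1)
chunked = torch.cat(
    [F.interpolate(s, scale_factor=2.0, mode="nearest") for s in splits], dim=1
)
direct = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert torch.equal(chunked, direct)  # identical result, smaller per-call workload
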
@@ -675,31 +668,19 @@ class ContextParallelResnetBlock3D(nn.Module):
 
     def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
         h = x
-
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
             h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
         else:
             h = self.norm1(h)
-        # if isinstance(self.norm1, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
         h = nonlinearity(h)
         h = self.conv1(h, clear_cache=clear_fake_cp_cache)
 
         if temb is not None:
             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
-
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         if zq is not None:
             h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
         else:
             h = self.norm2(h)
-        # if isinstance(self.norm2, torch.nn.GroupNorm):
-        #     h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
         h = nonlinearity(h)
         h = self.dropout(h)
         h = self.conv2(h, clear_cache=clear_fake_cp_cache)
@@ -826,10 +807,7 @@ class ContextParallelEncoder3D(nn.Module):
         h = self.mid.block_2(h, temb)
 
         # end
-        # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
         h = self.norm_out(h)
-        # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
-
         h = nonlinearity(h)
         h = self.conv_out(h)
 
@@ -934,10 +912,10 @@ class ContextParallelDecoder3D(nn.Module):
                 up.block = block
                 up.attn = attn
             if i_level != 0:
-                if i_level < self.num_resolutions - self.temporal_compress_level:
-                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
-                else:
+                if i_level <= self.temporal_compress_level:
                     up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True)
+                else:
+                    up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False)
             self.up.insert(0, up)
 
         self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm)
@@ -954,9 +932,7 @@ class ContextParallelDecoder3D(nn.Module):
         # timestep embedding
         temb = None
 
-        t = z.shape[2]
         # z to block_in
-
         zq = z
         h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
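
# The ContextParallelDecoder3D change above makes temporal (2x) upsampling depend on
# temporal_compress_level directly: compress_time is True only for levels with
# i_level <= temporal_compress_level. A tiny, hypothetical walk-through with
# num_resolutions=4 and temporal_compress_level=2:
num_resolutions = 4
temporal_compress_level = 2
for i_level in reversed(range(num_resolutions)):
    if i_level != 0:
        compress_time = i_level <= temporal_compress_level
        print(f"level {i_level}: Upsample3D(compress_time={compress_time})")
# level 3: Upsample3D(compress_time=False)
# level 2: Upsample3D(compress_time=True)
# level 1: Upsample3D(compress_time=True)
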
diff --git a/tools/convert_weight_sat2hf.py b/tools/convert_weight_sat2hf.py
index 183be62..f325018 100644
--- a/tools/convert_weight_sat2hf.py
+++ b/tools/convert_weight_sat2hf.py
@@ -1,22 +1,15 @@
 """
-This script demonstrates how to convert and generate video from a text prompt
-using CogVideoX with 🤗Huggingface Diffusers Pipeline.
-This script requires the `diffusers>=0.30.2` library to be installed.
-
-Functions:
-    - reassign_query_key_value_inplace: Reassigns the query, key, and value weights in-place.
-    - reassign_query_key_layernorm_inplace: Reassigns layer normalization for query and key in-place.
-    - reassign_adaln_norm_inplace: Reassigns adaptive layer normalization in-place.
-    - remove_keys_inplace: Removes specified keys from the state_dict in-place.
-    - replace_up_keys_inplace: Replaces keys in the "up" block in-place.
-    - get_state_dict: Extracts the state_dict from a saved checkpoint.
-    - update_state_dict_inplace: Updates the state_dict with new key assignments in-place.
-    - convert_transformer: Converts a transformer checkpoint to the CogVideoX format.
-    - convert_vae: Converts a VAE checkpoint to the CogVideoX format.
-    - get_args: Parses command-line arguments for the script.
-    - generate_video: Generates a video from a text prompt using the CogVideoX pipeline.
-"""
+The script demonstrates how to convert the weights of the CogVideoX model from SAT to Hugging Face format.
+This script supports the conversion of the following models:
+- CogVideoX-2B
+- CogVideoX-5B, CogVideoX-5B-I2V
+- CogVideoX1.1-5B, CogVideoX1.1-5B-I2V
+
+Original Script:
+https://github.com/huggingface/diffusers/blob/main/scripts/convert_cogvideox_to_diffusers.py
+
+"""
 
 import argparse
 from typing import Any, Dict
 
@@ -153,12 +146,12 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 
 
 def convert_transformer(
-        ckpt_path: str,
-        num_layers: int,
-        num_attention_heads: int,
-        use_rotary_positional_embeddings: bool,
-        i2v: bool,
-        dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
 ):
     PREFIX_KEY = "model.diffusion_model."
 
@@ -172,7 +165,7 @@ def convert_transformer(
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY) :]
+        new_key = key[len(PREFIX_KEY):]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
@@ -209,7 +202,8 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint")
+        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+    )
     parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
@@ -259,9 +253,10 @@ if __name__ == "__main__":
     if args.vae_ckpt_path is not None:
         vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
 
-    text_encoder_id = "google/t5-v1_1-xxl"
+    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+    # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
 
@@ -301,4 +296,7 @@ if __name__ == "__main__":
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
-    pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
+
+    # This is necessary for users with insufficient memory,
+    # such as those using Colab and notebooks, as it can save some memory used for model loading.
+    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
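
# A hedged usage sketch for the converted checkpoint: with max_shard_size="5GB" the
# pipeline is written as several smaller safetensors shards, and the output directory
# can be reloaded with the generic diffusers loader. The local path below is a
# placeholder, not a path from this patch.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("./cogvideox-converted", torch_dtype=torch.bfloat16)
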