diff --git a/.gitignore b/.gitignore index 2ee7d3d..99c05a3 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ venv **/results **/*.mp4 **/validation_set +CogVideo-1.0 diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py index 226ed6e..b9c0552 100644 --- a/sat/diffusion_video.py +++ b/sat/diffusion_video.py @@ -10,7 +10,6 @@ import torch from torch import nn from sgm.modules import UNCONDITIONAL_CONFIG -from sgm.modules.autoencoding.temporal_ae import VideoDecoder from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER from sgm.util import ( default, @@ -90,27 +89,37 @@ class SATVideoDiffusionEngine(nn.Module): self.no_cond_log = no_cond_log self.device = args.device + # put lora add here def disable_untrainable_params(self): total_trainable = 0 - for n, p in self.named_parameters(): - if p.requires_grad == False: - continue - flag = False - for prefix in self.not_trainable_prefixes: - if n.startswith(prefix) or prefix == "all": - flag = True - break + if self.lora_train: + for n, p in self.named_parameters(): + if p.requires_grad == False: + continue + if 'lora_layer' not in n: + p.lr_scale = 0 + else: + total_trainable += p.numel() + else: + for n, p in self.named_parameters(): + if p.requires_grad == False: + continue + flag = False + for prefix in self.not_trainable_prefixes: + if n.startswith(prefix) or prefix == "all": + flag = True + break - lora_prefix = ["matrix_A", "matrix_B"] - for prefix in lora_prefix: - if prefix in n: - flag = False - break + lora_prefix = ['matrix_A', 'matrix_B'] + for prefix in lora_prefix: + if prefix in n: + flag = False + break - if flag: - p.requires_grad_(False) - else: - total_trainable += p.numel() + if flag: + p.requires_grad_(False) + else: + total_trainable += p.numel() print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****") @@ -182,11 +191,7 @@ class SATVideoDiffusionEngine(nn.Module): for n in range(n_rounds): z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:] latent_time = z_now.shape[2] # check the time latent - temporal_compress_times = 4 - fake_cp_size = min(10, latent_time // 2) - start_frame = 0 - recons = [] start_frame = 0 for i in range(fake_cp_size): diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py index 22c3821..82c1d56 100644 --- a/sat/dit_video_concat.py +++ b/sat/dit_video_concat.py @@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin): def word_embedding_forward(self, input_ids, **kwargs): images = kwargs["images"] # (b,t,c,h,w) emb = rearrange(images, "b t c h w -> b (t h w) c") + # emb = rearrange(images, "b c t h w -> b (t h w) c") emb = rearrange( emb, "b (t o h p w q) c -> b (t h w) (c o p q)", @@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel): ), reinit=True, ) - + if "lora_config" in module_configs: + lora_config = module_configs["lora_config"] + self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) return def forward(self, x, timesteps=None, context=None, y=None, **kwargs):