From b080c6a010a396849f41d1b54b37cd26df30d2ef Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Thu, 2 Jan 2025 11:48:18 +0800
Subject: [PATCH] put lora back(sat), unavailable running

---
 .gitignore              |  1 +
 sat/diffusion_video.py  | 44 +++++++++++++++++++++++++----------------
 sat/dit_video_concat.py |  5 ++++-
 3 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2ee7d3d..99c05a3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@ venv
 **/results
 **/*.mp4
 **/validation_set
+CogVideo-1.0
diff --git a/sat/diffusion_video.py b/sat/diffusion_video.py
index 226ed6e..e39dc76 100644
--- a/sat/diffusion_video.py
+++ b/sat/diffusion_video.py
@@ -90,27 +90,37 @@ class SATVideoDiffusionEngine(nn.Module):
         self.no_cond_log = no_cond_log
         self.device = args.device
 
+        # put lora add here
     def disable_untrainable_params(self):
         total_trainable = 0
-        for n, p in self.named_parameters():
-            if p.requires_grad == False:
-                continue
-            flag = False
-            for prefix in self.not_trainable_prefixes:
-                if n.startswith(prefix) or prefix == "all":
-                    flag = True
-                    break
+        if self.lora_train:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                if 'lora_layer' not in n:
+                    p.lr_scale = 0
+                else:
+                    total_trainable += p.numel()
+        else:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                flag = False
+                for prefix in self.not_trainable_prefixes:
+                    if n.startswith(prefix) or prefix == "all":
+                        flag = True
+                        break
 
-            lora_prefix = ["matrix_A", "matrix_B"]
-            for prefix in lora_prefix:
-                if prefix in n:
-                    flag = False
-                    break
+                lora_prefix = ['matrix_A', 'matrix_B']
+                for prefix in lora_prefix:
+                    if prefix in n:
+                        flag = False
+                        break
 
-            if flag:
-                p.requires_grad_(False)
-            else:
-                total_trainable += p.numel()
+                if flag:
+                    p.requires_grad_(False)
+                else:
+                    total_trainable += p.numel()
 
         print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
 
diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
index 22c3821..82c1d56 100644
--- a/sat/dit_video_concat.py
+++ b/sat/dit_video_concat.py
@@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin):
     def word_embedding_forward(self, input_ids, **kwargs):
         images = kwargs["images"]  # (b,t,c,h,w)
         emb = rearrange(images, "b t c h w -> b (t h w) c")
+        # emb = rearrange(images, "b c t h w -> b (t h w) c")
         emb = rearrange(
             emb,
             "b (t o h p w q) c -> b (t h w) (c o p q)",
@@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel):
             ),
             reinit=True,
         )
-
+        if "lora_config" in module_configs:
+            lora_config = module_configs["lora_config"]
+            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
         return
 
     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):