Put LoRA back (SAT); running it is currently unavailable

This commit is contained in:
zR 2025-01-02 11:48:18 +08:00
parent a88c1ede69
commit b080c6a010
3 changed files with 32 additions and 18 deletions

1
.gitignore vendored
View File

@ -22,3 +22,4 @@ venv
**/results
**/*.mp4
**/validation_set
CogVideo-1.0

View File

@ -90,8 +90,18 @@ class SATVideoDiffusionEngine(nn.Module):
self.no_cond_log = no_cond_log
self.device = args.device
# put lora add here
def disable_untrainable_params(self):
total_trainable = 0
if self.lora_train:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
if 'lora_layer' not in n:
p.lr_scale = 0
else:
total_trainable += p.numel()
else:
for n, p in self.named_parameters():
if p.requires_grad == False:
continue
@ -101,7 +111,7 @@ class SATVideoDiffusionEngine(nn.Module):
flag = True
break
lora_prefix = ["matrix_A", "matrix_B"]
lora_prefix = ['matrix_A', 'matrix_B']
for prefix in lora_prefix:
if prefix in n:
flag = False

View File

@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin):
def word_embedding_forward(self, input_ids, **kwargs):
images = kwargs["images"] # (b,t,c,h,w)
emb = rearrange(images, "b t c h w -> b (t h w) c")
# emb = rearrange(images, "b c t h w -> b (t h w) c")
emb = rearrange(
emb,
"b (t o h p w q) c -> b (t h w) (c o p q)",
@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel):
),
reinit=True,
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):