Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev

OleehyO 2025-01-06 10:47:56 +00:00
commit 1b886326b2
3 changed files with 32 additions and 23 deletions

.gitignore (1 addition)

@@ -22,3 +22,4 @@ venv
 **/results
 **/*.mp4
 **/validation_set
+CogVideo-1.0

View File

@@ -10,7 +10,6 @@ import torch
 from torch import nn
 from sgm.modules import UNCONDITIONAL_CONFIG
-from sgm.modules.autoencoding.temporal_ae import VideoDecoder
 from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
 from sgm.util import (
     default,
@@ -90,27 +89,37 @@ class SATVideoDiffusionEngine(nn.Module):
         self.no_cond_log = no_cond_log
         self.device = args.device
+        # put lora add here
     def disable_untrainable_params(self):
         total_trainable = 0
-        for n, p in self.named_parameters():
-            if p.requires_grad == False:
-                continue
-            flag = False
-            for prefix in self.not_trainable_prefixes:
-                if n.startswith(prefix) or prefix == "all":
-                    flag = True
-                    break
-            lora_prefix = ["matrix_A", "matrix_B"]
-            for prefix in lora_prefix:
-                if prefix in n:
-                    flag = False
-                    break
-            if flag:
-                p.requires_grad_(False)
-            else:
-                total_trainable += p.numel()
+        if self.lora_train:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                if 'lora_layer' not in n:
+                    p.lr_scale = 0
+                else:
+                    total_trainable += p.numel()
+        else:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                flag = False
+                for prefix in self.not_trainable_prefixes:
+                    if n.startswith(prefix) or prefix == "all":
+                        flag = True
+                        break
+                lora_prefix = ['matrix_A', 'matrix_B']
+                for prefix in lora_prefix:
+                    if prefix in n:
+                        flag = False
+                        break
+                if flag:
+                    p.requires_grad_(False)
+                else:
+                    total_trainable += p.numel()
         print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")
@@ -182,11 +191,7 @@ class SATVideoDiffusionEngine(nn.Module):
             for n in range(n_rounds):
                 z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
                 latent_time = z_now.shape[2]  # check the time latent
-                temporal_compress_times = 4
                 fake_cp_size = min(10, latent_time // 2)
-                start_frame = 0
                 recons = []
                 start_frame = 0
                 for i in range(fake_cp_size):
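In the decode loop above, the temporal latent is split into at most ten chunks (fake_cp_size = min(10, latent_time // 2)) that are reconstructed one after another starting from start_frame = 0. The following hypothetical helper shows how such chunk boundaries could be laid out, assuming an even split with the remainder spread over the leading chunks; the actual frame split in the repository may differ.

def chunk_bounds_sketch(latent_time: int):
    """Hypothetical: split latent_time frames into fake_cp_size contiguous chunks."""
    fake_cp_size = min(10, latent_time // 2)  # same cap as the diff; assumes latent_time >= 2
    base, rem = divmod(latent_time, fake_cp_size)
    bounds, start_frame = [], 0
    for i in range(fake_cp_size):
        end_frame = start_frame + base + (1 if i < rem else 0)
        bounds.append((start_frame, end_frame))
        start_frame = end_frame
    return bounds


print(chunk_bounds_sketch(13))  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]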

View File

@@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin):
     def word_embedding_forward(self, input_ids, **kwargs):
         images = kwargs["images"]  # (b,t,c,h,w)
         emb = rearrange(images, "b t c h w -> b (t h w) c")
+        # emb = rearrange(images, "b c t h w -> b (t h w) c")
         emb = rearrange(
             emb,
             "b (t o h p w q) c -> b (t h w) (c o p q)",
@@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel):
             ),
             reinit=True,
         )
+        if "lora_config" in module_configs:
+            lora_config = module_configs["lora_config"]
+            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
         return

     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
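The lora mixin above is registered only when the training config carries a lora_config entry, and instantiate_from_config conventionally resolves a {"target": ..., "params": {...}} mapping, with extra keyword arguments such as layer_num forwarded to the constructor. Below is a self-contained sketch of that mechanism; the resolver is a re-implementation for illustration, and the target path and hyper-parameters are placeholders, not values from this repository.

import importlib


def instantiate_from_config_sketch(config: dict, **extra_kwargs):
    """Illustrative resolver: import config["target"] and call it with params + extras."""
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}), **extra_kwargs)


lora_config = {
    "target": "my_project.lora.LoraMixin",  # placeholder import path
    "params": {"r": 16, "alpha": 32},       # placeholder LoRA hyper-parameters
}
# mixin = instantiate_from_config_sketch(lora_config, layer_num=30)  # would mirror the add_mixin call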