Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)
Merge remote-tracking branch 'upstream/CogVideoX_dev' into dev
Commit 1b886326b2
.gitignore (vendored)
@@ -22,3 +22,4 @@ venv
 **/results
 **/*.mp4
 **/validation_set
+CogVideo-1.0
@@ -10,7 +10,6 @@ import torch
 from torch import nn

 from sgm.modules import UNCONDITIONAL_CONFIG
-from sgm.modules.autoencoding.temporal_ae import VideoDecoder
 from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
 from sgm.util import (
     default,
@@ -90,27 +89,37 @@ class SATVideoDiffusionEngine(nn.Module):
         self.no_cond_log = no_cond_log
         self.device = args.device

+    # put lora add here
     def disable_untrainable_params(self):
         total_trainable = 0
-        for n, p in self.named_parameters():
-            if p.requires_grad == False:
-                continue
-            flag = False
-            for prefix in self.not_trainable_prefixes:
-                if n.startswith(prefix) or prefix == "all":
-                    flag = True
-                    break
+        if self.lora_train:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                if 'lora_layer' not in n:
+                    p.lr_scale = 0
+                else:
+                    total_trainable += p.numel()
+        else:
+            for n, p in self.named_parameters():
+                if p.requires_grad == False:
+                    continue
+                flag = False
+                for prefix in self.not_trainable_prefixes:
+                    if n.startswith(prefix) or prefix == "all":
+                        flag = True
+                        break

-            lora_prefix = ["matrix_A", "matrix_B"]
+                lora_prefix = ['matrix_A', 'matrix_B']
                 for prefix in lora_prefix:
                     if prefix in n:
                         flag = False
                         break

                 if flag:
                     p.requires_grad_(False)
                 else:
                     total_trainable += p.numel()

         print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")

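Note: the new lora_train branch above keeps only parameters whose names contain "lora_layer" trainable and zeroes the lr_scale (a SAT-style per-parameter learning-rate multiplier) of everything else. A minimal standalone sketch of that behaviour, with a toy module and illustrative names (ToyBlock, freeze_non_lora are not from the repo):

import torch
from torch import nn


class ToyBlock(nn.Module):
    """Toy stand-in: one ordinary projection plus one LoRA-style layer."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)        # effectively frozen when lora_train is on
        self.lora_layer = nn.Linear(8, 8)  # the only weights left trainable


def freeze_non_lora(model: nn.Module) -> int:
    """Mimic the lora_train branch: zero lr_scale on non-LoRA params, count the rest."""
    total_trainable = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "lora_layer" not in name:
            param.lr_scale = 0  # assumed to be honored by the optimizer setup
        else:
            total_trainable += param.numel()
    return total_trainable


print(freeze_non_lora(ToyBlock()))  # prints 72: lora_layer weight (64) + bias (8)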
@@ -182,11 +191,7 @@ class SATVideoDiffusionEngine(nn.Module):
             for n in range(n_rounds):
                 z_now = z[n * n_samples : (n + 1) * n_samples, :, 1:]
                 latent_time = z_now.shape[2]  # check the time latent
-                temporal_compress_times = 4
-
                 fake_cp_size = min(10, latent_time // 2)
-                start_frame = 0
-
                 recons = []
                 start_frame = 0
                 for i in range(fake_cp_size):
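The decode hunk above drops the unused temporal_compress_times and the duplicated start_frame, keeping the chunked decode over the latent time axis (fake_cp_size = min(10, latent_time // 2) slices collected into recons). A rough, self-contained sketch of that slicing pattern, using an even split and a dummy decoder purely for illustration (the real loop keeps its own start_frame bookkeeping):

import torch


def chunked_decode(z_now, decode_chunk):
    """z_now: (b, c, t, h, w) latent; decode_chunk decodes one time slice."""
    latent_time = z_now.shape[2]
    fake_cp_size = min(10, max(latent_time // 2, 1))  # max() is a guard for tiny toy inputs
    chunk = (latent_time + fake_cp_size - 1) // fake_cp_size
    recons = []
    start_frame = 0
    for _ in range(fake_cp_size):
        end_frame = min(start_frame + chunk, latent_time)
        if start_frame >= end_frame:
            break
        recons.append(decode_chunk(z_now[:, :, start_frame:end_frame]))
        start_frame = end_frame
    return torch.cat(recons, dim=2)


# toy usage: identity "decoder" just to show the slicing
out = chunked_decode(torch.randn(1, 4, 13, 8, 8), lambda z: z)
print(out.shape)  # torch.Size([1, 4, 13, 8, 8])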
@@ -31,6 +31,7 @@ class ImagePatchEmbeddingMixin(BaseMixin):
     def word_embedding_forward(self, input_ids, **kwargs):
         images = kwargs["images"]  # (b,t,c,h,w)
         emb = rearrange(images, "b t c h w -> b (t h w) c")
+        # emb = rearrange(images, "b c t h w -> b (t h w) c")
         emb = rearrange(
             emb,
             "b (t o h p w q) c -> b (t h w) (c o p q)",
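For reference, the two rearrange calls above first flatten (t, h, w) into a token axis and then fold the per-patch factors (o, p, q) into the channel dimension. A small einops demo with assumed toy sizes (o = p = q = 2 are illustrative patch sizes, not the repo's configuration):

import torch
from einops import rearrange

b, t, c, h, w = 1, 4, 16, 8, 8  # toy video latent: batch, time, channels, height, width
o, p, q = 2, 2, 2               # assumed temporal / height / width patch sizes

images = torch.randn(b, t, c, h, w)
emb = rearrange(images, "b t c h w -> b (t h w) c")
emb = rearrange(
    emb,
    "b (t o h p w q) c -> b (t h w) (c o p q)",
    o=o, p=p, q=q, h=h // p, w=w // q,
)
print(emb.shape)  # torch.Size([1, 32, 128]): (t/o)*(h/p)*(w/q) tokens, c*o*p*q channels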
@@ -810,7 +811,9 @@ class DiffusionTransformer(BaseModel):
                 ),
                 reinit=True,
             )
-
+        if "lora_config" in module_configs:
+            lora_config = module_configs["lora_config"]
+            self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
         return

     def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
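The new branch reads an optional lora_config entry from module_configs and registers it as a mixin. sgm-style configs are dicts of the form {"target": "<import.path.Class>", "params": {...}} consumed by instantiate_from_config; the sketch below re-implements that plumbing with a placeholder DemoLoraMixin and an assumed rank parameter, just to show the shape of the config:

import importlib


def instantiate_from_config(config, **extra_kwargs):
    """Minimal stand-in for the sgm-style instantiate_from_config helper."""
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}), **extra_kwargs)


class DemoLoraMixin:
    """Placeholder for whatever LoRA mixin class the config points at."""
    def __init__(self, r, layer_num):
        self.r = r
        self.layer_num = layer_num


module_configs = {
    "lora_config": {
        "target": f"{__name__}.DemoLoraMixin",  # placeholder target, resolves to the class above
        "params": {"r": 128},                   # assumed LoRA rank parameter
    }
}

if "lora_config" in module_configs:
    lora_config = module_configs["lora_config"]
    # mirrors: self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
    mixin = instantiate_from_config(lora_config, layer_num=30)
    print(type(mixin).__name__, mixin.r, mixin.layer_num)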