mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-05-06 22:58:13 +08:00
Merge 7bc152ff350eed9967206a681032abe38b7b8ca7 into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2
This commit is contained in:
commit
256e7cfd91
@ -115,7 +115,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
||||
|
||||
patch_size_t = self.state.transformer_config.patch_size_t
|
||||
if patch_size_t is not None:
|
||||
ncopy = latent.shape[2] % patch_size_t
|
||||
ncopy = (patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t
|
||||
# Copy the first frame ncopy times to match patch_size_t
|
||||
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
||||
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
||||
|
||||
@ -109,7 +109,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
||||
|
||||
patch_size_t = self.state.transformer_config.patch_size_t
|
||||
if patch_size_t is not None:
|
||||
ncopy = latent.shape[2] % patch_size_t
|
||||
ncopy = (patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t
|
||||
# Copy the first frame ncopy times to match patch_size_t
|
||||
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
||||
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user