mirror of
https://github.com/THUDM/CogVideo.git
synced 2026-06-01 17:08:13 +08:00
Fix incorrect frame padding formula in lora trainers
The ncopy calculation used `latent.shape[2] % patch_size_t` which computes the remainder rather than the number of frames needed to reach alignment. For example, with shape[2]=13 and patch_size_t=4, this gives ncopy=1, resulting in 14 frames which is still not divisible by 4, causing the assertion to fail. The correct formula is `(patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t` which computes how many frames must be prepended to reach the next multiple of patch_size_t. The outer modulo handles the already-aligned case (returns 0 instead of patch_size_t). Fixes #782
This commit is contained in:
parent
7a1af71545
commit
7bc152ff35
@ -115,7 +115,7 @@ class CogVideoXI2VLoraTrainer(Trainer):
|
|||||||
|
|
||||||
patch_size_t = self.state.transformer_config.patch_size_t
|
patch_size_t = self.state.transformer_config.patch_size_t
|
||||||
if patch_size_t is not None:
|
if patch_size_t is not None:
|
||||||
ncopy = latent.shape[2] % patch_size_t
|
ncopy = (patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t
|
||||||
# Copy the first frame ncopy times to match patch_size_t
|
# Copy the first frame ncopy times to match patch_size_t
|
||||||
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
||||||
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
|||||||
|
|
||||||
patch_size_t = self.state.transformer_config.patch_size_t
|
patch_size_t = self.state.transformer_config.patch_size_t
|
||||||
if patch_size_t is not None:
|
if patch_size_t is not None:
|
||||||
ncopy = latent.shape[2] % patch_size_t
|
ncopy = (patch_size_t - latent.shape[2] % patch_size_t) % patch_size_t
|
||||||
# Copy the first frame ncopy times to match patch_size_t
|
# Copy the first frame ncopy times to match patch_size_t
|
||||||
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
||||||
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user