Merge 6c740579050efa68c17e7cdf59e33f1126c3d6fb into 7a1af7154511e0ce4e4be8d62faa8c5e5a3532d2

This commit is contained in:
Harikrishna KP 2026-02-07 05:16:56 +05:30 committed by GitHub
commit 2474e2440a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -109,10 +109,12 @@ class CogVideoXT2VLoraTrainer(Trainer):
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None:
ncopy = latent.shape[2] % patch_size_t
remainder = latent.shape[2] % patch_size_t
ncopy = (patch_size_t - remainder) % patch_size_t
# Copy the first frame ncopy times to match patch_size_t
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
if ncopy > 0:
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
assert latent.shape[2] % patch_size_t == 0
batch_size, num_channels, num_frames, height, width = latent.shape