Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-03 20:09:15 +08:00)
fix: pad latent frames to match patch_size_t requirements
This commit is contained in:
Parent: f6d722cec7
Commit: e213b6c083
@ -100,22 +100,12 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
|||||||
# Shape of latent: [B, C, F, H, W]
|
# Shape of latent: [B, C, F, H, W]
|
||||||
|
|
||||||
patch_size_t = self.state.transformer_config.patch_size_t
|
patch_size_t = self.state.transformer_config.patch_size_t
|
||||||
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
|
if patch_size_t is not None:
|
||||||
raise ValueError(
|
ncopy = latent.shape[2] % patch_size_t
|
||||||
"Number of frames in latent must be divisible by patch size, please check your args for training."
|
# Copy the first frame ncopy times to match patch_size_t
|
||||||
)
|
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
||||||
|
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
||||||
# Add 2 random noise frames at the beginning of frame dimension
|
assert latent.shape[2] % patch_size_t == 0
|
||||||
noise_frames = torch.randn(
|
|
||||||
latent.shape[0],
|
|
||||||
latent.shape[1],
|
|
||||||
2,
|
|
||||||
latent.shape[3],
|
|
||||||
latent.shape[4],
|
|
||||||
device=latent.device,
|
|
||||||
dtype=latent.dtype,
|
|
||||||
)
|
|
||||||
latent = torch.cat([noise_frames, latent], dim=2)
|
|
||||||
|
|
||||||
batch_size, num_channels, num_frames, height, width = latent.shape
|
batch_size, num_channels, num_frames, height, width = latent.shape
|
||||||
|
|
||||||
@ -183,7 +173,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
|
|||||||
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
|
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
|
||||||
|
|
||||||
video_generate = pipe(
|
video_generate = pipe(
|
||||||
num_frames=self.state.train_frames, # since we pad 2 frames in latent, we still use train_frames
|
num_frames=self.state.train_frames,
|
||||||
height=self.state.train_height,
|
height=self.state.train_height,
|
||||||
width=self.state.train_width,
|
width=self.state.train_width,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user