Mirror of https://github.com/THUDM/CogVideo.git (synced 2025-04-06 03:57:56 +08:00)
Adapt dataset for text embeddings and add noise padding

- Add text embedding support in dataset collation
- Pad 2 random noise frames at the beginning of the latent frame dimension during training
commit 9157e0cbc8 · parent 49dc370de6
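The collation half of this change is not shown in the hunks below. As a rough sketch of what passing precomputed text embeddings through the collate step could look like (the function and the field names `prompt_embedding` and `encoded_video` are illustrative assumptions, not code from this commit):

```python
import torch

def collate_fn(samples: list[dict]) -> dict:
    # Hypothetical collate: stack precomputed text embeddings alongside
    # encoded video latents so the trainer receives batched tensors.
    return {
        "prompt_embedding": torch.stack([s["prompt_embedding"] for s in samples]),
        "encoded_videos": torch.stack([s["encoded_video"] for s in samples]),
    }
```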
```diff
@@ -117,6 +117,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
             raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
 
+        # Add 2 random noise frames at the beginning of frame dimension
+        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        latent = torch.cat([noise_frames, latent], dim=2)
+
         batch_size, num_channels, num_frames, height, width = latent.shape
 
```
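In isolation, the padding step draws two frames of Gaussian noise that match the latent in every dimension except the frame axis, then prepends them along dim=2 of the [batch, channels, frames, height, width] layout. A minimal runnable sketch, with a toy latent shape chosen so the divisibility check passes (the shapes and the patch_size_t value are assumptions for illustration, not values from the repo):

```python
import torch

# Toy latent in CogVideoX layout: [batch, channels, frames, height, width]
latent = torch.randn(1, 16, 12, 60, 90)

patch_size_t = 2  # assumed temporal patch size, for illustration only
assert latent.shape[2] % patch_size_t == 0  # mirrors the ValueError check above

# Two frames of pure noise, matching device/dtype and every non-frame dim
noise_frames = torch.randn(
    latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
    device=latent.device, dtype=latent.dtype,
)

# Prepend along the frame axis: 12 frames -> 14, still divisible by patch_size_t
latent = torch.cat([noise_frames, latent], dim=2)
assert latent.shape == (1, 16, 14, 60, 90)
```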
```diff
@@ -185,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
 
         video_generate = pipe(
-            num_frames=self.state.train_frames - 1,  # -1 is because t2v does not require adding an image frame like i2v does
+            num_frames=self.state.train_frames,  # since we pad 2 frames in latent, we still use train_frames
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,
```
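As a back-of-the-envelope check on this one-line change: CogVideoX's 3D VAE compresses time by a factor of 4, so pixel frames map to latent frames via (pixel_frames - 1) // 4 + 1. The concrete train_frames value below is an assumed example, not taken from this commit:

```python
def latent_frame_count(pixel_frames: int, temporal_compression: int = 4) -> int:
    # Standard CogVideoX VAE mapping from pixel frames to latent frames.
    return (pixel_frames - 1) // temporal_compression + 1

train_frames = 49  # assumed example value
print(latent_frame_count(train_frames - 1))  # old call site -> 12 latent frames
print(latent_frame_count(train_frames))      # new call site -> 13 latent frames
```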