Adapt dataset for text embeddings and add noise padding

- Add text embedding support in dataset collation
- Pad 2 random noise frames at the beginning of latent space during training
This commit is contained in:
OleehyO 2025-01-06 10:44:58 +00:00
parent 49dc370de6
commit 9157e0cbc8

View File

@@ -117,6 +117,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
patch_size_t = self.state.transformer_config.patch_size_t
if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
# Add 2 random noise frames at the beginning of frame dimension
noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
latent = torch.cat([noise_frames, latent], dim=2)
batch_size, num_channels, num_frames, height, width = latent.shape
@@ -185,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
video_generate = pipe(
-            num_frames=self.state.train_frames - 1,  # -1 is because t2v does not require adding an image frame like i2v does
+            num_frames=self.state.train_frames,  # since we pad 2 frames in latent, we still use train_frames
height=self.state.train_height,
width=self.state.train_width,
prompt=prompt,