From 9157e0cbc8e23d9c551f8483100676473cc49892 Mon Sep 17 00:00:00 2001
From: OleehyO
Date: Mon, 6 Jan 2025 10:44:58 +0000
Subject: [PATCH] Adapt dataset for text embeddings and add noise padding

- Add text embedding support in dataset collation
- Pad 2 random noise frames at the beginning of the latent space during training

---
 finetune/models/cogvideox_t2v/lora_trainer.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/finetune/models/cogvideox_t2v/lora_trainer.py b/finetune/models/cogvideox_t2v/lora_trainer.py
index 45d31fd..2bf04bc 100644
--- a/finetune/models/cogvideox_t2v/lora_trainer.py
+++ b/finetune/models/cogvideox_t2v/lora_trainer.py
@@ -117,6 +117,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
             raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+
+        # Add 2 random noise frames at the beginning of the frame dimension
+        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        latent = torch.cat([noise_frames, latent], dim=2)
 
         batch_size, num_channels, num_frames, height, width = latent.shape
 
@@ -185,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
 
         video_generate = pipe(
-            num_frames=self.state.train_frames - 1,  # -1 is because t2v does not require adding an image frame like i2v does
+            num_frames=self.state.train_frames,  # since we pad 2 noise frames into the latent, we use the full train_frames
             height=self.state.train_height,
             width=self.state.train_width,
             prompt=prompt,
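
Note: for anyone who wants to exercise the padding step in isolation, below is a
minimal standalone sketch of the same logic. It is not part of the patch; the
[batch, channels, frames, height, width] latent layout matches the hunk above,
while the helper name pad_noise_frames and the example shapes are assumptions
made purely for illustration.

    import torch

    def pad_noise_frames(latent: torch.Tensor, num_noise_frames: int = 2) -> torch.Tensor:
        # Prepend random noise frames along the frame dimension (dim=2),
        # matching the device and dtype of the input latent, as in the
        # first hunk above.
        noise = torch.randn(
            latent.shape[0], latent.shape[1], num_noise_frames,
            latent.shape[3], latent.shape[4],
            device=latent.device, dtype=latent.dtype,
        )
        return torch.cat([noise, latent], dim=2)

    # Example: a [2, 16, 8, 60, 90] latent becomes [2, 16, 10, 60, 90].

With 2 frames prepended, the latent frame count grows by 2, which is why the
second hunk drops the old "- 1" and passes train_frames directly at validation
time, per the comment in the patch itself.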