Adapt dataset for text embeddings and add noise padding

- Add text embedding support in dataset collation
- Pad 2 random noise frames at the beginning of the latent's frame dimension during training
OleehyO 2025-01-06 10:44:58 +00:00
parent 49dc370de6
commit 9157e0cbc8


@@ -117,6 +117,10 @@ class CogVideoXT2VLoraTrainer(Trainer):
         patch_size_t = self.state.transformer_config.patch_size_t
         if patch_size_t is not None and latent.shape[2] % patch_size_t != 0:
             raise ValueError("Number of frames in latent must be divisible by patch size, please check your args for training.")
+        # Add 2 random noise frames at the beginning of frame dimension
+        noise_frames = torch.randn(latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4], device=latent.device, dtype=latent.dtype)
+        latent = torch.cat([noise_frames, latent], dim=2)
         batch_size, num_channels, num_frames, height, width = latent.shape
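
For context, a minimal standalone sketch of the padding step above; the batch size, channel count, and spatial size here are made up for the demo:

import torch

# Illustrative latent laid out as [batch, channels, frames, height, width]
latent = torch.randn(2, 16, 13, 60, 90)

# Prepend 2 frames of Gaussian noise along the frame dimension (dim=2),
# matching the device and dtype of the existing latent, as the diff does
noise_frames = torch.randn(
    latent.shape[0], latent.shape[1], 2, latent.shape[3], latent.shape[4],
    device=latent.device, dtype=latent.dtype,
)
latent = torch.cat([noise_frames, latent], dim=2)

print(latent.shape)  # torch.Size([2, 16, 15, 60, 90]) -- frame dim grew by 2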
@@ -185,7 +189,7 @@ class CogVideoXT2VLoraTrainer(Trainer):
         prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
         video_generate = pipe(
-            num_frames=self.state.train_frames - 1,  # -1 is because t2v does not require adding an image frame like i2v does
+            num_frames=self.state.train_frames,  # since we pad 2 frames in latent, we still use train_frames
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=prompt,
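
The collation change named in the commit message ("Add text embedding support in dataset collation") is not shown in this diff. As a rough sketch only, stacking precomputed text embeddings in a collate function could look like the following; the function and key names (collate_fn, prompt_embedding, encoded_video) are assumptions, not taken from this commit:

import torch
from typing import Any, Dict, List

def collate_fn(samples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    # Each sample is assumed to carry a precomputed text embedding and an
    # encoded video latent; torch.stack turns the per-sample tensors into
    # batch tensors of shape [batch, ...]
    return {
        "prompt_embedding": torch.stack([s["prompt_embedding"] for s in samples]),
        "encoded_video": torch.stack([s["encoded_video"] for s in samples]),
    }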