diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py index 76663a3..451611e 100644 --- a/finetune/datasets/i2v_dataset.py +++ b/finetune/datasets/i2v_dataset.py @@ -148,12 +148,13 @@ class BaseI2VDataset(Dataset): frames, image = self.preprocess(video, image) frames = frames.to(self.device) image = image.to(self.device) + image = self.image_transform(image) # Current shape of frames: [F, C, H, W] frames = self.video_transform(frames) # Add image into the first frame. # Note, **this operation maybe model-specific**, and maybe change in the future. - frames = torch.cat([self.image_transform(image).unsqueeze(0), frames], dim=0) + frames = torch.cat([image.unsqueeze(0), frames], dim=0) # Convert to [B, C, F, H, W] frames = frames.unsqueeze(0)