diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py index b26bb7f..451611e 100644 --- a/finetune/datasets/i2v_dataset.py +++ b/finetune/datasets/i2v_dataset.py @@ -148,6 +148,7 @@ class BaseI2VDataset(Dataset): frames, image = self.preprocess(video, image) frames = frames.to(self.device) image = image.to(self.device) + image = self.image_transform(image) # Current shape of frames: [F, C, H, W] frames = self.video_transform(frames) diff --git a/finetune/trainer.py b/finetune/trainer.py index 53fc193..d1eaab5 100644 --- a/finetune/trainer.py +++ b/finetune/trainer.py @@ -526,7 +526,7 @@ class Trainer: video, self.state.train_frames, self.state.train_height, self.state.train_width ) # Convert video tensor (F, C, H, W) to list of PIL images - video = (video * 255).round().clamp(0, 255).to(torch.uint8) + video = video.round().clamp(0, 255).to(torch.uint8) video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video] logger.debug(