Update i2v_dataset.py

The image should also be transformed to the [-1, 1] range.
This commit is contained in:
Zheng Guang Cong 2025-01-11 17:24:35 +08:00 committed by GitHub
parent 35383e2db3
commit cd861bbe1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -148,12 +148,13 @@ class BaseI2VDataset(Dataset):
frames, image = self.preprocess(video, image)
frames = frames.to(self.device)
image = image.to(self.device)
image = self.image_transform(image)
# Current shape of frames: [F, C, H, W]
frames = self.video_transform(frames)
# Add image into the first frame.
# Note: **this operation may be model-specific** and may change in the future.
frames = torch.cat([self.image_transform(image).unsqueeze(0), frames], dim=0)
frames = torch.cat([image.unsqueeze(0), frames], dim=0)
# Convert to [B, C, F, H, W]
frames = frames.unsqueeze(0)