mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
Update i2v_dataset.py
image should also be transformed to [-1, 1]
This commit is contained in:
parent
35383e2db3
commit
cd861bbe1e
@ -148,12 +148,13 @@ class BaseI2VDataset(Dataset):
|
|||||||
frames, image = self.preprocess(video, image)
|
frames, image = self.preprocess(video, image)
|
||||||
frames = frames.to(self.device)
|
frames = frames.to(self.device)
|
||||||
image = image.to(self.device)
|
image = image.to(self.device)
|
||||||
|
image = self.image_transform(image)
|
||||||
# Current shape of frames: [F, C, H, W]
|
# Current shape of frames: [F, C, H, W]
|
||||||
frames = self.video_transform(frames)
|
frames = self.video_transform(frames)
|
||||||
|
|
||||||
# Add image into the first frame.
|
# Add image into the first frame.
|
||||||
# Note, **this operation maybe model-specific**, and maybe change in the future.
|
# Note, **this operation maybe model-specific**, and maybe change in the future.
|
||||||
frames = torch.cat([self.image_transform(image).unsqueeze(0), frames], dim=0)
|
frames = torch.cat([image.unsqueeze(0), frames], dim=0)
|
||||||
|
|
||||||
# Convert to [B, C, F, H, W]
|
# Convert to [B, C, F, H, W]
|
||||||
frames = frames.unsqueeze(0)
|
frames = frames.unsqueeze(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user