mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Merge remote-tracking branch 'upstream/main' into dev
This commit is contained in:
commit
30ba1085ff
@ -147,6 +147,7 @@ class BaseI2VDataset(Dataset):
|
|||||||
frames, image = self.preprocess(video, image)
|
frames, image = self.preprocess(video, image)
|
||||||
frames = frames.to(self.device)
|
frames = frames.to(self.device)
|
||||||
image = image.to(self.device)
|
image = image.to(self.device)
|
||||||
|
image = self.image_transform(image)
|
||||||
# Current shape of frames: [F, C, H, W]
|
# Current shape of frames: [F, C, H, W]
|
||||||
frames = self.video_transform(frames)
|
frames = self.video_transform(frames)
|
||||||
|
|
||||||
|
@ -539,7 +539,7 @@ class Trainer:
|
|||||||
video, self.state.train_frames, self.state.train_height, self.state.train_width
|
video, self.state.train_frames, self.state.train_height, self.state.train_width
|
||||||
)
|
)
|
||||||
# Convert video tensor (F, C, H, W) to list of PIL images
|
# Convert video tensor (F, C, H, W) to list of PIL images
|
||||||
video = (video * 255).round().clamp(0, 255).to(torch.uint8)
|
video = video.round().clamp(0, 255).to(torch.uint8)
|
||||||
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
|
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user