mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-24 03:01:36 +08:00
Update train_cogvideox_image_to_video_lora.py
This commit is contained in:
parent
f28708d845
commit
e169e7b045
@ -573,36 +573,36 @@ class VideoDataset(Dataset):
|
||||
return instance_prompts, instance_videos
|
||||
|
||||
def _resize_for_rectangle_crop(self, arr):
|
||||
image_size = self.height, self.width
|
||||
reshape_mode = self.video_reshape_mode
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
image_size = self.height, self.width
|
||||
reshape_mode = self.video_reshape_mode
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
def _preprocess_data(self):
|
||||
try:
|
||||
@ -622,8 +622,7 @@ class VideoDataset(Dataset):
|
||||
videos = []
|
||||
|
||||
for filename in self.instance_video_paths:
|
||||
progress_dataset_bar.update(1)
|
||||
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
|
||||
video_reader = decord.VideoReader(uri=filename.as_posix())
|
||||
video_num_frames = len(video_reader)
|
||||
|
||||
start_frame = min(self.skip_frames_start, video_num_frames)
|
||||
@ -651,8 +650,12 @@ class VideoDataset(Dataset):
|
||||
# Training transforms
|
||||
frames = (frames - 127.5) / 127.5
|
||||
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
|
||||
progress_dataset_bar.set_description(
|
||||
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
|
||||
)
|
||||
frames = self._resize_for_rectangle_crop(frames)
|
||||
videos.append(frames.contiguous()) # [F, C, H, W]
|
||||
progress_dataset_bar.update(1)
|
||||
|
||||
progress_dataset_bar.close()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user