mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-06-25 03:29:18 +08:00
Update train_cogvideox_image_to_video_lora.py
This commit is contained in:
parent
f28708d845
commit
e169e7b045
@ -573,36 +573,36 @@ class VideoDataset(Dataset):
|
|||||||
return instance_prompts, instance_videos
|
return instance_prompts, instance_videos
|
||||||
|
|
||||||
def _resize_for_rectangle_crop(self, arr):
    """Resize a clip so it covers the target frame, then crop it to size.

    The clip is rescaled with bicubic interpolation, keeping its aspect
    ratio, so that one spatial side equals the target and the other is at
    least as large. The surplus is then cropped away according to
    ``self.video_reshape_mode``.

    Args:
        arr: Frames indexed so that ``shape[2]`` is the height and
            ``shape[3]`` the width (the caller passes ``[F, C, H, W]``).

    Returns:
        The frames cropped to ``self.height`` x ``self.width``. The leading
        axes are left as they were, including when there is only one frame.

    Raises:
        NotImplementedError: If ``self.video_reshape_mode`` is not one of
            ``"random"``, ``"none"`` or ``"center"``.
    """
    target_h, target_w = self.height, self.width
    src_h, src_w = arr.shape[2], arr.shape[3]

    if src_w / src_h > target_w / target_h:
        # Source is wider than the target: match heights, keep extra width.
        new_size = [target_h, int(src_w * target_h / src_h)]
    else:
        # Source is taller (or the same ratio): match widths, keep extra height.
        new_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

    # The frame axis is deliberately kept even when it has length 1 (no
    # ``squeeze(0)``), so a single-frame clip comes back with the same number
    # of dimensions as any other clip.
    delta_h = arr.shape[2] - target_h
    delta_w = arr.shape[3] - target_w

    mode = self.video_reshape_mode
    if mode in ("random", "none"):
        # "none" shares the random-offset behaviour of "random".
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError(f"Unsupported video_reshape_mode: {mode!r}")

    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)
|
||||||
|
|
||||||
def _preprocess_data(self):
|
def _preprocess_data(self):
|
||||||
try:
|
try:
|
||||||
@ -622,8 +622,7 @@ class VideoDataset(Dataset):
|
|||||||
videos = []
|
videos = []
|
||||||
|
|
||||||
for filename in self.instance_video_paths:
|
for filename in self.instance_video_paths:
|
||||||
progress_dataset_bar.update(1)
|
video_reader = decord.VideoReader(uri=filename.as_posix())
|
||||||
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
|
|
||||||
video_num_frames = len(video_reader)
|
video_num_frames = len(video_reader)
|
||||||
|
|
||||||
start_frame = min(self.skip_frames_start, video_num_frames)
|
start_frame = min(self.skip_frames_start, video_num_frames)
|
||||||
@ -651,8 +650,12 @@ class VideoDataset(Dataset):
|
|||||||
# Training transforms
|
# Training transforms
|
||||||
frames = (frames - 127.5) / 127.5
|
frames = (frames - 127.5) / 127.5
|
||||||
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
|
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
|
||||||
|
progress_dataset_bar.set_description(
|
||||||
|
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
|
||||||
|
)
|
||||||
frames = self._resize_for_rectangle_crop(frames)
|
frames = self._resize_for_rectangle_crop(frames)
|
||||||
videos.append(frames.contiguous()) # [F, C, H, W]
|
videos.append(frames.contiguous()) # [F, C, H, W]
|
||||||
|
progress_dataset_bar.update(1)
|
||||||
|
|
||||||
progress_dataset_bar.close()
|
progress_dataset_bar.close()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user