From c84b92e081cc7713224a9a8936152974af23e32d Mon Sep 17 00:00:00 2001
From: Jia Zheng
Date: Tue, 27 Aug 2024 21:05:16 +0800
Subject: [PATCH 1/2] :zap: loading video online

---
 sat/data_video.py | 150 +++++++++++++++++++++++-----------------------
 1 file changed, 76 insertions(+), 74 deletions(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index 04a11ca..505f2e0 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -362,93 +362,95 @@ class SFTDataset(Dataset):
             skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
         """
         super(SFTDataset, self).__init__()
+
+        self.video_size = video_size
+        self.fps = fps
+        self.max_num_frames = max_num_frames
+        self.skip_frms_num = skip_frms_num
 
-        self.videos_list = []
-        self.captions_list = []
-        self.num_frames_list = []
-        self.fps_list = []
+        self.video_paths = []
 
-        decord.bridge.set_bridge("torch")
         for root, dirnames, filenames in os.walk(data_dir):
             for filename in filenames:
                 if filename.endswith(".mp4"):
                     video_path = os.path.join(root, filename)
-                    vr = VideoReader(uri=video_path, height=-1, width=-1)
-                    actual_fps = vr.get_avg_fps()
-                    ori_vlen = len(vr)
-
-                    if ori_vlen / actual_fps * fps > max_num_frames:
-                        num_frames = max_num_frames
-                        start = int(skip_frms_num)
-                        end = int(start + num_frames / fps * actual_fps)
-                        end_safty = min(int(start + num_frames / fps * actual_fps), int(ori_vlen))
-                        indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                        temp_frms = vr.get_batch(np.arange(start, end_safty))
-                        assert temp_frms is not None
-                        tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                        tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                    else:
-                        if ori_vlen > max_num_frames:
-                            num_frames = max_num_frames
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-                            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
-                        else:
-
-                            def nearest_smaller_4k_plus_1(n):
-                                remainder = n % 4
-                                if remainder == 0:
-                                    return n - 3
-                                else:
-                                    return n - remainder + 1
-
-                            start = int(skip_frms_num)
-                            end = int(ori_vlen - skip_frms_num)
-                            num_frames = nearest_smaller_4k_plus_1(
-                                end - start
-                            )  # 3D VAE requires the number of frames to be 4k+1
-                            end = int(start + num_frames)
-                            temp_frms = vr.get_batch(np.arange(start, end))
-                            assert temp_frms is not None
-                            tensor_frms = (
-                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
-                            )
-
-                    tensor_frms = pad_last_frame(
-                        tensor_frms, max_num_frames
-                    )  # the len of indices may be less than num_frames, due to round error
-                    tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
-                    tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
-                    tensor_frms = (tensor_frms - 127.5) / 127.5
-                    self.videos_list.append(tensor_frms)
-
-                    # caption
-                    caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
-                    if os.path.exists(caption_path):
-                        caption = open(caption_path, "r").read().splitlines()[0]
-                    else:
-                        caption = ""
-                    self.captions_list.append(caption)
-                    self.num_frames_list.append(num_frames)
-                    self.fps_list.append(fps)
+                    self.video_paths.append(video_path)
 
     def __getitem__(self, index):
+
+        decord.bridge.set_bridge("torch")
+
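+        # Decode lazily: open and read this clip only when the sample is
+        # requested; __init__ now stores file paths instead of frames, and
+        # the torch bridge is set here so each DataLoader worker configures
+        # decord for itself.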
+        video_path = self.video_paths[index]
+        vr = VideoReader(uri=video_path, height=-1, width=-1)
+        actual_fps = vr.get_avg_fps()
+        ori_vlen = len(vr)
+
+        if ori_vlen / actual_fps * self.fps > self.max_num_frames:
+            num_frames = self.max_num_frames
+            start = int(self.skip_frms_num)
+            end = int(start + num_frames / self.fps * actual_fps)
+            end_safety = min(int(start + num_frames / self.fps * actual_fps), int(ori_vlen))
+            indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+            temp_frms = vr.get_batch(np.arange(start, end_safety))
+            assert temp_frms is not None
+            tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+        else:
+            if ori_vlen > self.max_num_frames:
+                num_frames = self.max_num_frames
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                indices = np.arange(start, end, (end - start) // num_frames).astype(int)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+                tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
+            else:
+
+                def nearest_smaller_4k_plus_1(n):
+                    remainder = n % 4
+                    if remainder == 0:
+                        return n - 3
+                    else:
+                        return n - remainder + 1
+
+                start = int(self.skip_frms_num)
+                end = int(ori_vlen - self.skip_frms_num)
+                num_frames = nearest_smaller_4k_plus_1(
+                    end - start
+                )  # 3D VAE requires the number of frames to be 4k+1
+                end = int(start + num_frames)
+                temp_frms = vr.get_batch(np.arange(start, end))
+                assert temp_frms is not None
+                tensor_frms = (
+                    torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
+                )
+
+        tensor_frms = pad_last_frame(
+            tensor_frms, self.max_num_frames
+        )  # the len of indices may be less than num_frames, due to rounding error
+        tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
+        tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
+        tensor_frms = (tensor_frms - 127.5) / 127.5
+
+        caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
+        if os.path.exists(caption_path):
+            caption = open(caption_path, "r").read().splitlines()[0]
+        else:
+            caption = ""
+
         item = {
-            "mp4": self.videos_list[index],
-            "txt": self.captions_list[index],
-            "num_frames": self.num_frames_list[index],
-            "fps": self.fps_list[index],
+            "mp4": tensor_frms,
+            "txt": caption,
+            "num_frames": num_frames,
+            "fps": self.fps,
         }
         return item
 
     def __len__(self):
-        return len(self.fps_list)
+        return len(self.video_paths)
 
     @classmethod
     def create_dataset_function(cls, path, args, **kwargs):
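
Patch 1/2 makes loading lazy: __init__ now only walks data_dir and records .mp4 paths, and all decord decoding happens per sample in __getitem__, so construction no longer decodes the whole dataset up front. The one non-obvious rule the moved code preserves is the frame-count constraint: the 3D VAE requires T = 4k + 1 frames, which nearest_smaller_4k_plus_1 enforces by rounding down. A minimal standalone sketch of that rule (same logic as the helper above; the asserted values are illustrative):

def nearest_smaller_4k_plus_1(n: int) -> int:
    # Largest m <= n with m % 4 == 1, i.e. a valid 4k + 1 frame count.
    remainder = n % 4
    return n - 3 if remainder == 0 else n - remainder + 1

# 49 = 4 * 12 + 1 is already valid; 50, 51, and 52 all round down to it.
assert [nearest_smaller_4k_plus_1(n) for n in range(49, 53)] == [49, 49, 49, 49]
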
From a6291c985af3059ba7d8117e067fe7e2dc0934a7 Mon Sep 17 00:00:00 2001
From: Jia Zheng
Date: Tue, 27 Aug 2024 22:05:37 +0800
Subject: [PATCH 2/2] :zap: preload captions

---
 sat/data_video.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/sat/data_video.py b/sat/data_video.py
index 505f2e0..d5d0ea0 100644
--- a/sat/data_video.py
+++ b/sat/data_video.py
@@ -369,6 +369,7 @@ class SFTDataset(Dataset):
         self.skip_frms_num = skip_frms_num
 
         self.video_paths = []
+        self.captions = []
 
         for root, dirnames, filenames in os.walk(data_dir):
             for filename in filenames:
@@ -376,6 +377,13 @@
                     video_path = os.path.join(root, filename)
                     self.video_paths.append(video_path)
 
+                    caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
+                    if os.path.exists(caption_path):
+                        caption = open(caption_path, "r").read().splitlines()[0]
+                    else:
+                        caption = ""
+                    self.captions.append(caption)
+
     def __getitem__(self, index):
 
         decord.bridge.set_bridge("torch")
@@ -435,15 +443,9 @@ class SFTDataset(Dataset):
         tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
         tensor_frms = (tensor_frms - 127.5) / 127.5
 
-        caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
-        if os.path.exists(caption_path):
-            caption = open(caption_path, "r").read().splitlines()[0]
-        else:
-            caption = ""
-
         item = {
             "mp4": tensor_frms,
-            "txt": caption,
+            "txt": self.captions[index],
             "num_frames": num_frames,
             "fps": self.fps,
         }
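
Together the two patches leave __init__ metadata-only (video paths plus captions preloaded by patch 2/2) and defer all frame decoding to __getitem__, which is what lets a multi-worker DataLoader parallelize it; moving decord.bridge.set_bridge("torch") into __getitem__ also means each worker process configures the bridge for itself. A hypothetical usage sketch (import path, directory layout, and argument values are illustrative, not taken from the patch):

from torch.utils.data import DataLoader

from data_video import SFTDataset  # assuming sat/ is on sys.path

dataset = SFTDataset(
    data_dir="data/videos",  # walked recursively for *.mp4 files
    video_size=[480, 720],
    fps=8,
    max_num_frames=49,       # 4k + 1, as the 3D VAE requires
    skip_frms_num=3,
)
loader = DataLoader(dataset, batch_size=2, num_workers=4, shuffle=True)

batch = next(iter(loader))
frames = batch["mp4"]  # [B, T, C, H, W], values normalized to [-1, 1]
texts = batch["txt"]   # captions looked up from the preloaded list
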