From a6291c985af3059ba7d8117e067fe7e2dc0934a7 Mon Sep 17 00:00:00 2001 From: Jia Zheng Date: Tue, 27 Aug 2024 22:05:37 +0800 Subject: [PATCH] :zap: preload captions --- sat/data_video.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sat/data_video.py b/sat/data_video.py index 505f2e0..d5d0ea0 100644 --- a/sat/data_video.py +++ b/sat/data_video.py @@ -369,6 +369,7 @@ class SFTDataset(Dataset): self.skip_frms_num = skip_frms_num self.video_paths = [] + self.captions = [] for root, dirnames, filenames in os.walk(data_dir): for filename in filenames: @@ -376,6 +377,13 @@ class SFTDataset(Dataset): video_path = os.path.join(root, filename) self.video_paths.append(video_path) + caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels") + if os.path.exists(caption_path): + caption = open(caption_path, "r").read().splitlines()[0] + else: + caption = "" + self.captions.append(caption) + def __getitem__(self, index): decord.bridge.set_bridge("torch") @@ -435,15 +443,9 @@ class SFTDataset(Dataset): tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center") tensor_frms = (tensor_frms - 127.5) / 127.5 - caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels") - if os.path.exists(caption_path): - caption = open(caption_path, "r").read().splitlines()[0] - else: - caption = "" - item = { "mp4": tensor_frms, - "txt": caption, + "txt": self.captions[index], "num_frames": num_frames, "fps": self.fps, }