mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
⚡ preload captions
This commit is contained in:
parent
c84b92e081
commit
a6291c985a
@ -369,6 +369,7 @@ class SFTDataset(Dataset):
|
||||
self.skip_frms_num = skip_frms_num
|
||||
|
||||
self.video_paths = []
|
||||
self.captions = []
|
||||
|
||||
for root, dirnames, filenames in os.walk(data_dir):
|
||||
for filename in filenames:
|
||||
@ -376,6 +377,13 @@ class SFTDataset(Dataset):
|
||||
video_path = os.path.join(root, filename)
|
||||
self.video_paths.append(video_path)
|
||||
|
||||
caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
|
||||
if os.path.exists(caption_path):
|
||||
caption = open(caption_path, "r").read().splitlines()[0]
|
||||
else:
|
||||
caption = ""
|
||||
self.captions.append(caption)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
@ -435,15 +443,9 @@ class SFTDataset(Dataset):
|
||||
tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
|
||||
tensor_frms = (tensor_frms - 127.5) / 127.5
|
||||
|
||||
caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
|
||||
if os.path.exists(caption_path):
|
||||
caption = open(caption_path, "r").read().splitlines()[0]
|
||||
else:
|
||||
caption = ""
|
||||
|
||||
item = {
|
||||
"mp4": tensor_frms,
|
||||
"txt": caption,
|
||||
"txt": self.captions[index],
|
||||
"num_frames": num_frames,
|
||||
"fps": self.fps,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user