mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-06 03:57:56 +08:00
⚡ preload captions
This commit is contained in:
parent
c84b92e081
commit
a6291c985a
@ -369,6 +369,7 @@ class SFTDataset(Dataset):
|
|||||||
self.skip_frms_num = skip_frms_num
|
self.skip_frms_num = skip_frms_num
|
||||||
|
|
||||||
self.video_paths = []
|
self.video_paths = []
|
||||||
|
self.captions = []
|
||||||
|
|
||||||
for root, dirnames, filenames in os.walk(data_dir):
|
for root, dirnames, filenames in os.walk(data_dir):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
@ -376,6 +377,13 @@ class SFTDataset(Dataset):
|
|||||||
video_path = os.path.join(root, filename)
|
video_path = os.path.join(root, filename)
|
||||||
self.video_paths.append(video_path)
|
self.video_paths.append(video_path)
|
||||||
|
|
||||||
|
caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
|
||||||
|
if os.path.exists(caption_path):
|
||||||
|
caption = open(caption_path, "r").read().splitlines()[0]
|
||||||
|
else:
|
||||||
|
caption = ""
|
||||||
|
self.captions.append(caption)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
|
|
||||||
decord.bridge.set_bridge("torch")
|
decord.bridge.set_bridge("torch")
|
||||||
@ -435,15 +443,9 @@ class SFTDataset(Dataset):
|
|||||||
tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
|
tensor_frms = resize_for_rectangle_crop(tensor_frms, self.video_size, reshape_mode="center")
|
||||||
tensor_frms = (tensor_frms - 127.5) / 127.5
|
tensor_frms = (tensor_frms - 127.5) / 127.5
|
||||||
|
|
||||||
caption_path = video_path.replace(".mp4", ".txt").replace("videos", "labels")
|
|
||||||
if os.path.exists(caption_path):
|
|
||||||
caption = open(caption_path, "r").read().splitlines()[0]
|
|
||||||
else:
|
|
||||||
caption = ""
|
|
||||||
|
|
||||||
item = {
|
item = {
|
||||||
"mp4": tensor_frms,
|
"mp4": tensor_frms,
|
||||||
"txt": caption,
|
"txt": self.captions[index],
|
||||||
"num_frames": num_frames,
|
"num_frames": num_frames,
|
||||||
"fps": self.fps,
|
"fps": self.fps,
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user