mirror of https://github.com/THUDM/CogVideo.git (synced 2025-06-01 02:24:39 +08:00)
feat(datasets): implement video dataset modules
- Add dataset implementations for text-to-video and image-to-video
- Include bucket sampler for efficient batch processing
- Add utility functions for data processing
- Create dataset package structure with proper initialization
This commit is contained in:
parent e3f6def234, commit 918ebb5a54
finetune/datasets/__init__.py (Normal file, 12 lines)
@@ -0,0 +1,12 @@
from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
from .bucket_sampler import BucketSampler


__all__ = [
    "I2VDatasetWithResize",
    "I2VDatasetWithBuckets",
    "T2VDatasetWithResize",
    "T2VDatasetWithBuckets",
    "BucketSampler",
]
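With the package initializer in place, training code can import the public classes from one place. A minimal sketch, assuming the repository root is on PYTHONPATH:

    # Hypothetical usage; assumes `finetune` is importable from the repo root.
    from finetune.datasets import (
        BucketSampler,
        I2VDatasetWithResize,
        T2VDatasetWithBuckets,
    )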
finetune/datasets/bucket_sampler.py (Normal file, 73 lines)
@@ -0,0 +1,73 @@
import random
import logging

from torch.utils.data import Sampler
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.
    """

    def __init__(
        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. "
                "This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for index, data in enumerate(self.data_source):
            video_metadata = data["video_metadata"]
            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]

            self.buckets[(f, h, w)].append(data)
            if len(self.buckets[(f, h, w)]) == self.batch_size:
                if self.shuffle:
                    random.shuffle(self.buckets[(f, h, w)])
                yield self.buckets[(f, h, w)]
                self.buckets[(f, h, w)] = []

        if self.drop_last:
            return

        # Flush any partially filled buckets once the dataset is exhausted.
        for fhw, bucket in list(self.buckets.items()):
            if len(bucket) == 0:
                continue
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
            self.buckets[fhw] = []
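Because the sampler buffers whole samples and yields them as ready-made batches, its behaviour is easiest to see on a toy dataset. A minimal, self-contained sketch; the toy dataset and all shapes are illustrative stand-ins for the video datasets added below:

    # Illustrative only: a dummy dataset carrying just the metadata BucketSampler
    # reads, used with the BucketSampler defined in bucket_sampler.py above.
    import torch
    from torch.utils.data import Dataset


    class ToyVideoDataset(Dataset):
        def __init__(self):
            self.video_resolution_buckets = [(8, 32, 32), (16, 64, 64)]
            self._shapes = [(8, 32, 32)] * 5 + [(16, 64, 64)] * 3

        def __len__(self):
            return len(self._shapes)

        def __getitem__(self, index):
            f, h, w = self._shapes[index]
            return {
                "video": torch.zeros(f, 3, h, w),
                "video_metadata": {"num_frames": f, "height": h, "width": w},
            }


    sampler = BucketSampler(ToyVideoDataset(), batch_size=4, shuffle=False)
    for batch in sampler:
        shapes = {tuple(item["video"].shape) for item in batch}
        print(len(batch), shapes)  # every batch holds a single resolution
    # Expected: one full batch of 4 from the (8, 32, 32) bucket, then (because
    # drop_last=False) a leftover batch of 1 from that bucket and a leftover
    # batch of 3 from the (16, 64, 64) bucket.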
finetune/datasets/i2v_dataset.py (Normal file, 206 lines)
@@ -0,0 +1,206 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing_extensions import override

import torch
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms

from .utils import (
    load_prompts, load_videos, load_images,
    preprocess_image_with_resize,
    preprocess_video_with_resize,
    preprocess_video_with_buckets
)

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports, but it happens. Look in decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(__name__)


class BaseI2VDataset(Dataset):
    """
    Base dataset class for image-to-video (I2V) fine-tuning.

    Loads prompts, videos and conditioning images from plain-text column files under
    `data_root`; subclasses implement `preprocess`, `video_transform` and `image_transform`.
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        self.images = load_images(data_root / image_column)

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found "
                f"{len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure "
                f"that the number of caption prompts, videos and images match in your dataset."
            )

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. "
                f"Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if all image files exist
        if any(not path.is_file() for path in self.images):
            raise ValueError(
                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. "
                f"Missing file: {next(path for path in self.images if not path.is_file())}"
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, `index` is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is about
            # to yield, we yield the cached data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]

        # shape of frames: [F, C, H, W]
        # shape of image: [C, H, W]
        frames, image = self.preprocess(self.videos[index], self.images[index])

        frames = self.video_transform(frames)
        image = self.image_transform(image)

        return {
            "prompt": prompt,
            "video": frames,
            "video_metadata": {
                "num_frames": frames.shape[0],
                "height": frames.shape[2],
                "width": frames.shape[3],
            },
            "image": image,
        }

    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads and preprocesses a video and an image.

        Args:
            video_path: Path to the video file to load
            image_path: Path to the image file to load

        Returns:
            A tuple containing:
                - video (torch.Tensor) of shape [F, C, H, W] where F is number of frames,
                  C is number of channels, H is height and W is width
                - image (torch.Tensor) of shape [C, H, W]
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to an image.

        Args:
            image (torch.Tensor): A 3D tensor representing an image with shape [C, H, W] where:
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed image tensor
        """
        raise NotImplementedError("Subclass must implement this method")


class I2VDatasetWithResize(BaseI2VDataset):
    """
    I2V dataset that resizes every video to a fixed `max_num_frames` x `height` x `width`
    and every conditioning image to `height` x `width`.
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        # Scale uint8 pixel values from [0, 255] to [-1, 1].
        self.__frame_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
            ]
        )
        # Images reuse the frame normalization; without this, the base class call to
        # `image_transform` in `__getitem__` would raise NotImplementedError.
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        image = preprocess_image_with_resize(image_path, self.height, self.width)
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)


class I2VDatasetWithBuckets(BaseI2VDataset):
    """
    I2V dataset that assigns each video to the nearest (num_frames, height, width)
    resolution bucket instead of resizing everything to one fixed shape.
    """

    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = video_resolution_buckets
        self.__frame_transforms = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
            ]
        )
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
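For reference, a sketch of the on-disk layout these datasets expect and how the resize variant might be constructed. The file names and numeric settings are assumptions for illustration: any names can be passed via the *_column arguments, and paths inside the column files are resolved relative to the column file's own directory (see load_videos/load_images in utils.py below):

    # Assumed layout (illustrative):
    #   data_root/
    #   |-- prompts.txt   # one caption per line
    #   |-- videos.txt    # one relative video path per line, e.g. videos/0001.mp4
    #   |-- images.txt    # one relative image path per line, e.g. images/0001.png
    #   |-- videos/ ...
    #   `-- images/ ...
    dataset = I2VDatasetWithResize(
        data_root="data_root",
        caption_column="prompts.txt",
        video_column="videos.txt",
        image_column="images.txt",
        max_num_frames=49,
        height=480,
        width=720,
    )
    sample = dataset[0]
    print(sample["video"].shape)  # [F, C, H, W], normalized to [-1, 1]
    print(sample["image"].shape)  # [C, H, W]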
finetune/datasets/t2v_dataset.py (Normal file, 177 lines)
@@ -0,0 +1,177 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing_extensions import override

import torch
from accelerate.logging import get_logger
from torch.utils.data import Dataset
from torchvision import transforms

from .utils import (
    load_prompts, load_videos,
    preprocess_video_with_resize,
    preprocess_video_with_buckets
)

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports, but it happens. Look in decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(__name__)


class BaseT2VDataset(Dataset):
    """
    Base dataset class for text-to-video (T2V) fine-tuning.

    Loads prompts and videos from plain-text column files under `data_root`;
    subclasses implement `preprocess` and `video_transform`.
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. "
                f"Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if number of prompts matches number of videos
        if len(self.videos) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and "
                f"{len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, `index` is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is about
            # to yield, we yield the cached data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]

        # shape of frames: [F, C, H, W]
        frames = self.preprocess(self.videos[index])
        frames = self.video_transform(frames)

        return {
            "prompt": prompt,
            "video": frames,
            "video_metadata": {
                "num_frames": frames.shape[0],
                "height": frames.shape[2],
                "width": frames.shape[3],
            },
        }

    def preprocess(self, video_path: Path) -> torch.Tensor:
        """
        Loads and preprocesses a video.

        Args:
            video_path: Path to the video file to load.

        Returns:
            torch.Tensor: Video tensor of shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")


class T2VDatasetWithResize(BaseT2VDataset):
    """
    T2V dataset that resizes every video to a fixed `max_num_frames` x `height` x `width`.
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        # Scale uint8 pixel values from [0, 255] to [-1, 1].
        self.__frame_transform = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
            ]
        )

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_resize(
            video_path, self.max_num_frames, self.height, self.width,
        )

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)


class T2VDatasetWithBuckets(BaseT2VDataset):
    """
    T2V dataset that assigns each video to the nearest (num_frames, height, width)
    resolution bucket instead of resizing everything to one fixed shape.
    """

    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
        """
        Args:
            video_resolution_buckets: List of tuples representing the resolution buckets.
                Each tuple contains three integers: (max_num_frames, height, width).
        """
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = video_resolution_buckets

        self.__frame_transform = transforms.Compose(
            [
                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
            ]
        )

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_buckets(
            video_path, self.video_resolution_buckets
        )

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
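One plausible way to wire the bucketed dataset and BucketSampler into a DataLoader; this is a sketch, not taken from this commit. Since the sampler already yields fully formed batches and `__getitem__` passes them through unchanged, the loader uses batch_size=1 with a collate_fn that unwraps the singleton list. Paths and bucket sizes are illustrative:

    from torch.utils.data import DataLoader
    from finetune.datasets import BucketSampler, T2VDatasetWithBuckets  # assumes repo root on sys.path

    dataset = T2VDatasetWithBuckets(
        video_resolution_buckets=[(49, 480, 720), (49, 256, 256)],
        data_root="data_root",
        caption_column="prompts.txt",
        video_column="videos.txt",
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        sampler=BucketSampler(dataset, batch_size=8, shuffle=True),
        collate_fn=lambda samples: samples[0],  # unwrap the singleton batch list
    )
    for batch in loader:
        # `batch` is a list of sample dicts that all share one (F, H, W) bucket.
        ...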
finetune/datasets/utils.py (Normal file, 141 lines)
@@ -0,0 +1,141 @@
import torch
import cv2

from typing import List, Tuple
from pathlib import Path
from torchvision.transforms.functional import resize

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports, but it happens. Look in decord GitHub issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")


##########  loaders  ##########

def load_prompts(prompt_path: Path) -> List[str]:
    with open(prompt_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_videos(video_path: Path) -> List[Path]:
    # Paths listed in the column file are resolved relative to the file's own directory.
    with open(video_path, "r", encoding="utf-8") as file:
        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_images(image_path: Path) -> List[Path]:
    with open(image_path, "r", encoding="utf-8") as file:
        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


##########  preprocessors  ##########

def preprocess_image_with_resize(
    image_path: Path,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    image = cv2.imread(image_path.as_posix())
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (width, height))
    image = torch.from_numpy(image).float()
    image = image.permute(2, 0, 1).contiguous()
    return image


def preprocess_video_with_resize(
    video_path: Path,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
        1. If video frame count > max_num_frames, downsample frames evenly
        2. If video dimensions don't match (height, width), resize frames

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A torch.Tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)

    # NOTE: assumes the video has at least `max_num_frames` frames; otherwise the
    # sampling step below would be 0 and range() would raise a ValueError.
    indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
    frames = video_reader.get_batch(indices)
    frames = frames[:max_num_frames].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()
    return frames


def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    The function processes the video through these steps:
        1. Finds nearest frame bucket <= video frame count
        2. Downsamples frames evenly to match bucket size
        3. Finds nearest resolution bucket based on dimensions
        4. Resizes frames to match bucket resolution
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)

    # Raise before filtering so the error message reports the original buckets.
    if all(video_num_frames < bucket[0] for bucket in resolution_buckets):
        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
    resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]

    # Pick the largest frame bucket that still fits, i.e. the smallest non-negative
    # (video_num_frames - bucket[0]).
    nearest_frame_bucket = min(
        resolution_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
    )[0]
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()

    nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

    return frames
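To make the bucket selection above concrete, a small runnable walkthrough with illustrative numbers; no video decoding involved, only the same arithmetic as preprocess_video_with_buckets:

    # Worked sketch of the bucket arithmetic (values are made up for illustration).
    video_num_frames, height, width = 100, 300, 500
    buckets = [(49, 480, 720), (81, 256, 256)]

    fitting = [b for b in buckets if b[0] <= video_num_frames]
    nearest_frame_bucket = min(fitting, key=lambda b: video_num_frames - b[0])[0]
    print(nearest_frame_bucket)  # 81: the largest frame count not exceeding 100

    nearest_res = min(fitting, key=lambda b: abs(b[1] - height) + abs(b[2] - width))
    print(nearest_res[1:])  # (256, 256): distance 44 + 244 = 288 beats (480, 720)'s 400

So a 100-frame, 300x500 video would be evenly sampled down to 81 frames and each frame resized to 256x256, giving an [81, 3, 256, 256] tensor.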