diff --git a/finetune/datasets/__init__.py b/finetune/datasets/__init__.py
new file mode 100644
index 0000000..8c9b61a
--- /dev/null
+++ b/finetune/datasets/__init__.py
@@ -0,0 +1,12 @@
+from .i2v_dataset import I2VDatasetWithResize, I2VDatasetWithBuckets
+from .t2v_dataset import T2VDatasetWithResize, T2VDatasetWithBuckets
+from .bucket_sampler import BucketSampler
+
+
+__all__ = [
+    "I2VDatasetWithResize",
+    "I2VDatasetWithBuckets",
+    "T2VDatasetWithResize",
+    "T2VDatasetWithBuckets",
+    "BucketSampler"
+]
diff --git a/finetune/datasets/bucket_sampler.py b/finetune/datasets/bucket_sampler.py
new file mode 100644
index 0000000..bf1beb1
--- /dev/null
+++ b/finetune/datasets/bucket_sampler.py
@@ -0,0 +1,73 @@
+import random
+import logging
+
+from torch.utils.data import Sampler
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class BucketSampler(Sampler):
+    r"""
+    PyTorch Sampler that groups 3D data by height, width and frames.
+
+    Args:
+        data_source (`VideoDataset`):
+            A PyTorch dataset object that is an instance of `VideoDataset`.
+        batch_size (`int`, defaults to `8`):
+            The batch size to use for training.
+        shuffle (`bool`, defaults to `True`):
+            Whether or not to shuffle the data in each batch before dispatching to dataloader.
+        drop_last (`bool`, defaults to `False`):
+            Whether or not to drop incomplete buckets of data after completely iterating over all data
+            in the dataset. If set to True, only batches that have `batch_size` number of entries will
+            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
+            and batches that do not have `batch_size` number of entries will also be yielded.
+    """
+
+    def __init__(
+        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+    ) -> None:
+        self.data_source = data_source
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+
+        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}
+
+        self._raised_warning_for_drop_last = False
+
+
+    def __len__(self):
+        if self.drop_last and not self._raised_warning_for_drop_last:
+            self._raised_warning_for_drop_last = True
+            logger.warning(
+                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
+            )
+        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
+
+
+    def __iter__(self):
+        for index, data in enumerate(self.data_source):
+            video_metadata = data["video_metadata"]
+            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
+
+            self.buckets[(f, h, w)].append(data)
+            if len(self.buckets[(f, h, w)]) == self.batch_size:
+                if self.shuffle:
+                    random.shuffle(self.buckets[(f, h, w)])
+                yield self.buckets[(f, h, w)]
+                del self.buckets[(f, h, w)]
+                self.buckets[(f, h, w)] = []
+
+        if self.drop_last:
+            return
+
+        for fhw, bucket in list(self.buckets.items()):
+            if len(bucket) == 0:
+                continue
+            if self.shuffle:
+                random.shuffle(bucket)
+            yield bucket
+            del self.buckets[fhw]
+            self.buckets[fhw] = []
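
To illustrate how `BucketSampler` groups samples, here is a small self-contained sketch (not part of this diff) that mirrors the interface the sampler expects: a dataset exposing `video_resolution_buckets` and items carrying `video_metadata`. The `ToyVideoDataset` class and all values below are hypothetical.

```python
# Illustrative only: a toy dataset showing how BucketSampler groups samples by
# (num_frames, height, width). The real datasets added below expose the same interface.
import torch
from torch.utils.data import Dataset

from finetune.datasets import BucketSampler  # assumes the package added in this diff is importable


class ToyVideoDataset(Dataset):
    def __init__(self, metas):
        self.metas = metas
        # BucketSampler reads this attribute to pre-create its buckets.
        self.video_resolution_buckets = sorted(set(metas))

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, index):
        f, h, w = self.metas[index]
        return {
            "video": torch.zeros(f, 3, h, w),
            "video_metadata": {"num_frames": f, "height": h, "width": w},
        }


dataset = ToyVideoDataset([(8, 256, 256)] * 5 + [(16, 480, 720)] * 3)
sampler = BucketSampler(dataset, batch_size=2, shuffle=False, drop_last=False)

for batch in sampler:
    # Every sample in a yielded batch shares the same (num_frames, height, width).
    print(len(batch), batch[0]["video"].shape)
```

With `drop_last=False`, the two leftover samples are yielded as undersized batches at the end of the epoch.
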
diff --git a/finetune/datasets/i2v_dataset.py b/finetune/datasets/i2v_dataset.py
new file mode 100644
index 0000000..2137dce
--- /dev/null
+++ b/finetune/datasets/i2v_dataset.py
@@ -0,0 +1,206 @@
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+from typing_extensions import override
+
+import torch
+from accelerate.logging import get_logger
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from .utils import (
+    load_prompts, load_videos, load_images,
+    preprocess_image_with_resize,
+    preprocess_video_with_resize,
+    preprocess_video_with_buckets
+)
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger(__name__)
+
+
+class BaseI2VDataset(Dataset):
+    """
+    Base dataset class for Image-to-Video (I2V) training.
+
+    This dataset loads prompts, videos and their corresponding conditioning images, and returns
+    preprocessed samples. Subclasses must implement `preprocess`, `video_transform` and
+    `image_transform`.
+
+    Args:
+        data_root (str): Root directory of the dataset.
+        caption_column (str): File (relative to `data_root`) listing one prompt per line.
+        video_column (str): File (relative to `data_root`) listing one video path per line.
+        image_column (str): File (relative to `data_root`) listing one image path per line.
+    """
+    def __init__(
+        self,
+        data_root: str,
+        caption_column: str,
+        video_column: str,
+        image_column: str,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__()
+
+        data_root = Path(data_root)
+        self.prompts = load_prompts(data_root / caption_column)
+        self.videos = load_videos(data_root / video_column)
+        self.images = load_images(data_root / image_column)
+
+        # Check if number of prompts matches number of videos and images
+        if not (len(self.videos) == len(self.prompts) == len(self.images)):
+            raise ValueError(
+                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
+            )
+
+        # Check if all video files exist
+        if any(not path.is_file() for path in self.videos):
+            raise ValueError(
+                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
+            )
+
+        # Check if all image files exist
+        if any(not path.is_file() for path in self.images):
+            raise ValueError(
+                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
+            )
+
+    def __len__(self) -> int:
+        return len(self.videos)
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        if isinstance(index, list):
+            # Here, index is actually a list of data objects that we need to return.
+            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
+            # to have information about num_frames, height and width. Since this is not stored
+            # as metadata, we need to read the video to get this information. You could read this
+            # information without loading the full video in memory, but we do it anyway. In order
+            # to not load the video twice (once to get the metadata, and once to return the loaded video
+            # based on sampled indices), we cache it in the BucketSampler. When the sampler is ready
+            # to yield, we yield the cached data instead of indices. So, this special check ensures
+            # that data is not loaded a second time. PRs are welcome for improvements.
+            return index
+
+        prompt = self.prompts[index]
+
+        # shape of frames: [F, C, H, W]
+        # shape of image: [C, H, W]
+        frames, image = self.preprocess(self.videos[index], self.images[index])
+
+        frames = self.video_transform(frames)
+        image = self.image_transform(image)
+
+        return {
+            "prompt": prompt,
+            "video": frames,
+            "video_metadata": {
+                "num_frames": frames.shape[0],
+                "height": frames.shape[2],
+                "width": frames.shape[3],
+            },
+            "image": image
+        }
+
+    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Loads and preprocesses a video and an image.
+
+        Args:
+            video_path: Path to the video file to load
+            image_path: Path to the image file to load
+
+        Returns:
+            A tuple containing:
+                - video (torch.Tensor) of shape [F, C, H, W] where F is number of frames,
+                  C is number of channels, H is height and W is width
+                - image (torch.Tensor) of shape [C, H, W]
+        """
+        raise NotImplementedError("Subclass must implement this method")
+
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        """
+        Applies transformations to a video.
+
+        Args:
+            frames (torch.Tensor): A 4D tensor representing a video
+                with shape [F, C, H, W] where:
+                - F is number of frames
+                - C is number of channels (3 for RGB)
+                - H is height
+                - W is width
+
+        Returns:
+            torch.Tensor: The transformed video tensor
+        """
+        raise NotImplementedError("Subclass must implement this method")
+
+    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Applies transformations to an image.
+
+        Args:
+            image (torch.Tensor): A 3D tensor representing an image
+                with shape [C, H, W] where:
+                - C is number of channels (3 for RGB)
+                - H is height
+                - W is width
+
+        Returns:
+            torch.Tensor: The transformed image tensor
+        """
+        raise NotImplementedError("Subclass must implement this method")
+
+
+class I2VDatasetWithResize(BaseI2VDataset):
+    """
+    I2V dataset that resizes every video to a fixed (height, width) resolution and keeps at most
+    `max_num_frames` evenly sampled frames. The conditioning image is resized to the same resolution.
+    """
+    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.max_num_frames = max_num_frames
+        self.height = height
+        self.width = width
+
+        self.__frame_transforms = transforms.Compose(
+            [
+                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+            ]
+        )
+        # Reuse the frame transform for the conditioning image so that `image_transform`
+        # (called in BaseI2VDataset.__getitem__) is defined for this subclass as well.
+        self.__image_transforms = self.__frame_transforms
+
+    @override
+    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+        video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
+        image = preprocess_image_with_resize(image_path, self.height, self.width)
+        return video, image
+
+    @override
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
+
+    @override
+    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+        return self.__image_transforms(image)
+
+
+class I2VDatasetWithBuckets(BaseI2VDataset):
+    """
+    I2V dataset that assigns each video to the nearest (num_frames, height, width) bucket from
+    `video_resolution_buckets` instead of resizing to a single fixed resolution. The conditioning
+    image is resized to match the bucketed video resolution.
+    """
+    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.video_resolution_buckets = video_resolution_buckets
+        self.__frame_transforms = transforms.Compose(
+            [
+                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+            ]
+        )
+        self.__image_transforms = self.__frame_transforms
+
+    @override
+    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
+        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
+        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
+        return video, image
+
+    @override
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
+
+    @override
+    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
+        return self.__image_transforms(image)
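
For orientation, this is roughly how `I2VDatasetWithResize` would be instantiated and inspected; the `data_root` layout and file names below are placeholders, not paths from this repository.

```python
# Hypothetical usage sketch -- dataset paths are placeholders.
from finetune.datasets import I2VDatasetWithResize

dataset = I2VDatasetWithResize(
    max_num_frames=49,
    height=480,
    width=720,
    data_root="data/my_i2v_dataset",
    caption_column="prompts.txt",
    video_column="videos.txt",
    image_column="images.txt",
)

sample = dataset[0]
print(sample["video"].shape)     # [F, C, H, W], frame values scaled to [-1, 1]
print(sample["image"].shape)     # [C, H, W]
print(sample["video_metadata"])  # {"num_frames": F, "height": H, "width": W}
```
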
diff --git a/finetune/datasets/t2v_dataset.py b/finetune/datasets/t2v_dataset.py
new file mode 100644
index 0000000..aa4ed72
--- /dev/null
+++ b/finetune/datasets/t2v_dataset.py
@@ -0,0 +1,177 @@
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+from typing_extensions import override
+
+import torch
+from accelerate.logging import get_logger
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from .utils import (
+    load_prompts, load_videos,
+    preprocess_video_with_resize,
+    preprocess_video_with_buckets
+)
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger(__name__)
+
+
+class BaseT2VDataset(Dataset):
+    """
+    Base dataset class for Text-to-Video (T2V) training.
+
+    This dataset loads prompts and videos and returns preprocessed samples. Subclasses must
+    implement `preprocess` and `video_transform`.
+
+    Args:
+        data_root (str): Root directory of the dataset.
+        caption_column (str): File (relative to `data_root`) listing one prompt per line.
+        video_column (str): File (relative to `data_root`) listing one video path per line.
+    """
+    def __init__(
+        self,
+        data_root: str,
+        caption_column: str,
+        video_column: str,
+        *args,
+        **kwargs
+    ) -> None:
+        super().__init__()
+
+        data_root = Path(data_root)
+        self.prompts = load_prompts(data_root / caption_column)
+        self.videos = load_videos(data_root / video_column)
+
+        # Check if all video files exist
+        if any(not path.is_file() for path in self.videos):
+            raise ValueError(
+                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
+            )
+
+        # Check if number of prompts matches number of videos
+        if len(self.videos) != len(self.prompts):
+            raise ValueError(
+                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
+            )
+
+    def __len__(self) -> int:
+        return len(self.videos)
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        if isinstance(index, list):
+            # Here, index is actually a list of data objects that we need to return.
+            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
+            # to have information about num_frames, height and width. Since this is not stored
+            # as metadata, we need to read the video to get this information. You could read this
+            # information without loading the full video in memory, but we do it anyway. In order
+            # to not load the video twice (once to get the metadata, and once to return the loaded video
+            # based on sampled indices), we cache it in the BucketSampler. When the sampler is ready
+            # to yield, we yield the cached data instead of indices. So, this special check ensures
+            # that data is not loaded a second time. PRs are welcome for improvements.
+            return index
+
+        prompt = self.prompts[index]
+
+        # shape of frames: [F, C, H, W]
+        frames = self.preprocess(self.videos[index])
+        frames = self.video_transform(frames)
+
+        return {
+            "prompt": prompt,
+            "video": frames,
+            "video_metadata": {
+                "num_frames": frames.shape[0],
+                "height": frames.shape[2],
+                "width": frames.shape[3],
+            },
+        }
+
+    def preprocess(self, video_path: Path) -> torch.Tensor:
+        """
+        Loads and preprocesses a video.
+
+        Args:
+            video_path: Path to the video file to load.
+
+        Returns:
+            torch.Tensor: Video tensor of shape [F, C, H, W] where:
+                - F is number of frames
+                - C is number of channels (3 for RGB)
+                - H is height
+                - W is width
+        """
+        raise NotImplementedError("Subclass must implement this method")
+
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        """
+        Applies transformations to a video.
+
+        Args:
+            frames (torch.Tensor): A 4D tensor representing a video
+                with shape [F, C, H, W] where:
+                - F is number of frames
+                - C is number of channels (3 for RGB)
+                - H is height
+                - W is width
+
+        Returns:
+            torch.Tensor: The transformed video tensor
+        """
+        raise NotImplementedError("Subclass must implement this method")
+
+
+class T2VDatasetWithResize(BaseT2VDataset):
+    """
+    T2V dataset that resizes every video to a fixed (height, width) resolution and keeps at most
+    `max_num_frames` evenly sampled frames.
+    """
+    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.max_num_frames = max_num_frames
+        self.height = height
+        self.width = width
+
+        self.__frame_transform = transforms.Compose(
+            [
+                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+            ]
+        )
+
+    @override
+    def preprocess(self, video_path: Path) -> torch.Tensor:
+        return preprocess_video_with_resize(
+            video_path, self.max_num_frames, self.height, self.width,
+        )
+
+    @override
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
+
+
+class T2VDatasetWithBuckets(BaseT2VDataset):
+    """
+    T2V dataset that assigns each video to the nearest (num_frames, height, width) bucket from
+    `video_resolution_buckets` instead of resizing to a single fixed resolution.
+    """
+    def __init__(self, video_resolution_buckets: List[Tuple[int, int, int]], *args, **kwargs) -> None:
+        """
+        Args:
+            video_resolution_buckets: List of tuples representing the resolution buckets.
+                Each tuple contains three integers: (max_num_frames, height, width).
+        """
+        super().__init__(*args, **kwargs)
+
+        self.video_resolution_buckets = video_resolution_buckets
+
+        self.__frame_transform = transforms.Compose(
+            [
+                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
+            ]
+        )
+
+    @override
+    def preprocess(self, video_path: Path) -> torch.Tensor:
+        return preprocess_video_with_buckets(
+            video_path, self.video_resolution_buckets
+        )
+
+    @override
+    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
+        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
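
One plausible way to wire the bucketed dataset and `BucketSampler` into a `DataLoader` is sketched below. Because the sampler yields fully loaded buckets rather than indices, it is passed as `sampler` with `batch_size=1`, and the hypothetical `collate_fn` unwraps the bucket; none of this wiring is part of the diff itself, and the paths and bucket sizes are placeholders.

```python
# Illustrative wiring, not taken from this diff.
import torch
from torch.utils.data import DataLoader

from finetune.datasets import BucketSampler, T2VDatasetWithBuckets

dataset = T2VDatasetWithBuckets(
    video_resolution_buckets=[(49, 480, 720), (49, 256, 256)],
    data_root="data/my_t2v_dataset",
    caption_column="prompts.txt",
    video_column="videos.txt",
)


def collate_fn(batch):
    samples = batch[0]  # the bucket yielded by BucketSampler, passed through __getitem__
    return {
        "prompts": [s["prompt"] for s in samples],
        "videos": torch.stack([s["video"] for s in samples]),
    }


loader = DataLoader(
    dataset,
    batch_size=1,  # the sampler already yields whole buckets
    sampler=BucketSampler(dataset, batch_size=4, shuffle=True),
    collate_fn=collate_fn,
)
```
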
diff --git a/finetune/datasets/utils.py b/finetune/datasets/utils.py
new file mode 100644
index 0000000..cf0525e
--- /dev/null
+++ b/finetune/datasets/utils.py
@@ -0,0 +1,141 @@
+import torch
+import cv2
+
+from typing import List, Tuple
+from pathlib import Path
+from torchvision import transforms
+from torchvision.transforms.functional import resize
+
+# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
+# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
+import decord # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+
+########## loaders ##########
+
+def load_prompts(prompt_path: Path) -> List[str]:
+    with open(prompt_path, "r", encoding="utf-8") as file:
+        return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+def load_videos(video_path: Path) -> List[Path]:
+    with open(video_path, "r", encoding="utf-8") as file:
+        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+def load_images(image_path: Path) -> List[Path]:
+    with open(image_path, "r", encoding="utf-8") as file:
+        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
+
+
+########## preprocessors ##########
+
+def preprocess_image_with_resize(
+    image_path: Path,
+    height: int,
+    width: int,
+) -> torch.Tensor:
+    """
+    Loads and resizes a single image.
+
+    Args:
+        image_path: Path to the image file.
+        height: Target height for resizing.
+        width: Target width for resizing.
+
+    Returns:
+        torch.Tensor: Image tensor with shape [C, H, W] where:
+            C = number of channels (3 for RGB)
+            H = height
+            W = width
+    """
+    image = cv2.imread(image_path.as_posix())
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = cv2.resize(image, (width, height))
+    image = torch.from_numpy(image).float()
+    image = image.permute(2, 0, 1).contiguous()
+    return image
+
+
+def preprocess_video_with_resize(
+    video_path: Path,
+    max_num_frames: int,
+    height: int,
+    width: int,
+) -> torch.Tensor:
+    """
+    Loads and resizes a single video.
+
+    The function processes the video through these steps:
+        1. If video frame count > max_num_frames, downsample frames evenly
+        2. If video dimensions don't match (height, width), resize frames
+
+    Args:
+        video_path: Path to the video file.
+        max_num_frames: Maximum number of frames to keep.
+        height: Target height for resizing.
+        width: Target width for resizing.
+
+    Returns:
+        A torch.Tensor with shape [F, C, H, W] where:
+            F = number of frames
+            C = number of channels (3 for RGB)
+            H = height
+            W = width
+    """
+    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
+    video_num_frames = len(video_reader)
+
+    # Use a sampling step of at least 1 so that videos shorter than max_num_frames
+    # do not produce a zero step (range() would raise a ValueError).
+    sample_step = max(video_num_frames // max_num_frames, 1)
+    indices = list(range(0, video_num_frames, sample_step))
+    frames = video_reader.get_batch(indices)
+    frames = frames[:max_num_frames].float()
+    frames = frames.permute(0, 3, 1, 2).contiguous()
+    return frames
+
+
+def preprocess_video_with_buckets(
+    video_path: Path,
+    resolution_buckets: List[Tuple[int, int, int]],
+) -> torch.Tensor:
+    """
+    Args:
+        video_path: Path to the video file.
+        resolution_buckets: List of tuples (num_frames, height, width) representing
+            available resolution buckets.
+
+    Returns:
+        torch.Tensor: Video tensor with shape [F, C, H, W] where:
+            F = number of frames
+            C = number of channels (3 for RGB)
+            H = height
+            W = width
+
+    The function processes the video through these steps:
+        1. Finds nearest frame bucket <= video frame count
+        2. Downsamples frames evenly to match bucket size
+        3. Finds nearest resolution bucket based on dimensions
+        4. Resizes frames to match bucket resolution
+    """
+    video_reader = decord.VideoReader(uri=video_path.as_posix())
+    video_num_frames = len(video_reader)
+
+    # Keep only the buckets whose frame count fits into this video.
+    eligible_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
+    if len(eligible_buckets) == 0:
+        raise ValueError(
+            f"Video frame count ({video_num_frames}) in {video_path} is smaller than all frame buckets {resolution_buckets}"
+        )
+
+    nearest_frame_bucket = min(
+        eligible_buckets,
+        key=lambda bucket: video_num_frames - bucket[0],
+    )[0]
+    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
+    frames = video_reader.get_batch(frame_indices)
+    frames = frames[:nearest_frame_bucket].float()
+    frames = frames.permute(0, 3, 1, 2).contiguous()
+
+    nearest_res = min(eligible_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
+    nearest_res = (nearest_res[1], nearest_res[2])
+    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
+
+    return frames
\ No newline at end of file
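
To make the bucket-selection arithmetic in `preprocess_video_with_buckets` concrete, here is a small self-contained sketch that mirrors its selection logic on made-up numbers (no video decoding involved).

```python
# Mirrors the bucket selection above on dummy numbers; purely illustrative.
video_num_frames, video_height, video_width = 70, 500, 768
resolution_buckets = [(16, 256, 256), (48, 480, 720), (64, 480, 720), (64, 720, 1080)]

# 1. Keep only buckets whose frame count fits into the video.
eligible_buckets = [b for b in resolution_buckets if b[0] <= video_num_frames]

# 2. Nearest frame bucket <= frame count (here: 64).
nearest_frame_bucket = min(eligible_buckets, key=lambda b: video_num_frames - b[0])[0]

# 3. Evenly spaced frame indices, truncated to the bucket size (here: 64 indices).
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))[:nearest_frame_bucket]

# 4. Nearest resolution bucket by L1 distance on (height, width) (here: (480, 720)).
nearest_res = min(eligible_buckets, key=lambda b: abs(b[1] - video_height) + abs(b[2] - video_width))

print(nearest_frame_bucket, len(frame_indices), nearest_res[1:])  # 64 64 (480, 720)
```
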