import logging
import random
from typing import Any, Iterator, List

from torch.utils.data import Dataset, Sampler


logger = logging.getLogger(__name__)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Items are read directly from ``data_source`` and collected into one bucket per
    ``(num_frames, height, width)`` key taken from ``item["video_metadata"]``. Whenever a
    bucket reaches ``batch_size`` entries it is yielded as a list of dataset items (not
    indices) and replaced by a fresh, empty bucket. Every pass over the sampler starts from
    empty buckets, so an abandoned pass or dropped leftovers never leak into the next one.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`. It must expose a
            `video_resolution_buckets` attribute listing the `(num_frames, height, width)`
            keys to bucket by. Every item must contain a `"video_metadata"` mapping with
            `"num_frames"`, `"height"` and `"width"` entries forming one of those keys;
            an item with any other key makes iteration raise `KeyError`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training. Must be at least 1.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data in the
            dataset. If set to True, only batches that have `batch_size` number of entries will be yielded. If
            set to False, it is guaranteed that all data in the dataset will be processed and batches that do
            not have `batch_size` number of entries will also be yielded.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """

    def __init__(
        self,
        data_source: Dataset,
        batch_size: int = 8,
        shuffle: bool = True,
        drop_last: bool = False,
    ) -> None:
        if batch_size < 1:
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size!r}.")

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # One list of pending items per (num_frames, height, width) key.
        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self) -> int:
        """Return ``ceil(len(data_source) / batch_size)`` as an estimate of the number of batches.

        The value is only an estimate. Each resolution bucket can emit its own partial batch,
        so the real count can be higher when ``drop_last=False``. Partial buckets are discarded
        when ``drop_last=True``, so the real count can be lower in that case.
        """
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def _reset_buckets(self) -> None:
        """Replace every bucket with a brand-new empty list, keeping the key order."""
        # New lists rather than ``list.clear()``: a batch already handed to a consumer may
        # still be referenced from a bucket (e.g. after an interrupted pass) and must not
        # be emptied out from under it.
        self.buckets = {resolution: [] for resolution in self.buckets}

    def __iter__(self) -> Iterator[List[Any]]:
        # Leftovers from a previous pass (an interrupted pass, or partial buckets dropped
        # because of ``drop_last``) would otherwise be mixed into this pass, and a stale
        # full bucket would grow past ``batch_size`` and never be yielded again.
        self._reset_buckets()

        for data in self.data_source:
            video_metadata = data["video_metadata"]
            fhw = (
                video_metadata["num_frames"],
                video_metadata["height"],
                video_metadata["width"],
            )
            bucket = self.buckets[fhw]
            bucket.append(data)

            if len(bucket) == self.batch_size:
                # Detach the full bucket *before* yielding so later appends never touch the
                # consumer's batch, even if the pass is abandoned at this yield.
                # del + reinsert (instead of plain assignment) moves the key to the end of the
                # dict, which determines the order partial buckets are flushed in below.
                del self.buckets[fhw]
                self.buckets[fhw] = []
                if self.shuffle:
                    random.shuffle(bucket)
                yield bucket

        if self.drop_last:
            # Really drop the incomplete buckets: free them now instead of keeping them
            # around until (and carrying them into) the next pass.
            self._reset_buckets()
            return

        # Snapshot the items because the loop body reorders keys in the dict.
        for fhw, bucket in list(self.buckets.items()):
            if not bucket:
                continue
            del self.buckets[fhw]
            self.buckets[fhw] = []
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket