# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
# reference: https://github.com/lifeiteng/vall-e
import itertools
import math
import random
from random import shuffle
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

__all__ = [
    "DistributedBucketSampler",
]

T_co = TypeVar("T_co", covariant=True)


class DistributedBucketSampler(Sampler[T_co]):
    r"""
    Length-bucketed distributed sampler:

    1. sort the dataset by sample length,
    2. group the sorted samples into buckets of similar length,
    3. shuffle within each bucket,
    4. chunk the result into global batches,
    5. shuffle the batches.

    Each batch therefore contains samples of similar length (reducing
    padding), while the batch order is still randomized every epoch.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        batch_size: int = 32,
    ) -> None:
        # Fall back to single-process defaults when CUDA is unavailable.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank() if torch.cuda.is_available() else 0
        if torch.cuda.is_available():
            torch.cuda.set_device(rank)
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas,  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.dataset) / self.num_replicas,  # type: ignore[arg-type]
            )
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.id_with_length = self._get_sample_lengths()
        self.id_buckets = self.make_buckets(bucket_width=2.0)

    def _get_sample_lengths(self):
        # The dataset must expose `get_sample_length(i)`, returning the length
        # (e.g. duration in seconds) of sample `i`.
        id_with_lengths = []
        for i in range(len(self.dataset)):
            id_with_lengths.append((i, self.dataset.get_sample_length(i)))
        id_with_lengths.sort(key=lambda x: x[1])
        return id_with_lengths

    def make_buckets(self, bucket_width: float = 2.0):
        # Samples arrive sorted by length; cut them into buckets whose upper
        # boundaries are consecutive multiples of `bucket_width`.
        buckets = []
        cur = []
        max_sec = bucket_width
        for idx, sec in self.id_with_length:
            if sec < max_sec:
                cur.append(idx)
            else:
                if cur:  # avoid emitting an empty leading bucket
                    buckets.append(cur)
                cur = [idx]
                # Advance the boundary past this sample; it may skip several
                # bucket widths when lengths jump.
                while sec >= max_sec:
                    max_sec += bucket_width
        if cur:
            buckets.append(cur)
        return buckets

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministic per-epoch shuffling: every replica seeds Python's
            # `random` module identically, so all ranks build the same global
            # batch order before the rank-wise subsampling below.
            random.seed(self.epoch + self.seed)
            shuffled_bucket = []
            for buc in self.id_buckets:
                buc_copy = buc.copy()
                shuffle(buc_copy)
                shuffled_bucket.append(buc_copy)
            # Chunk the bucket-ordered indices into global batches (one batch
            # per optimization step across all replicas), then shuffle the
            # batch order so length buckets are not visited monotonically.
            grouped_batch_size = self.batch_size * self.num_replicas
            shuffled_bucket = list(itertools.chain(*shuffled_bucket))
            n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
            batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
            shuffle(batches)
            indices = list(itertools.chain(*batches))
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make the total evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: each rank takes every `num_replicas`-th index
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this
        ensures all replicas use a different random ordering for each epoch.
        Otherwise, the next iteration of this sampler will yield the same
        ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
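
# --- Usage sketch -----------------------------------------------------------
# A minimal, self-contained example of driving the sampler in a single
# process; `_ToyDataset` and its synthetic lengths are hypothetical and exist
# only to illustrate the `get_sample_length(i)` contract the sampler expects.
# In real multi-GPU training, `num_replicas` and `rank` would come from
# `torch.distributed`, and `set_epoch` should be called once per epoch so the
# shuffle differs between epochs.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class _ToyDataset(Dataset):
        """Hypothetical map-style dataset; items are (index, duration)."""

        def __init__(self, n: int = 100) -> None:
            self.lengths = [0.5 + 0.1 * i for i in range(n)]

        def __len__(self) -> int:
            return len(self.lengths)

        def __getitem__(self, idx: int):
            return idx, self.lengths[idx]

        def get_sample_length(self, idx: int) -> float:
            # Length in seconds, used for bucketing.
            return self.lengths[idx]

    ds = _ToyDataset()
    sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=4)
    loader = DataLoader(ds, batch_size=4, sampler=sampler)
    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
        for _batch in loader:
            pass  # a training step would consume `_batch` here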