diff --git a/GPT_SoVITS/module/core_vq.py b/GPT_SoVITS/module/core_vq.py
index b7dab317..40745386 100644
--- a/GPT_SoVITS/module/core_vq.py
+++ b/GPT_SoVITS/module/core_vq.py
@@ -37,6 +37,10 @@
 from einops import rearrange, repeat
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist
+
+from module.distrib import broadcast_tensors, is_distributed
+from module.ddp_utils import SyncFunction
 from tqdm import tqdm
 
@@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
     return samples[indices]
 
 
-def kmeans(samples, num_clusters: int, num_iters: int = 10):
-    dim, dtype = samples.shape[-1], samples.dtype
-    max_kmeans_samples = 500
-    samples = samples[:max_kmeans_samples, :]
+def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
+    N, D = samples.shape
+    dtype, device = samples.dtype, samples.device
+
+    if frames_to_use < N:
+        indices = torch.randperm(N, device=device)[:frames_to_use]
+        samples = samples[indices]
+
     means = sample_vectors(samples, num_clusters)
 
     print("kmeans start ... ")
     for _ in tqdm(range(num_iters)):
-        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs**2).sum(dim=-1)
+        # Store cluster assignments
+        all_assignments = []
 
-        buckets = dists.max(dim=-1).indices
+        for i in range(0, samples.shape[0], batch_size):
+            batch = samples[i : i + batch_size]  # [B, D]
+            dists = torch.cdist(batch, means, p=2)  # [B, C]
+            assignments = dists.argmin(dim=1)  # [B]
+            all_assignments.append(assignments)
+
+        buckets = torch.cat(all_assignments, dim=0)  # [N]
         bins = torch.bincount(buckets, minlength=num_clusters)
         zero_mask = bins == 0
         bins_min_clamped = bins.masked_fill(zero_mask, 1)
-        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
-        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
-        new_means = new_means / bins_min_clamped[..., None]
+        # Compute new means
+        new_means = torch.zeros_like(means)
+        for i in range(num_clusters):
+            mask = buckets == i
+            if mask.any():
+                new_means[i] = samples[mask].mean(dim=0)
 
-        means = torch.where(zero_mask[..., None], means, new_means)
+        means = torch.where(zero_mask[:, None], means, new_means)
 
     return means, bins
 
@@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
         if self.inited:
             return
 
-        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        if dist.is_available() and dist.is_initialized():
+            # [B * T * world_size, D]
+            data = SyncFunction.apply(data)
+
+        if dist.get_rank() == 0:
+            embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        else:
+            embed = torch.empty_like(self.embed)
+            cluster_size = torch.empty_like(self.cluster_size)
+        dist.broadcast(embed, src=0)
+        dist.broadcast(cluster_size, src=0)
+
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.cluster_size.data.copy_(cluster_size)
         self.inited.data.copy_(torch.Tensor([True]))
         # Make sure all buffers across workers are in sync after initialization
-        # broadcast_tensors(self.buffers())
+        broadcast_tensors(self.buffers())
 
     def replace_(self, samples, mask):
         modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
@@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
         if not torch.any(expired_codes):
             return
 
-        batch_samples = rearrange(batch_samples, "... d -> (...) d")
-        self.replace_(batch_samples, mask=expired_codes)
-        # broadcast_tensors(self.buffers())
+        if is_distributed():
+            # [B * T * world_size, D]
+            batch_samples = SyncFunction.apply(batch_samples)
+
+        if dist.get_rank() == 0:
+            new_embeds = sample_vectors(batch_samples, expired_codes.sum())
+        else:
+            new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
+        dist.broadcast(new_embeds, src=0)
+        self.embed.data[expired_codes] = new_embeds
+        broadcast_tensors(self.buffers())
 
     def preprocess(self, x):
         x = rearrange(x, "... d -> (...) d")
@@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
         quantize = self.dequantize(embed_ind)
 
         if self.training:
+            ### Update codebook by EMA
+            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
+            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
+            if is_distributed():
+                dist.all_reduce(embed_onehot_sum)
+                dist.all_reduce(embed_sum)
+            # Update ema cluster count N_i^t, eq. (6) in vqvae paper
+            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
+            # Update ema embed: eq. (7) in vqvae paper
+            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
+            # apply laplace smoothing
+            n = self.cluster_size.sum()
+            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
+            # Update ema embed: eq. (8) in vqvae paper
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
             # We do the expiry of code at that point as buffers are in sync
             # and all the workers will take the same decision.
             self.expire_codes_(x)
-            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = x.t() @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
-            cluster_size = (
-                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
-            )
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embed.data.copy_(embed_normalized)
 
         return quantize, embed_ind
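Note: the last hunk above replaces the `ema_inplace` / `laplace_smoothing` helpers with an explicit update so that the summed statistics (`embed_onehot_sum`, `embed_sum`) can be all-reduced across ranks before they enter the EMA. Below is a minimal single-process sketch of that update, eqs. (6)-(8) of the VQ-VAE paper; the codebook size, dimension, `decay`, and `epsilon` are toy values chosen for illustration, not taken from the diff.

import torch

codebook_size, dim, decay, epsilon = 4, 8, 0.99, 1e-5

embed = torch.randn(codebook_size, dim)          # current codebook
embed_avg = embed.clone()                        # EMA of summed assigned vectors
cluster_size = torch.zeros(codebook_size)        # EMA of per-code counts

x = torch.randn(16, dim)                         # flattened encoder outputs [N, D]
embed_ind = torch.cdist(x, embed).argmin(dim=1)  # nearest-code assignments [N]
embed_onehot = torch.nn.functional.one_hot(embed_ind, codebook_size).type(x.dtype)

# eq. (6): EMA of per-code counts; eq. (7): EMA of per-code vector sums
# (in the distributed hunk, both sums are all_reduce'd across ranks before this step)
cluster_size.mul_(decay).add_(embed_onehot.sum(0), alpha=1 - decay)
embed_avg.mul_(decay).add_((x.t() @ embed_onehot).t(), alpha=1 - decay)

# Laplace smoothing of the counts, then eq. (8): normalized codebook update
n = cluster_size.sum()
smoothed = (cluster_size + epsilon) / (n + codebook_size * epsilon) * n
embed = embed_avg / smoothed.unsqueeze(1)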
d") - self.replace_(batch_samples, mask=expired_codes) - # broadcast_tensors(self.buffers()) + if is_distributed(): + # [B * T * world_size, D] + batch_samples = SyncFunction.apply(batch_samples) + + if dist.get_rank() == 0: + new_embeds = sample_vectors(batch_samples, expired_codes.sum()) + else: + new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device) + dist.broadcast(new_embeds, src=0) + self.embed.data[expired_codes] = new_embeds + broadcast_tensors(self.buffers()) def preprocess(self, x): x = rearrange(x, "... d -> (...) d") @@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module): quantize = self.dequantize(embed_ind) if self.training: + ### Update codebook by EMA + embed_onehot_sum = embed_onehot.sum(0) # [cb-size,] + embed_sum = x.t() @ embed_onehot # [D, cb-size] + if is_distributed(): + dist.all_reduce(embed_onehot_sum) + dist.all_reduce(embed_sum) + # Update ema cluster count N_i^t, eq. (6) in vqvae paper + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + # Update ema embed: eq. (7) in vqvae paper + self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay) + # apply laplace smoothing + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n + # Update ema embed: eq. (8) in vqvae paper + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) return quantize, embed_ind diff --git a/GPT_SoVITS/module/ddp_utils.py b/GPT_SoVITS/module/ddp_utils.py new file mode 100644 index 00000000..af30dd3f --- /dev/null +++ b/GPT_SoVITS/module/ddp_utils.py @@ -0,0 +1,181 @@ +import torch +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +from packaging import version + + +# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 +class SyncFunction(torch.autograd.Function): + @staticmethod + # @torch.no_grad() + def forward(ctx, tensor): + world_size = torch.distributed.get_world_size() + + # Collect batch sizes from all processes + local_bs = torch.tensor([tensor.shape[0]], device=tensor.device) + batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)] + torch.distributed.all_gather(batch_sizes, local_bs) + + # Convert to integer list and find the minimum + batch_sizes_int = [bs.item() for bs in batch_sizes] + min_bs = min(batch_sizes_int) + + # Crop the tensor to the minimum batch size if needed + cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor + + # Prepare for gathering + out_shape = (min_bs * world_size,) + tensor.shape[1:] + gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device) + + # Build tensor list for all_gather + tensor_list = list(torch.chunk(gathered_tensor, world_size)) + + # Perform all_gather using the cropped tensors + 
+
+class DDP(DistributedDataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+    """
+
+    def forward(self, *inputs, **kwargs):  # pragma: no cover
+        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
+            self._sync_params()
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            assert len(self.device_ids) == 1
+            if self.module.training:
+                output = self.module.training_step(*inputs[0], **kwargs[0])
+            elif self.module.testing:
+                output = self.module.test_step(*inputs[0], **kwargs[0])
+            else:
+                output = self.module.validation_step(*inputs[0], **kwargs[0])
+            if torch.is_grad_enabled():
+                # We'll return the output object verbatim since it is a freeform
+                # object. We need to find any tensors in this object, though,
+                # because we need to figure out which parameters were used during
+                # this forward pass, to ensure we short circuit reduction for any
+                # unused parameters. Only if `find_unused_parameters` is set.
+                if self.find_unused_parameters:
+                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
+                else:
+                    self.reducer.prepare_for_backward([])
+        else:
+            from torch.nn.parallel.distributed import (
+                Join,
+                _DDPSink,
+                _tree_flatten_with_rref,
+                _tree_unflatten_with_rref,
+            )
+
+            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
+                if torch.is_grad_enabled() and self.require_backward_grad_sync:
+                    self.logger.set_runtime_stats_and_log()
+                    self.num_iterations += 1
+                    self.reducer.prepare_for_forward()
+
+                # Notify the join context that this process has not joined, if
+                # needed
+                work = Join.notify_join_context(self)
+                if work:
+                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
+
+                # Calling _rebuild_buckets before forward computation,
+                # It may allocate new buckets before deallocating old buckets
+                # inside _rebuild_buckets. To save peak memory usage,
+                # call _rebuild_buckets before the peak memory usage increases
+                # during forward computation.
+                # This should be called only once during whole training period.
+                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
+                    print("Reducer buckets have been rebuilt in this iteration.")
+                    self._has_rebuilt_buckets = True
+
+                # sync params according to location (before/after forward) user
+                # specified as part of hook, if hook was specified.
+                buffer_hook_registered = hasattr(self, "buffer_hook")
+                if self._check_sync_bufs_pre_fwd():
+                    self._sync_buffers()
+
+                if self._join_config.enable:
+                    # Notify joined ranks whether they should sync in backwards pass or not.
+                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)
+
+                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+                if self.module.training:
+                    output = self.module.training_step(*inputs[0], **kwargs[0])
+                elif self.module.testing:
+                    output = self.module.test_step(*inputs[0], **kwargs[0])
+                else:
+                    output = self.module.validation_step(*inputs[0], **kwargs[0])
+
+                # sync params according to location (before/after forward) user
+                # specified as part of hook, if hook was specified.
+                if self._check_sync_bufs_post_fwd():
+                    self._sync_buffers()
+
+                if torch.is_grad_enabled() and self.require_backward_grad_sync:
+                    self.require_forward_param_sync = True
+                    # We'll return the output object verbatim since it is a freeform
+                    # object. We need to find any tensors in this object, though,
+                    # because we need to figure out which parameters were used during
+                    # this forward pass, to ensure we short circuit reduction for any
+                    # unused parameters. Only if `find_unused_parameters` is set.
+                    if self.find_unused_parameters and not self.static_graph:
+                        # Do not need to populate this for static graph.
+                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
+                    else:
+                        self.reducer.prepare_for_backward([])
+                else:
+                    self.require_forward_param_sync = False
+
+                # TODO: DDPSink is currently enabled for unused parameter detection and
+                # static graph training for first iteration.
+                if (self.find_unused_parameters and not self.static_graph) or (
+                    self.static_graph and self.num_iterations == 1
+                ):
+                    state_dict = {
+                        "static_graph": self.static_graph,
+                        "num_iterations": self.num_iterations,
+                    }
+
+                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
+                    output_placeholders = [None for _ in range(len(output_tensor_list))]
+                    # Do not touch tensors that have no grad_fn, which can cause issues
+                    # such as https://github.com/pytorch/pytorch/issues/60733
+                    for i, output in enumerate(output_tensor_list):
+                        if torch.is_tensor(output) and output.grad_fn is None:
+                            output_placeholders[i] = output
+
+                    # When find_unused_parameters=True, makes tensors which require grad
+                    # run through the DDPSink backward pass. When not all outputs are
+                    # used in loss, this makes those corresponding tensors receive
+                    # undefined gradient which the reducer then handles to ensure
+                    # param.grad field is not touched and we don't error out.
+                    passthrough_tensor_list = _DDPSink.apply(
+                        self.reducer,
+                        state_dict,
+                        *output_tensor_list,
+                    )
+                    for i in range(len(output_placeholders)):
+                        if output_placeholders[i] is None:
+                            output_placeholders[i] = passthrough_tensor_list[i]
+
+                    # Reconstruct output data structure.
+                    output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
+        return output
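Note: the `DDP` override above dispatches `forward` to `training_step`, `test_step`, or `validation_step` on the wrapped module, choosing by `module.training` and a `module.testing` flag. The sketch below shows the interface the wrapper therefore expects; the `ToyTask` name, shapes, and the wrapping line in the trailing comment are illustrative assumptions, not part of the diff.

import torch
from torch import nn


class ToyTask(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        self.testing = False  # read by DDP.forward when module.training is False

    def training_step(self, batch):
        return self.linear(batch).mean()

    def validation_step(self, batch):
        return self.linear(batch).mean()

    def test_step(self, batch):
        return self.linear(batch).mean()


# Hypothetical wrapping inside a distributed training loop:
#   model = DDP(ToyTask().to(rank), device_ids=[rank])
#   loss = model(batch)  # routed to training_step because model.training is True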
diff --git a/GPT_SoVITS/module/distrib.py b/GPT_SoVITS/module/distrib.py
new file mode 100644
index 00000000..cabf8f8a
--- /dev/null
+++ b/GPT_SoVITS/module/distrib.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+    # utility function to check that the number of params in all workers is the same,
+    # and thus avoid a deadlock with distributed all reduce.
+    if not is_distributed() or not params:
+        return
+    # print('params[0].device ', params[0].device)
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If not all the workers have the same number, for at least one of them,
+        # this inequality will be verified.
+        raise RuntimeError(
+            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
+        )
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
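Note: `broadcast_tensors` is what `EuclideanCodebook` now calls on `self.buffers()` so every worker leaves initialization with the same codebook. A small sketch of the call; the buffer names are stand-ins, the import path assumes the `GPT_SoVITS` directory is on `sys.path`, and without an initialized process group `world_size()` falls back to 1, so the call is a safe no-op.

import torch

from module.distrib import broadcast_tensors, is_distributed

codebook = torch.randn(4, 8)  # stand-in for the `embed` buffer
counts = torch.zeros(4)       # stand-in for the `cluster_size` buffer

print(is_distributed())       # False in a plain, non-distributed process
broadcast_tensors([codebook, counts], src=0)  # no-op here; under torchrun it pushes rank 0's values to all ranks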
+ +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + else: + handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged))