Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-09-29 00:30:15 +08:00)

Merge f54cbbe74316756b5ab75a9e14aafdfbd304da5b into 11aa78bd9bda8b53047cfcae03abf7ca94d27391
This commit is contained in: commit 1efd74a316
@@ -37,6 +37,10 @@ from einops import rearrange, repeat
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist
+
+from module.distrib import broadcast_tensors, is_distributed
+from module.ddp_utils import SyncFunction
 from tqdm import tqdm
 
 
@@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
     return samples[indices]
 
 
-def kmeans(samples, num_clusters: int, num_iters: int = 10):
-    dim, dtype = samples.shape[-1], samples.dtype
-    max_kmeans_samples = 500
-    samples = samples[:max_kmeans_samples, :]
+def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
+    N, D = samples.shape
+    dtype, device = samples.dtype, samples.device
+
+    if frames_to_use < N:
+        indices = torch.randperm(N, device=device)[:frames_to_use]
+        samples = samples[indices]
+
     means = sample_vectors(samples, num_clusters)
 
     print("kmeans start ... ")
     for _ in tqdm(range(num_iters)):
-        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs**2).sum(dim=-1)
-
-        buckets = dists.max(dim=-1).indices
+        # Store cluster assignments
+        all_assignments = []
+
+        for i in range(0, samples.shape[0], batch_size):
+            batch = samples[i : i + batch_size]  # [B, D]
+            dists = torch.cdist(batch, means, p=2)  # [B, C]
+            assignments = dists.argmin(dim=1)  # [B]
+            all_assignments.append(assignments)
+
+        buckets = torch.cat(all_assignments, dim=0)  # [N]
         bins = torch.bincount(buckets, minlength=num_clusters)
         zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)
 
-        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
-        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
-        new_means = new_means / bins_min_clamped[..., None]
-        means = torch.where(zero_mask[..., None], means, new_means)
+        # Compute new means
+        new_means = torch.zeros_like(means)
+        for i in range(num_clusters):
+            mask = buckets == i
+            if mask.any():
+                new_means[i] = samples[mask].mean(dim=0)
+
+        means = torch.where(zero_mask[:, None], means, new_means)
 
     return means, bins
 
 
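The rewritten assignment step computes distances in chunks of `batch_size` rows with `torch.cdist`, so peak memory scales with `batch_size * num_clusters` rather than `N * num_clusters` as in the old broadcasted-difference version. A self-contained sketch of the same batched-assignment / masked-mean loop (illustrative helper, not the repository's API):

import torch

def batched_kmeans(samples: torch.Tensor, num_clusters: int, num_iters: int = 10, batch_size: int = 64):
    """samples: [N, D] -> (means [C, D], bins [C]); mirrors the batched loop above."""
    means = samples[torch.randperm(samples.shape[0], device=samples.device)[:num_clusters]].clone()
    for _ in range(num_iters):
        assignments = []
        for i in range(0, samples.shape[0], batch_size):
            batch = samples[i : i + batch_size]                        # [B, D]
            assignments.append(torch.cdist(batch, means).argmin(dim=1))  # nearest centroid per row
        buckets = torch.cat(assignments)                               # [N]
        bins = torch.bincount(buckets, minlength=num_clusters)
        new_means = torch.zeros_like(means)
        for c in range(num_clusters):
            mask = buckets == c
            if mask.any():
                new_means[c] = samples[mask].mean(dim=0)
        means = torch.where((bins == 0)[:, None], means, new_means)    # keep old centroid for empty clusters
    return means, bins

# Example: batched_kmeans(torch.randn(10_000, 32), num_clusters=256)

The final `torch.where` keeps the previous centroid for any cluster that received no points in an iteration, matching the `zero_mask` handling in the diff.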
@@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
         if self.inited:
             return
 
-        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        if dist.is_available() and dist.is_initialized():
+            # [B * T * world_size, D]
+            data = SyncFunction.apply(data)
+
+            if dist.get_rank() == 0:
+                embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+            else:
+                embed = torch.empty_like(self.embed)
+                cluster_size = torch.empty_like(self.cluster_size)
+            dist.broadcast(embed, src=0)
+            dist.broadcast(cluster_size, src=0)
+
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.cluster_size.data.copy_(cluster_size)
         self.inited.data.copy_(torch.Tensor([True]))
         # Make sure all buffers across workers are in sync after initialization
-        # broadcast_tensors(self.buffers())
+        broadcast_tensors(self.buffers())
 
     def replace_(self, samples, mask):
         modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
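The codebook initialization now gathers features from every rank, runs k-means only on rank 0, and broadcasts the result so all workers start from identical codebook and EMA buffers. A hedged sketch of that rank-0-compute-then-broadcast pattern in isolation (hypothetical function name; assumes `dist.init_process_group(...)` has already been called; random rows stand in for the k-means call):

import torch
import torch.distributed as dist

def init_codebook_distributed(data: torch.Tensor, codebook_size: int) -> torch.Tensor:
    """data: [N, D] features already gathered onto each rank."""
    if dist.get_rank() == 0:
        # any clustering routine works here; random rows stand in for k-means
        embed = data[torch.randperm(data.shape[0], device=data.device)[:codebook_size]].contiguous()
    else:
        embed = torch.empty(codebook_size, data.shape[1], dtype=data.dtype, device=data.device)
    dist.broadcast(embed, src=0)  # every rank now holds rank 0's codebook
    return embed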
@@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
         if not torch.any(expired_codes):
             return
 
-        batch_samples = rearrange(batch_samples, "... d -> (...) d")
-        self.replace_(batch_samples, mask=expired_codes)
-        # broadcast_tensors(self.buffers())
+        if is_distributed():
+            # [B * T * world_size, D]
+            batch_samples = SyncFunction.apply(batch_samples)
+
+            if dist.get_rank() == 0:
+                new_embeds = sample_vectors(batch_samples, expired_codes.sum())
+            else:
+                new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
+            dist.broadcast(new_embeds, src=0)
+            self.embed.data[expired_codes] = new_embeds
+        broadcast_tensors(self.buffers())
 
     def preprocess(self, x):
         x = rearrange(x, "... d -> (...) d")
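The expiry path likewise draws replacement vectors on rank 0 and broadcasts them, so every worker overwrites the same dead codes. A standalone single-process sketch of the underlying idea, resampling under-used codes from the current batch (illustrative names and threshold, not the class's method):

import torch

def expire_dead_codes(embed: torch.Tensor, cluster_size: torch.Tensor,
                      batch_samples: torch.Tensor, threshold: float = 2.0) -> torch.Tensor:
    """embed: [C, D] codebook, cluster_size: [C] EMA usage counts, batch_samples: [N, D]."""
    expired = cluster_size < threshold                   # [C] bool mask of dead codes
    if expired.any():
        idx = torch.randint(0, batch_samples.shape[0], (int(expired.sum()),), device=batch_samples.device)
        embed = embed.clone()
        embed[expired] = batch_samples[idx]              # resample dead entries from the batch
    return embed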
@@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
         quantize = self.dequantize(embed_ind)
 
         if self.training:
+            ### Update codebook by EMA
+            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
+            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
+            if is_distributed():
+                dist.all_reduce(embed_onehot_sum)
+                dist.all_reduce(embed_sum)
+            # Update ema cluster count N_i^t, eq. (6) in vqvae paper
+            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
+            # Update ema embed: eq. (7) in vqvae paper
+            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
+            # apply laplace smoothing
+            n = self.cluster_size.sum()
+            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
+            # Update ema embed: eq. (8) in vqvae paper
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
             # We do the expiry of code at that point as buffers are in sync
             # and all the workers will take the same decision.
             self.expire_codes_(x)
-            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = x.t() @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
-            cluster_size = (
-                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
-            )
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embed.data.copy_(embed_normalized)
 
         return quantize, embed_ind
 
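For reference, the equations cited in the comments are the EMA codebook update from the VQ-VAE paper; in one possible notation (decay gamma, per-code assignment count n_i, code sum m_i, corresponding to cluster_size, embed_avg and embed above):

\begin{aligned}
N_i^{(t)} &= \gamma\, N_i^{(t-1)} + (1-\gamma)\, n_i^{(t)} && \text{eq. (6), cluster\_size} \\
m_i^{(t)} &= \gamma\, m_i^{(t-1)} + (1-\gamma) \sum_{j:\, x_j \mapsto i} x_j && \text{eq. (7), embed\_avg} \\
e_i^{(t)} &= m_i^{(t)} \,/\, N_i^{(t)} && \text{eq. (8), embed}
\end{aligned}

with Laplace smoothing applied to N_i before the division, as in the `cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n` line.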
GPT_SoVITS/module/ddp_utils.py (new file, 181 lines added)
@@ -0,0 +1,181 @@
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
from packaging import version


# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    @staticmethod
    # @torch.no_grad()
    def forward(ctx, tensor):
        world_size = torch.distributed.get_world_size()

        # Collect batch sizes from all processes
        local_bs = torch.tensor([tensor.shape[0]], device=tensor.device)
        batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)]
        torch.distributed.all_gather(batch_sizes, local_bs)

        # Convert to integer list and find the minimum
        batch_sizes_int = [bs.item() for bs in batch_sizes]
        min_bs = min(batch_sizes_int)

        # Crop the tensor to the minimum batch size if needed
        cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor

        # Prepare for gathering
        out_shape = (min_bs * world_size,) + tensor.shape[1:]
        gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device)

        # Build tensor list for all_gather
        tensor_list = list(torch.chunk(gathered_tensor, world_size))

        # Perform all_gather using the cropped tensors
        torch.distributed.all_gather(tensor_list, cropped_tensor)

        # Save for backward pass
        ctx.min_bs = min_bs
        ctx.world_size = world_size
        ctx.orig_shape = tensor.shape

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        assert False
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]


class DDP(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Calling _rebuild_buckets before forward compuation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    print("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                # TODO: DDPSink is currently enabled for unused parameter detection and
                # static graph training for first iteration.
                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        "static_graph": self.static_graph,
                        "num_iterations": self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Do not touch tensors that have no grad_fn, which can cause issues
                    # such as https://github.com/pytorch/pytorch/issues/60733
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # When find_unused_parameters=True, makes tensors which require grad
                    # run through the DDPSink backward pass. When not all outputs are
                    # used in loss, this makes those corresponding tensors receive
                    # undefined gradient which the reducer then handles to ensure
                    # param.grad field is not touched and we don't error out.
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    # Reconstruct output data structure.
                    output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
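SyncFunction.forward is an all_gather that first crops every rank's tensor to the smallest batch size in the group, so the gathered result has shape [min_bs * world_size, ...]. Its backward would sum-reduce the incoming gradient and return this rank's slice, but the `assert False` currently disables it, so in practice the gather carries no gradient. A hedged sketch of how such a gather is typically driven from training code (hypothetical wrapper; assumes the process group is initialized and SyncFunction is imported from the file above):

import torch
import torch.distributed as dist

def gather_features(local_features: torch.Tensor) -> torch.Tensor:
    """[B, D] features on each rank -> [min_B * world_size, D] on every rank."""
    if dist.is_available() and dist.is_initialized():
        return SyncFunction.apply(local_features)  # SyncFunction as defined in ddp_utils.py
    return local_features  # single-process fallback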
GPT_SoVITS/module/distrib.py (new file, 123 lines added)
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    # print('params[0].device ', params[0].device)
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
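These helpers degrade to no-ops when `world_size() == 1`, so the same training script runs unchanged on a single GPU. `sync_grad` is the manual alternative to DDP described in its docstring, and `broadcast_tensors` keeps parameters or buffers identical across workers. A small usage sketch (hypothetical model and data, using the two helpers defined above):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(16, 4)
broadcast_tensors(model.parameters())  # start every rank from rank 0's weights

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = F.mse_loss(model(x), y)
loss.backward()

sync_grad(model.parameters())  # sum-reduce grads across ranks, then divide by world_size()
# optimizer.step() would follow here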