Merge f54cbbe74316756b5ab75a9e14aafdfbd304da5b into 11aa78bd9bda8b53047cfcae03abf7ca94d27391

commit 1efd74a316
Author: wzy3650
Date: 2025-09-10 03:11:02 -04:00 (committed by GitHub)
3 changed files with 373 additions and 24 deletions

File 1 of 3 (modified) — vector-quantization codebook module (EuclideanCodebook / kmeans)

@@ -37,6 +37,10 @@ from einops import rearrange, repeat
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist
+from module.distrib import broadcast_tensors, is_distributed
+from module.ddp_utils import SyncFunction
 from tqdm import tqdm
@@ -69,27 +73,40 @@ def sample_vectors(samples, num: int):
     return samples[indices]


-def kmeans(samples, num_clusters: int, num_iters: int = 10):
-    dim, dtype = samples.shape[-1], samples.dtype
-    max_kmeans_samples = 500
-    samples = samples[:max_kmeans_samples, :]
+def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
+    N, D = samples.shape
+    dtype, device = samples.dtype, samples.device
+    if frames_to_use < N:
+        indices = torch.randperm(N, device=device)[:frames_to_use]
+        samples = samples[indices]
     means = sample_vectors(samples, num_clusters)

     print("kmeans start ... ")
     for _ in tqdm(range(num_iters)):
-        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs**2).sum(dim=-1)
-        buckets = dists.max(dim=-1).indices
+        # Store cluster assignments
+        all_assignments = []
+        for i in range(0, samples.shape[0], batch_size):
+            batch = samples[i : i + batch_size]  # [B, D]
+            dists = torch.cdist(batch, means, p=2)  # [B, C]
+            assignments = dists.argmin(dim=1)  # [B]
+            all_assignments.append(assignments)
+        buckets = torch.cat(all_assignments, dim=0)  # [N]
         bins = torch.bincount(buckets, minlength=num_clusters)
         zero_mask = bins == 0
         bins_min_clamped = bins.masked_fill(zero_mask, 1)
-        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
-        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
-        new_means = new_means / bins_min_clamped[..., None]
-        means = torch.where(zero_mask[..., None], means, new_means)
+        # Compute new means
+        new_means = torch.zeros_like(means)
+        for i in range(num_clusters):
+            mask = buckets == i
+            if mask.any():
+                new_means[i] = samples[mask].mean(dim=0)
+        means = torch.where(zero_mask[:, None], means, new_means)
     return means, bins
@@ -141,13 +158,24 @@ class EuclideanCodebook(nn.Module):
         if self.inited:
             return

-        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        if dist.is_available() and dist.is_initialized():
+            # [B * T * world_size, D]
+            data = SyncFunction.apply(data)
+            if dist.get_rank() == 0:
+                embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+            else:
+                embed = torch.empty_like(self.embed)
+                cluster_size = torch.empty_like(self.cluster_size)
+            dist.broadcast(embed, src=0)
+            dist.broadcast(cluster_size, src=0)
         self.embed.data.copy_(embed)
         self.embed_avg.data.copy_(embed.clone())
         self.cluster_size.data.copy_(cluster_size)
         self.inited.data.copy_(torch.Tensor([True]))
         # Make sure all buffers across workers are in sync after initialization
-        # broadcast_tensors(self.buffers())
+        broadcast_tensors(self.buffers())

     def replace_(self, samples, mask):
         modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
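The initialization above follows a gather → compute-on-rank-0 → broadcast pattern so that every worker starts from an identical codebook. A minimal, hedged sketch of that pattern in isolation; the helper `compute_on_rank0` is hypothetical and not part of this diff, and it assumes an initialized process group:

import torch
import torch.distributed as dist

def compute_on_rank0(inputs: torch.Tensor, out_shape, fn):
    # Run `fn` on rank 0 only, then broadcast the result so every rank
    # ends up with the same values. Assumes the result has the given shape
    # and the same dtype/device as `inputs`.
    if dist.get_rank() == 0:
        result = fn(inputs)
    else:
        result = torch.empty(out_shape, dtype=inputs.dtype, device=inputs.device)
    dist.broadcast(result, src=0)
    return result
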
@@ -161,9 +189,17 @@ class EuclideanCodebook(nn.Module):
         if not torch.any(expired_codes):
             return

-        batch_samples = rearrange(batch_samples, "... d -> (...) d")
-        self.replace_(batch_samples, mask=expired_codes)
-        # broadcast_tensors(self.buffers())
+        if is_distributed():
+            # [B * T * world_size, D]
+            batch_samples = SyncFunction.apply(batch_samples)
+            if dist.get_rank() == 0:
+                new_embeds = sample_vectors(batch_samples, expired_codes.sum())
+            else:
+                new_embeds = torch.zeros(expired_codes.sum(), self.embed.size(1), device=self.embed.device)
+            dist.broadcast(new_embeds, src=0)
+            self.embed.data[expired_codes] = new_embeds
+        broadcast_tensors(self.buffers())

     def preprocess(self, x):
         x = rearrange(x, "... d -> (...) d")
@@ -208,17 +244,26 @@ class EuclideanCodebook(nn.Module):
         quantize = self.dequantize(embed_ind)

         if self.training:
+            ### Update codebook by EMA
+            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
+            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
+            if is_distributed():
+                dist.all_reduce(embed_onehot_sum)
+                dist.all_reduce(embed_sum)
+            # Update ema cluster count N_i^t, eq. (6) in vqvae paper
+            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
+            # Update ema embed: eq. (7) in vqvae paper
+            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
+            # apply laplace smoothing
+            n = self.cluster_size.sum()
+            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
+            # Update ema embed: eq. (8) in vqvae paper
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
             # We do the expiry of code at that point as buffers are in sync
             # and all the workers will take the same decision.
             self.expire_codes_(x)
-            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
-            embed_sum = x.t() @ embed_onehot
-            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
-            cluster_size = (
-                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
-            )
-            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
-            self.embed.data.copy_(embed_normalized)

         return quantize, embed_ind
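For reference, the added lines implement the EMA codebook update, equations (6)–(8) of the VQ-VAE paper cited in the comments, with decay \(\gamma\) (self.decay) and Laplace smoothing \(\epsilon\) (self.epsilon); the all_reduce calls simply make the per-batch statistics global before the update:

\begin{aligned}
N_i^{(t)} &= \gamma\, N_i^{(t-1)} + (1-\gamma)\, n_i^{(t)} && \text{(embed\_onehot\_sum} \to \text{cluster\_size)} \\
m_i^{(t)} &= \gamma\, m_i^{(t-1)} + (1-\gamma) \sum_j z_{i,j}^{(t)} && \text{(embed\_sum} \to \text{embed\_avg)} \\
e_i^{(t)} &= \frac{m_i^{(t)}}{\hat N_i^{(t)}}, \qquad
\hat N_i^{(t)} = \frac{N_i^{(t)} + \epsilon}{\sum_k N_k^{(t)} + K\epsilon} \sum_k N_k^{(t)}
\end{aligned}

where \(n_i^{(t)}\) counts the encoder outputs assigned to code \(i\) in the current batch, \(z_{i,j}^{(t)}\) are those assigned vectors, and \(K\) is the codebook size.
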

File 2 of 3 (new, +181 lines) — SyncFunction and a DDP forward override; imported above as module.ddp_utils

@@ -0,0 +1,181 @@
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
from packaging import version


# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    @staticmethod
    # @torch.no_grad()
    def forward(ctx, tensor):
        world_size = torch.distributed.get_world_size()

        # Collect batch sizes from all processes
        local_bs = torch.tensor([tensor.shape[0]], device=tensor.device)
        batch_sizes = [torch.zeros_like(local_bs) for _ in range(world_size)]
        torch.distributed.all_gather(batch_sizes, local_bs)

        # Convert to integer list and find the minimum
        batch_sizes_int = [bs.item() for bs in batch_sizes]
        min_bs = min(batch_sizes_int)

        # Crop the tensor to the minimum batch size if needed
        cropped_tensor = tensor[:min_bs] if tensor.shape[0] > min_bs else tensor

        # Prepare for gathering
        out_shape = (min_bs * world_size,) + tensor.shape[1:]
        gathered_tensor = torch.zeros(out_shape, dtype=tensor.dtype, device=tensor.device)

        # Build tensor list for all_gather
        tensor_list = list(torch.chunk(gathered_tensor, world_size))

        # Perform all_gather using the cropped tensors
        torch.distributed.all_gather(tensor_list, cropped_tensor)

        # Save for backward pass
        ctx.min_bs = min_bs
        ctx.world_size = world_size
        ctx.orig_shape = tensor.shape

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        assert False
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]

class DDP(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])

            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Calling _rebuild_buckets before forward computation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    print("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # TODO: DDPSink is currently enabled for unused parameter detection and
            # static graph training for first iteration.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                output_placeholders = [None for _ in range(len(output_tensor_list))]
                # Do not touch tensors that have no grad_fn, which can cause issues
                # such as https://github.com/pytorch/pytorch/issues/60733
                for i, output in enumerate(output_tensor_list):
                    if torch.is_tensor(output) and output.grad_fn is None:
                        output_placeholders[i] = output

                # When find_unused_parameters=True, makes tensors which require grad
                # run through the DDPSink backward pass. When not all outputs are
                # used in loss, this makes those corresponding tensors receive
                # undefined gradient which the reducer then handles to ensure
                # param.grad field is not touched and we don't error out.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i in range(len(output_placeholders)):
                    if output_placeholders[i] is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Reconstruct output data structure.
                output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
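
As used in the codebook changes in file 1, SyncFunction.apply gathers each rank's local tensor and returns the concatenation after cropping every rank to the smallest local batch, so rank 0 can run k-means (or resample expired codes) over the pooled data. A minimal, hedged usage sketch; the wrapper name and feature tensor are illustrative assumptions, not part of this diff:

import torch
import torch.distributed as dist
from module.ddp_utils import SyncFunction

def gather_across_ranks(features: torch.Tensor) -> torch.Tensor:
    # Illustrative wrapper: only sync when a process group is actually running.
    if dist.is_available() and dist.is_initialized():
        # Returns [min_batch * world_size, ...]. Note that backward is effectively
        # disabled above (assert False), so only pass tensors that do not require grad,
        # which matches how the codebook code uses it.
        return SyncFunction.apply(features)
    return features
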

File 3 of 3 (new, +123 lines) — Torch distributed utilities; imported above as module.distrib

@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Torch distributed utilities."""
import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    # print('params[0].device ', params[0].device)
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()

def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()

def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
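
Finally, a hedged usage sketch for these helpers; the metric names, values, and the commented training-loop call are illustrative assumptions, not part of this diff:

from module.distrib import average_metrics, rank, sync_grad

# Illustrative per-step metrics; in training these would come from the loss computation.
step_metrics = {"loss": 0.42, "codebook_usage": 0.87}

# Weighted all-reduce average across workers (returns the dict unchanged when single-process).
avg = average_metrics(step_metrics, count=1.0)
if rank() == 0:
    print(avg)

# When the model is not wrapped in DistributedDataParallel, gradients can be
# averaged manually after backward():
#   sync_grad(model.parameters())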