fix: bypass gloo DDP for Windows single-GPU training

On Windows with a single GPU, dist.init_process_group() using the
gloo backend frequently fails with 'unsupported gloo device', caused
by virtual network adapters (VPN, VMware, Hyper-V, etc.).

Changes:
- s2_train.py: skip dist.init_process_group() on Windows single-GPU;
  add DummyDDP wrapper to maintain .module interface compatibility
- s1_train.py: set USE_LIBUV=0 to avoid socket conflicts;
  use strategy='auto' for single-GPU (bypasses gloo entirely),
  DDPStrategy only activated for multi-GPU setups
- utils.py, bucket_sampler.py: related compatibility adjustments

Tested on: Windows 11, single NVIDIA GPU, Python 3.10, PyTorch 2.5
This commit is contained in:
fanfan-love-meatmeat 2026-03-05 10:56:21 +08:00
parent 2d9193b0d3
commit 832e5b6160
4 changed files with 38 additions and 14 deletions

View File

@ -39,12 +39,12 @@ class DistributedBucketSampler(Sampler[T_co]):
if num_replicas is None: if num_replicas is None:
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available") raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1 num_replicas = dist.get_world_size() if dist.is_initialized() else 1
if rank is None: if rank is None:
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available") raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank() if torch.cuda.is_available() else 0 rank = dist.get_rank() if dist.is_initialized() else 0
if torch.cuda.is_available(): if torch.cuda.is_available() and dist.is_initialized():
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0: if rank >= num_replicas or rank < 0:
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1)) raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))

View File

@ -118,7 +118,7 @@ def main(args):
benchmark=False, benchmark=False,
fast_dev_run=False, fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
if torch.cuda.is_available() if torch.cuda.is_available() and torch.cuda.device_count() > 1
else "auto", else "auto",
precision=config["train"]["precision"], precision=config["train"]["precision"],
logger=logger, logger=logger,

View File

@ -77,6 +77,7 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
if not (os.name == "nt" and n_gpus == 1):
dist.init_process_group( dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False", init_method="env://?use_libuv=False",
@ -197,6 +198,16 @@ def run(rank, n_gpus, hps):
eps=hps.train.eps, eps=hps.train.eps,
) )
if torch.cuda.is_available(): if torch.cuda.is_available():
if os.name == "nt" and n_gpus == 1:
class DummyDDP(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
net_g = DummyDDP(net_g)
net_d = DummyDDP(net_d)
else:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
else: else:

View File

@ -64,12 +64,25 @@ import shutil
from time import time as ttime from time import time as ttime
import time
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path) dir = os.path.dirname(path)
name = os.path.basename(path) name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime()) tmp_path = "%s.pth" % (time.time())
torch.save(fea, tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name)) target_path = "%s/%s" % (dir, name)
try:
shutil.move(tmp_path, target_path)
except Exception as e:
print(f"Move failed with error {e}, retrying via copy and delete...")
if os.path.exists(target_path):
try:
os.remove(target_path)
except:
pass
shutil.copyfile(tmp_path, target_path)
os.remove(tmp_path)
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):