diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py index ff62ccfe..de9da40e 100644 --- a/GPT_SoVITS/s2_train_v3_lora.py +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -55,6 +55,10 @@ def main(): n_gpus = torch.cuda.device_count() else: n_gpus = 1 + if n_gpus <= 1: + run(0, n_gpus, hps) + return + os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) @@ -77,12 +81,14 @@ def run(rank, n_gpus, hps): writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) - dist.init_process_group( - backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False", - world_size=n_gpus, - rank=rank, - ) + use_ddp = n_gpus > 1 + if use_ddp: + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) torch.manual_seed(hps.train.seed) if torch.cuda.is_available(): torch.cuda.set_device(rank) @@ -118,15 +124,20 @@ def run(rank, n_gpus, hps): shuffle=True, ) collate_fn = TextAudioSpeakerCollate() - train_loader = DataLoader( - train_dataset, - num_workers=5, + worker_count = 0 if os.name == "nt" and n_gpus <= 1 else min(2 if os.name == "nt" else 5, os.cpu_count() or 1) + loader_kwargs = dict( + num_workers=worker_count, shuffle=False, - pin_memory=True, + pin_memory=torch.cuda.is_available(), collate_fn=collate_fn, batch_sampler=train_sampler, - persistent_workers=True, - prefetch_factor=3, + ) + if worker_count > 0: + loader_kwargs["persistent_workers"] = True + loader_kwargs["prefetch_factor"] = 2 if os.name == "nt" else 3 + train_loader = DataLoader( + train_dataset, + **loader_kwargs, ) save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank) os.makedirs(save_root, exist_ok=True) @@ -156,7 +167,9 @@ def run(rank, n_gpus, hps): def model2cuda(net_g, rank): if torch.cuda.is_available(): - net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) + net_g = net_g.cuda(rank) + if use_ddp: + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) else: net_g = net_g.to(device) return net_g @@ -242,6 +255,8 @@ def run(rank, n_gpus, hps): None, ) scheduler_g.step() + if use_ddp and dist.is_initialized(): + dist.destroy_process_group() print("training done") @@ -327,22 +342,28 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade global_step += 1 if epoch % hps.train.save_every_epoch == 0 and rank == 0: - if hps.train.if_save_latest == 0: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join(save_root, "G_{}.pth".format(global_step)), - ) - else: - utils.save_checkpoint( - net_g, - optim_g, - hps.train.learning_rate, - epoch, - os.path.join(save_root, "G_{}.pth".format(233333333333)), - ) + try: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(global_step)), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(233333333333)), + ) + except Exception as e: + if logger is not None: + logger.warning(f"skip large checkpoint save due to error: {e}") + else: + print(f"skip large checkpoint save due to error: {e}") if rank == 0 and hps.train.if_save_every_weights == True: if hasattr(net_g, "module"): ckpt = net_g.module.state_dict() diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py index 08e18384..a0fe21a6 100644 --- a/GPT_SoVITS/utils.py +++ b/GPT_SoVITS/utils.py @@ -69,7 +69,8 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path name = os.path.basename(path) tmp_path = "%s.pth" % (ttime()) torch.save(fea, tmp_path) - shutil.move(tmp_path, "%s/%s" % (dir, name)) + os.makedirs(dir, exist_ok=True) + os.replace(tmp_path, "%s/%s" % (dir, name)) def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):