From d3d5694ed896f67b821e08516353a4c5424b9949 Mon Sep 17 00:00:00 2001
From: lalala-233 <2317987274@qq.com>
Date: Sat, 3 Feb 2024 15:42:23 +0800
Subject: [PATCH] Add support for CPU-only training on Linux
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/s1_train.py |  6 +++---
 GPT_SoVITS/s2_train.py | 33 +++++++++++++++------------------
 2 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 3bbfdfb3..9bf2af26 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -108,15 +108,15 @@ def main(args):
     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
-        accelerator="gpu",
+        accelerator="cpu",
         # val_check_interval=9999999999999999999999,###不要验证
         # check_val_every_n_epoch=None,
         limit_val_batches=0,
-        devices=-1,
+        devices=1,
         benchmark=False,
         fast_dev_run=False,
         strategy = "auto" if torch.backends.mps.is_available() else DDPStrategy(
-            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
+            process_group_backend="nccl" if platform.system() != "Windows" and torch.cuda.is_available() else "gloo"
         ),  # mps 不支持多节点训练
         precision=config["train"]["precision"],
         logger=logger,
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index e6b64f6b..80f51734 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -44,12 +44,9 @@ global_step = 0
 
 def main():
     """Assume Single Node Multi GPUs Training Only"""
-    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."
-
-    if torch.backends.mps.is_available():
+    n_gpus = torch.cuda.device_count()
+    if n_gpus == 0:
         n_gpus = 1
-    else:
-        n_gpus = torch.cuda.device_count()
 
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@@ -73,7 +70,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend = "gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
+        backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() or torch.backends.mps.is_available() else "nccl",
         init_method="env://",
         world_size=n_gpus,
         rank=rank,
@@ -137,9 +134,9 @@ def run(rank, n_gpus, hps):
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         **hps.model,
-    ).to("mps")
+    ).to("cpu")
 
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("cpu")
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
@@ -187,8 +184,8 @@ def run(rank, n_gpus, hps):
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
-        net_g = net_g.to("mps")
-        net_d = net_d.to("mps")
+        net_g = net_g.to("cpu")
+        net_d = net_d.to("cpu")
 
     try:  # 如果能加载自动resume
         _, _, _, epoch_str = utils.load_checkpoint(
@@ -320,12 +317,12 @@ def train_and_evaluate(
                 rank, non_blocking=True
             )
         else:
-            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-            y, y_lengths = y.to("mps"), y_lengths.to("mps")
-            ssl = ssl.to("mps")
+            spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+            y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+            ssl = ssl.to("cpu")
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.to("mps"), text_lengths.to("mps")
+            text, text_lengths = text.to("cpu"), text_lengths.to("cpu")
 
         with autocast(enabled=hps.train.fp16_run):
             (
@@ -532,10 +529,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                 ssl = ssl.cuda()
                 text, text_lengths = text.cuda(), text_lengths.cuda()
             else:
-                spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-                y, y_lengths = y.to("mps"), y_lengths.to("mps")
-                ssl = ssl.to("mps")
-                text, text_lengths = text.to("mps"), text_lengths.to("mps")
+                spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+                y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+                ssl = ssl.to("cpu")
+                text, text_lengths = text.to("cpu"), text_lengths.to("cpu")
             for test in [0, 1]:
                 y_hat, mask, *_ = generator.module.infer(
                     ssl, spec, spec_lengths, text, text_lengths, test=test