Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-10-06 22:50:00 +08:00
Add support for pure-CPU training on Linux
This commit is contained in:
parent 372a2ff89a
commit d3d5694ed8
@@ -108,15 +108,15 @@ def main(args):
     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
-        accelerator="gpu",
+        accelerator="cpu",
         # val_check_interval=9999999999999999999999,  ### do not validate
         # check_val_every_n_epoch=None,
         limit_val_batches=0,
-        devices=-1,
+        devices=1,
         benchmark=False,
         fast_dev_run=False,
         strategy="auto" if torch.backends.mps.is_available() else DDPStrategy(
-            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
+            process_group_backend="nccl" if platform.system() != "Windows" and torch.cuda.is_available() else "gloo"
         ),  # mps does not support multi-node training
         precision=config["train"]["precision"],
         logger=logger,
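The hunk above pins the s1 trainer to accelerator="cpu" and devices=1, while guarding the NCCL backend behind torch.cuda.is_available(). A device-aware construction would keep CUDA hosts fast and still run on CPU-only machines. The sketch below is not from the commit; build_trainer and its defaults are hypothetical, assuming the same pytorch_lightning imports the script already uses.

import platform
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

def build_trainer(max_epochs: int, precision=32) -> Trainer:
    # Hypothetical helper, not part of the commit: pick GPU settings when
    # CUDA is present, otherwise fall back to a single CPU process.
    if torch.cuda.is_available():
        accelerator, devices = "gpu", -1  # -1 = use every visible CUDA device
        # NCCL requires CUDA on a non-Windows host; gloo works everywhere.
        strategy = DDPStrategy(
            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
        )
    else:
        accelerator, devices, strategy = "cpu", 1, "auto"
    return Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
        strategy=strategy,
        precision=precision,
        limit_val_batches=0,  # skip validation, matching the hunk above
        benchmark=False,
    )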
@@ -44,12 +44,9 @@ global_step = 0

 def main():
     """Assume Single Node Multi GPUs Training Only"""
-    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."
-
-    if torch.backends.mps.is_available():
-        n_gpus = 1
-    else:
-        n_gpus = torch.cuda.device_count()
+    n_gpus = torch.cuda.device_count()
+    if n_gpus == 0:
+        n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
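Dropping the GPU-only assert, the replacement logic amounts to "one worker process per CUDA device, else a single process". A self-contained restatement of that rule (the helper name get_world_size is invented here, not from the repo):

import torch

def get_world_size() -> int:
    # torch.cuda.device_count() is 0 on CPU-only and Apple-silicon hosts,
    # so fall back to a single worker process there.
    n_gpus = torch.cuda.device_count()
    return n_gpus if n_gpus > 0 else 1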
@@ -73,7 +70,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

     dist.init_process_group(
-        backend="gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
+        backend="gloo" if os.name == "nt" or torch.cuda.is_available() == False or torch.backends.mps.is_available() else "nccl",
         init_method="env://",
         world_size=n_gpus,
         rank=rank,
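The widened condition chooses gloo whenever NCCL cannot work: on Windows, without CUDA, or on Apple MPS. The same rule expressed as a predicate, for clarity only (pick_backend is a name invented for this sketch):

import os
import torch

def pick_backend() -> str:
    # NCCL only works with CUDA on non-Windows hosts; gloo covers
    # Windows, CPU-only, and Apple MPS machines.
    cuda_ok = torch.cuda.is_available() and not torch.backends.mps.is_available()
    return "nccl" if os.name != "nt" and cuda_ok else "gloo"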
@@ -137,9 +134,9 @@ def run(rank, n_gpus, hps):
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         **hps.model,
-    ).to("mps")
+    ).to("cpu")

-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("cpu")
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
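Both arms of the net_d conditional expression repeat the MultiPeriodDiscriminator constructor; only one arm ever executes, but the duplication invites drift. A possible refactor, not part of the commit (place is a hypothetical helper):

import torch
import torch.nn as nn

def place(module: nn.Module, rank: int = 0) -> nn.Module:
    # Build the module once, then move it: to CUDA if present, else keep
    # it on CPU. Avoids writing the constructor twice in a ternary.
    return module.cuda(rank) if torch.cuda.is_available() else module.to("cpu")

# Usage against the code above (names from the diff):
# net_d = place(MultiPeriodDiscriminator(hps.model.use_spectral_norm), rank)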
@@ -187,8 +184,8 @@ def run(rank, n_gpus, hps):
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
-        net_g = net_g.to("mps")
-        net_d = net_d.to("mps")
+        net_g = net_g.to("cpu")
+        net_d = net_d.to("cpu")

     try:  # auto-resume if a checkpoint can be loaded
         _, _, _, epoch_str = utils.load_checkpoint(
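The same "cpu"/"mps"/CUDA split recurs throughout the script; computing a single torch.device up front would centralize it. A minimal sketch under that assumption (pick_device is invented here, not from the repo):

import torch

def pick_device(rank: int = 0) -> torch.device:
    # Preference order: per-rank CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# device = pick_device(rank)
# net_g, net_d = net_g.to(device), net_d.to(device)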
@@ -320,12 +317,12 @@ def train_and_evaluate(
                 rank, non_blocking=True
             )
         else:
-            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-            y, y_lengths = y.to("mps"), y_lengths.to("mps")
-            ssl = ssl.to("mps")
+            spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+            y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+            ssl = ssl.to("cpu")
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.to("mps"), text_lengths.to("mps")
+            text, text_lengths = text.to("cpu"), text_lengths.to("cpu")

         with autocast(enabled=hps.train.fp16_run):
             (
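Moving seven tensors one by one is verbose; a small helper could move a whole batch and only request non-blocking copies where they help (pinned host memory to CUDA). Illustrative only; batch_to is not in the repo.

import torch

def batch_to(device: torch.device, *tensors: torch.Tensor):
    # non_blocking only has an effect for host-to-CUDA copies from pinned
    # memory; it is harmlessly ignored on CPU and MPS.
    nb = device.type == "cuda"
    return tuple(t.to(device, non_blocking=nb) for t in tensors)

# spec, spec_lengths, y, y_lengths, ssl, text, text_lengths = batch_to(
#     device, spec, spec_lengths, y, y_lengths, ssl, text, text_lengths)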
@@ -532,10 +529,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                 ssl = ssl.cuda()
                 text, text_lengths = text.cuda(), text_lengths.cuda()
             else:
-                spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-                y, y_lengths = y.to("mps"), y_lengths.to("mps")
-                ssl = ssl.to("mps")
-                text, text_lengths = text.to("mps"), text_lengths.to("mps")
+                spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+                y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+                ssl = ssl.to("cpu")
+                text, text_lengths = text.to("cpu"), text_lengths.to("cpu")
             for test in [0, 1]:
                 y_hat, mask, *_ = generator.module.infer(
                     ssl, spec, spec_lengths, text, text_lengths, test=test
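One caveat the hunk leaves in place: generator.module.infer(...) appears to assume the generator is DDP-wrapped, yet on the CPU path in this same commit the nets are moved with a plain .to("cpu") and never wrapped, so .module would raise AttributeError there. A defensive unwrap would cover both paths; unwrap is a hypothetical helper, not from the repo.

import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # DistributedDataParallel stores the real network under .module;
    # an unwrapped module is returned unchanged.
    return model.module if hasattr(model, "module") else model

# y_hat, mask, *_ = unwrap(generator).infer(
#     ssl, spec, spec_lengths, text, text_lengths, test=test)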