Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-07 15:19:59 +08:00)
Add support for pure CPU training on Linux
This commit is contained in: parent 372a2ff89a, commit d3d5694ed8
```diff
@@ -108,15 +108,15 @@ def main(args):
     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
     trainer: Trainer = Trainer(
         max_epochs=config["train"]["epochs"],
-        accelerator="gpu",
+        accelerator="cpu",
         # val_check_interval=9999999999999999999999,  ### do not validate
         # check_val_every_n_epoch=None,
         limit_val_batches=0,
-        devices=-1,
+        devices=1,
         benchmark=False,
         fast_dev_run=False,
         strategy = "auto" if torch.backends.mps.is_available() else DDPStrategy(
-            process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
+            process_group_backend="nccl" if platform.system() != "Windows" and torch.cuda.is_available() else "gloo"
         ),  # mps does not support multi-node training
         precision=config["train"]["precision"],
         logger=logger,
```
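This first hunk (likely the Lightning-based GPT/S1 trainer, `s1_train.py`) pins the `Trainer` to `accelerator="cpu"` with a single device, and requests the `nccl` process-group backend only when CUDA is actually available, since `nccl` cannot serve CPU-only workers. As committed, the CPU setting is unconditional, so this file always trains on CPU afterwards. Below is a minimal sketch of making the same choices conditionally, so one script covers CUDA, MPS, and pure-CPU hosts; it assumes `pytorch-lightning >= 2.0`, and `build_trainer` is a hypothetical helper, not code from the commit:

```python
import platform

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy


def build_trainer(epochs: int, precision: str = "32") -> Trainer:
    """Choose accelerator/devices/strategy from what the host actually has."""
    if torch.cuda.is_available():
        accelerator, devices = "gpu", -1  # -1 = use every visible GPU
        # nccl only moves CUDA tensors and does not run on Windows
        backend = "nccl" if platform.system() != "Windows" else "gloo"
        strategy = DDPStrategy(process_group_backend=backend)
    elif torch.backends.mps.is_available():
        # mps supports a single device and no multi-node training
        accelerator, devices, strategy = "mps", 1, "auto"
    else:
        accelerator, devices, strategy = "cpu", 1, "auto"  # pure CPU training
    return Trainer(
        max_epochs=epochs,
        accelerator=accelerator,
        devices=devices,
        strategy=strategy,
        precision=precision,
        limit_val_batches=0,  # the commit skips validation entirely
        benchmark=False,
        fast_dev_run=False,
    )
```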
```diff
@@ -44,12 +44,9 @@ global_step = 0
 def main():
     """Assume Single Node Multi GPUs Training Only"""
-    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."
 
-    if torch.backends.mps.is_available():
-        n_gpus = 1
-    else:
-        n_gpus = torch.cuda.device_count()
+    n_gpus = torch.cuda.device_count()
+    if n_gpus == 0:
+        n_gpus = 1
 
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
-
```
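In the S2/SoVITS trainer the hard `assert` that refused to start without CUDA or MPS is gone; a host reporting zero CUDA devices now simply gets a world size of one. A sketch of that fallback plus the usual single-node launch follows; `resolve_world_size` and `launch` are hypothetical names, and the `mp.spawn` call illustrates how a per-process `run(rank, n_gpus, hps)` entry point like this file's is typically driven, rather than quoting the commit:

```python
import os
from random import randint

import torch
import torch.multiprocessing as mp


def resolve_world_size() -> int:
    """One process per CUDA device; a single CPU process when none exist."""
    n_gpus = torch.cuda.device_count()  # 0 on a CPU-only Linux box
    return n_gpus if n_gpus > 0 else 1


def launch(run_fn, hps=None) -> None:
    n_gpus = resolve_world_size()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    # fork one worker per "device"; the rank is passed as the first argument
    mp.spawn(run_fn, nprocs=n_gpus, args=(n_gpus, hps))
```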
```diff
@@ -73,7 +70,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend = "gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
+        backend = "gloo" if os.name == "nt" or torch.cuda.is_available() == False or torch.backends.mps.is_available() else "nccl",
         init_method="env://",
         world_size=n_gpus,
         rank=rank,
```
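`nccl` requires CUDA, so the backend expression gains a CPU clause. `torch.cuda.is_available() == False` works, but the idiomatic spelling is `not torch.cuda.is_available()`; factored into a helper, the same selection reads as below (the helper name and the single-process group values are illustrative, not from the commit):

```python
import os

import torch
import torch.distributed as dist


def pick_backend() -> str:
    """gloo covers Windows, CPU-only, and MPS hosts; nccl needs CUDA."""
    cpu_only = not torch.cuda.is_available()
    if os.name == "nt" or cpu_only or torch.backends.mps.is_available():
        return "gloo"
    return "nccl"


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # a single-process group, as the CPU path in this commit ends up using
    dist.init_process_group(backend=pick_backend(), init_method="env://",
                            world_size=1, rank=0)
    print(dist.get_backend())
    dist.destroy_process_group()
```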
```diff
@@ -137,9 +134,9 @@ def run(rank, n_gpus, hps):
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         **hps.model,
-    ).to("mps")
+    ).to("cpu")
 
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("cpu")
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
```
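With CUDA absent, both networks are now built on `cpu` instead of `mps`. Repeating the `.cuda(rank) if torch.cuda.is_available() else .to("cpu")` dance at every construction site invites drift; one way to centralize it is a single `torch.device` per worker, sketched below. `training_device` is a hypothetical helper, and the commented lines show where this file's `SynthesizerTrn` and `MultiPeriodDiscriminator` would use it (note the commit itself routes non-CUDA hosts to `cpu` rather than `mps` here):

```python
import torch


def training_device(rank: int = 0) -> torch.device:
    """One device object for the whole worker process."""
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = training_device(rank=0)
# net_g = SynthesizerTrn(..., n_speakers=hps.data.n_speakers, **hps.model).to(device)
# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
```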
```diff
@@ -187,8 +184,8 @@ def run(rank, n_gpus, hps):
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
-        net_g = net_g.to("mps")
-        net_d = net_d.to("mps")
+        net_g = net_g.to("cpu")
+        net_d = net_d.to("cpu")
 
     try:  # auto-resume if the checkpoint can be loaded
         _, _, _, epoch_str = utils.load_checkpoint(
```
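On the CPU path there is no DDP wrapping: the `else` branch just moves both networks to `cpu`. Since the `gloo` process group is still initialized above, wrapping a CPU model in `DDP` with `device_ids=None` would also be legal; the commit opts for the plain module. Both paths behind one hypothetical helper, as a sketch:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_training(model: torch.nn.Module, rank: int,
                      device: torch.device) -> torch.nn.Module:
    """DDP across CUDA workers; a plain module on a lone cpu/mps process."""
    model = model.to(device)
    if device.type == "cuda":
        # one wrapped replica per GPU rank
        return DDP(model, device_ids=[rank], find_unused_parameters=True)
    return model
```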
```diff
@@ -320,12 +317,12 @@ def train_and_evaluate(
                 rank, non_blocking=True
             )
         else:
-            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-            y, y_lengths = y.to("mps"), y_lengths.to("mps")
-            ssl = ssl.to("mps")
+            spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+            y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+            ssl = ssl.to("cpu")
             ssl.requires_grad = False
             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-            text, text_lengths = text.to("mps"), text_lengths.to("mps")
+            text, text_lengths = text.to("cpu"), text_lengths.to("cpu")
 
         with autocast(enabled=hps.train.fp16_run):
             (
```
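The training loop mirrors the model placement: each batch tensor that previously went to `mps` now goes to `cpu` on a GPU-less host. The per-tensor `.to(...)` lines can be collapsed with a small helper, assuming the batch is a flat tuple of tensors (a hypothetical helper, not from the commit):

```python
from typing import Tuple

import torch


def batch_to(device: torch.device, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Move a whole batch in one call; non_blocking only helps for CUDA."""
    non_blocking = device.type == "cuda"
    return tuple(t.to(device, non_blocking=non_blocking) for t in tensors)


# spec, spec_lengths, y, y_lengths, ssl, text, text_lengths = batch_to(
#     device, spec, spec_lengths, y, y_lengths, ssl, text, text_lengths)
```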
```diff
@@ -532,10 +529,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             ssl = ssl.cuda()
             text, text_lengths = text.cuda(), text_lengths.cuda()
         else:
-            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
-            y, y_lengths = y.to("mps"), y_lengths.to("mps")
-            ssl = ssl.to("mps")
-            text, text_lengths = text.to("mps"), text_lengths.to("mps")
+            spec, spec_lengths = spec.to("cpu"), spec_lengths.to("cpu")
+            y, y_lengths = y.to("cpu"), y_lengths.to("cpu")
+            ssl = ssl.to("cpu")
+            text, text_lengths = text.to("cpu"), text_lengths.to("cpu")
         for test in [0, 1]:
             y_hat, mask, *_ = generator.module.infer(
                 ssl, spec, spec_lengths, text, text_lengths, test=test
```
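Evaluation gets the same substitution. One detail visible in the context lines: `evaluate()` calls `generator.module.infer(...)`, and `.module` exists only on a DDP-wrapped model, while the CPU branch earlier leaves `net_g` unwrapped. If the CPU path ever reaches this call it would raise `AttributeError`; a defensive unwrap, offered purely as a sketch:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    """Return the underlying module whether or not DDP wrapped it."""
    return model.module if isinstance(model, DDP) else model


# y_hat, mask, *_ = unwrap(generator).infer(
#     ssl, spec, spec_lengths, text, text_lengths, test=test)
```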