mps support

This commit is contained in:
Wu Zichen 2024-01-24 17:30:49 +08:00
parent ed05443cfe
commit a4c22b24f8

View File

@ -116,9 +116,9 @@ def main(args):
devices=-1,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(
strategy = "auto" if torch.mps.is_available() else DDPStrategy(
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
),
), # mps 不支持多节点训练
precision=config["train"]["precision"],
logger=logger,
num_sanity_val_steps=0,