diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 4a77006..db7b9a3 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -116,9 +116,9 @@ def main(args): devices=-1, benchmark=False, fast_dev_run=False, - strategy=DDPStrategy( + strategy = "auto" if torch.mps.is_available() else DDPStrategy( process_group_backend="nccl" if platform.system() != "Windows" else "gloo" - ), + ), # mps 不支持多节点训练 precision=config["train"]["precision"], logger=logger, num_sanity_val_steps=0,