diff --git a/config.py b/config.py index 8b5f378..c9124bf 100644 --- a/config.py +++ b/config.py @@ -1,5 +1,6 @@ import sys,os +import torch # 推理用的指定模型 sovits_path = "" @@ -14,7 +15,12 @@ pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch= exp_root = "logs" python_exec = sys.executable or "python" -infer_device = "cuda" +if torch.cuda.is_available(): + infer_device = "cuda" +elif torch.mps.is_available(): + infer_device = "mps" +else: + infer_device = "cpu" webui_port_main = 9874 webui_port_uvr5 = 9873