diff --git a/api.py b/api.py
index ea0e39d0..4f44d0b6 100644
--- a/api.py
+++ b/api.py
@@ -13,7 +13,7 @@
 `-dt` - `默认参考音频文本`
 `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
-`-d` - `推理设备, "cuda","cpu"`
+`-d` - `推理设备, "cuda","cpu","musa"`
 `-a` - `绑定地址, 默认"127.0.0.1"`
 `-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
 `-fp` - `覆盖 config.py 使用全精度`
@@ -124,6 +124,10 @@ import signal
 import LangSegment
 from time import time as ttime
 import torch
+try:
+    import torch_musa
+except ImportError:
+    pass
 import librosa
 import soundfile as sf
 from fastapi import FastAPI, Request, HTTPException
@@ -570,7 +574,7 @@ parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, hel
 parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
 parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
 parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
-parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
+parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / MUSA ")
 parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
 parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
diff --git a/config.py b/config.py
index 1f741285..d53789ec 100644
--- a/config.py
+++ b/config.py
@@ -17,10 +17,23 @@ pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=
 exp_root = "logs"
 python_exec = sys.executable or "python"
+
+infer_device = "cpu"
+
+# 判断是否有摩尔线程显卡可用
+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False
+if use_torch_musa:
+    if torch.musa.is_available():
+        infer_device = "musa"
+        is_half=False
+        print("GPT-SoVITS running on MUSA!")
+
 if torch.cuda.is_available():
     infer_device = "cuda"
-else:
-    infer_device = "cpu"
 
 webui_port_main = 9874
 webui_port_uvr5 = 9873
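
For reviewers, the device-selection logic this patch adds to config.py can be read on its own as the sketch below. It is a minimal, self-contained illustration of the pattern, not a drop-in replacement: the `is_half` default and the final `print` of the chosen device are illustrative assumptions, not part of the patch.

```python
import torch

infer_device = "cpu"
is_half = True  # assumed default; config.py sets this elsewhere

# torch_musa is an optional dependency: importing it registers the
# torch.musa namespace used below. If the package is missing, fall
# through to the CUDA/CPU checks.
try:
    import torch_musa
    if torch.musa.is_available():
        infer_device = "musa"
        is_half = False  # the patch forces full precision on MUSA devices
        print("GPT-SoVITS running on MUSA!")
except ImportError:
    pass

# As in the patched config.py, the CUDA check runs last, so CUDA takes
# precedence when both a CUDA and a MUSA device are present.
if torch.cuda.is_available():
    infer_device = "cuda"

print(f"Selected inference device: {infer_device}")
```

Downstream code can then move models and tensors with the usual `.to(infer_device)` call; beyond the optional import, no MUSA-specific API is required.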