From aed4935fcea3b36da6d2080246cba71524b6bf6b Mon Sep 17 00:00:00 2001 From: Wu Zichen Date: Wed, 24 Jan 2024 16:41:23 +0800 Subject: [PATCH] mps support --- webui.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 4461056..59eb0ff 100644 --- a/webui.py +++ b/webui.py @@ -45,14 +45,17 @@ i18n = I18nAuto() from scipy.io import wavfile from tools.my_utils import load_audio from multiprocessing import cpu_count + +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu + n_cpu=cpu_count() -# 判断是否有能用来训练和加速推理的N卡 ngpu = torch.cuda.device_count() gpu_infos = [] mem = [] if_gpu_ok = False +# 判断是否有能用来训练和加速推理的N卡 if torch.cuda.is_available() or ngpu != 0: for i in range(ngpu): gpu_name = torch.cuda.get_device_name(i) @@ -61,6 +64,12 @@ if torch.cuda.is_available() or ngpu != 0: if_gpu_ok = True # 至少有一张能用的N卡 gpu_infos.append("%s\t%s" % (i, gpu_name)) mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4)) +# 判断是否支持mps加速 +if torch.backends.mps.is_available(): + if_gpu_ok = True + gpu_infos.append("%s\t%s" % ("0", "Apple GPU")) + mem.append(psutil.virtual_memory().total/ 1024 / 1024 / 1024) # 实测使用系统内存作为显存不会爆显存 + if if_gpu_ok and len(gpu_infos) > 0: gpu_info = "\n".join(gpu_infos) default_batch_size = min(mem) // 2