diff --git a/webui.py b/webui.py index 4b896ff..70bf5f0 100644 --- a/webui.py +++ b/webui.py @@ -51,13 +51,17 @@ n_cpu=cpu_count() ngpu = torch.cuda.device_count() gpu_infos = [] mem = [] +if_gpu_ok = False if torch.cuda.is_available() or ngpu != 0: for i in range(ngpu): gpu_name = torch.cuda.get_device_name(i) - gpu_infos.append("%s\t%s" % (i, gpu_name)) - mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4)) -if len(gpu_infos) > 0: + if any(value in gpu_name.upper()for value in ["10","16","20","30","40","A2","A3","A4","P4","A50","500","A60","70","80","90","M4","T4","TITAN","L","4060"]): + # A10#A100#V100#A40#P40#M40#K80#A4500 + if_gpu_ok = True # 至少有一张能用的N卡 + gpu_infos.append("%s\t%s" % (i, gpu_name)) + mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4)) +if if_gpu_ok and len(gpu_infos) > 0: gpu_info = "\n".join(gpu_infos) default_batch_size = min(mem) // 2 else: