diff --git a/config.py b/config.py index 81cda36..8f4be14 100644 --- a/config.py +++ b/config.py @@ -144,7 +144,7 @@ webui_port_subfix = 9871 api_port = 9880 - +#Thanks to the contribution of @Karasukaigan and @XXXXRT666 def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]: cpu = torch.device("cpu") cuda = torch.device(f"cuda:{idx}") @@ -157,14 +157,10 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo mem_gb = mem_bytes / (1024**3) + 0.4 major, minor = capability sm_version = major + minor / 10.0 - is_16_series = bool(re.search(r"16\d{2}", name)) - if mem_gb < 4: - return cpu, torch.float32, 0.0, 0.0 - if (sm_version >= 7.0 and sm_version != 7.5) or (5.3 <= sm_version <= 6.0): - if is_16_series and sm_version == 7.5: - return cuda, torch.float32, sm_version, mem_gb # 16系卡除外 - else: - return cuda, torch.float16, sm_version, mem_gb + is_16_series = bool(re.search(r"16\d{2}", name))and sm_version == 7.5 + if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0 + if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb + if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb return cpu, torch.float32, 0.0, 0.0