From ee1c71f198ff8db79a883a6178866a236e054864 Mon Sep 17 00:00:00 2001
From: Karasukaigan <80465610+Karasukaigan@users.noreply.github.com>
Date: Thu, 5 Jun 2025 17:10:20 +0800
Subject: [PATCH] Fix an issue where some GPUs could not be correctly identified
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Improved the GPU detection logic in config.py. The original branching logic
prevented the 2080Ti from being recognized correctly.
---
 config.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/config.py b/config.py
index 81cda369..ad7e6fd9 100644
--- a/config.py
+++ b/config.py
@@ -158,13 +158,14 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
     major, minor = capability
     sm_version = major + minor / 10.0
     is_16_series = bool(re.search(r"16\d{2}", name))
-    if mem_gb < 4:
+    if mem_gb < 4 or sm_version < 5.3:
         return cpu, torch.float32, 0.0, 0.0
-    if (sm_version >= 7.0 and sm_version != 7.5) or (5.3 <= sm_version <= 6.0):
-        if is_16_series and sm_version == 7.5:
-            return cuda, torch.float32, sm_version, mem_gb  # except 16-series cards
-        else:
-            return cuda, torch.float16, sm_version, mem_gb
+    if sm_version < 6.0:
+        return cuda, torch.float32, sm_version, mem_gb
+    if is_16_series and sm_version == 7.5:  # 16-series cards do not use float16
+        return cuda, torch.float32, sm_version, mem_gb
+    if sm_version >= 7.0:
+        return cuda, torch.float16, sm_version, mem_gb
     return cpu, torch.float32, 0.0, 0.0
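
The sketch below is not the actual config.py code; it is a minimal, dependency-free rework of the patched selection logic (dtypes are plain strings instead of torch dtypes, and the helper name `pick_device_dtype` is invented here) so the 2080Ti case can be checked without a GPU or torch installed.

```python
import re

def pick_device_dtype(name: str, sm_version: float, mem_gb: float):
    """Mirror of the patched branch order in get_device_dtype_sm (sketch only)."""
    is_16_series = bool(re.search(r"16\d{2}", name))
    if mem_gb < 4 or sm_version < 5.3:        # too little VRAM or compute capability below 5.3
        return "cpu", "float32"
    if sm_version < 6.0:                      # 5.3 <= SM < 6.0: use CUDA but keep float32
        return "cuda", "float32"
    if is_16_series and sm_version == 7.5:    # GTX 16xx cards stay on float32
        return "cuda", "float32"
    if sm_version >= 7.0:                     # SM 7.0+ (non-16-series): float16
        return "cuda", "float16"
    return "cpu", "float32"                   # e.g. SM 6.x still falls back to CPU

# With the old condition `(sm >= 7.0 and sm != 7.5) or (5.3 <= sm <= 6.0)`,
# an RTX 2080 Ti (SM 7.5, not a 16-series card) matched neither branch and
# was not identified; with the patched ordering it resolves to cuda/float16,
# while a GTX 1660 Ti (also SM 7.5) still keeps float32.
assert pick_device_dtype("NVIDIA GeForce RTX 2080 Ti", 7.5, 11) == ("cuda", "float16")
assert pick_device_dtype("NVIDIA GeForce GTX 1660 Ti", 7.5, 6) == ("cuda", "float32")
```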