From fd8c860f491a948542a1a9d4948f912db8c2f2bc Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Mon, 13 Oct 2025 22:14:23 +0800 Subject: [PATCH] clean --- musa_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/musa_utils.py b/musa_utils.py index 9cdde220..c5200023 100644 --- a/musa_utils.py +++ b/musa_utils.py @@ -61,18 +61,18 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]: mem_bytes = properties.total_memory mem_gb = mem_bytes / (1024**3) + 0.4 device_name = properties.name - dtype = torch.float32 + dtype = torch.float16 numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)] - if any(num >= 4000 for num in numbers_in_name): - dtype = torch.float16 + if version_val <= 2.1: + dtype = torch.float32 return dtype, version_val, mem_gb def should_ddp(device_idx: int = 0) -> bool: device_name = torch.musa.get_device_properties(device_idx).name numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)] - if any(num >= 4000 for num in numbers_in_name): - return True - else: + if version_val <= 2.1: return False + else: + return True DEVICE: Optional[torch.device] = get_device()