This commit is contained in:
KakaruHayate 2025-10-13 22:14:23 +08:00
parent aada52050e
commit fd8c860f49

View File

@ -61,18 +61,18 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
mem_bytes = properties.total_memory
mem_gb = mem_bytes / (1024**3) + 0.4
device_name = properties.name
dtype = torch.float32
dtype = torch.float16
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
if any(num >= 4000 for num in numbers_in_name):
dtype = torch.float16
if version_val <= 2.1:
dtype = torch.float32
return dtype, version_val, mem_gb
def should_ddp(device_idx: int = 0) -> bool:
device_name = torch.musa.get_device_properties(device_idx).name
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
if any(num >= 4000 for num in numbers_in_name):
return True
else:
if version_val <= 2.1:
return False
else:
return True
DEVICE: Optional[torch.device] = get_device()