diff --git a/musa_accelerator.py b/musa_accelerator.py index 2eaccd2e..56ec7029 100644 --- a/musa_accelerator.py +++ b/musa_accelerator.py @@ -17,7 +17,6 @@ from lightning_fabric.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException -import musa_utils # from icecream import ic import logging @@ -55,7 +54,7 @@ class MUSAAccelerator(Accelerator): @staticmethod def parse_devices(devices: Any) -> Any: - return musa_utils.get_device() + return torch.device("musa") if _musa_available else None @staticmethod @override