diff --git a/musa_accelerator.py b/musa_accelerator.py index 4788852e..4512bc25 100644 --- a/musa_accelerator.py +++ b/musa_accelerator.py @@ -36,10 +36,10 @@ class MUSAAccelerator(Accelerator): device = torch.device(f"musa:0") torch.musa.set_device(device) - @override - def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: - if device.type == "musa": - return torch_musa.memory_stats(device) + #@override + #def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + # if device.type == "musa": + # return torch_musa.memory_stats(device) @override def teardown(self) -> None: