This commit is contained in:
KakaruHayate 2025-10-13 22:21:50 +08:00
parent fd8c860f49
commit 90141d2029

View File

@ -36,10 +36,10 @@ class MUSAAccelerator(Accelerator):
device = torch.device(f"musa:0")
torch.musa.set_device(device)
@override
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
if device.type == "musa":
return torch_musa.memory_stats(device)
#@override
#def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
# if device.type == "musa":
# return torch_musa.memory_stats(device)
@override
def teardown(self) -> None: