From 47bb5a2cba6f8d586227b8a92266350711a79371 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sat, 11 Oct 2025 22:47:05 +0800 Subject: [PATCH] clean --- musa_accelerator.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/musa_accelerator.py b/musa_accelerator.py index 56ec7029..4788852e 100644 --- a/musa_accelerator.py +++ b/musa_accelerator.py @@ -1,4 +1,5 @@ # refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip +# Kakaru(https://github.com/KakaruHayate/) 2025/10/11 import logging import os @@ -7,7 +8,7 @@ import subprocess from typing import Any, Optional, Union import torch -import torch_musa # 添加MUSA支持 +import torch_musa from typing_extensions import override import pytorch_lightning as pl @@ -17,8 +18,6 @@ from lightning_fabric.utilities.types import _DEVICE from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.utilities.exceptions import MisconfigurationException -# from icecream import ic - import logging _log = logging.getLogger(__name__) @@ -31,19 +30,14 @@ class MUSAAccelerator(Accelerator): Raises: MisconfigurationException: If the selected device is not MUSA. """ - # ic(device) if device.type != "musa": raise MisconfigurationException(f"Device should be MUSA, got {device} instead.") - # MUSA 也需要设置当前设备 if device.index is None: device = torch.device(f"musa:0") - # ic(device) torch.musa.set_device(device) - @override def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: - # 修改设备检测逻辑以支持MUSA if device.type == "musa": return torch_musa.memory_stats(device) @@ -58,47 +52,31 @@ class MUSAAccelerator(Accelerator): @staticmethod @override - def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any,输出list[torch.device] + def get_parallel_devices(devices: Any) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" - # ic(devices) # 保留调试 - if isinstance(devices, torch.device): - # 新增:如果已经是单个device,返回列表包装 return [devices] - if devices is None or devices == "auto": - # auto:使用所有可用设备 num_devices = MUSAAccelerator.auto_device_count() devices = list(range(num_devices)) - elif isinstance(devices, int): - # int:生成索引范围 devices = list(range(devices)) - elif isinstance(devices, (list, tuple)): - # 已列表:直接用 pass - else: - # 其他:raise错误 raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.") - - # 现在devices是索引列表,创建设备 return [torch.device("musa", i) for i in devices] - @staticmethod @override def auto_device_count() -> int: """Get the number of MUSA devices when set to `auto`.""" - # 直接使用我们的工具函数 return torch_musa.device_count() @staticmethod @override def is_available() -> bool: """Checks if MUSA is available on the system.""" - # 直接使用我们的工具函数 return torch_musa.is_available() @classmethod @@ -111,6 +89,3 @@ class MUSAAccelerator(Accelerator): description=cls.__name__, override=True, ) - -# registry = _AcceleratorRegistry -# MUSAAccelerator.register_accelerators(registry) \ No newline at end of file