This commit is contained in:
KakaruHayate 2025-10-11 22:47:05 +08:00
parent 1f91faf51e
commit 47bb5a2cba

View File

@ -1,4 +1,5 @@
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip # refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
# Kakaru(https://github.com/KakaruHayate/) 2025/10/11
import logging import logging
import os import os
@ -7,7 +8,7 @@ import subprocess
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
import torch_musa # 添加MUSA支持 import torch_musa
from typing_extensions import override from typing_extensions import override
import pytorch_lightning as pl import pytorch_lightning as pl
@ -17,8 +18,6 @@ from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
# from icecream import ic
import logging import logging
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -31,19 +30,14 @@ class MUSAAccelerator(Accelerator):
Raises: Raises:
MisconfigurationException: If the selected device is not MUSA. MisconfigurationException: If the selected device is not MUSA.
""" """
# ic(device)
if device.type != "musa": if device.type != "musa":
raise MisconfigurationException(f"Device should be MUSA, got {device} instead.") raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
# MUSA 也需要设置当前设备
if device.index is None: if device.index is None:
device = torch.device(f"musa:0") device = torch.device(f"musa:0")
# ic(device)
torch.musa.set_device(device) torch.musa.set_device(device)
@override @override
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
# 修改设备检测逻辑以支持MUSA
if device.type == "musa": if device.type == "musa":
return torch_musa.memory_stats(device) return torch_musa.memory_stats(device)
@ -58,47 +52,31 @@ class MUSAAccelerator(Accelerator):
@staticmethod @staticmethod
@override @override
def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any输出list[torch.device] def get_parallel_devices(devices: Any) -> list[torch.device]:
"""Gets parallel devices for the Accelerator.""" """Gets parallel devices for the Accelerator."""
# ic(devices) # 保留调试
if isinstance(devices, torch.device): if isinstance(devices, torch.device):
# 新增如果已经是单个device返回列表包装
return [devices] return [devices]
if devices is None or devices == "auto": if devices is None or devices == "auto":
# auto使用所有可用设备
num_devices = MUSAAccelerator.auto_device_count() num_devices = MUSAAccelerator.auto_device_count()
devices = list(range(num_devices)) devices = list(range(num_devices))
elif isinstance(devices, int): elif isinstance(devices, int):
# int生成索引范围
devices = list(range(devices)) devices = list(range(devices))
elif isinstance(devices, (list, tuple)): elif isinstance(devices, (list, tuple)):
# 已列表:直接用
pass pass
else: else:
# 其他raise错误
raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.") raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
# 现在devices是索引列表创建设备
return [torch.device("musa", i) for i in devices] return [torch.device("musa", i) for i in devices]
@staticmethod @staticmethod
@override @override
def auto_device_count() -> int: def auto_device_count() -> int:
"""Get the number of MUSA devices when set to `auto`.""" """Get the number of MUSA devices when set to `auto`."""
# 直接使用我们的工具函数
return torch_musa.device_count() return torch_musa.device_count()
@staticmethod @staticmethod
@override @override
def is_available() -> bool: def is_available() -> bool:
"""Checks if MUSA is available on the system.""" """Checks if MUSA is available on the system."""
# 直接使用我们的工具函数
return torch_musa.is_available() return torch_musa.is_available()
@classmethod @classmethod
@ -111,6 +89,3 @@ class MUSAAccelerator(Accelerator):
description=cls.__name__, description=cls.__name__,
override=True, override=True,
) )
# registry = _AcceleratorRegistry
# MUSAAccelerator.register_accelerators(registry)