GPT-SoVITS/musa_accelerator.py
2025-10-11 22:36:58 +08:00

116 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
import logging
import os
import shutil
import subprocess
from typing import Any, Optional, Union
import torch
import torch_musa # 添加MUSA支持
from typing_extensions import override
import pytorch_lightning as pl
from lightning_fabric.accelerators import _AcceleratorRegistry
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
# from icecream import ic
import logging
_log = logging.getLogger(__name__)
class MUSAAccelerator(Accelerator):
"""Accelerator for Moore Threads MUSA devices."""
@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
MisconfigurationException: If the selected device is not MUSA.
"""
# ic(device)
if device.type != "musa":
raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
# MUSA 也需要设置当前设备
if device.index is None:
device = torch.device(f"musa:0")
# ic(device)
torch.musa.set_device(device)
@override
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
# 修改设备检测逻辑以支持MUSA
if device.type == "musa":
return torch_musa.memory_stats(device)
@override
def teardown(self) -> None:
"""Clean up any state created by the accelerator."""
torch.musa.empty_cache()
@staticmethod
def parse_devices(devices: Any) -> Any:
return torch.device("musa") if _musa_available else None
@staticmethod
@override
def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any输出list[torch.device]
"""Gets parallel devices for the Accelerator."""
# ic(devices) # 保留调试
if isinstance(devices, torch.device):
# 新增如果已经是单个device返回列表包装
return [devices]
if devices is None or devices == "auto":
# auto使用所有可用设备
num_devices = MUSAAccelerator.auto_device_count()
devices = list(range(num_devices))
elif isinstance(devices, int):
# int生成索引范围
devices = list(range(devices))
elif isinstance(devices, (list, tuple)):
# 已列表:直接用
pass
else:
# 其他raise错误
raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
# 现在devices是索引列表创建设备
return [torch.device("musa", i) for i in devices]
@staticmethod
@override
def auto_device_count() -> int:
"""Get the number of MUSA devices when set to `auto`."""
# 直接使用我们的工具函数
return torch_musa.device_count()
@staticmethod
@override
def is_available() -> bool:
"""Checks if MUSA is available on the system."""
# 直接使用我们的工具函数
return torch_musa.is_available()
@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
"""Register this accelerator with its name and description."""
accelerator_registry.register(
"musa",
cls,
description=cls.__name__,
override=True,
)
# registry = _AcceleratorRegistry
# MUSAAccelerator.register_accelerators(registry)