GPT-SoVITS/musa_accelerator.py
KakaruHayate 90141d2029 clean
2025-10-13 22:21:50 +08:00

92 lines
3.0 KiB
Python

# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
# Kakaru(https://github.com/KakaruHayate/) 2025/10/11
import logging
import os
import shutil
import subprocess
from typing import Any, Optional, Union
import torch
import torch_musa
from typing_extensions import override
import pytorch_lightning as pl
from lightning_fabric.accelerators import _AcceleratorRegistry
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import logging
# Logger scoped to this module's import name.
_log: logging.Logger = logging.getLogger(name=__name__)
class MUSAAccelerator(Accelerator):
    """Accelerator for Moore Threads MUSA devices.

    Device counting and availability are queried through ``torch_musa``.
    Device selection and cache clearing go through the ``torch.musa``
    namespace, which is presumably installed by ``import torch_musa``
    (confirm against the installed torch_musa version).
    """

    @override
    def setup_device(self, device: torch.device) -> None:
        """Make ``device`` the current MUSA device.

        A device without an index (e.g. ``torch.device("musa")``) is treated
        as ``musa:0``.

        Args:
            device: The device assigned to this process.

        Raises:
            MisconfigurationException: If the selected device is not MUSA.
        """
        if device.type != "musa":
            raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
        if device.index is None:
            # No explicit index: fall back to the first MUSA device.
            device = torch.device("musa", 0)
        torch.musa.set_device(device)

    # NOTE(review): ``get_device_stats`` is deliberately not overridden, so
    # callers get whatever the base ``Accelerator`` provides. An earlier draft
    # returned ``torch_musa.memory_stats(device)``; confirm that API exists in
    # the installed torch_musa before enabling it.

    @override
    def teardown(self) -> None:
        """Clean up any state created by the accelerator.

        Calls ``torch.musa.empty_cache()``.
        """
        torch.musa.empty_cache()

    @staticmethod
    @override
    def parse_devices(devices: Any) -> Any:
        """Resolve the Trainer's ``devices`` flag for MUSA.

        The requested ``devices`` value is ignored. When MUSA is available,
        this always selects a single, index-less MUSA device, which
        ``setup_device`` maps to ``musa:0``. Otherwise it returns ``None``,
        which ``get_parallel_devices`` expands to every visible MUSA device.

        NOTE(review): if multi-device MUSA training is ever wanted, parse
        ``devices`` here instead of ignoring it.

        Args:
            devices: The ``devices`` flag given to the Trainer (unused).

        Returns:
            ``torch.device("musa")`` if MUSA is available, else ``None``.
        """
        # Availability is queried at call time; there is no module-level flag.
        return torch.device("musa") if MUSAAccelerator.is_available() else None

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> list[torch.device]:
        """Gets parallel devices for the Accelerator.

        Args:
            devices: One of the following.
                - A ``torch.device``: returned as a one-element list.
                - ``None`` or ``"auto"``: every device reported by
                  ``auto_device_count``.
                - An ``int`` ``n``: indices ``0..n-1``.
                - A list/tuple: taken as explicit device indices.

        Returns:
            The list of MUSA devices to run on.

        Raises:
            ValueError: If ``devices`` has an unsupported type.
        """
        if isinstance(devices, torch.device):
            return [devices]
        if devices is None or devices == "auto":
            num_devices = MUSAAccelerator.auto_device_count()
            devices = list(range(num_devices))
        elif isinstance(devices, int):
            devices = list(range(devices))
        elif isinstance(devices, (list, tuple)):
            pass
        else:
            raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
        return [torch.device("musa", i) for i in devices]

    @staticmethod
    @override
    def auto_device_count() -> int:
        """Get the number of MUSA devices when set to `auto`."""
        return torch_musa.device_count()

    @staticmethod
    @override
    def is_available() -> bool:
        """Checks if MUSA is available on the system."""
        return torch_musa.is_available()

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        """Register this accelerator under the name ``"musa"``.

        ``override=True`` lets this registration overwrite any existing
        ``"musa"`` entry instead of failing on the duplicate name.
        """
        accelerator_registry.register(
            "musa",
            cls,
            description=cls.__name__,
            override=True,
        )