mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-15 21:26:51 +08:00
clean
This commit is contained in:
parent
1f91faf51e
commit
47bb5a2cba
@ -1,4 +1,5 @@
|
|||||||
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
|
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
|
||||||
|
# Kakaru(https://github.com/KakaruHayate/) 2025/10/11
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -7,7 +8,7 @@ import subprocess
|
|||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_musa # 添加MUSA支持
|
import torch_musa
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
@ -17,8 +18,6 @@ from lightning_fabric.utilities.types import _DEVICE
|
|||||||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
|
||||||
# from icecream import ic
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -31,19 +30,14 @@ class MUSAAccelerator(Accelerator):
|
|||||||
Raises:
|
Raises:
|
||||||
MisconfigurationException: If the selected device is not MUSA.
|
MisconfigurationException: If the selected device is not MUSA.
|
||||||
"""
|
"""
|
||||||
# ic(device)
|
|
||||||
if device.type != "musa":
|
if device.type != "musa":
|
||||||
raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
|
raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
|
||||||
# MUSA 也需要设置当前设备
|
|
||||||
if device.index is None:
|
if device.index is None:
|
||||||
device = torch.device(f"musa:0")
|
device = torch.device(f"musa:0")
|
||||||
# ic(device)
|
|
||||||
torch.musa.set_device(device)
|
torch.musa.set_device(device)
|
||||||
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
|
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
|
||||||
# 修改设备检测逻辑以支持MUSA
|
|
||||||
if device.type == "musa":
|
if device.type == "musa":
|
||||||
return torch_musa.memory_stats(device)
|
return torch_musa.memory_stats(device)
|
||||||
|
|
||||||
@ -58,47 +52,31 @@ class MUSAAccelerator(Accelerator):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@override
|
@override
|
||||||
def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any,输出list[torch.device]
|
def get_parallel_devices(devices: Any) -> list[torch.device]:
|
||||||
"""Gets parallel devices for the Accelerator."""
|
"""Gets parallel devices for the Accelerator."""
|
||||||
# ic(devices) # 保留调试
|
|
||||||
|
|
||||||
if isinstance(devices, torch.device):
|
if isinstance(devices, torch.device):
|
||||||
# 新增:如果已经是单个device,返回列表包装
|
|
||||||
return [devices]
|
return [devices]
|
||||||
|
|
||||||
if devices is None or devices == "auto":
|
if devices is None or devices == "auto":
|
||||||
# auto:使用所有可用设备
|
|
||||||
num_devices = MUSAAccelerator.auto_device_count()
|
num_devices = MUSAAccelerator.auto_device_count()
|
||||||
devices = list(range(num_devices))
|
devices = list(range(num_devices))
|
||||||
|
|
||||||
elif isinstance(devices, int):
|
elif isinstance(devices, int):
|
||||||
# int:生成索引范围
|
|
||||||
devices = list(range(devices))
|
devices = list(range(devices))
|
||||||
|
|
||||||
elif isinstance(devices, (list, tuple)):
|
elif isinstance(devices, (list, tuple)):
|
||||||
# 已列表:直接用
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 其他:raise错误
|
|
||||||
raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
|
raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
|
||||||
|
|
||||||
# 现在devices是索引列表,创建设备
|
|
||||||
return [torch.device("musa", i) for i in devices]
|
return [torch.device("musa", i) for i in devices]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@override
|
@override
|
||||||
def auto_device_count() -> int:
|
def auto_device_count() -> int:
|
||||||
"""Get the number of MUSA devices when set to `auto`."""
|
"""Get the number of MUSA devices when set to `auto`."""
|
||||||
# 直接使用我们的工具函数
|
|
||||||
return torch_musa.device_count()
|
return torch_musa.device_count()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@override
|
@override
|
||||||
def is_available() -> bool:
|
def is_available() -> bool:
|
||||||
"""Checks if MUSA is available on the system."""
|
"""Checks if MUSA is available on the system."""
|
||||||
# 直接使用我们的工具函数
|
|
||||||
return torch_musa.is_available()
|
return torch_musa.is_available()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -111,6 +89,3 @@ class MUSAAccelerator(Accelerator):
|
|||||||
description=cls.__name__,
|
description=cls.__name__,
|
||||||
override=True,
|
override=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# registry = _AcceleratorRegistry
|
|
||||||
# MUSAAccelerator.register_accelerators(registry)
|
|
Loading…
x
Reference in New Issue
Block a user