support S1 train on MUSA

This commit is contained in:
KakaruHayate 2025-10-11 22:33:32 +08:00
parent 72be145051
commit 50db2e9199
2 changed files with 138 additions and 6 deletions

View File

@ -15,7 +15,7 @@ from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
@ -25,6 +25,11 @@ from collections import OrderedDict
from AR.utils import get_newest_ckpt
from process_ckpt import my_save
import musa_utils
if musa_utils.is_available():
import musa_accelerator
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
class my_model_ckpt(ModelCheckpoint):
def __init__(
@ -108,18 +113,28 @@ def main(args):
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["USE_LIBUV"] = "0"
if musa_utils.is_available():
accelerator = musa_accelerator.MUSAAccelerator()
devices = -1 # MUSA暂时使用单GPU
strategy = SingleDeviceStrategy(device=musa_utils.get_device()) # 不使用分布式训练
elif torch.cuda.is_available():
accelerator = "gpu"
devices = -1
strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
else:
accelerator = "cpu"
devices = 1
strategy = "auto"
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
accelerator="gpu" if torch.cuda.is_available() else "cpu",
accelerator=accelerator,
# val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None,
limit_val_batches=0,
devices=-1 if torch.cuda.is_available() else 1,
devices=devices,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
if torch.cuda.is_available()
else "auto",
strategy=strategy,
precision=config["train"]["precision"],
logger=logger,
num_sanity_val_steps=0,

117
musa_accelerator.py Normal file
View File

@ -0,0 +1,117 @@
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
import logging
import os
import shutil
import subprocess
from typing import Any, Optional, Union
import torch
import torch_musa # 添加MUSA支持
from typing_extensions import override
import pytorch_lightning as pl
from lightning_fabric.accelerators import _AcceleratorRegistry
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
import musa_utils
# from icecream import ic
import logging
_log = logging.getLogger(__name__)
class MUSAAccelerator(Accelerator):
"""Accelerator for Moore Threads MUSA devices."""
@override
def setup_device(self, device: torch.device) -> None:
"""
Raises:
MisconfigurationException: If the selected device is not MUSA.
"""
# ic(device)
if device.type != "musa":
raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
# MUSA 也需要设置当前设备
if device.index is None:
device = torch.device(f"musa:0")
# ic(device)
torch.musa.set_device(device)
@override
def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
# 修改设备检测逻辑以支持MUSA
if device.type == "musa":
return torch_musa.memory_stats(device)
@override
def teardown(self) -> None:
"""Clean up any state created by the accelerator."""
torch.musa.empty_cache()
@staticmethod
def parse_devices(devices: Any) -> Any:
return musa_utils.get_device()
@staticmethod
@override
def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any输出list[torch.device]
"""Gets parallel devices for the Accelerator."""
# ic(devices) # 保留调试
if isinstance(devices, torch.device):
# 新增如果已经是单个device返回列表包装
return [devices]
if devices is None or devices == "auto":
# auto使用所有可用设备
num_devices = MUSAAccelerator.auto_device_count()
devices = list(range(num_devices))
elif isinstance(devices, int):
# int生成索引范围
devices = list(range(devices))
elif isinstance(devices, (list, tuple)):
# 已列表:直接用
pass
else:
# 其他raise错误
raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")
# 现在devices是索引列表创建设备
return [torch.device("musa", i) for i in devices]
@staticmethod
@override
def auto_device_count() -> int:
"""Get the number of MUSA devices when set to `auto`."""
# 直接使用我们的工具函数
return torch_musa.device_count()
@staticmethod
@override
def is_available() -> bool:
"""Checks if MUSA is available on the system."""
# 直接使用我们的工具函数
return torch_musa.is_available()
@classmethod
@override
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
"""Register this accelerator with its name and description."""
accelerator_registry.register(
"musa",
cls,
description=cls.__name__,
override=True,
)
# registry = _AcceleratorRegistry
# MUSAAccelerator.register_accelerators(registry)