diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 3eeb57c1..ba8b4cbb 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -15,7 +15,7 @@ from AR.utils.io import load_yaml_config from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger -from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy logging.getLogger("numba").setLevel(logging.WARNING) logging.getLogger("matplotlib").setLevel(logging.WARNING) @@ -25,6 +25,11 @@ from collections import OrderedDict from AR.utils import get_newest_ckpt from process_ckpt import my_save +import musa_utils +if musa_utils.is_available(): + import musa_accelerator + os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"] + class my_model_ckpt(ModelCheckpoint): def __init__( @@ -108,18 +113,28 @@ def main(args): logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) os.environ["MASTER_ADDR"] = "localhost" os.environ["USE_LIBUV"] = "0" + if musa_utils.is_available(): + accelerator = musa_accelerator.MUSAAccelerator() + devices = -1 # MUSA暂时使用单GPU + strategy = SingleDeviceStrategy(device=musa_utils.get_device()) # 不使用分布式训练 + elif torch.cuda.is_available(): + accelerator = "gpu" + devices = -1 + strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") + else: + accelerator = "cpu" + devices = 1 + strategy = "auto" trainer: Trainer = Trainer( max_epochs=config["train"]["epochs"], - accelerator="gpu" if torch.cuda.is_available() else "cpu", + accelerator=accelerator, # val_check_interval=9999999999999999999999,###不要验证 # check_val_every_n_epoch=None, limit_val_batches=0, - devices=-1 if torch.cuda.is_available() else 1, + devices=devices, benchmark=False, fast_dev_run=False, - strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") - if torch.cuda.is_available() - else "auto", + strategy=strategy, precision=config["train"]["precision"], logger=logger, num_sanity_val_steps=0, diff --git a/musa_accelerator.py b/musa_accelerator.py new file mode 100644 index 00000000..2eaccd2e --- /dev/null +++ b/musa_accelerator.py @@ -0,0 +1,117 @@ +# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip + +import logging +import os +import shutil +import subprocess +from typing import Any, Optional, Union + +import torch +import torch_musa # 添加MUSA支持 +from typing_extensions import override + +import pytorch_lightning as pl +from lightning_fabric.accelerators import _AcceleratorRegistry +from lightning_fabric.utilities.device_parser import _parse_gpu_ids +from lightning_fabric.utilities.types import _DEVICE +from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +import musa_utils +# from icecream import ic + +import logging +_log = logging.getLogger(__name__) + +class MUSAAccelerator(Accelerator): + """Accelerator for Moore Threads MUSA devices.""" + + @override + def setup_device(self, device: torch.device) -> None: + """ + Raises: + MisconfigurationException: If the selected device is not MUSA. + """ + # ic(device) + if device.type != "musa": + raise MisconfigurationException(f"Device should be MUSA, got {device} instead.") + # MUSA 也需要设置当前设备 + if device.index is None: + device = torch.device(f"musa:0") + # ic(device) + torch.musa.set_device(device) + + + @override + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: + # 修改设备检测逻辑以支持MUSA + if device.type == "musa": + return torch_musa.memory_stats(device) + + @override + def teardown(self) -> None: + """Clean up any state created by the accelerator.""" + torch.musa.empty_cache() + + @staticmethod + def parse_devices(devices: Any) -> Any: + return musa_utils.get_device() + + @staticmethod + @override + def get_parallel_devices(devices: Any) -> list[torch.device]: # 改为Any,输出list[torch.device] + """Gets parallel devices for the Accelerator.""" + # ic(devices) # 保留调试 + + if isinstance(devices, torch.device): + # 新增:如果已经是单个device,返回列表包装 + return [devices] + + if devices is None or devices == "auto": + # auto:使用所有可用设备 + num_devices = MUSAAccelerator.auto_device_count() + devices = list(range(num_devices)) + + elif isinstance(devices, int): + # int:生成索引范围 + devices = list(range(devices)) + + elif isinstance(devices, (list, tuple)): + # 已列表:直接用 + pass + + else: + # 其他:raise错误 + raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.") + + # 现在devices是索引列表,创建设备 + return [torch.device("musa", i) for i in devices] + + + @staticmethod + @override + def auto_device_count() -> int: + """Get the number of MUSA devices when set to `auto`.""" + # 直接使用我们的工具函数 + return torch_musa.device_count() + + @staticmethod + @override + def is_available() -> bool: + """Checks if MUSA is available on the system.""" + # 直接使用我们的工具函数 + return torch_musa.is_available() + + @classmethod + @override + def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None: + """Register this accelerator with its name and description.""" + accelerator_registry.register( + "musa", + cls, + description=cls.__name__, + override=True, + ) + +# registry = _AcceleratorRegistry +# MUSAAccelerator.register_accelerators(registry) \ No newline at end of file