mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-15 21:26:51 +08:00
support S1 train on MUSA
This commit is contained in:
parent
72be145051
commit
50db2e9199
@ -15,7 +15,7 @@ from AR.utils.io import load_yaml_config
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||
@ -25,6 +25,11 @@ from collections import OrderedDict
|
||||
from AR.utils import get_newest_ckpt
|
||||
from process_ckpt import my_save
|
||||
|
||||
import musa_utils
|
||||
if musa_utils.is_available():
    # Imported lazily: musa_accelerator pulls in torch_musa at import time,
    # which only exists on machines with the MUSA stack installed.
    import musa_accelerator

    # Copy the device mask from `_MUSA_VISIBLE_DEVICES` into `MUSA_VISIBLE_DEVICES`.
    # NOTE(review): presumably the launcher stashes the mask under the
    # underscore-prefixed name — confirm against the caller.
    # When that variable is absent (e.g. the script is started directly), leave
    # the environment untouched instead of failing with KeyError at import time.
    if "_MUSA_VISIBLE_DEVICES" in os.environ:
        os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
|
||||
|
||||
|
||||
class my_model_ckpt(ModelCheckpoint):
|
||||
def __init__(
|
||||
@ -108,18 +113,28 @@ def main(args):
|
||||
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
if musa_utils.is_available():
|
||||
accelerator = musa_accelerator.MUSAAccelerator()
|
||||
devices = -1 # MUSA暂时使用单GPU
|
||||
strategy = SingleDeviceStrategy(device=musa_utils.get_device()) # 不使用分布式训练
|
||||
elif torch.cuda.is_available():
|
||||
accelerator = "gpu"
|
||||
devices = -1
|
||||
strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
||||
else:
|
||||
accelerator = "cpu"
|
||||
devices = 1
|
||||
strategy = "auto"
|
||||
trainer: Trainer = Trainer(
|
||||
max_epochs=config["train"]["epochs"],
|
||||
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
||||
accelerator=accelerator,
|
||||
# val_check_interval=9999999999999999999999,###不要验证
|
||||
# check_val_every_n_epoch=None,
|
||||
limit_val_batches=0,
|
||||
devices=-1 if torch.cuda.is_available() else 1,
|
||||
devices=devices,
|
||||
benchmark=False,
|
||||
fast_dev_run=False,
|
||||
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
||||
if torch.cuda.is_available()
|
||||
else "auto",
|
||||
strategy=strategy,
|
||||
precision=config["train"]["precision"],
|
||||
logger=logger,
|
||||
num_sanity_val_steps=0,
|
||||
|
117
musa_accelerator.py
Normal file
117
musa_accelerator.py
Normal file
@ -0,0 +1,117 @@
|
||||
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_musa # 添加MUSA支持
|
||||
from typing_extensions import override
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from lightning_fabric.accelerators import _AcceleratorRegistry
|
||||
from lightning_fabric.utilities.device_parser import _parse_gpu_ids
|
||||
from lightning_fabric.utilities.types import _DEVICE
|
||||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
import musa_utils
|
||||
# from icecream import ic
|
||||
|
||||
import logging
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
class MUSAAccelerator(Accelerator):
    """Accelerator for Moore Threads MUSA devices.

    Implements the PyTorch Lightning ``Accelerator`` hooks on top of
    ``torch_musa`` / ``torch.musa`` so a ``Trainer`` can run on ``musa`` devices.
    """

    @override
    def setup_device(self, device: torch.device) -> None:
        """Make ``device`` the current MUSA device.

        Args:
            device: The device selected for this process. A device without an
                index is treated as ``musa:0``.

        Raises:
            MisconfigurationException: If the selected device is not MUSA.
        """
        if device.type != "musa":
            raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
        # MUSA also needs the current device to be set explicitly; fall back to
        # the first card when no index was given.
        if device.index is None:
            device = torch.device("musa:0")
        torch.musa.set_device(device)

    @override
    def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
        """Return the memory statistics reported by ``torch_musa`` for ``device``.

        Args:
            device: A ``torch.device``, a device string such as ``"musa:0"``, or
                an integer that is interpreted as a MUSA device index.

        Returns:
            The result of ``torch_musa.memory_stats(device)`` for a MUSA device,
            or an empty dict for any other device type.
        """
        # `_DEVICE` also admits str/int, which have no `.type`; normalise first.
        if isinstance(device, int):
            device = torch.device("musa", device)
        elif isinstance(device, str):
            device = torch.device(device)

        if device.type != "musa":
            # Honour the annotated return type instead of implicitly returning None.
            return {}
        return torch_musa.memory_stats(device)

    @override
    def teardown(self) -> None:
        """Clean up any state created by the accelerator."""
        torch.musa.empty_cache()

    @staticmethod
    @override
    def parse_devices(devices: Any) -> Any:
        """Return the device chosen by ``musa_utils.get_device()``.

        The requested ``devices`` value is ignored; the result is whatever
        ``musa_utils.get_device()`` returns.
        """
        return musa_utils.get_device()

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> list[torch.device]:
        """Gets parallel devices for the Accelerator.

        Args:
            devices: One of
                * a ``torch.device`` — returned wrapped in a one-element list;
                * ``None``, ``"auto"`` or ``-1`` — every device counted by
                  :meth:`auto_device_count`;
                * an ``int`` ``n`` — device indices ``0 .. n-1``;
                * a ``list``/``tuple`` of indices and/or ``torch.device`` objects.

        Returns:
            A list of ``torch.device`` objects of type ``musa``; ``torch.device``
            inputs are passed through unchanged.

        Raises:
            ValueError: If ``devices`` has any other type.
        """
        if isinstance(devices, torch.device):
            # Already a single device: just wrap it.
            return [devices]

        wants_all = devices is None or (isinstance(devices, str) and devices == "auto")
        # -1 is the conventional "use every device" value; range(-1) would
        # silently yield no devices at all.
        if not wants_all and isinstance(devices, int) and devices == -1:
            wants_all = True

        if wants_all:
            indices = range(MUSAAccelerator.auto_device_count())
        elif isinstance(devices, int):
            indices = range(devices)
        elif isinstance(devices, (list, tuple)):
            indices = devices
        else:
            raise ValueError(f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'.")

        # Entries that are already devices are kept; bare indices become musa devices.
        return [
            entry if isinstance(entry, torch.device) else torch.device("musa", entry)
            for entry in indices
        ]

    @staticmethod
    @override
    def auto_device_count() -> int:
        """Get the number of MUSA devices when set to `auto`."""
        # Delegates directly to torch_musa.
        return torch_musa.device_count()

    @staticmethod
    @override
    def is_available() -> bool:
        """Checks if MUSA is available on the system."""
        # Delegates directly to torch_musa.
        return torch_musa.is_available()

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        """Register this accelerator with its name and description."""
        accelerator_registry.register(
            "musa",
            cls,
            description=cls.__name__,
            override=True,
        )
|
||||
|
||||
# registry = _AcceleratorRegistry
|
||||
# MUSAAccelerator.register_accelerators(registry)
|
Loading…
x
Reference in New Issue
Block a user