diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index cb978181..dc9e7643 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -17,7 +17,7 @@ import torch.multiprocessing as mp
 from torch.cuda.amp import GradScaler
 musa_ddp = False
 if musa_utils.is_available():
-    os.environ["MUSA_LAUNCH_BLOCKING"] = 1
+    os.environ["MUSA_LAUNCH_BLOCKING"] = "1"
     autocast = torch.musa.amp.autocast
     musa_ddp = musa_utils.should_ddp()
 elif torch.cuda.is_available():
diff --git a/musa_accelerator.py b/musa_accelerator.py
index 4512bc25..249e152c 100644
--- a/musa_accelerator.py
+++ b/musa_accelerator.py
@@ -18,6 +18,8 @@ from lightning_fabric.utilities.types import _DEVICE
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+from musa_utils import _musa_available
+
 import logging
 _log = logging.getLogger(__name__)
diff --git a/musa_utils.py b/musa_utils.py
index c5200023..ebc14ac4 100644
--- a/musa_utils.py
+++ b/musa_utils.py
@@ -68,8 +68,7 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
     return dtype, version_val, mem_gb
 
 def should_ddp(device_idx: int = 0) -> bool:
-    device_name = torch.musa.get_device_properties(device_idx).name
-    numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
+    _, version_val, _ = get_device_dtype(device_idx)
     if version_val <= 2.1:
         return False
     else: