Merge pull request #3 from plae-tljg/musa_MUSA1009

Solved subtle problems with imports and syntax
This commit is contained in:
Kakaru 2025-10-20 20:43:46 +08:00 committed by GitHub
commit 70a9243285
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 3 deletions

View File

@ -17,7 +17,7 @@ import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler
musa_ddp = False
if musa_utils.is_available():
os.environ["MUSA_LAUNCH_BLOCKING"] = 1
os.environ["MUSA_LAUNCH_BLOCKING"] = "1"
autocast = torch.musa.amp.autocast
musa_ddp = musa_utils.should_ddp()
elif torch.cuda.is_available():

View File

@ -18,6 +18,8 @@ from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from musa_utils import _musa_available
import logging
_log = logging.getLogger(__name__)

View File

@ -68,8 +68,7 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
return dtype, version_val, mem_gb
def should_ddp(device_idx: int = 0) -> bool:
device_name = torch.musa.get_device_properties(device_idx).name
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
_, version_val, _ = get_device_dtype(device_idx)
if version_val <= 2.1:
return False
else: