mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-16 01:06:57 +08:00
Merge pull request #3 from plae-tljg/musa_MUSA1009
solved subtle problem on import and syntax
This commit is contained in:
commit
70a9243285
@ -17,7 +17,7 @@ import torch.multiprocessing as mp
|
||||
from torch.cuda.amp import GradScaler
|
||||
musa_ddp = False
|
||||
if musa_utils.is_available():
|
||||
os.environ["MUSA_LAUNCH_BLOCKING"] = 1
|
||||
os.environ["MUSA_LAUNCH_BLOCKING"] = "1"
|
||||
autocast = torch.musa.amp.autocast
|
||||
musa_ddp = musa_utils.should_ddp()
|
||||
elif torch.cuda.is_available():
|
||||
|
||||
@ -18,6 +18,8 @@ from lightning_fabric.utilities.types import _DEVICE
|
||||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
from musa_utils import _musa_available
|
||||
|
||||
import logging
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -68,8 +68,7 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
|
||||
return dtype, version_val, mem_gb
|
||||
|
||||
def should_ddp(device_idx: int = 0) -> bool:
|
||||
device_name = torch.musa.get_device_properties(device_idx).name
|
||||
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
|
||||
_, version_val, _ = get_device_dtype(device_idx)
|
||||
if version_val <= 2.1:
|
||||
return False
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user