mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-15 21:26:51 +08:00
clean
This commit is contained in:
parent
aada52050e
commit
fd8c860f49
@ -61,18 +61,18 @@ def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
|
||||
mem_bytes = properties.total_memory
|
||||
mem_gb = mem_bytes / (1024**3) + 0.4
|
||||
device_name = properties.name
|
||||
dtype = torch.float32
|
||||
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
|
||||
if any(num >= 4000 for num in numbers_in_name):
|
||||
dtype = torch.float16
|
||||
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
|
||||
if version_val <= 2.1:
|
||||
dtype = torch.float32
|
||||
return dtype, version_val, mem_gb
|
||||
|
||||
def should_ddp(device_idx: int = 0) -> bool:
|
||||
device_name = torch.musa.get_device_properties(device_idx).name
|
||||
numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
|
||||
if any(num >= 4000 for num in numbers_in_name):
|
||||
return True
|
||||
else:
|
||||
if version_val <= 2.1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
DEVICE: Optional[torch.device] = get_device()
|
||||
|
Loading…
x
Reference in New Issue
Block a user