mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-06-24 13:33:33 +08:00
fix precision auto detection
fix precision auto detection
This commit is contained in:
parent
dbf7702b54
commit
7d70852a3f
14
config.py
14
config.py
@ -144,7 +144,7 @@ webui_port_subfix = 9871
|
|||||||
|
|
||||||
api_port = 9880
|
api_port = 9880
|
||||||
|
|
||||||
|
#Thanks to the contribution of @Karasukaigan and @XXXXRT666
|
||||||
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
cuda = torch.device(f"cuda:{idx}")
|
cuda = torch.device(f"cuda:{idx}")
|
||||||
@ -157,14 +157,10 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
|
|||||||
mem_gb = mem_bytes / (1024**3) + 0.4
|
mem_gb = mem_bytes / (1024**3) + 0.4
|
||||||
major, minor = capability
|
major, minor = capability
|
||||||
sm_version = major + minor / 10.0
|
sm_version = major + minor / 10.0
|
||||||
is_16_series = bool(re.search(r"16\d{2}", name))
|
is_16_series = bool(re.search(r"16\d{2}", name))and sm_version == 7.5
|
||||||
if mem_gb < 4:
|
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb
|
||||||
if (sm_version >= 7.0 and sm_version != 7.5) or (5.3 <= sm_version <= 6.0):
|
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb
|
||||||
if is_16_series and sm_version == 7.5:
|
|
||||||
return cuda, torch.float32, sm_version, mem_gb # 16系卡除外
|
|
||||||
else:
|
|
||||||
return cuda, torch.float16, sm_version, mem_gb
|
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user