mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 17:10:02 +08:00
31 lines
1.1 KiB
Python
31 lines
1.1 KiB
Python
import importlib.util
|
|
|
|
import torch
|
|
|
|
from .sample_funcs import sample_naive
|
|
from .structs import T2SRequest, T2SResult
|
|
from .t2s_engine import T2SEngine as T2SEngineTorch
|
|
|
|
# Turn autograd off process-wide as soon as this package is imported.
# NOTE(review): presumably because this package is inference-only — confirm.
torch.set_grad_enabled(False)

# Backend identifiers detected as usable on this machine. "torch_varlen" is
# always registered; the rest depend on the hardware / optional packages
# probed below. Each identifier is registered at most once.
backends = ["torch_varlen"]

if torch.cuda.is_available():
    backends.append("torch_static_cuda_graph")

    # Currently disabled: SageAttention backend probe.
    # if importlib.util.find_spec("sageattention") is not None:
    #     for i in range(torch.cuda.device_count()):
    #         major, minor = torch.cuda.get_device_capability(i)
    #         sm_version = major + minor / 10.0
    #         if sm_version >= 7.0:
    #             backends.append("sage_attn_varlen_cuda_graph")

    # find_spec only checks that flash_attn is installed; it does not import it.
    if importlib.util.find_spec("flash_attn") is not None:
        # Register the backend once if at least one visible device has
        # compute capability >= 7.5. Comparing the (major, minor) tuple
        # avoids float arithmetic, and any() stops at the first match so
        # multi-GPU hosts no longer get one duplicate entry per device.
        has_capable_device = any(
            torch.cuda.get_device_capability(device_index) >= (7, 5)
            for device_index in range(torch.cuda.device_count())
        )
        if has_capable_device:
            backends.append("flash_attn_varlen_cuda_graph")

# Currently disabled: MPS backend probe.
# if torch.mps.is_available():
#     backends.append("mps_flash_attn_varlen")
|
|
|
|
|
|
# Names exported by a star-import of this package: the objects re-exported
# from the sibling modules (T2SEngineTorch is the alias given to
# t2s_engine.T2SEngine) plus the module-level `backends` list.
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
|