mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-31 19:43:09 +08:00
Fix Onnxruntime-gpu NVRTC Error
This commit is contained in:
parent
5d8d9e3232
commit
d5e3dbf09f
@ -17,6 +17,7 @@ import onnxruntime
|
||||
import requests
|
||||
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
onnxruntime.preload_dlls()
|
||||
from pypinyin import Style, pinyin
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
@ -147,7 +147,7 @@ def load_cudnn():
|
||||
if torch_lib_dir.exists():
|
||||
os.add_dll_directory(str(torch_lib_dir))
|
||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
||||
pattern = str(torch_lib_dir / "cudnn_cnn*.so*")
|
||||
pattern = str(torch_lib_dir / "cudnn_cnn*.dll")
|
||||
matching_files = sorted(glob.glob(pattern))
|
||||
if not matching_files:
|
||||
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
|
||||
@ -182,3 +182,52 @@ def load_cudnn():
|
||||
print(f"[INFO] Loaded: {so_path}")
|
||||
except OSError as e:
|
||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
||||
|
||||
|
||||
def load_nvrtc():
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("[INFO] CUDA is not available, skipping nvrtc setup.")
|
||||
return
|
||||
|
||||
if sys.platform == "win32":
|
||||
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
||||
if torch_lib_dir.exists():
|
||||
os.add_dll_directory(str(torch_lib_dir))
|
||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
||||
pattern = str(torch_lib_dir / "nvrtc*.dll")
|
||||
matching_files = sorted(glob.glob(pattern))
|
||||
if not matching_files:
|
||||
print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
|
||||
return
|
||||
for dll_path in matching_files:
|
||||
dll_name = os.path.basename(dll_path)
|
||||
try:
|
||||
ctypes.CDLL(dll_name)
|
||||
print(f"[INFO] Loaded: {dll_name}")
|
||||
except OSError as e:
|
||||
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
||||
else:
|
||||
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
||||
|
||||
elif sys.platform == "linux":
|
||||
site_packages = Path(torch.__file__).resolve().parents[1]
|
||||
nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
|
||||
|
||||
if not nvrtc_dir.exists():
|
||||
print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
|
||||
return
|
||||
|
||||
pattern = str(nvrtc_dir / "libnvrtc*.so*")
|
||||
matching_files = sorted(glob.glob(pattern))
|
||||
if not matching_files:
|
||||
print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
|
||||
return
|
||||
|
||||
for so_path in matching_files:
|
||||
try:
|
||||
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
|
||||
print(f"[INFO] Loaded: {so_path}")
|
||||
except OSError as e:
|
||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user