Fix Onnxruntime-gpu NVRTC Error

This commit is contained in:
XXXXRT666 2025-05-12 08:01:10 +01:00
parent 5d8d9e3232
commit d5e3dbf09f
2 changed files with 51 additions and 1 deletions

View File

@ -17,6 +17,7 @@ import onnxruntime
import requests
onnxruntime.set_default_logger_severity(3)
onnxruntime.preload_dlls()
from pypinyin import Style, pinyin
from transformers import AutoTokenizer

View File

@ -147,7 +147,7 @@ def load_cudnn():
if torch_lib_dir.exists():
os.add_dll_directory(str(torch_lib_dir))
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
pattern = str(torch_lib_dir / "cudnn_cnn*.so*")
pattern = str(torch_lib_dir / "cudnn_cnn*.dll")
matching_files = sorted(glob.glob(pattern))
if not matching_files:
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
@ -182,3 +182,52 @@ def load_cudnn():
print(f"[INFO] Loaded: {so_path}")
except OSError as e:
print(f"[WARNING] Failed to load {so_path}: {e}")
def load_nvrtc():
import torch
if not torch.cuda.is_available():
print("[INFO] CUDA is not available, skipping nvrtc setup.")
return
if sys.platform == "win32":
torch_lib_dir = Path(torch.__file__).parent / "lib"
if torch_lib_dir.exists():
os.add_dll_directory(str(torch_lib_dir))
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
pattern = str(torch_lib_dir / "nvrtc*.dll")
matching_files = sorted(glob.glob(pattern))
if not matching_files:
print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
return
for dll_path in matching_files:
dll_name = os.path.basename(dll_path)
try:
ctypes.CDLL(dll_name)
print(f"[INFO] Loaded: {dll_name}")
except OSError as e:
print(f"[WARNING] Failed to load {dll_name}: {e}")
else:
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
elif sys.platform == "linux":
site_packages = Path(torch.__file__).resolve().parents[1]
nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
if not nvrtc_dir.exists():
print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
return
pattern = str(nvrtc_dir / "libnvrtc*.so*")
matching_files = sorted(glob.glob(pattern))
if not matching_files:
print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
return
for so_path in matching_files:
try:
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
print(f"[INFO] Loaded: {so_path}")
except OSError as e:
print(f"[WARNING] Failed to load {so_path}: {e}")