diff --git a/tools/my_utils.py b/tools/my_utils.py index 21fd55f8..03a0b66a 100644 --- a/tools/my_utils.py +++ b/tools/my_utils.py @@ -147,6 +147,18 @@ def load_cudnn(): if torch_lib_dir.exists(): os.add_dll_directory(str(torch_lib_dir)) print(f"[INFO] Added DLL directory: {torch_lib_dir}") + pattern = str(torch_lib_dir / "cudnn_cnn*.so*") + matching_files = sorted(glob.glob(pattern)) + if not matching_files: + print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}") + return + for dll_path in matching_files: + dll_name = os.path.basename(dll_path) + try: + ctypes.CDLL(dll_name) + print(f"[INFO] Loaded: {dll_name}") + except OSError as e: + print(f"[WARNING] Failed to load {dll_name}: {e}") else: print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")