diff --git a/GPT_SoVITS/text/g2pw/onnx_api.py b/GPT_SoVITS/text/g2pw/onnx_api.py index de7b9ff3..a9b596b5 100644 --- a/GPT_SoVITS/text/g2pw/onnx_api.py +++ b/GPT_SoVITS/text/g2pw/onnx_api.py @@ -17,6 +17,7 @@ import onnxruntime import requests onnxruntime.set_default_logger_severity(3) +onnxruntime.preload_dlls() from pypinyin import Style, pinyin from transformers import AutoTokenizer diff --git a/tools/my_utils.py b/tools/my_utils.py index 03a0b66a..c54d432d 100644 --- a/tools/my_utils.py +++ b/tools/my_utils.py @@ -147,7 +147,7 @@ def load_cudnn(): if torch_lib_dir.exists(): os.add_dll_directory(str(torch_lib_dir)) print(f"[INFO] Added DLL directory: {torch_lib_dir}") - pattern = str(torch_lib_dir / "cudnn_cnn*.so*") + pattern = str(torch_lib_dir / "cudnn_cnn*.dll") matching_files = sorted(glob.glob(pattern)) if not matching_files: print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}") @@ -182,3 +182,52 @@ def load_cudnn(): print(f"[INFO] Loaded: {so_path}") except OSError as e: print(f"[WARNING] Failed to load {so_path}: {e}") + + +def load_nvrtc(): + import torch + + if not torch.cuda.is_available(): + print("[INFO] CUDA is not available, skipping nvrtc setup.") + return + + if sys.platform == "win32": + torch_lib_dir = Path(torch.__file__).parent / "lib" + if torch_lib_dir.exists(): + os.add_dll_directory(str(torch_lib_dir)) + print(f"[INFO] Added DLL directory: {torch_lib_dir}") + pattern = str(torch_lib_dir / "nvrtc*.dll") + matching_files = sorted(glob.glob(pattern)) + if not matching_files: + print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}") + return + for dll_path in matching_files: + dll_name = os.path.basename(dll_path) + try: + ctypes.CDLL(dll_name) + print(f"[INFO] Loaded: {dll_name}") + except OSError as e: + print(f"[WARNING] Failed to load {dll_name}: {e}") + else: + print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}") + + elif sys.platform == "linux": + site_packages = Path(torch.__file__).resolve().parents[1] + nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib" + + if not nvrtc_dir.exists(): + print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}") + return + + pattern = str(nvrtc_dir / "libnvrtc*.so*") + matching_files = sorted(glob.glob(pattern)) + if not matching_files: + print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}") + return + + for so_path in matching_files: + try: + ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + print(f"[INFO] Loaded: {so_path}") + except OSError as e: + print(f"[WARNING] Failed to load {so_path}: {e}")