mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-01 03:55:53 +08:00
Fix Onnxruntime-gpu NVRTC Error
This commit is contained in:
parent
5d8d9e3232
commit
d5e3dbf09f
@ -17,6 +17,7 @@ import onnxruntime
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
onnxruntime.set_default_logger_severity(3)
|
onnxruntime.set_default_logger_severity(3)
|
||||||
|
onnxruntime.preload_dlls()
|
||||||
from pypinyin import Style, pinyin
|
from pypinyin import Style, pinyin
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ def load_cudnn():
|
|||||||
if torch_lib_dir.exists():
|
if torch_lib_dir.exists():
|
||||||
os.add_dll_directory(str(torch_lib_dir))
|
os.add_dll_directory(str(torch_lib_dir))
|
||||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
||||||
pattern = str(torch_lib_dir / "cudnn_cnn*.so*")
|
pattern = str(torch_lib_dir / "cudnn_cnn*.dll")
|
||||||
matching_files = sorted(glob.glob(pattern))
|
matching_files = sorted(glob.glob(pattern))
|
||||||
if not matching_files:
|
if not matching_files:
|
||||||
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
|
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
|
||||||
@ -182,3 +182,52 @@ def load_cudnn():
|
|||||||
print(f"[INFO] Loaded: {so_path}")
|
print(f"[INFO] Loaded: {so_path}")
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
print(f"[WARNING] Failed to load {so_path}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_nvrtc():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("[INFO] CUDA is not available, skipping nvrtc setup.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if sys.platform == "win32":
|
||||||
|
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
||||||
|
if torch_lib_dir.exists():
|
||||||
|
os.add_dll_directory(str(torch_lib_dir))
|
||||||
|
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
||||||
|
pattern = str(torch_lib_dir / "nvrtc*.dll")
|
||||||
|
matching_files = sorted(glob.glob(pattern))
|
||||||
|
if not matching_files:
|
||||||
|
print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
|
||||||
|
return
|
||||||
|
for dll_path in matching_files:
|
||||||
|
dll_name = os.path.basename(dll_path)
|
||||||
|
try:
|
||||||
|
ctypes.CDLL(dll_name)
|
||||||
|
print(f"[INFO] Loaded: {dll_name}")
|
||||||
|
except OSError as e:
|
||||||
|
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
||||||
|
else:
|
||||||
|
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
||||||
|
|
||||||
|
elif sys.platform == "linux":
|
||||||
|
site_packages = Path(torch.__file__).resolve().parents[1]
|
||||||
|
nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
|
||||||
|
|
||||||
|
if not nvrtc_dir.exists():
|
||||||
|
print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
pattern = str(nvrtc_dir / "libnvrtc*.so*")
|
||||||
|
matching_files = sorted(glob.glob(pattern))
|
||||||
|
if not matching_files:
|
||||||
|
print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for so_path in matching_files:
|
||||||
|
try:
|
||||||
|
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
|
||||||
|
print(f"[INFO] Loaded: {so_path}")
|
||||||
|
except OSError as e:
|
||||||
|
print(f"[WARNING] Failed to load {so_path}: {e}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user