Preload CUDNN For Ctranslate2

This commit is contained in:
XXXXRT666 2025-05-09 10:44:08 +01:00
parent 5c4f9e8e00
commit 8f76e19603
2 changed files with 48 additions and 2 deletions

View File

@ -10,6 +10,7 @@ from faster_whisper import WhisperModel
from tqdm import tqdm
from tools.asr.config import check_fw_local_models
from tools.my_utils import load_cudnn
# fmt: off
language_code_list = [
@ -93,6 +94,8 @@ def execute_asr(input_folder, output_folder, model_size, language, precision):
return output_file_path
load_cudnn()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(

View File

@ -1,11 +1,17 @@
import ctypes
import glob
import os
import sys
import traceback
from pathlib import Path
import ffmpeg
import numpy as np
import gradio as gr
from tools.i18n.i18n import I18nAuto
import numpy as np
import pandas as pd
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto(language=os.environ.get("language", "Auto"))
@ -127,3 +133,40 @@ def check_details(path_list=None, is_train=False, is_dataset_processing=False):
...
else:
gr.Warning(i18n("缺少语义数据集"))
def load_cudnn():
import torch
if not torch.cuda.is_available():
print("[INFO] CUDA is not available, skipping cuDNN setup.")
return
if sys.platform == "win32":
torch_lib_dir = Path(torch.__file__).parent / "lib"
if torch_lib_dir.exists():
os.add_dll_directory(str(torch_lib_dir))
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
else:
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
elif sys.platform == "linux":
site_packages = Path(torch.__file__).resolve().parents[1]
cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
if not cudnn_dir.exists():
print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
return
pattern = str(cudnn_dir / "libcudnn_cnn*.so*")
matching_files = sorted(glob.glob(pattern))
if not matching_files:
print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
return
for so_path in matching_files:
try:
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
print(f"[INFO] Loaded: {so_path}")
except OSError as e:
print(f"[WARNING] Failed to load {so_path}: {e}")