diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index e570f174..27cabbc2 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -10,6 +10,7 @@ from faster_whisper import WhisperModel from tqdm import tqdm from tools.asr.config import check_fw_local_models +from tools.my_utils import load_cudnn # fmt: off language_code_list = [ @@ -93,6 +94,8 @@ def execute_asr(input_folder, output_folder, model_size, language, precision): return output_file_path +load_cudnn() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tools/my_utils.py b/tools/my_utils.py index 44d326e1..21fd55f8 100644 --- a/tools/my_utils.py +++ b/tools/my_utils.py @@ -1,11 +1,17 @@ +import ctypes +import glob import os +import sys import traceback +from pathlib import Path + import ffmpeg -import numpy as np import gradio as gr -from tools.i18n.i18n import I18nAuto +import numpy as np import pandas as pd +from tools.i18n.i18n import I18nAuto + i18n = I18nAuto(language=os.environ.get("language", "Auto")) @@ -127,3 +133,40 @@ def check_details(path_list=None, is_train=False, is_dataset_processing=False): ... else: gr.Warning(i18n("缺少语义数据集")) + + +def load_cudnn(): + import torch + + if not torch.cuda.is_available(): + print("[INFO] CUDA is not available, skipping cuDNN setup.") + return + + if sys.platform == "win32": + torch_lib_dir = Path(torch.__file__).parent / "lib" + if torch_lib_dir.exists(): + os.add_dll_directory(str(torch_lib_dir)) + print(f"[INFO] Added DLL directory: {torch_lib_dir}") + else: + print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}") + + elif sys.platform == "linux": + site_packages = Path(torch.__file__).resolve().parents[1] + cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib" + + if not cudnn_dir.exists(): + print(f"[ERROR] cudnn dir not found: {cudnn_dir}") + return + + pattern = str(cudnn_dir / "libcudnn_cnn*.so*") + matching_files = sorted(glob.glob(pattern)) + if not matching_files: + print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}") + return + + for so_path in matching_files: + try: + ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + print(f"[INFO] Loaded: {so_path}") + except OSError as e: + print(f"[WARNING] Failed to load {so_path}: {e}")