From 84e902eaa8ec5739230f9e9040f1b79249556751 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 5 Oct 2025 12:35:01 +0100 Subject: [PATCH] Add ModelScope Snapshot Download For ASR --- requirements.txt | 4 +- tools/asr/config.py | 31 +++----------- tools/asr/fasterwhisper_asr.py | 74 ++++++++++++++++++++++------------ 3 files changed, 54 insertions(+), 55 deletions(-) diff --git a/requirements.txt b/requirements.txt index 90e4957d..578bb87c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ pypinyin pyopenjtalk>=0.4.1 g2p_en torchaudio -modelscope==1.10.0 +modelscope sentencepiece transformers>=4.43,<=4.50 peft @@ -39,7 +39,5 @@ x_transformers torchmetrics<=1.5 pydantic<=2.10.6 ctranslate2>=4.0,<5 -huggingface_hub>=0.13 -tokenizers>=0.13,<1 av>=11 tqdm diff --git a/tools/asr/config.py b/tools/asr/config.py index 9c26a4f6..097aa643 100644 --- a/tools/asr/config.py +++ b/tools/asr/config.py @@ -1,34 +1,13 @@ -import os - - -def check_fw_local_models(): - """ - 启动时检查本地是否有 Faster Whisper 模型. - """ - model_size_list = [ - "medium", - "medium.en", - "distil-large-v2", - "distil-large-v3", - "large-v1", - "large-v2", - "large-v3", - ] - for i, size in enumerate(model_size_list): - if os.path.exists(f"tools/asr/models/faster-whisper-{size}"): - model_size_list[i] = size + "-local" - return model_size_list - - def get_models(): model_size_list = [ "medium", "medium.en", - "distil-large-v2", - "distil-large-v3", - "large-v1", "large-v2", "large-v3", + "large-v3-turbo", + "distil-large-v2", + "distil-large-v3", + "distil-large-v3.5", ] return model_size_list @@ -36,7 +15,7 @@ def get_models(): asr_dict = { "达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]}, "Faster Whisper (多语种)": { - "lang": ["auto", "zh", "en", "ja", "ko", "yue"], + "lang": ["auto", "en", "ja", "ko", "yue"], "size": get_models(), "path": "fasterwhisper_asr.py", "precision": ["float32", "float16", "int8"], diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index a2ebe975..1f98b840 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -1,12 +1,12 @@ import argparse import os -import time import traceback +import requests import torch from faster_whisper import WhisperModel -from huggingface_hub import snapshot_download -from huggingface_hub.errors import LocalEntryNotFoundError +from huggingface_hub import snapshot_download as snapshot_download_hf +from modelscope import snapshot_download as snapshot_download_ms from tqdm import tqdm from tools.asr.config import get_models @@ -40,11 +40,35 @@ language_code_list = [ def download_model(model_size: str): - if "distil" in model_size: - repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1)) + url = "https://huggingface.co/api/models/gpt2" + try: + requests.get(url, timeout=3) + source = "HF" + except Exception: + source = "ModelScope" + + model_path = "" + if source == "HF": + if "distil" in model_size: + if "3.5" in model_size: + repo_id = "distil-whisper/distil-large-v3.5-ct2" + model_path = "tools/asr/models/faster-whisper-distil-large-v3.5" + else: + repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1)) + elif model_size == "large-v3-turbo": + repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo" + model_path = "tools/asr/models/faster-whisper-large-v3-turbo" + else: + repo_id = f"Systran/faster-whisper-{model_size}" + model_path = ( + model_path + or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace( + "distil-whisper", "whisper-distil" + ) + ) else: - repo_id = f"Systran/faster-whisper-{model_size}" - model_path = f"tools/asr/models/{repo_id.strip('Systran/')}" + repo_id = "XXXXRT/faster-whisper" + model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil") files: list[str] = [ "config.json", @@ -58,26 +82,24 @@ def download_model(model_size: str): files.remove("vocabulary.txt") - for attempt in range(2): - try: - snapshot_download( - repo_id=repo_id, - allow_patterns=files, - local_dir=model_path, - ) - break - except LocalEntryNotFoundError: - if attempt < 1: - time.sleep(2) - else: - print("[ERROR] LocalEntryNotFoundError and no fallback.") - traceback.print_exc() - exit(1) - except Exception as e: - print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}") - traceback.print_exc() - exit(1) + if source == "ModelScope": + files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files] + if source == "HF": + print(f"Downloading model from HuggingFace: {repo_id} to {model_path}") + snapshot_download_hf( + repo_id, + local_dir=model_path, + local_dir_use_symlinks=False, + allow_patterns=files, + ) + else: + print(f"Downloading model from ModelScope: {repo_id} to {model_path}") + snapshot_download_ms( + repo_id, + local_dir=model_path, + allow_patterns=files, + ) return model_path