Add ModelScope Snapshot Download For ASR

This commit is contained in:
XXXXRT666 2025-10-05 12:35:01 +01:00
parent 11aa78bd9b
commit 84e902eaa8
3 changed files with 54 additions and 55 deletions

View File

@ -16,7 +16,7 @@ pypinyin
pyopenjtalk>=0.4.1
g2p_en
torchaudio
modelscope==1.10.0
modelscope
sentencepiece
transformers>=4.43,<=4.50
peft
@ -39,7 +39,5 @@ x_transformers
torchmetrics<=1.5
pydantic<=2.10.6
ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
av>=11
tqdm

View File

@ -1,34 +1,13 @@
import os
def check_fw_local_models(models_dir: str = "tools/asr/models") -> list[str]:
    """
    Check at startup which Faster Whisper models are already present locally.

    Args:
        models_dir: Directory searched for ``faster-whisper-<size>``
            sub-directories. Defaults to the location this function has
            always looked in, so existing callers are unaffected.

    Returns:
        The known model sizes in a fixed order. A size whose directory exists
        under ``models_dir`` is returned with ``"-local"`` appended.
    """
    model_size_list = [
        "medium",
        "medium.en",
        "distil-large-v2",
        "distil-large-v3",
        "large-v1",
        "large-v2",
        "large-v3",
    ]
    # Build a new list instead of overwriting entries in place while iterating.
    return [
        f"{size}-local" if os.path.exists(os.path.join(models_dir, f"faster-whisper-{size}")) else size
        for size in model_size_list
    ]
def get_models() -> list[str]:
    """
    Return the Faster Whisper model sizes offered for selection.

    Each size appears exactly once. Full-size models come first, followed by
    the distilled variants grouped together.
    """
    model_size_list = [
        "medium",
        "medium.en",
        "large-v1",
        "large-v2",
        "large-v3",
        "large-v3-turbo",
        "distil-large-v2",
        "distil-large-v3",
        "distil-large-v3.5",
    ]
    return model_size_list
@ -36,7 +15,7 @@ def get_models():
asr_dict = {
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
"Faster Whisper (多语种)": {
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
"lang": ["auto", "en", "ja", "ko", "yue"],
"size": get_models(),
"path": "fasterwhisper_asr.py",
"precision": ["float32", "float16", "int8"],

View File

@ -1,12 +1,12 @@
import argparse
import os
import time
import traceback
import requests
import torch
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from huggingface_hub import snapshot_download as snapshot_download_hf
from modelscope import snapshot_download as snapshot_download_ms
from tqdm import tqdm
from tools.asr.config import get_models
@ -40,11 +40,35 @@ language_code_list = [
def download_model(model_size: str):
if "distil" in model_size:
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
url = "https://huggingface.co/api/models/gpt2"
try:
requests.get(url, timeout=3)
source = "HF"
except Exception:
source = "ModelScope"
model_path = ""
if source == "HF":
if "distil" in model_size:
if "3.5" in model_size:
repo_id = "distil-whisper/distil-large-v3.5-ct2"
model_path = "tools/asr/models/faster-whisper-distil-large-v3.5"
else:
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
elif model_size == "large-v3-turbo":
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
else:
repo_id = f"Systran/faster-whisper-{model_size}"
model_path = (
model_path
or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace(
"distil-whisper", "whisper-distil"
)
)
else:
repo_id = f"Systran/faster-whisper-{model_size}"
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
repo_id = "XXXXRT/faster-whisper"
model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil")
files: list[str] = [
"config.json",
@ -58,26 +82,24 @@ def download_model(model_size: str):
files.remove("vocabulary.txt")
for attempt in range(2):
try:
snapshot_download(
repo_id=repo_id,
allow_patterns=files,
local_dir=model_path,
)
break
except LocalEntryNotFoundError:
if attempt < 1:
time.sleep(2)
else:
print("[ERROR] LocalEntryNotFoundError and no fallback.")
traceback.print_exc()
exit(1)
except Exception as e:
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
traceback.print_exc()
exit(1)
if source == "ModelScope":
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
if source == "HF":
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
snapshot_download_hf(
repo_id,
local_dir=model_path,
local_dir_use_symlinks=False,
allow_patterns=files,
)
else:
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
snapshot_download_ms(
repo_id,
local_dir=model_path,
allow_patterns=files,
)
return model_path