mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 05:36:34 +08:00
Add ModelScope Snapshot Download For ASR
This commit is contained in:
parent
11aa78bd9b
commit
84e902eaa8
@ -16,7 +16,7 @@ pypinyin
|
||||
pyopenjtalk>=0.4.1
|
||||
g2p_en
|
||||
torchaudio
|
||||
modelscope==1.10.0
|
||||
modelscope
|
||||
sentencepiece
|
||||
transformers>=4.43,<=4.50
|
||||
peft
|
||||
@ -39,7 +39,5 @@ x_transformers
|
||||
torchmetrics<=1.5
|
||||
pydantic<=2.10.6
|
||||
ctranslate2>=4.0,<5
|
||||
huggingface_hub>=0.13
|
||||
tokenizers>=0.13,<1
|
||||
av>=11
|
||||
tqdm
|
||||
|
@ -1,34 +1,13 @@
|
||||
import os
|
||||
|
||||
|
||||
def check_fw_local_models():
|
||||
"""
|
||||
启动时检查本地是否有 Faster Whisper 模型.
|
||||
"""
|
||||
model_size_list = [
|
||||
"medium",
|
||||
"medium.en",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
]
|
||||
for i, size in enumerate(model_size_list):
|
||||
if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
|
||||
model_size_list[i] = size + "-local"
|
||||
return model_size_list
|
||||
|
||||
|
||||
def get_models():
|
||||
model_size_list = [
|
||||
"medium",
|
||||
"medium.en",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"large-v3-turbo",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
"distil-large-v3.5",
|
||||
]
|
||||
return model_size_list
|
||||
|
||||
@ -36,7 +15,7 @@ def get_models():
|
||||
asr_dict = {
|
||||
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
||||
"Faster Whisper (多语种)": {
|
||||
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
|
||||
"lang": ["auto", "en", "ja", "ko", "yue"],
|
||||
"size": get_models(),
|
||||
"path": "fasterwhisper_asr.py",
|
||||
"precision": ["float32", "float16", "int8"],
|
||||
|
@ -1,12 +1,12 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from faster_whisper import WhisperModel
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from huggingface_hub import snapshot_download as snapshot_download_hf
|
||||
from modelscope import snapshot_download as snapshot_download_ms
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.asr.config import get_models
|
||||
@ -40,11 +40,35 @@ language_code_list = [
|
||||
|
||||
|
||||
def download_model(model_size: str):
|
||||
if "distil" in model_size:
|
||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||
url = "https://huggingface.co/api/models/gpt2"
|
||||
try:
|
||||
requests.get(url, timeout=3)
|
||||
source = "HF"
|
||||
except Exception:
|
||||
source = "ModelScope"
|
||||
|
||||
model_path = ""
|
||||
if source == "HF":
|
||||
if "distil" in model_size:
|
||||
if "3.5" in model_size:
|
||||
repo_id = "distil-whisper/distil-large-v3.5-ct2"
|
||||
model_path = "tools/asr/models/faster-whisper-distil-large-v3.5"
|
||||
else:
|
||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||
elif model_size == "large-v3-turbo":
|
||||
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
|
||||
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
|
||||
else:
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
model_path = (
|
||||
model_path
|
||||
or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace(
|
||||
"distil-whisper", "whisper-distil"
|
||||
)
|
||||
)
|
||||
else:
|
||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
|
||||
repo_id = "XXXXRT/faster-whisper"
|
||||
model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil")
|
||||
|
||||
files: list[str] = [
|
||||
"config.json",
|
||||
@ -58,26 +82,24 @@ def download_model(model_size: str):
|
||||
|
||||
files.remove("vocabulary.txt")
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
allow_patterns=files,
|
||||
local_dir=model_path,
|
||||
)
|
||||
break
|
||||
except LocalEntryNotFoundError:
|
||||
if attempt < 1:
|
||||
time.sleep(2)
|
||||
else:
|
||||
print("[ERROR] LocalEntryNotFoundError and no fallback.")
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
if source == "ModelScope":
|
||||
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
|
||||
|
||||
if source == "HF":
|
||||
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
|
||||
snapshot_download_hf(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False,
|
||||
allow_patterns=files,
|
||||
)
|
||||
else:
|
||||
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
|
||||
snapshot_download_ms(
|
||||
repo_id,
|
||||
local_dir=model_path,
|
||||
allow_patterns=files,
|
||||
)
|
||||
return model_path
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user