mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 22:15:36 +08:00
Add ModelScope Snapshot Download For ASR
This commit is contained in:
parent
11aa78bd9b
commit
84e902eaa8
@ -16,7 +16,7 @@ pypinyin
|
|||||||
pyopenjtalk>=0.4.1
|
pyopenjtalk>=0.4.1
|
||||||
g2p_en
|
g2p_en
|
||||||
torchaudio
|
torchaudio
|
||||||
modelscope==1.10.0
|
modelscope
|
||||||
sentencepiece
|
sentencepiece
|
||||||
transformers>=4.43,<=4.50
|
transformers>=4.43,<=4.50
|
||||||
peft
|
peft
|
||||||
@ -39,7 +39,5 @@ x_transformers
|
|||||||
torchmetrics<=1.5
|
torchmetrics<=1.5
|
||||||
pydantic<=2.10.6
|
pydantic<=2.10.6
|
||||||
ctranslate2>=4.0,<5
|
ctranslate2>=4.0,<5
|
||||||
huggingface_hub>=0.13
|
|
||||||
tokenizers>=0.13,<1
|
|
||||||
av>=11
|
av>=11
|
||||||
tqdm
|
tqdm
|
||||||
|
@ -1,34 +1,13 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def check_fw_local_models():
|
|
||||||
"""
|
|
||||||
启动时检查本地是否有 Faster Whisper 模型.
|
|
||||||
"""
|
|
||||||
model_size_list = [
|
|
||||||
"medium",
|
|
||||||
"medium.en",
|
|
||||||
"distil-large-v2",
|
|
||||||
"distil-large-v3",
|
|
||||||
"large-v1",
|
|
||||||
"large-v2",
|
|
||||||
"large-v3",
|
|
||||||
]
|
|
||||||
for i, size in enumerate(model_size_list):
|
|
||||||
if os.path.exists(f"tools/asr/models/faster-whisper-{size}"):
|
|
||||||
model_size_list[i] = size + "-local"
|
|
||||||
return model_size_list
|
|
||||||
|
|
||||||
|
|
||||||
def get_models():
|
def get_models():
|
||||||
model_size_list = [
|
model_size_list = [
|
||||||
"medium",
|
"medium",
|
||||||
"medium.en",
|
"medium.en",
|
||||||
"distil-large-v2",
|
|
||||||
"distil-large-v3",
|
|
||||||
"large-v1",
|
|
||||||
"large-v2",
|
"large-v2",
|
||||||
"large-v3",
|
"large-v3",
|
||||||
|
"large-v3-turbo",
|
||||||
|
"distil-large-v2",
|
||||||
|
"distil-large-v3",
|
||||||
|
"distil-large-v3.5",
|
||||||
]
|
]
|
||||||
return model_size_list
|
return model_size_list
|
||||||
|
|
||||||
@ -36,7 +15,7 @@ def get_models():
|
|||||||
asr_dict = {
|
asr_dict = {
|
||||||
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
"达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
|
||||||
"Faster Whisper (多语种)": {
|
"Faster Whisper (多语种)": {
|
||||||
"lang": ["auto", "zh", "en", "ja", "ko", "yue"],
|
"lang": ["auto", "en", "ja", "ko", "yue"],
|
||||||
"size": get_models(),
|
"size": get_models(),
|
||||||
"path": "fasterwhisper_asr.py",
|
"path": "fasterwhisper_asr.py",
|
||||||
"precision": ["float32", "float16", "int8"],
|
"precision": ["float32", "float16", "int8"],
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download as snapshot_download_hf
|
||||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
from modelscope import snapshot_download as snapshot_download_ms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from tools.asr.config import get_models
|
from tools.asr.config import get_models
|
||||||
@ -40,11 +40,35 @@ language_code_list = [
|
|||||||
|
|
||||||
|
|
||||||
def download_model(model_size: str):
|
def download_model(model_size: str):
|
||||||
if "distil" in model_size:
|
url = "https://huggingface.co/api/models/gpt2"
|
||||||
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
try:
|
||||||
|
requests.get(url, timeout=3)
|
||||||
|
source = "HF"
|
||||||
|
except Exception:
|
||||||
|
source = "ModelScope"
|
||||||
|
|
||||||
|
model_path = ""
|
||||||
|
if source == "HF":
|
||||||
|
if "distil" in model_size:
|
||||||
|
if "3.5" in model_size:
|
||||||
|
repo_id = "distil-whisper/distil-large-v3.5-ct2"
|
||||||
|
model_path = "tools/asr/models/faster-whisper-distil-large-v3.5"
|
||||||
|
else:
|
||||||
|
repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
|
||||||
|
elif model_size == "large-v3-turbo":
|
||||||
|
repo_id = "mobiuslabsgmbh/faster-whisper-large-v3-turbo"
|
||||||
|
model_path = "tools/asr/models/faster-whisper-large-v3-turbo"
|
||||||
|
else:
|
||||||
|
repo_id = f"Systran/faster-whisper-{model_size}"
|
||||||
|
model_path = (
|
||||||
|
model_path
|
||||||
|
or f"tools/asr/models/{repo_id.replace('Systran/', '').replace('distil-whisper/', '', 1)}".replace(
|
||||||
|
"distil-whisper", "whisper-distil"
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
repo_id = f"Systran/faster-whisper-{model_size}"
|
repo_id = "XXXXRT/faster-whisper"
|
||||||
model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
|
model_path = f"tools/asr/models/faster-whisper-{model_size}".replace("distil-whisper", "whisper-distil")
|
||||||
|
|
||||||
files: list[str] = [
|
files: list[str] = [
|
||||||
"config.json",
|
"config.json",
|
||||||
@ -58,26 +82,24 @@ def download_model(model_size: str):
|
|||||||
|
|
||||||
files.remove("vocabulary.txt")
|
files.remove("vocabulary.txt")
|
||||||
|
|
||||||
for attempt in range(2):
|
if source == "ModelScope":
|
||||||
try:
|
files = [f"faster-whisper-{model_size}/{file}".replace("whisper-distil", "distil-whisper") for file in files]
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
allow_patterns=files,
|
|
||||||
local_dir=model_path,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except LocalEntryNotFoundError:
|
|
||||||
if attempt < 1:
|
|
||||||
time.sleep(2)
|
|
||||||
else:
|
|
||||||
print("[ERROR] LocalEntryNotFoundError and no fallback.")
|
|
||||||
traceback.print_exc()
|
|
||||||
exit(1)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
|
if source == "HF":
|
||||||
|
print(f"Downloading model from HuggingFace: {repo_id} to {model_path}")
|
||||||
|
snapshot_download_hf(
|
||||||
|
repo_id,
|
||||||
|
local_dir=model_path,
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
allow_patterns=files,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"Downloading model from ModelScope: {repo_id} to {model_path}")
|
||||||
|
snapshot_download_ms(
|
||||||
|
repo_id,
|
||||||
|
local_dir=model_path,
|
||||||
|
allow_patterns=files,
|
||||||
|
)
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user