Merge 19f51ae6f9b3a3dfa12758afb5e8f7e10cd0275d into 08d627c3338173c3229286d8787060d6559fe0f8

This commit is contained in:
DayDaySpeed 2026-06-01 19:55:42 +08:00 committed by GitHub
commit bcdfbbd33f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

259
asr_api.py Normal file
View File

@ -0,0 +1,259 @@
"""
# asr_api.py usage
` python asr_api.py -a 127.0.0.1 -p 9881 `
## 调用:
### 语音转文本
endpoint: `/asr`
GET:
`http://127.0.0.1:9881/asr?audio_path=data/voice_ref/ref.wav&language=zh`
POST:
```json
{
"audio_path": "data/voice_ref/ref.wav",
"language": "zh",
"engine": "auto",
"model_size": "large-v3",
"precision": "float16"
}
```
RESP:
成功: {"text": "..."}
失败: {"message": "..."} (http code 400)
"""
import argparse
import os
import signal
import sys
from typing import Optional
import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS")
from faster_whisper import WhisperModel
from tools.asr.config import get_models
from tools.asr.fasterwhisper_asr import download_model, language_code_list
from tools.asr.funasr_asr import only_asr
from tools.my_utils import load_cudnn
load_cudnn()
parser = argparse.ArgumentParser(description="GPT-SoVITS ASR api")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default=9881, help="default: 9881")
parser.add_argument(
"-s",
"--model_size",
type=str,
default="large-v3",
choices=get_models(),
help="default: large-v3",
)
parser.add_argument(
"-l",
"--language",
type=str,
default="auto",
choices=language_code_list,
help="default language",
)
parser.add_argument(
"-pr",
"--precision",
type=str,
default="float16",
choices=["float16", "float32", "int8"],
help="compute precision for faster-whisper",
)
args = parser.parse_args()
host = args.bind_addr
port = args.port
argv = sys.argv
APP = FastAPI()
class ASRRequest(BaseModel):
audio_path: Optional[str] = None
language: Optional[str] = None
engine: str = "auto" # auto | fasterwhisper | funasr
model_size: Optional[str] = None
precision: Optional[str] = None
whisper_models = {}
whisper_model_paths = {}
def handle_control(command: str):
if command == "restart":
os.execl(sys.executable, sys.executable, *argv)
elif command == "exit":
os.kill(os.getpid(), signal.SIGTERM)
exit(0)
def check_params(req: dict):
audio_path = req.get("audio_path")
language = req.get("language", "auto")
engine = req.get("engine", "auto")
model_size = req.get("model_size", args.model_size)
precision = req.get("precision", args.precision)
if not audio_path:
return JSONResponse(status_code=400, content={"message": "audio_path is required"})
audio_path = os.path.abspath(audio_path)
if not os.path.isfile(audio_path):
return JSONResponse(status_code=400, content={"message": f"audio_path not found: {audio_path}"})
req["audio_path"] = audio_path
if language not in language_code_list:
return JSONResponse(status_code=400, content={"message": f"language not supported: {language}"})
if engine not in ["auto", "fasterwhisper", "funasr"]:
return JSONResponse(status_code=400, content={"message": f"engine not supported: {engine}"})
if model_size not in get_models():
return JSONResponse(status_code=400, content={"message": f"model_size not supported: {model_size}"})
if precision not in ["float16", "float32", "int8"]:
return JSONResponse(status_code=400, content={"message": f"precision not supported: {precision}"})
return None
def get_whisper_model(model_size: str, precision: str):
key = f"{model_size}:{precision}"
if key in whisper_models:
return whisper_models[key]
if model_size not in whisper_model_paths:
whisper_model_paths[model_size] = download_model(model_size)
model_path = whisper_model_paths[model_size]
device = "cuda" if torch.cuda.is_available() else "cpu"
model = WhisperModel(model_path, device=device, compute_type=precision)
whisper_models[key] = model
return model
def asr_with_fasterwhisper(audio_path: str, language: str, model_size: str, precision: str):
model = get_whisper_model(model_size=model_size, precision=precision)
fw_language = None if language == "auto" else language
segments, info = model.transcribe(
audio=audio_path,
beam_size=5,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=700),
language=fw_language,
)
# 中文/粤语默认转到 FunASR文本稳定性通常更好
if info.language in ["zh", "yue"]:
text = only_asr(audio_path, language=info.language.lower())
if text:
return text
text = "".join(segment.text for segment in segments).strip()
return text
def asr_with_funasr(audio_path: str, language: str):
lang = language if language in ["zh", "yue"] else "zh"
return (only_asr(audio_path, language=lang) or "").strip()
async def asr_handle(req: dict):
check_res = check_params(req)
if check_res is not None:
return check_res
audio_path = req["audio_path"]
language = req.get("language", args.language)
engine = req.get("engine", "auto")
model_size = req.get("model_size", args.model_size)
precision = req.get("precision", args.precision)
try:
if engine == "funasr":
text = asr_with_funasr(audio_path=audio_path, language=language)
elif engine == "fasterwhisper":
text = asr_with_fasterwhisper(
audio_path=audio_path,
language=language,
model_size=model_size,
precision=precision,
)
else:
# auto: 中文/粤语优先 FunASR其它语种走 Faster-Whisper
if language in ["zh", "yue"]:
text = asr_with_funasr(audio_path=audio_path, language=language)
if not text:
text = asr_with_fasterwhisper(
audio_path=audio_path,
language=language,
model_size=model_size,
precision=precision,
)
else:
text = asr_with_fasterwhisper(
audio_path=audio_path,
language=language,
model_size=model_size,
precision=precision,
)
return JSONResponse(status_code=200, content={"text": text})
except Exception as e:
return JSONResponse(status_code=400, content={"message": "asr failed", "Exception": str(e)})
@APP.get("/control")
async def control(command: str = None):
if command is None:
return JSONResponse(status_code=400, content={"message": "command is required"})
handle_control(command)
@APP.get("/asr")
async def asr_get_endpoint(
audio_path: str = None,
language: str = args.language,
engine: str = "auto",
model_size: str = args.model_size,
precision: str = args.precision,
):
req = {
"audio_path": audio_path,
"language": language.lower() if language else "auto",
"engine": engine,
"model_size": model_size,
"precision": precision,
}
return await asr_handle(req)
@APP.post("/asr")
async def asr_post_endpoint(request: ASRRequest):
req = request.dict()
req["language"] = (req.get("language") or args.language).lower()
req["model_size"] = req.get("model_size") or args.model_size
req["precision"] = req.get("precision") or args.precision
return await asr_handle(req)
if __name__ == "__main__":
import uvicorn
uvicorn.run(APP, host=host, port=port, workers=1)