"""
# asr_api.py usage

` python asr_api.py -a 127.0.0.1 -p 9881 `

## Calling the API:

### Speech to text

endpoint: `/asr`

GET:
    `http://127.0.0.1:9881/asr?audio_path=data/voice_ref/ref.wav&language=zh`

POST:
```json
{
    "audio_path": "data/voice_ref/ref.wav",
    "language": "zh",
    "engine": "auto",
    "model_size": "large-v3",
    "precision": "float16"
}
```

RESP:
    success: {"text": "..."}
    failure: {"message": "..."} (http code 400)

### Process control

endpoint: `/control`

GET:
    `http://127.0.0.1:9881/control?command=restart`
    `http://127.0.0.1:9881/control?command=exit`
"""

import argparse
import os
import signal
import sys
from typing import Optional

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# The project-local imports below are resolved relative to the working
# directory, so it has to be on sys.path before they run.
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS")

from faster_whisper import WhisperModel

from tools.asr.config import get_models
from tools.asr.fasterwhisper_asr import download_model, language_code_list
from tools.asr.funasr_asr import only_asr
from tools.my_utils import load_cudnn

load_cudnn()

# Shared by the CLI parser, request validation and engine routing so the
# accepted values cannot drift apart.
SUPPORTED_ENGINES = ["auto", "fasterwhisper", "funasr"]
SUPPORTED_PRECISIONS = ["float16", "float32", "int8"]
SUPPORTED_COMMANDS = ["restart", "exit"]
FUNASR_LANGUAGES = ["zh", "yue"]

parser = argparse.ArgumentParser(description="GPT-SoVITS ASR api")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default=9881, help="default: 9881")
parser.add_argument(
    "-s",
    "--model_size",
    type=str,
    default="large-v3",
    choices=get_models(),
    help="default: large-v3",
)
parser.add_argument(
    "-l",
    "--language",
    type=str,
    default="auto",
    choices=language_code_list,
    help="default language",
)
parser.add_argument(
    "-pr",
    "--precision",
    type=str,
    default="float16",
    choices=SUPPORTED_PRECISIONS,
    help="compute precision for faster-whisper",
)
args = parser.parse_args()

host = args.bind_addr
port = args.port
# Kept so that the "restart" command can re-exec with the same arguments.
argv = sys.argv

APP = FastAPI()


class ASRRequest(BaseModel):
    """JSON body accepted by ``POST /asr``.

    Fields left as ``None`` are replaced with the command-line defaults by
    the endpoint before validation.
    """

    audio_path: Optional[str] = None
    language: Optional[str] = None
    engine: str = "auto"  # auto | fasterwhisper | funasr
    model_size: Optional[str] = None
    precision: Optional[str] = None


# Loaded WhisperModel instances, keyed by "<model_size>:<precision>".
whisper_models = {}
# Paths returned by download_model(), keyed by model size.
whisper_model_paths = {}


def handle_control(command: str):
    """Restart or terminate the server process.

    Args:
        command: ``"restart"`` re-executes the interpreter with the original
            argv; ``"exit"`` sends SIGTERM to this process and then exits.
            Any other value is ignored and the function returns ``None``.
    """
    if command == "restart":
        os.execl(sys.executable, sys.executable, *argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        # sys.exit instead of the builtin exit(): the latter is injected by
        # the `site` module and is not guaranteed to exist.
        sys.exit(0)


def check_params(req: dict):
    """Validate an ASR request dict.

    On success ``req["audio_path"]`` is rewritten in place to its absolute
    path.

    Args:
        req: Request parameters (``audio_path``, ``language``, ``engine``,
            ``model_size``, ``precision``). Missing keys fall back to the
            same defaults that ``asr_handle`` uses.

    Returns:
        ``None`` when the request is valid, otherwise a 400 ``JSONResponse``
        describing the first failing parameter.
    """

    def reject(message: str) -> JSONResponse:
        return JSONResponse(status_code=400, content={"message": message})

    audio_path = req.get("audio_path")
    language = req.get("language", args.language)
    engine = req.get("engine", "auto")
    model_size = req.get("model_size", args.model_size)
    precision = req.get("precision", args.precision)

    if not audio_path:
        return reject("audio_path is required")

    audio_path = os.path.abspath(audio_path)
    if not os.path.isfile(audio_path):
        return reject(f"audio_path not found: {audio_path}")
    req["audio_path"] = audio_path

    if language not in language_code_list:
        return reject(f"language not supported: {language}")
    if engine not in SUPPORTED_ENGINES:
        return reject(f"engine not supported: {engine}")
    if model_size not in get_models():
        return reject(f"model_size not supported: {model_size}")
    if precision not in SUPPORTED_PRECISIONS:
        return reject(f"precision not supported: {precision}")
    return None


def get_whisper_model(model_size: str, precision: str):
    """Return a cached ``WhisperModel`` for the given size and precision.

    The model is created on first use (CUDA when available, otherwise CPU)
    and reused for every later request with the same size/precision pair.

    Args:
        model_size: Model size name passed to ``download_model``.
        precision: Value passed to ``WhisperModel`` as ``compute_type``.

    Returns:
        The loaded ``WhisperModel`` instance.
    """
    key = f"{model_size}:{precision}"
    if key in whisper_models:
        return whisper_models[key]

    if model_size not in whisper_model_paths:
        whisper_model_paths[model_size] = download_model(model_size)
    model_path = whisper_model_paths[model_size]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = WhisperModel(model_path, device=device, compute_type=precision)
    whisper_models[key] = model
    return model


def asr_with_fasterwhisper(
    audio_path: str,
    language: str,
    model_size: str,
    precision: str,
    *,
    allow_funasr_redirect: bool = True,
):
    """Transcribe ``audio_path`` with Faster-Whisper.

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Language code; ``"auto"`` lets the model detect it.
        model_size: Whisper model size.
        precision: Compute type for the model.
        allow_funasr_redirect: When true (the default), audio whose
            language is reported as Chinese or Cantonese is first handed to
            FunASR, and the Whisper segments are used only if FunASR yields
            no text. Pass ``False`` when FunASR has already been tried on
            this audio, to avoid running it a second time.

    Returns:
        The transcribed text.
    """
    model = get_whisper_model(model_size=model_size, precision=precision)
    fw_language = None if language == "auto" else language
    segments, info = model.transcribe(
        audio=audio_path,
        beam_size=5,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=700),
        language=fw_language,
    )

    # Chinese/Cantonese is redirected to FunASR by default; its text is
    # usually more stable for these languages.
    if allow_funasr_redirect and info.language in FUNASR_LANGUAGES:
        text = only_asr(audio_path, language=info.language.lower())
        if text:
            return text

    return "".join(segment.text for segment in segments).strip()


def asr_with_funasr(audio_path: str, language: str):
    """Transcribe ``audio_path`` with FunASR.

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Language code; anything other than ``"zh"``/``"yue"`` is
            treated as ``"zh"``.

    Returns:
        The stripped transcription, or ``""`` when FunASR returns nothing.
    """
    lang = language if language in FUNASR_LANGUAGES else "zh"
    return (only_asr(audio_path, language=lang) or "").strip()


async def asr_handle(req: dict):
    """Validate a request, route it to an ASR engine and build the response.

    Args:
        req: Request parameters; see ``check_params``.

    Returns:
        A 200 ``JSONResponse`` with ``{"text": ...}`` on success, or a 400
        ``JSONResponse`` when validation or transcription fails.
    """
    check_res = check_params(req)
    if check_res is not None:
        return check_res

    audio_path = req["audio_path"]
    language = req.get("language", args.language)
    engine = req.get("engine", "auto")
    model_size = req.get("model_size", args.model_size)
    precision = req.get("precision", args.precision)

    try:
        if engine == "funasr":
            text = asr_with_funasr(audio_path=audio_path, language=language)
        elif engine == "fasterwhisper":
            text = asr_with_fasterwhisper(
                audio_path=audio_path,
                language=language,
                model_size=model_size,
                precision=precision,
            )
        elif language in FUNASR_LANGUAGES:
            # auto: Chinese/Cantonese goes to FunASR first.
            text = asr_with_funasr(audio_path=audio_path, language=language)
            if not text:
                # FunASR already produced nothing for this file, so fall
                # back to the Whisper transcript without retrying FunASR.
                text = asr_with_fasterwhisper(
                    audio_path=audio_path,
                    language=language,
                    model_size=model_size,
                    precision=precision,
                    allow_funasr_redirect=False,
                )
        else:
            # auto: every other language goes through Faster-Whisper.
            text = asr_with_fasterwhisper(
                audio_path=audio_path,
                language=language,
                model_size=model_size,
                precision=precision,
            )
        return JSONResponse(status_code=200, content={"text": text})
    except Exception as e:
        # Top-level boundary: report any engine failure to the client
        # instead of letting the request crash.
        return JSONResponse(status_code=400, content={"message": "asr failed", "Exception": str(e)})


@APP.get("/control")
async def control(command: Optional[str] = None):
    """``GET /control``: restart or stop the server.

    Returns a 400 response when ``command`` is missing or is not one of the
    supported commands; otherwise the process is restarted or terminated.
    """
    if command is None:
        return JSONResponse(status_code=400, content={"message": "command is required"})
    if command not in SUPPORTED_COMMANDS:
        # Previously an unknown command was silently answered with 200/null.
        return JSONResponse(status_code=400, content={"message": f"command not supported: {command}"})
    handle_control(command)


@APP.get("/asr")
async def asr_get_endpoint(
    audio_path: Optional[str] = None,
    language: str = args.language,
    engine: str = "auto",
    model_size: str = args.model_size,
    precision: str = args.precision,
):
    """``GET /asr``: transcribe the file named by the query parameters."""
    req = {
        "audio_path": audio_path,
        # An empty value falls back to the CLI default, same as POST.
        "language": (language or args.language).lower(),
        "engine": engine,
        "model_size": model_size,
        "precision": precision,
    }
    return await asr_handle(req)


@APP.post("/asr")
async def asr_post_endpoint(request: ASRRequest):
    """``POST /asr``: transcribe the file described by the JSON body."""
    req = request.dict()
    req["language"] = (req.get("language") or args.language).lower()
    req["model_size"] = req.get("model_size") or args.model_size
    req["precision"] = req.get("precision") or args.precision
    return await asr_handle(req)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(APP, host=host, port=port, workers=1)