mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-04 05:01:27 +08:00
Merge 19f51ae6f9b3a3dfa12758afb5e8f7e10cd0275d into 08d627c3338173c3229286d8787060d6559fe0f8
This commit is contained in:
commit
bcdfbbd33f
259
asr_api.py
Normal file
259
asr_api.py
Normal file
@ -0,0 +1,259 @@
|
||||
"""
|
||||
# asr_api.py usage
|
||||
|
||||
` python asr_api.py -a 127.0.0.1 -p 9881 `
|
||||
|
||||
## Usage:
|
||||
|
||||
### Speech to text
|
||||
|
||||
endpoint: `/asr`
|
||||
|
||||
GET:
|
||||
`http://127.0.0.1:9881/asr?audio_path=data/voice_ref/ref.wav&language=zh`
|
||||
|
||||
POST:
|
||||
```json
|
||||
{
|
||||
"audio_path": "data/voice_ref/ref.wav",
|
||||
"language": "zh",
|
||||
"engine": "auto",
|
||||
"model_size": "large-v3",
|
||||
"precision": "float16"
|
||||
}
|
||||
```
|
||||
|
||||
RESP:
|
||||
Success: {"text": "..."}
|
||||
Failure: {"message": "..."} (http code 400; ASR errors also include an "Exception" field)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Make the current working directory and its GPT_SoVITS subdirectory
# importable, so the project-local `tools.*` imports below can resolve.
now_dir = os.getcwd()
sys.path.extend((now_dir, f"{now_dir}/GPT_SoVITS"))
|
||||
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from tools.asr.config import get_models
|
||||
from tools.asr.fasterwhisper_asr import download_model, language_code_list
|
||||
from tools.asr.funasr_asr import only_asr
|
||||
from tools.my_utils import load_cudnn
|
||||
|
||||
|
||||
# NOTE(review): load_cudnn comes from tools.my_utils; presumably it makes the
# cuDNN libraries loadable before any model is created -- confirm there.
load_cudnn()

# Command-line options. model_size, language and precision are also used as
# the per-request defaults by the /asr endpoints.
parser = argparse.ArgumentParser(description="GPT-SoVITS ASR api")
parser.add_argument(
    "-a",
    "--bind_addr",
    type=str,
    default="127.0.0.1",
    help="default: 127.0.0.1",
)
parser.add_argument(
    "-p",
    "--port",
    type=int,
    default=9881,
    help="default: 9881",
)
parser.add_argument("-s", "--model_size", type=str, default="large-v3", choices=get_models(), help="default: large-v3")
parser.add_argument("-l", "--language", type=str, default="auto", choices=language_code_list, help="default language")
parser.add_argument(
    "-pr",
    "--precision",
    type=str,
    default="float16",
    choices=["float16", "float32", "int8"],
    help="compute precision for faster-whisper",
)
args = parser.parse_args()

# Bind address/port for uvicorn, and the original command line, which
# handle_control reuses to re-execute the process on "restart".
host, port = args.bind_addr, args.port
argv = sys.argv

APP = FastAPI()
|
||||
|
||||
|
||||
class ASRRequest(BaseModel):
    """JSON body accepted by ``POST /asr``.

    ``language``, ``model_size`` and ``precision`` may be omitted (``None``);
    the POST endpoint then substitutes the command-line defaults.
    """

    # Path of the audio file to transcribe, resolved on the server side.
    audio_path: Optional[str] = None
    # Language code; validated against language_code_list.
    language: Optional[str] = None
    # Engine selector: auto | fasterwhisper | funasr
    engine: str = "auto"
    # faster-whisper model size; validated against get_models().
    model_size: Optional[str] = None
    # faster-whisper compute precision: float16 | float32 | int8
    precision: Optional[str] = None
|
||||
|
||||
|
||||
# Loaded faster-whisper models, keyed by "<model_size>:<precision>".
whisper_models: dict = {}
# Model paths returned by download_model, keyed by model size.
whisper_model_paths: dict = {}
|
||||
|
||||
|
||||
def handle_control(command: str):
    """Restart or terminate the server process.

    Args:
        command: ``"restart"`` re-executes the interpreter with the original
            command line; ``"exit"`` sends SIGTERM to the current process and
            then exits. Any other value is ignored.

    Returns:
        None. For ``"restart"`` the call does not return on success, and for
        ``"exit"`` the process is terminated.
    """
    if command == "restart":
        # Replaces the current process image with a fresh interpreter.
        os.execl(sys.executable, sys.executable, *argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        # sys.exit instead of the site-provided ``exit`` helper, which is not
        # guaranteed to exist (e.g. ``python -S`` or frozen builds).
        sys.exit(0)
|
||||
|
||||
|
||||
def check_params(req: dict):
    """Validate an ASR request and normalize it in place.

    Missing, ``None`` and empty values for ``language``, ``engine``,
    ``model_size`` and ``precision`` are replaced by the server defaults
    (the same defaults the request handler uses), then every value is checked
    against its allowed set. On success the normalized values, including the
    absolute ``audio_path``, are written back into ``req``.

    Args:
        req: Request parameters; ``audio_path`` is required.

    Returns:
        ``None`` when the request is valid, otherwise a ``JSONResponse`` with
        status 400 and a ``{"message": ...}`` body describing the first
        problem found.
    """
    # ``dict.get(key, default)`` only covers a missing key; ``or`` also covers
    # keys that are present but None/empty (as produced by ``request.dict()``).
    audio_path = req.get("audio_path")
    language = req.get("language") or args.language
    engine = req.get("engine") or "auto"
    model_size = req.get("model_size") or args.model_size
    precision = req.get("precision") or args.precision

    if not audio_path:
        return JSONResponse(status_code=400, content={"message": "audio_path is required"})
    audio_path = os.path.abspath(audio_path)
    if not os.path.isfile(audio_path):
        return JSONResponse(status_code=400, content={"message": f"audio_path not found: {audio_path}"})

    # (field name, value, allowed values), checked in this order.
    checks = (
        ("language", language, language_code_list),
        ("engine", engine, ("auto", "fasterwhisper", "funasr")),
        ("model_size", model_size, get_models()),
        ("precision", precision, ("float16", "float32", "int8")),
    )
    for field, value, allowed in checks:
        if value not in allowed:
            return JSONResponse(status_code=400, content={"message": f"{field} not supported: {value}"})

    # Hand the validated, defaulted values back to the caller.
    req.update(
        audio_path=audio_path,
        language=language,
        engine=engine,
        model_size=model_size,
        precision=precision,
    )
    return None
|
||||
|
||||
|
||||
def get_whisper_model(model_size: str, precision: str):
    """Return a faster-whisper model, loading and caching it on first use.

    Models are cached in ``whisper_models`` under ``"<model_size>:<precision>"``
    and resolved model paths in ``whisper_model_paths`` under ``model_size``.

    Args:
        model_size: Model size name passed to ``download_model``.
        precision: Requested compute type (``float16``, ``float32`` or ``int8``).

    Returns:
        The cached or newly created ``WhisperModel``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # An explicit float16 compute type is rejected on CPU by the CTranslate2
    # backend, and float16 is this server's default precision. Fall back to
    # float32 so CPU-only hosts do not fail every default request.
    if device == "cpu" and precision == "float16":
        precision = "float32"

    key = f"{model_size}:{precision}"
    cached = whisper_models.get(key)
    if cached is not None:
        return cached

    # Resolve the model path once per model size, whatever the precision.
    if model_size not in whisper_model_paths:
        whisper_model_paths[model_size] = download_model(model_size)
    model_path = whisper_model_paths[model_size]

    model = WhisperModel(model_path, device=device, compute_type=precision)
    whisper_models[key] = model
    return model
|
||||
|
||||
|
||||
def asr_with_fasterwhisper(audio_path: str, language: str, model_size: str, precision: str):
    """Transcribe ``audio_path`` with faster-whisper.

    When the language reported by faster-whisper is Chinese or Cantonese, the
    audio is first handed to FunASR (``only_asr``); the faster-whisper text is
    used only if FunASR returns nothing.

    Args:
        audio_path: Path of the audio file.
        language: Language code, or ``"auto"`` to let faster-whisper detect it.
        model_size: Model size forwarded to ``get_whisper_model``.
        precision: Compute precision forwarded to ``get_whisper_model``.

    Returns:
        The transcribed text.
    """
    whisper = get_whisper_model(model_size=model_size, precision=precision)
    segments, info = whisper.transcribe(
        audio=audio_path,
        beam_size=5,
        vad_filter=True,
        vad_parameters={"min_silence_duration_ms": 700},
        # "auto" is expressed to faster-whisper as "no language given".
        language=language if language != "auto" else None,
    )

    # Chinese/Cantonese is routed to FunASR by default; its text is usually
    # more stable.
    detected = info.language
    if detected in ("zh", "yue"):
        funasr_text = only_asr(audio_path, language=detected.lower())
        if funasr_text:
            return funasr_text

    return "".join(seg.text for seg in segments).strip()
|
||||
|
||||
|
||||
def asr_with_funasr(audio_path: str, language: str):
    """Transcribe ``audio_path`` with FunASR.

    Any language other than ``"zh"`` or ``"yue"`` is treated as ``"zh"``.

    Returns:
        The stripped transcription, or an empty string when FunASR yields
        nothing.
    """
    if language not in ("zh", "yue"):
        language = "zh"
    text = only_asr(audio_path, language=language)
    return text.strip() if text else ""
|
||||
|
||||
|
||||
async def asr_handle(req: dict):
    """Validate a request, run the selected ASR engine and build the response.

    Args:
        req: Request parameters (``audio_path``, ``language``, ``engine``,
            ``model_size``, ``precision``).

    Returns:
        The error response from ``check_params`` if validation fails;
        otherwise a 200 ``{"text": ...}`` response, or a 400 response with
        ``"message"`` and ``"Exception"`` keys if transcription raises.
    """
    error = check_params(req)
    if error is not None:
        return error

    audio_path = req["audio_path"]
    language = req.get("language", args.language)
    engine = req.get("engine", "auto")
    whisper_kwargs = {
        "audio_path": audio_path,
        "language": language,
        "model_size": req.get("model_size", args.model_size),
        "precision": req.get("precision", args.precision),
    }

    try:
        if engine == "funasr":
            text = asr_with_funasr(audio_path=audio_path, language=language)
        else:
            text = ""
            # auto: Chinese/Cantonese prefers FunASR, other languages go to
            # Faster-Whisper. An explicit "fasterwhisper" skips FunASR here.
            if engine != "fasterwhisper" and language in ("zh", "yue"):
                text = asr_with_funasr(audio_path=audio_path, language=language)
            if not text:
                text = asr_with_fasterwhisper(**whisper_kwargs)

        return JSONResponse(status_code=200, content={"text": text})
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "asr failed", "Exception": str(e)})
|
||||
|
||||
|
||||
@APP.get("/control")
async def control(command: Optional[str] = None):
    """Process-control endpoint: ``/control?command=restart|exit``.

    Args:
        command: ``"restart"`` or ``"exit"``.

    Returns:
        A 400 ``{"message": ...}`` response when ``command`` is missing or
        not supported. For a supported command the process is restarted or
        terminated by ``handle_control``.
    """
    if command is None:
        return JSONResponse(status_code=400, content={"message": "command is required"})
    # Reject unknown commands explicitly instead of answering 200 with no body.
    if command not in ("restart", "exit"):
        return JSONResponse(status_code=400, content={"message": f"command not supported: {command}"})
    # NOTE(review): no authentication is visible on this endpoint; anyone who
    # can reach the bind address can restart or stop the server.
    handle_control(command)
|
||||
|
||||
|
||||
@APP.get("/asr")
async def asr_get_endpoint(
    audio_path: Optional[str] = None,
    language: str = args.language,
    engine: str = "auto",
    model_size: str = args.model_size,
    precision: str = args.precision,
):
    """Speech-to-text via query parameters.

    Empty values fall back to the command-line defaults, matching what the
    POST endpoint does for omitted fields. The language is lower-cased
    before validation.

    Returns:
        The response produced by ``asr_handle``.
    """
    req = {
        "audio_path": audio_path,
        "language": (language or args.language).lower(),
        "engine": engine or "auto",
        "model_size": model_size or args.model_size,
        "precision": precision or args.precision,
    }
    return await asr_handle(req)
|
||||
|
||||
|
||||
@APP.post("/asr")
async def asr_post_endpoint(request: ASRRequest):
    """Speech-to-text via a JSON body.

    Omitted or empty ``language``, ``model_size`` and ``precision`` fields are
    replaced by the command-line defaults; the language is lower-cased.

    Returns:
        The response produced by ``asr_handle``.
    """
    payload = request.dict()
    payload.update(
        language=(payload.get("language") or args.language).lower(),
        model_size=payload.get("model_size") or args.model_size,
        precision=payload.get("precision") or args.precision,
    )
    return await asr_handle(payload)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Serve the API on the configured address with a single worker process.
    server_options = {"host": host, "port": port, "workers": 1}
    uvicorn.run(APP, **server_options)
|
||||
Loading…
x
Reference in New Issue
Block a user