mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-04 05:01:27 +08:00
Merge 19f51ae6f9b3a3dfa12758afb5e8f7e10cd0275d into 08d627c3338173c3229286d8787060d6559fe0f8
This commit is contained in:
commit
bcdfbbd33f
259
asr_api.py
Normal file
259
asr_api.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
"""
|
||||||
|
# asr_api.py usage
|
||||||
|
|
||||||
|
` python asr_api.py -a 127.0.0.1 -p 9881 `
|
||||||
|
|
||||||
|
## 调用:
|
||||||
|
|
||||||
|
### 语音转文本
|
||||||
|
|
||||||
|
endpoint: `/asr`
|
||||||
|
|
||||||
|
GET:
|
||||||
|
`http://127.0.0.1:9881/asr?audio_path=data/voice_ref/ref.wav&language=zh`
|
||||||
|
|
||||||
|
POST:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"audio_path": "data/voice_ref/ref.wav",
|
||||||
|
"language": "zh",
|
||||||
|
"engine": "auto",
|
||||||
|
"model_size": "large-v3",
|
||||||
|
"precision": "float16"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
RESP:
|
||||||
|
成功: {"text": "..."}
|
||||||
|
失败: {"message": "..."} (http code 400)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
now_dir = os.getcwd()
|
||||||
|
sys.path.append(now_dir)
|
||||||
|
sys.path.append(f"{now_dir}/GPT_SoVITS")
|
||||||
|
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
from tools.asr.config import get_models
|
||||||
|
from tools.asr.fasterwhisper_asr import download_model, language_code_list
|
||||||
|
from tools.asr.funasr_asr import only_asr
|
||||||
|
from tools.my_utils import load_cudnn
|
||||||
|
|
||||||
|
|
||||||
|
load_cudnn()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="GPT-SoVITS ASR api")
|
||||||
|
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
|
||||||
|
parser.add_argument("-p", "--port", type=int, default=9881, help="default: 9881")
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--model_size",
|
||||||
|
type=str,
|
||||||
|
default="large-v3",
|
||||||
|
choices=get_models(),
|
||||||
|
help="default: large-v3",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-l",
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
default="auto",
|
||||||
|
choices=language_code_list,
|
||||||
|
help="default language",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-pr",
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
default="float16",
|
||||||
|
choices=["float16", "float32", "int8"],
|
||||||
|
help="compute precision for faster-whisper",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
host = args.bind_addr
|
||||||
|
port = args.port
|
||||||
|
argv = sys.argv
|
||||||
|
|
||||||
|
APP = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
class ASRRequest(BaseModel):
|
||||||
|
audio_path: Optional[str] = None
|
||||||
|
language: Optional[str] = None
|
||||||
|
engine: str = "auto" # auto | fasterwhisper | funasr
|
||||||
|
model_size: Optional[str] = None
|
||||||
|
precision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
whisper_models = {}
|
||||||
|
whisper_model_paths = {}
|
||||||
|
|
||||||
|
|
||||||
|
def handle_control(command: str):
|
||||||
|
if command == "restart":
|
||||||
|
os.execl(sys.executable, sys.executable, *argv)
|
||||||
|
elif command == "exit":
|
||||||
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def check_params(req: dict):
|
||||||
|
audio_path = req.get("audio_path")
|
||||||
|
language = req.get("language", "auto")
|
||||||
|
engine = req.get("engine", "auto")
|
||||||
|
model_size = req.get("model_size", args.model_size)
|
||||||
|
precision = req.get("precision", args.precision)
|
||||||
|
|
||||||
|
if not audio_path:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "audio_path is required"})
|
||||||
|
audio_path = os.path.abspath(audio_path)
|
||||||
|
if not os.path.isfile(audio_path):
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"audio_path not found: {audio_path}"})
|
||||||
|
req["audio_path"] = audio_path
|
||||||
|
if language not in language_code_list:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"language not supported: {language}"})
|
||||||
|
if engine not in ["auto", "fasterwhisper", "funasr"]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"engine not supported: {engine}"})
|
||||||
|
if model_size not in get_models():
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"model_size not supported: {model_size}"})
|
||||||
|
if precision not in ["float16", "float32", "int8"]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"precision not supported: {precision}"})
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_whisper_model(model_size: str, precision: str):
|
||||||
|
key = f"{model_size}:{precision}"
|
||||||
|
if key in whisper_models:
|
||||||
|
return whisper_models[key]
|
||||||
|
|
||||||
|
if model_size not in whisper_model_paths:
|
||||||
|
whisper_model_paths[model_size] = download_model(model_size)
|
||||||
|
model_path = whisper_model_paths[model_size]
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = WhisperModel(model_path, device=device, compute_type=precision)
|
||||||
|
whisper_models[key] = model
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def asr_with_fasterwhisper(audio_path: str, language: str, model_size: str, precision: str):
|
||||||
|
model = get_whisper_model(model_size=model_size, precision=precision)
|
||||||
|
fw_language = None if language == "auto" else language
|
||||||
|
segments, info = model.transcribe(
|
||||||
|
audio=audio_path,
|
||||||
|
beam_size=5,
|
||||||
|
vad_filter=True,
|
||||||
|
vad_parameters=dict(min_silence_duration_ms=700),
|
||||||
|
language=fw_language,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 中文/粤语默认转到 FunASR,文本稳定性通常更好
|
||||||
|
if info.language in ["zh", "yue"]:
|
||||||
|
text = only_asr(audio_path, language=info.language.lower())
|
||||||
|
if text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
text = "".join(segment.text for segment in segments).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def asr_with_funasr(audio_path: str, language: str):
|
||||||
|
lang = language if language in ["zh", "yue"] else "zh"
|
||||||
|
return (only_asr(audio_path, language=lang) or "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def asr_handle(req: dict):
|
||||||
|
check_res = check_params(req)
|
||||||
|
if check_res is not None:
|
||||||
|
return check_res
|
||||||
|
|
||||||
|
audio_path = req["audio_path"]
|
||||||
|
language = req.get("language", args.language)
|
||||||
|
engine = req.get("engine", "auto")
|
||||||
|
model_size = req.get("model_size", args.model_size)
|
||||||
|
precision = req.get("precision", args.precision)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if engine == "funasr":
|
||||||
|
text = asr_with_funasr(audio_path=audio_path, language=language)
|
||||||
|
elif engine == "fasterwhisper":
|
||||||
|
text = asr_with_fasterwhisper(
|
||||||
|
audio_path=audio_path,
|
||||||
|
language=language,
|
||||||
|
model_size=model_size,
|
||||||
|
precision=precision,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# auto: 中文/粤语优先 FunASR,其它语种走 Faster-Whisper
|
||||||
|
if language in ["zh", "yue"]:
|
||||||
|
text = asr_with_funasr(audio_path=audio_path, language=language)
|
||||||
|
if not text:
|
||||||
|
text = asr_with_fasterwhisper(
|
||||||
|
audio_path=audio_path,
|
||||||
|
language=language,
|
||||||
|
model_size=model_size,
|
||||||
|
precision=precision,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text = asr_with_fasterwhisper(
|
||||||
|
audio_path=audio_path,
|
||||||
|
language=language,
|
||||||
|
model_size=model_size,
|
||||||
|
precision=precision,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JSONResponse(status_code=200, content={"text": text})
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "asr failed", "Exception": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
@APP.get("/control")
|
||||||
|
async def control(command: str = None):
|
||||||
|
if command is None:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "command is required"})
|
||||||
|
handle_control(command)
|
||||||
|
|
||||||
|
|
||||||
|
@APP.get("/asr")
|
||||||
|
async def asr_get_endpoint(
|
||||||
|
audio_path: str = None,
|
||||||
|
language: str = args.language,
|
||||||
|
engine: str = "auto",
|
||||||
|
model_size: str = args.model_size,
|
||||||
|
precision: str = args.precision,
|
||||||
|
):
|
||||||
|
req = {
|
||||||
|
"audio_path": audio_path,
|
||||||
|
"language": language.lower() if language else "auto",
|
||||||
|
"engine": engine,
|
||||||
|
"model_size": model_size,
|
||||||
|
"precision": precision,
|
||||||
|
}
|
||||||
|
return await asr_handle(req)
|
||||||
|
|
||||||
|
|
||||||
|
@APP.post("/asr")
|
||||||
|
async def asr_post_endpoint(request: ASRRequest):
|
||||||
|
req = request.dict()
|
||||||
|
req["language"] = (req.get("language") or args.language).lower()
|
||||||
|
req["model_size"] = req.get("model_size") or args.model_size
|
||||||
|
req["precision"] = req.get("precision") or args.precision
|
||||||
|
return await asr_handle(req)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(APP, host=host, port=port, workers=1)
|
||||||
Loading…
x
Reference in New Issue
Block a user