diff --git a/api.py b/api.py index 60d5919..32379c3 100644 --- a/api.py +++ b/api.py @@ -1,3 +1,107 @@ +""" +# api.py usage + +` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` + +## 执行参数: + +`-s` - `SoVITS模型路径, 可在 config.py 中指定` +`-g` - `GPT模型路径, 可在 config.py 中指定` + +调用请求缺少参考音频时使用 +`-dr` - `默认参考音频路径` +`-dt` - `默认参考音频文本` +`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"` + +`-d` - `推理设备, "cuda","cpu"` +`-a` - `绑定地址, 默认"127.0.0.1"` +`-p` - `绑定端口, 默认9880, 可在 config.py 中指定` +`-fp` - `覆盖 config.py 使用全精度` +`-hp` - `覆盖 config.py 使用半精度` + +`-hb` - `cnhubert路径` +`-b` - `bert路径` + +## 调用: + +### 推理 + +endpoint: `/` + +使用执行参数指定的参考音频: +GET: + `http://127.0.0.1:9880?text=你所热爱的,就是你的生活。&text_language=zh` +POST: +```json +{ + "text": "你所热爱的,就是你的生活。", + "text_language": "zh" +} +``` + +手动指定当次推理所使用的参考音频: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=你所热爱的,就是你的生活。&text_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "你所热爱的,就是你的生活。", + "text_language": "zh" +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 更换默认参考音频 + +endpoint: `/change_refer` + +key与推理端一样 + +GET: + `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh" +} +``` + +RESP: +成功: json, http code 200 +失败: json, 400 + + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: + `http://127.0.0.1:9880/control?command=restart` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + +""" + + import argparse import os import signal @@ -30,14 +134,13 @@ parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") -parser.add_argument("-dr", "--default_refer_path", type=str, default="", - help="默认参考音频路径, 请求缺少参考音频时调用") +parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") -parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度") parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度") # bool值的用法为 `python ./api.py -fp ...` @@ -284,9 +387,17 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) +def handle_control(command): + if command == "restart": + os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + def handle_change(path, text, language): if is_empty(path, text, language): - raise HTTPException(status_code=400, detail='缺少任意一项以下参数: "path", "text", "language"') + return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400) if path != "" or path is not None: default_refer.path = path @@ -303,13 +414,7 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language): - if command == "/restart": - os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) - elif command == "/exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - +def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): if ( refer_wav_path == "" or refer_wav_path is None or prompt_text == "" or prompt_text is None @@ -321,7 +426,7 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan default_refer.language, ) if not default_refer.is_ready(): - raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设") + return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) with torch.no_grad(): gen = get_tts_wav( @@ -340,30 +445,40 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan app = FastAPI() +@app.post("/control") +async def control(request: Request): + json_post_raw = await request.json() + return handle_control(json_post_raw.get("command")) + + +@app.get("/control") +async def control(command: str = None): + return handle_control(command) + + @app.post("/change_refer") async def change_refer(request: Request): json_post_raw = await request.json() return handle_change( - json_post_raw.get("path"), - json_post_raw.get("text"), - json_post_raw.get("language") + json_post_raw.get("refer_wav_path"), + json_post_raw.get("prompt_text"), + json_post_raw.get("prompt_language") ) @app.get("/change_refer") async def change_refer( - path: str = None, - text: str = None, - language: str = None + refer_wav_path: str = None, + prompt_text: str = None, + prompt_language: str = None ): - return handle_change(path, text, language) + return handle_change(refer_wav_path, prompt_text, prompt_language) @app.post("/") async def tts_endpoint(request: Request): json_post_raw = await request.json() return handle( - json_post_raw.get("command"), json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language"), @@ -374,14 +489,13 @@ async def tts_endpoint(request: Request): @app.get("/") async def tts_endpoint( - command: str = None, refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None, text: str = None, text_language: str = None, ): - return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language) + return handle(refer_wav_path, prompt_text, prompt_language, text, text_language) if __name__ == "__main__":