# Source provenance: patch b6572a7 by YYuX-1145 (2024-03-24), "Add files via upload",
# creating api_simple.py.
"""
GPT-SoVITS simple HTTP API.

Usage:
    python api_simple.py -dr "123.wav" -dt "一二三。" -dl "zh"

Command-line arguments:
    -s   SoVITS model path (default from config.py)
    -g   GPT model path (default from config.py)
    -dr  default reference-audio path (used when a request omits one)
    -dt  default reference-audio transcript
    -dl  default reference-audio language: "中文"/"英文"/"日文"/"zh"/"en"/"ja"
    -d   inference device: "cuda" / "cpu"
    -a   bind address (default "0.0.0.0")
    -p   bind port (default 9880, configurable in config.py)
    -w   number of uvicorn workers
    -hb  cnhubert path
    -b   bert path

Endpoints
---------

/ (GET, POST) — TTS inference.

    Using the default reference audio configured at startup:
        GET  http://127.0.0.1:9880?text=...&text_language=zh
        POST {"text": "...", "text_language": "zh"}

    Overriding the reference audio for one request:
        GET  http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=...&text_language=zh
        POST {"refer_wav_path": "123.wav", "prompt_text": "一二三。",
              "prompt_language": "zh", "text": "...", "text_language": "zh"}

    Success: raw wav audio stream, HTTP 200.
    Failure: JSON error message, HTTP 400.

/change_refer (GET, POST) — replace the default reference audio.
    Same keys as the inference endpoint.

    GET  http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh
    POST {"refer_wav_path": "123.wav", "prompt_text": "一二三。", "prompt_language": "zh"}

    Success: JSON, HTTP 200.  Failure: JSON, HTTP 400.

/control (GET, POST) — process control; no response body.

    command: "restart" (re-exec the server) | "exit" (terminate)
    GET  http://127.0.0.1:9880/control?command=restart
    POST {"command": "restart"}
"""


import argparse
import os
import signal
import sys
from io import BytesIO

# Make the project root and the GPT_SoVITS package importable before the
# project-local imports below.
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))

import soundfile as sf
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse

import config as global_config
import inference_webui
from inference_webui import inference as get_tts_wav

g_config = global_config.Config()
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"

parser = argparse.ArgumentParser(description="GPT-SoVITS api")

parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")

parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")

parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-w", "--workers", type=int, default=1, help="num_workers")
# parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="override config.is_half to False (full precision)")
# parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="override config.is_half to True (half precision)")
# The boolean flags are plain switches: `python ./api.py -fp ...`
# yields full_precision==True, half_precision==False.
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")

args = parser.parse_args()

sovits_path = args.sovits_path
gpt_path = args.gpt_path


def change_sovits_weights(sovits_path):
    """Load new SoVITS weights into the shared TTS pipeline; no-op for None/""."""
    if sovits_path is not None and sovits_path != "":
        inference_webui.tts_pipline.init_vits_weights(sovits_path)


def change_gpt_weights(gpt_path):
    """Load new GPT weights into the shared TTS pipeline; no-op for None/""."""
    if gpt_path is not None and gpt_path != "":
        inference_webui.tts_pipline.init_t2s_weights(gpt_path)


change_sovits_weights(sovits_path)
change_gpt_weights(gpt_path)


class DefaultRefer:
    """Server-wide default reference audio: path, transcript and language."""

    def __init__(self, path, text, language):
        # BUG FIX: the original ignored its parameters and read the
        # module-level `args` directly; honour the constructor arguments.
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        """True when path, text and language are all non-empty."""
        return is_full(self.path, self.text, self.language)


default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)

device = args.device
port = args.port
host = args.bind_addr
workers = args.workers


def is_empty(*items):
    """Return True only when every item is None or "" (any non-empty item -> False)."""
    return all(item is None or item == "" for item in items)


def is_full(*items):
    """Return True only when no item is None or "" (any empty item -> False)."""
    return all(item is not None and item != "" for item in items)


# Accepted language labels, normalised to the codes the pipeline expects.
dict_language = {
    "中文": "zh",
    "英文": "en",
    "日文": "ja",
    "ZH": "zh",
    "EN": "en",
    "JA": "ja",
    "zh": "zh",
    "en": "en",
    "ja": "ja",
}


def handle_control(command):
    """Execute a process-control command: "restart" re-execs this process, "exit" terminates it."""
    if command == "restart":
        os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
    elif command == "exit":
        os.kill(os.getpid(), signal.SIGTERM)
        exit(0)


def handle_change(path, text, language):
    """Update the default reference audio; only non-empty fields overwrite the stored ones."""
    if is_empty(path, text, language):
        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)

    # BUG FIX: the original tested `x != "" or x is not None`, which is true
    # even for empty values and so clobbered fields with ""/None.
    if path is not None and path != "":
        default_refer.path = path
    if text is not None and text != "":
        default_refer.text = text
    if language is not None and language != "":
        default_refer.language = language

    print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
    print(f"[INFO] is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


def handle(text, text_language,
           refer_wav_path, prompt_text,
           prompt_language, top_k,
           top_p, temperature,
           text_split_method, batch_size,
           speed_factor, ref_text_free,
           split_bucket, fragment_interval,
           seed):
    """Run TTS inference and return the result as a streamed wav response.

    Falls back to the server default reference audio when the request does not
    fully specify one; returns HTTP 400 when no default is configured either.
    """
    if not is_full(refer_wav_path, prompt_text, prompt_language):
        refer_wav_path, prompt_text, prompt_language = (
            default_refer.path,
            default_refer.text,
            default_refer.language,
        )
        if not default_refer.is_ready():
            return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)

    prompt_text = prompt_text.strip("\n")
    text = text.strip("\n")
    gen = get_tts_wav(text, text_language,
                      refer_wav_path, prompt_text,
                      prompt_language, top_k,
                      top_p, temperature,
                      text_split_method, batch_size,
                      speed_factor, ref_text_free,
                      split_bucket, fragment_interval,
                      seed)
    # The generator yields ((sampling_rate, audio_data), actual_seed) items;
    # the first chunk is used — splitting/batching happens inside the pipeline.
    audio, _actual_seed = next(gen)
    sampling_rate, audio_data = audio

    wav = BytesIO()
    sf.write(wav, audio_data, sampling_rate, format="wav")
    wav.seek(0)
    return StreamingResponse(wav, media_type="audio/wav")


app = FastAPI()


# clark新增-----2024-02-21
# Allows swapping speakers (GPT/SoVITS weights) on a running API without restart.
@app.post("/set_model")
async def set_model(request: Request):
    """Hot-swap the GPT and SoVITS weights used for subsequent requests."""
    json_post_raw = await request.json()
    global gpt_path, sovits_path
    gpt_path = json_post_raw.get("gpt_model_path")
    sovits_path = json_post_raw.get("sovits_model_path")
    # BUG FIX: f-string instead of `+` concatenation, which raised TypeError
    # whenever a path was missing (None) from the request body.
    print(f"gptpath{gpt_path};vitspath{sovits_path}")
    change_sovits_weights(sovits_path)
    change_gpt_weights(gpt_path)
    return "ok"
# 新增-----end------


@app.post("/control")
async def control(request: Request):
    json_post_raw = await request.json()
    return handle_control(json_post_raw.get("command"))


@app.get("/control")
async def control_get(command: str = None):
    # Renamed from `control` so the GET handler no longer shadows the POST one
    # at module level; the route path is unchanged.
    return handle_control(command)


@app.post("/change_refer")
async def change_refer(request: Request):
    json_post_raw = await request.json()
    return handle_change(
        json_post_raw.get("refer_wav_path"),
        json_post_raw.get("prompt_text"),
        json_post_raw.get("prompt_language")
    )


@app.get("/change_refer")
async def change_refer_get(
    refer_wav_path: str = None,
    prompt_text: str = None,
    prompt_language: str = None
):
    return handle_change(refer_wav_path, prompt_text, prompt_language)


@app.post("/")
async def tts_post_endpoint(request: Request):
    """POST variant of the inference endpoint (same keys as the GET query).

    Restores the POST route promised by the module docstring; the previous
    version was commented out and its argument order no longer matched
    `handle`'s signature.
    """
    payload = await request.json()
    return handle(
        payload.get("text"),
        payload.get("text_language"),
        payload.get("refer_wav_path"),
        payload.get("prompt_text"),
        payload.get("prompt_language"),
        payload.get("top_k", 5),
        payload.get("top_p", 1),
        payload.get("temperature", 1),
        payload.get("text_split_method", "凑四句一切"),
        payload.get("batch_size", 20),
        payload.get("speed_factor", 1),
        payload.get("ref_text_free", False),
        payload.get("split_bucket", True),
        payload.get("fragment_interval", 0.3),
        payload.get("seed", -1),
    )


@app.get("/")
async def tts_endpoint(
    refer_wav_path: str = None,
    prompt_text: str = None,
    prompt_language: str = None,
    text: str = None,
    text_language: str = None,
    top_k: int = 5,
    top_p: float = 1,
    temperature: float = 1,
    text_split_method: str = "凑四句一切",
    batch_size: int = 20,
    speed_factor: float = 1,
    ref_text_free: bool = False,
    split_bucket: bool = True,
    fragment_interval: float = 0.3,
    seed: int = -1,
):
    return handle(text, text_language,
                  refer_wav_path, prompt_text,
                  prompt_language, top_k,
                  top_p, temperature,
                  text_split_method, batch_size,
                  speed_factor, ref_text_free,
                  split_bucket, fragment_interval,
                  seed)


if __name__ == "__main__":
    # splitext is robust to extra dots in the filename, unlike split(".")[0].
    uvicorn.run(f"{os.path.splitext(os.path.basename(__file__))[0]}:app", host=host, port=port, workers=workers)