From 7104a7e67137810256a2183ae25a96834d559666 Mon Sep 17 00:00:00 2001 From: Miuzarte <982809597@qq.com> Date: Wed, 21 Feb 2024 22:22:32 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=9B=B4=E6=94=B9=E9=80=BB=E8=BE=91=E3=80=81=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 169 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 131 insertions(+), 38 deletions(-) diff --git a/api.py b/api.py index 754f0769..687c4845 100644 --- a/api.py +++ b/api.py @@ -14,7 +14,7 @@ `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"` `-d` - `推理设备, "cuda","cpu","mps"` -`-a` - `绑定地址, 默认"127.0.0.1"` +`-a` - `绑定地址, 默认"0.0.0.0"` `-p` - `绑定端口, 默认9880, 可在 config.py 中指定` `-fp` - `覆盖 config.py 使用全精度` `-hp` - `覆盖 config.py 使用半精度` @@ -54,13 +54,13 @@ POST: ``` RESP: -成功: 直接返回 wav 音频流, http code 200 -失败: 返回包含错误信息的 json, http code 400 + 成功: 直接返回 wav 音频流, http code 200 + 失败: 返回包含错误信息的 json, http code 400 ### 更换默认参考音频 -endpoint: `/change_refer` +endpoints: `/change_refer`, `/set_refer` key与推理端一样 @@ -76,8 +76,31 @@ POST: ``` RESP: -成功: json, http code 200 -失败: json, 400 + 成功: json, http code 200 + 失败: json, 400 + + +### 更换模型 + +endpoints: `/change_model`, `/change_weight`, `/set_model`, `/set_weight` + +key alias: + "gpt", "gpt_path", "gpt_model_path" + "sovits", "sovits_path", "sovits_model_path" + +GET: + `http://127.0.0.1:9880/change_weight?gpt=./GPT_weights/suijiSUI-e20.ckpt&sovits=./SoVITS_weights/suijiSUI_e20_s3280.pth` +POST: +```json +{ + "gpt": "./GPT_weights/suijiSUI-e20.ckpt", + "sovits": "./SoVITS_weights/suijiSUI_e20_s3280.pth" +} +``` + +RESP: + 成功: json, http code 200 + 失败: json, 400 | "Internal Server Error" ### 命令控制 @@ -85,8 +108,8 @@ RESP: endpoint: `/control` command: -"restart": 重新运行 -"exit": 结束运行 + "restart": 重新运行 + "exit": 结束运行 GET: `http://127.0.0.1:9880/control?command=restart` @@ -101,7 +124,6 @@ RESP: 无 """ - import argparse import os import sys @@ -132,8 +154,6 @@ import config as global_config g_config = global_config.Config() -# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu" - parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") @@ -216,17 +236,18 @@ else: def is_empty(*items): # 任意一项不为空返回False for item in items: - if item is not None and item != "": + if item: return False return True def is_full(*items): # 任意一项为空返回False for item in items: - if item is None or item == "": + if not item: return False return True + def change_sovits_weights(sovits_path): global vq_model, hps dict_s2 = torch.load(sovits_path, map_location="cpu") @@ -249,6 +270,8 @@ def change_sovits_weights(sovits_path): print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) with open("./sweight.txt", "w", encoding="utf-8") as f: f.write(sovits_path) + + def change_gpt_weights(gpt_path): global hz, max_sec, t2s_model, config hz = 50 @@ -438,9 +461,10 @@ def handle_control(command): exit(0) -def handle_change(path, text, language): +def handle_change_refer(path, text, language): if is_empty(path, text, language): - return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400) + return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, + status_code=400) if path != "" or path is not None: default_refer.path = path @@ -457,12 +481,12 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): - if ( - refer_wav_path == "" or refer_wav_path is None - or prompt_text == "" or prompt_text is None - or prompt_language == "" or prompt_language is None - ): +def handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language): + if ( # 缺任意一个 + not refer_wav_path + or not prompt_text + or not prompt_language + ): # 使用全局 refer_wav_path, prompt_text, prompt_language = ( default_refer.path, default_refer.text, @@ -481,29 +505,96 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): sf.write(wav, audio_data, sampling_rate, format="wav") wav.seek(0) - torch.cuda.empty_cache() + if device == "cuda": + torch.cuda.empty_cache() if device == "mps": print('executed torch.mps.empty_cache()') torch.mps.empty_cache() return StreamingResponse(wav, media_type="audio/wav") +def handle_change_weights(gpt, sovits): + if is_empty(gpt, sovits): + return JSONResponse({"code": 400, "message": f"缺少任意一项以下参数: {gpt_alias}, {sovits_alias}"}, + status_code=400) + + global gpt_path, sovits_path + + if gpt: + gpt_path = gpt + print(f"New gpt_path: {gpt_path}") + change_gpt_weights(gpt_path) + + if sovits: + sovits_path = sovits + print(f"New sovits_path: {sovits_path}") + change_sovits_weights(sovits_path) + + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +gpt_alias = ( + "gpt", + "gpt_path", + "gpt_model_path" # @JavaAndPython55 用的这个key, 嫌太长直接alias了 +) +sovits_alias = ( + "sovits", + "sovits_path", + "sovits_model_path" +) + app = FastAPI() -#clark新增-----2024-02-21 -#可在启动后动态修改模型,以此满足同一个api不同的朗读者请求 + +# clark新增-----2024-02-21 +# 可在启动后动态修改模型,以此满足同一个api不同的朗读者请求 @app.post("/set_model") -async def set_model(request: Request): +@app.post("/set_weight") +@app.post("/change_model") +@app.post("/change_weight") +async def change_weight(request: Request): json_post_raw = await request.json() - global gpt_path - gpt_path=json_post_raw.get("gpt_model_path") - global sovits_path - sovits_path=json_post_raw.get("sovits_model_path") - print("gptpath"+gpt_path+";vitspath"+sovits_path) - change_sovits_weights(sovits_path) - change_gpt_weights(gpt_path) - return "ok" -# 新增-----end------ + + gpt, sovits = "", "" + for ga in gpt_alias: + g = json_post_raw.get(ga) + if g: + gpt = g + break + for sa in sovits_alias: + s = json_post_raw.get(sa) + if s: + sovits = s + break + + return handle_change_weights(gpt, sovits) + + +@app.get("/set_model") +@app.get("/set_weight") +@app.get("/change_model") +@app.get("/change_weight") +async def change_weight( + gpt: str = None, + gpt_path: str = None, + gpt_model_path: str = None, + sovits: str = None, + sovits_path: str = None, + sovits_model_path: str = None, +): + GPT, SOVITS = "", "" + for gg in (gpt, gpt_path, gpt_model_path): + if gg: + GPT = gg + break + for ss in (sovits, sovits_path, sovits_model_path): + if ss: + SOVITS = ss + break + + return handle_change_weights(GPT, SOVITS) + @app.post("/control") async def control(request: Request): @@ -516,29 +607,31 @@ async def control(command: str = None): return handle_control(command) +@app.post("/set_refer") @app.post("/change_refer") async def change_refer(request: Request): json_post_raw = await request.json() - return handle_change( + return handle_change_refer( json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") ) +@app.get("/set_refer") @app.get("/change_refer") async def change_refer( refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None ): - return handle_change(refer_wav_path, prompt_text, prompt_language) + return handle_change_refer(refer_wav_path, prompt_text, prompt_language) @app.post("/") async def tts_endpoint(request: Request): json_post_raw = await request.json() - return handle( + return handle_refer( json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language"), @@ -555,7 +648,7 @@ async def tts_endpoint( text: str = None, text_language: str = None, ): - return handle(refer_wav_path, prompt_text, prompt_language, text, text_language) + return handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language) if __name__ == "__main__": From 501d0fdff14eec4d2dbeb265ca97fe55c38ee9e8 Mon Sep 17 00:00:00 2001 From: Miuzarte <982809597@qq.com> Date: Wed, 21 Feb 2024 22:25:47 +0800 Subject: [PATCH 2/2] wrong word --- api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api.py b/api.py index 687c4845..2e4a30d5 100644 --- a/api.py +++ b/api.py @@ -481,7 +481,7 @@ def handle_change_refer(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) -def handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language): +def handle_infer(refer_wav_path, prompt_text, prompt_language, text, text_language): if ( # 缺任意一个 not refer_wav_path or not prompt_text @@ -631,7 +631,7 @@ async def change_refer( @app.post("/") async def tts_endpoint(request: Request): json_post_raw = await request.json() - return handle_refer( + return handle_infer( json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language"), @@ -648,7 +648,7 @@ async def tts_endpoint( text: str = None, text_language: str = None, ): - return handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language) + return handle_infer(refer_wav_path, prompt_text, prompt_language, text, text_language) if __name__ == "__main__":