diff --git a/api.py b/api.py index cc0896a2..d38e78ef 100644 --- a/api.py +++ b/api.py @@ -1213,6 +1213,7 @@ parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切 # 切割常用分句符为 `python ./api.py -cp ".?!。?!"` parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") +parser.add_argument("-ak", "--api_key", type=str, default="", help="API密钥,不为空时开启鉴权") args = parser.parse_args() sovits_path = args.sovits_path @@ -1298,6 +1299,15 @@ change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path) app = FastAPI() +@app.middleware("http") +async def api_key_middleware(request: Request, call_next): + if args.api_key: + key = request.headers.get("X-API-Key", "") + if key != args.api_key: + return JSONResponse(status_code=401, content={"message": "Unauthorized"}) + return await call_next(request) + + @app.post("/set_model") async def set_model(request: Request): json_post_raw = await request.json()