diff --git a/api.py b/api.py index 725b12d..60d5919 100644 --- a/api.py +++ b/api.py @@ -7,7 +7,7 @@ import torch import librosa import soundfile as sf from fastapi import FastAPI, Request, HTTPException -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse import uvicorn from transformers import AutoModelForMaskedLM, AutoTokenizer import numpy as np @@ -51,10 +51,18 @@ args = parser.parse_args() sovits_path = args.sovits_path gpt_path = args.gpt_path -default_refer_path = args.default_refer_path -default_refer_text = args.default_refer_text -default_refer_language = args.default_refer_language -has_preset = False + +class DefaultRefer: + def __init__(self, path, text, language): + self.path = args.default_refer_path + self.text = args.default_refer_text + self.language = args.default_refer_language + + def is_ready(self) -> bool: + return is_full(self.path, self.text, self.language) + + +default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) device = args.device port = args.port @@ -68,15 +76,13 @@ if gpt_path == "": print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}") # 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 -if default_refer_path == "" or default_refer_text == "" or default_refer_language == "": - default_refer_path, default_refer_text, default_refer_language = "", "", "" +if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": + default_refer.path, default_refer.text, default_refer.language = "", "", "" print("[INFO] 未指定默认参考音频") - has_preset = False else: - print(f"[INFO] 默认参考音频路径: {default_refer_path}") - print(f"[INFO] 默认参考音频文本: {default_refer_text}") - print(f"[INFO] 默认参考音频语种: {default_refer_language}") - has_preset = True + print(f"[INFO] 默认参考音频路径: {default_refer.path}") + print(f"[INFO] 默认参考音频文本: {default_refer.text}") + print(f"[INFO] 默认参考音频语种: {default_refer.language}") is_half = g_config.is_half if args.full_precision: @@ -100,6 +106,20 @@ else: bert_model = bert_model.to(device) +def is_empty(*items): # 任意一项不为空返回False + for item in items: + if item is not None and item != "": + return False + return True + + +def is_full(*items): # 任意一项为空返回False + for item in items: + if item is None or item == "": + return False + return True + + def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") @@ -203,7 +223,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) else: wav16k = wav16k.to(device) zero_wav_torch = zero_wav_torch.to(device) - wav16k=torch.cat([wav16k,zero_wav_torch]) + wav16k = torch.cat([wav16k, zero_wav_torch]) ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() codes = vq_model.extract_latent(ssl_content) prompt_semantic = codes[0, 0] @@ -264,6 +284,25 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) +def handle_change(path, text, language): + if is_empty(path, text, language): + raise HTTPException(status_code=400, detail='缺少任意一项以下参数: "path", "text", "language"') + + if path != "" or path is not None: + default_refer.path = path + if text != "" or text is not None: + default_refer.text = text + if language != "" or language is not None: + default_refer.language = language + + print(f"[INFO] 当前默认参考音频路径: {default_refer.path}") + print(f"[INFO] 当前默认参考音频文本: {default_refer.text}") + print(f"[INFO] 当前默认参考音频语种: {default_refer.language}") + print(f"[INFO] is_ready: {default_refer.is_ready()}") + + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language): if command == "/restart": os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) @@ -277,11 +316,11 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan or prompt_language == "" or prompt_language is None ): refer_wav_path, prompt_text, prompt_language = ( - default_refer_path, - default_refer_text, - default_refer_language, + default_refer.path, + default_refer.text, + default_refer.language, ) - if not has_preset: + if not default_refer.is_ready(): raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设") with torch.no_grad(): @@ -301,6 +340,25 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan app = FastAPI() +@app.post("/change_refer") +async def change_refer(request: Request): + json_post_raw = await request.json() + return handle_change( + json_post_raw.get("path"), + json_post_raw.get("text"), + json_post_raw.get("language") + ) + + +@app.get("/change_refer") +async def change_refer( + path: str = None, + text: str = None, + language: str = None +): + return handle_change(path, text, language) + + @app.post("/") async def tts_endpoint(request: Request): json_post_raw = await request.json()