From 4b0fae83020389eed0dfd283c5122e5f3df584fc Mon Sep 17 00:00:00 2001
From: JavaAndPython55 <34533090+JavaAndPython55@users.noreply.github.com>
Date: Wed, 21 Feb 2024 18:11:59 +0800
Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Eapi.py=E4=B8=AD=EF=BC=9A?=
 =?UTF-8?q?=E5=8F=AF=E5=9C=A8=E5=90=AF=E5=8A=A8=E5=90=8E=E5=8A=A8=E6=80=81?=
 =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A8=A1=E5=9E=8B=EF=BC=8C=E4=BB=A5=E6=AD=A4?=
 =?UTF-8?q?=E6=BB=A1=E8=B6=B3=E5=90=8C=E4=B8=80=E4=B8=AAapi=E4=B8=8D?=
 =?UTF-8?q?=E5=90=8C=E7=9A=84=E6=9C=97=E8=AF=BB=E8=80=85=E8=AF=B7=E6=B1=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
---
Review notes (ignored by git am, not part of the commit message):
* The default bind address becomes 0.0.0.0 while /set_model passes
  caller-supplied paths to torch.load (pickle). Only expose the port to
  trusted clients.
* /set_model validates its input before touching any global; the blob id
  after ".." on the index line predates that revision.

 api.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 86 insertions(+), 1 deletion(-)

diff --git a/api.py b/api.py
index b8d584e..754f076 100644
--- a/api.py
+++ b/api.py
@@ -144,7 +144,7 @@ parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="
 parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
 
 parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
-parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
+parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
 parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
 parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
@@ -227,6 +227,60 @@ def is_full(*items): # 任意一项为空返回False
             return False
     return True
 
+
+def change_sovits_weights(sovits_path):
+    """Load a SoVITS checkpoint into the module-level ``vq_model``.
+
+    Rebinds the module-level ``vq_model`` and ``hps``, then records the
+    checkpoint path in ./sweight.txt.
+    """
+    global vq_model, hps
+    # NOTE(security): torch.load unpickles the file, so the path must point
+    # to a trusted checkpoint.
+    dict_s2 = torch.load(sovits_path, map_location="cpu")
+    hps = DictToAttrRecursive(dict_s2["config"])
+    hps.model.semantic_frame_rate = "25hz"
+    vq_model = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model
+    )
+    if "pretrained" not in sovits_path:
+        del vq_model.enc_q
+    if is_half:
+        vq_model = vq_model.half()
+    vq_model = vq_model.to(device)
+    vq_model.eval()
+    print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+    with open("./sweight.txt", "w", encoding="utf-8") as f:
+        f.write(sovits_path)
+
+
+def change_gpt_weights(gpt_path):
+    """Load a GPT checkpoint into the module-level ``t2s_model``.
+
+    Rebinds the module-level ``hz``, ``max_sec``, ``t2s_model`` and
+    ``config``, then records the checkpoint path in ./gweight.txt.
+    """
+    global hz, max_sec, t2s_model, config
+    hz = 50
+    # NOTE(security): torch.load unpickles the file, so the path must point
+    # to a trusted checkpoint.
+    dict_s1 = torch.load(gpt_path, map_location="cpu")
+    config = dict_s1["config"]
+    max_sec = config["data"]["max_sec"]
+    t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+    t2s_model.load_state_dict(dict_s1["weight"])
+    if is_half:
+        t2s_model = t2s_model.half()
+    t2s_model = t2s_model.to(device)
+    t2s_model.eval()
+    total = sum(param.nelement() for param in t2s_model.parameters())
+    print("Number of parameter: %.2fM" % (total / 1e6))
+    with open("./gweight.txt", "w", encoding="utf-8") as f:
+        f.write(gpt_path)
+
 
 def get_bert_feature(text, word2ph):
     with torch.no_grad():
@@ -452,6 +506,37 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
 
 
 app = FastAPI()
+
+
+# Added by clark, 2024-02-21: models can be swapped after startup so that a
+# single API instance can serve requests for different speakers.
+@app.post("/set_model")
+async def set_model(request: Request):
+    """Load the GPT and SoVITS checkpoints named in the JSON request body.
+
+    Reads ``gpt_model_path`` and ``sovits_model_path``; returns "ok" on
+    success and an HTTP 400 response when either path is missing or empty.
+    """
+    # Imported here so the handler does not depend on the file-level imports.
+    from fastapi.responses import JSONResponse
+
+    global gpt_path, sovits_path
+    json_post_raw = await request.json()
+    new_gpt_path = json_post_raw.get("gpt_model_path")
+    new_sovits_path = json_post_raw.get("sovits_model_path")
+    if not all(isinstance(p, str) and p for p in (new_gpt_path, new_sovits_path)):
+        # Reject early: concatenating None below would raise a TypeError.
+        return JSONResponse(
+            {"code": 400, "message": "gpt_model_path and sovits_model_path are required"},
+            status_code=400,
+        )
+    print("gptpath" + new_gpt_path + ";vitspath" + new_sovits_path)
+    change_sovits_weights(new_sovits_path)
+    change_gpt_weights(new_gpt_path)
+    # Publish the new paths only after both checkpoints loaded successfully.
+    gpt_path, sovits_path = new_gpt_path, new_sovits_path
+    return "ok"
+# End of addition.
 
 @app.post("/control")
 async def control(request: Request):