From 0bcdf0155c340b32d18a33aaf7a96f43b8f1e91e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=82=A6?= Date: Fri, 26 Jan 2024 14:09:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=A8=A1=E5=9E=8B=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=B8=8D=E7=94=A8=E4=BA=8C?= =?UTF-8?q?=E6=AC=A1=E9=80=89=E6=8B=A9=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加模型记忆功能,不用二次选择模型 --- GPT_SoVITS/inference_webui.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index bb57183..fdee8d9 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -6,10 +6,25 @@ logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("asyncio").setLevel(logging.ERROR) import pdb -gpt_path = os.environ.get( - "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" -) -sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth") +if os.path.exists("./gweight.txt"): + with open("./gweight.txt", 'r',encoding="utf-8") as file: + gweight_data = file.read() + gpt_path = os.environ.get( + "gpt_path", gweight_data) +else: + gpt_path = os.environ.get( + "gpt_path", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt") + +if os.path.exists("./sweight.txt"): + with open("./sweight.txt", 'r',encoding="utf-8") as file: + sweight_data = file.read() + sovits_path = os.environ.get("sovits_path", sweight_data) +else: + sovits_path = os.environ.get("sovits_path", "GPT_SoVITS/pretrained_models/s2G488k.pth") +# gpt_path = os.environ.get( +# "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" +# ) +# sovits_path = os.environ.get("sovits_path", "pretrained_models/s2G488k.pth") cnhubert_base_path = os.environ.get( "cnhubert_base_path", "pretrained_models/chinese-hubert-base" ) @@ -124,6 +139,7 @@ def change_sovits_weights(sovits_path): vq_model = vq_model.to(device) vq_model.eval() print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) + with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path) change_sovits_weights(sovits_path) def change_gpt_weights(gpt_path): @@ -140,6 +156,7 @@ def change_gpt_weights(gpt_path): t2s_model.eval() total = sum([param.nelement() for param in t2s_model.parameters()]) print("Number of parameter: %.2fM" % (total / 1e6)) + with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path) change_gpt_weights(gpt_path) def get_spepc(hps, filename):