From 0c25e57959e1f6287cbd0fe8fa07369864af2fd3 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sat, 3 Aug 2024 21:03:05 +0800 Subject: [PATCH] =?UTF-8?q?=E8=8B=A5=E5=B9=B2=E6=9D=82=E9=A1=B9=EF=BC=8C?= =?UTF-8?q?=E7=95=8C=E9=9D=A2=E4=BC=98=E5=8C=96=20(#1388)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 现支持直接启动inference_webui,支持选择version和i18n 将不同版本模型放在两个文件夹中,启动inference_webui后将根据version选择模型文件夹 训练保存文件夹也将根据verison变化 --- .gitignore | 3 +- GPT_SoVITS/inference_webui.py | 53 +++++++++++++++++++++-------------- webui.py | 17 ++++++----- 3 files changed, 43 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 3899ba0..5970dae 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,4 @@ reference GPT_weights SoVITS_weights TEMP -gweight.txt -sweight.txt \ No newline at end of file +weight.json diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 131c6f6..a7e775e 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -14,30 +14,32 @@ logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("asyncio").setLevel(logging.ERROR) logging.getLogger("charset_normalizer").setLevel(logging.ERROR) logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) -import LangSegment, os, re, sys +import LangSegment, os, re, sys, json import pdb import torch +if len(sys.argv)==1:sys.argv.append('v1') version=os.environ.get("version","v1") +version="v2"if sys.argv[1]=="v2" else version +os.environ['version']=version language=os.environ.get("language","auto") +language=sys.argv[-1] if len(sys.argv[-1])==5 else language pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"if version=="v1"else"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth" pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"if version=="v1"else "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt" -if os.path.exists("./gweight.txt"): - with open("./gweight.txt", 'r', encoding="utf-8") as file: - gweight_data = file.read() - gpt_path = os.environ.get( - "gpt_path", gweight_data) +if os.path.exists(f"./weight.json"): + pass else: - gpt_path = os.environ.get( - "gpt_path", pretrained_gpt_name) + with open(f"./weight.json", 'w', encoding="utf-8") as file:json.dump({'GPT':{},'SoVITS':{}},file) -if os.path.exists("./sweight.txt"): - with open("./sweight.txt", 'r', encoding="utf-8") as file: - sweight_data = file.read() - sovits_path = os.environ.get("sovits_path", sweight_data) -else: - sovits_path = os.environ.get("sovits_path", pretrained_sovits_name) +with open(f"./weight.json", 'r', encoding="utf-8") as file: + weight_data = file.read() + weight_data=json.loads(weight_data) + gpt_path = os.environ.get( + "gpt_path", weight_data.get('GPT',{}).get(version,pretrained_gpt_name)) + sovits_path = os.environ.get( + "sovits_path", weight_data.get('SoVITS',{}).get(version,pretrained_sovits_name)) + # gpt_path = os.environ.get( # "gpt_path", pretrained_gpt_name # ) @@ -164,8 +166,11 @@ def change_sovits_weights(sovits_path): vq_model = vq_model.to(device) vq_model.eval() print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) - with open("./sweight.txt", "w", encoding="utf-8") as f: - f.write(sovits_path) + with open("./weight.json")as f: + data=f.read() + data=json.loads(data) + data["SoVITS"][version]=sovits_path + with open("./weight.json","w")as f:f.write(json.dumps(data)) change_sovits_weights(sovits_path) @@ -185,7 +190,11 @@ def change_gpt_weights(gpt_path): t2s_model.eval() total = sum([param.nelement() for param in t2s_model.parameters()]) print("Number of parameter: %.2fM" % (total / 1e6)) - with open("./gweight.txt", "w", encoding="utf-8") as f: f.write(gpt_path) + with open("./weight.json")as f: + data=f.read() + data=json.loads(data) + data["GPT"][version]=gpt_path + with open("./weight.json","w")as f:f.write(json.dumps(data)) change_gpt_weights(gpt_path) @@ -586,10 +595,12 @@ def change_choices(): return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"} -SoVITS_weight_root = "SoVITS_weights" -GPT_weight_root = "GPT_weights" -os.makedirs(SoVITS_weight_root, exist_ok=True) -os.makedirs(GPT_weight_root, exist_ok=True) +SoVITS_weight_root="SoVITS_weights_v2" if version=='v2' else "SoVITS_weights" +GPT_weight_root="GPT_weights_v2" if version=='v2' else "GPT_weights" +os.makedirs("SoVITS_weights",exist_ok=True) +os.makedirs("GPT_weights",exist_ok=True) +os.makedirs("SoVITS_weights_v2",exist_ok=True) +os.makedirs("GPT_weights_v2",exist_ok=True) def get_weights_names(): diff --git a/webui.py b/webui.py index e962bec..faebb78 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,7 @@ import os,shutil,sys,pdb,re -version="v2"if sys.argv[0]=="v2" else"v1" -language=sys.argv[-1] if sys.argv[-1]!='v2' and sys.argv[-1]!='v1' else 'auto' +if len(sys.argv)==1:sys.argv.append('v1') +version="v2"if sys.argv[1]=="v2" else"v1" +language=sys.argv[-1] if len(sys.argv[-1])==5 else "auto" os.environ["version"]=version os.environ["language"]=language now_dir = os.getcwd() @@ -121,10 +122,12 @@ def get_weights_names(): for name in os.listdir(GPT_weight_root): if name.endswith(".ckpt"): GPT_names.append(name) return SoVITS_names,GPT_names -SoVITS_weight_root="SoVITS_weights" -GPT_weight_root="GPT_weights" -os.makedirs(SoVITS_weight_root,exist_ok=True) -os.makedirs(GPT_weight_root,exist_ok=True) +SoVITS_weight_root="SoVITS_weights_v2" if version=='v2' else "SoVITS_weights" +GPT_weight_root="GPT_weights_v2" if version=='v2' else "GPT_weights" +os.makedirs("SoVITS_weights",exist_ok=True) +os.makedirs("GPT_weights",exist_ok=True) +os.makedirs("SoVITS_weights_v2",exist_ok=True) +os.makedirs("GPT_weights_v2",exist_ok=True) SoVITS_names,GPT_names = get_weights_names() def custom_sort_key(s): @@ -208,7 +211,7 @@ def change_tts_inference(if_tts,bert_path,cnhubert_base_path,gpu_number,gpt_path os.environ["is_half"]=str(is_half) os.environ["infer_ttswebui"]=str(webui_port_infer_tts) os.environ["is_share"]=str(is_share) - cmd = '"%s" GPT_SoVITS/inference_webui.py'%(python_exec) + cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language) yield i18n("TTS推理进程已开启") print(cmd) p_tts_inference = Popen(cmd, shell=True)