diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 246748a..3046d7a 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -1,4 +1,4 @@
-import os
+import os, re
gpt_path = os.environ.get(
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
@@ -42,8 +43,6 @@ if is_half == True:
else:
bert_model = bert_model.to(device)
-
-# bert_model=bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
@@ -57,15 +56,8 @@ def get_bert_feature(text, word2ph):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
-
-n_semantic = 1024
-
-dict_s2=torch.load(sovits_path,map_location="cpu")
-hps=dict_s2["config"]
-
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
@@ -94,40 +86,48 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found")
-hps = DictToAttrRecursive(hps)
-
-hps.model.semantic_frame_rate = "25hz"
-dict_s1 = torch.load(gpt_path, map_location="cpu")
-config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half == True:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
-vq_model = SynthesizerTrn(
- hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- n_speakers=hps.data.n_speakers,
- **hps.model
-)
-if is_half == True:
- vq_model = vq_model.half().to(device)
-else:
- vq_model = vq_model.to(device)
-vq_model.eval()
-print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
-hz = 50
-max_sec = config["data"]["max_sec"]
-t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
-t2s_model.load_state_dict(dict_s1["weight"])
-if is_half == True:
- t2s_model = t2s_model.half()
-t2s_model = t2s_model.to(device)
-t2s_model.eval()
-total = sum([param.nelement() for param in t2s_model.parameters()])
-print("Number of parameter: %.2fM" % (total / 1e6))
+def change_sovits_weights(sovits_path):
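+    # Hot-swap the SoVITS synthesizer: rebuild SynthesizerTrn from the checkpoint's bundled config and load its weights without restarting the WebUI.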
+ global vq_model,hps
+ dict_s2=torch.load(sovits_path,map_location="cpu")
+ hps=dict_s2["config"]
+ hps = DictToAttrRecursive(hps)
+ hps.model.semantic_frame_rate = "25hz"
+ vq_model = SynthesizerTrn(
+ hps.data.filter_length // 2 + 1,
+ hps.train.segment_size // hps.data.hop_length,
+ n_speakers=hps.data.n_speakers,
+ **hps.model
+ )
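+    # enc_q (the posterior encoder) is only needed during training; drop it to save memory at inference time.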
+ del vq_model.enc_q
+ if is_half == True:
+ vq_model = vq_model.half().to(device)
+ else:
+ vq_model = vq_model.to(device)
+ vq_model.eval()
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+change_sovits_weights(sovits_path)
+def change_gpt_weights(gpt_path):
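+    # Reload the GPT (text-to-semantic) model from the selected checkpoint, mirroring change_sovits_weights above.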
+ global hz,max_sec,t2s_model,config
+ hz = 50
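+    # 50 semantic tokens per second; together with max_sec this bounds how long a generation may run.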
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
+ config = dict_s1["config"]
+ max_sec = config["data"]["max_sec"]
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
+ t2s_model.load_state_dict(dict_s1["weight"])
+ if is_half == True:
+ t2s_model = t2s_model.half()
+ t2s_model = t2s_model.to(device)
+ t2s_model.eval()
+ total = sum([param.nelement() for param in t2s_model.parameters()])
+ print("Number of parameter: %.2fM" % (total / 1e6))
+change_gpt_weights(gpt_path)
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
@@ -325,14 +325,46 @@ def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
+def custom_sort_key(s):
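+    # Natural-sort key: "model_10.pth" sorts after "model_2.pth" instead of lexicographically before it.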
+    # split the string into numeric and non-numeric parts with a regex
+    parts = re.split(r'(\d+)', s)
+    # convert the numeric parts to integers so they compare numerically; leave the rest as strings
+    parts = [int(part) if part.isdigit() else part for part in parts]
+ return parts
+
+def change_choices():
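+    # Rescan the weight folders and return Gradio update dicts that repopulate both model dropdowns.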
+ SoVITS_names, GPT_names = get_weights_names()
+ return {"choices": sorted(SoVITS_names,key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names,key=custom_sort_key), "__type__": "update"}
+
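+# Default pretrained checkpoints; fine-tuned models saved under the weight roots are picked up alongside them.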
+pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
+pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
+SoVITS_weight_root="SoVITS_weights"
+GPT_weight_root="GPT_weights"
+os.makedirs(SoVITS_weight_root,exist_ok=True)
+os.makedirs(GPT_weight_root,exist_ok=True)
+def get_weights_names():
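+    # List available checkpoints: the bundled pretrained models plus any .pth/.ckpt files found on disk.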
+ SoVITS_names = [pretrained_sovits_name]
+ for name in os.listdir(SoVITS_weight_root):
+ if name.endswith(".pth"):SoVITS_names.append("%s/%s"%(SoVITS_weight_root,name))
+ GPT_names = [pretrained_gpt_name]
+ for name in os.listdir(GPT_weight_root):
+ if name.endswith(".ckpt"): GPT_names.append("%s/%s"%(GPT_weight_root,name))
+ return SoVITS_names,GPT_names
+SoVITS_names,GPT_names = get_weights_names()
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
)
- # with gr.Tabs():
- # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
with gr.Group():
+ gr.Markdown(value=i18n("模型切换"))
+ with gr.Row():
+ GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path,interactive=True)
+ SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path,interactive=True)
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
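+        # Selecting an entry in either dropdown hot-swaps the corresponding model via the loader functions above.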
+ SoVITS_dropdown.change(change_sovits_weights,[SoVITS_dropdown],[])
+ GPT_dropdown.change(change_gpt_weights,[GPT_dropdown],[])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("请上传参考音频"), type="filepath")