diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 246748a..3046d7a 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -1,4 +1,5 @@ -import os +import os,re +import pdb gpt_path = os.environ.get( "gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" @@ -42,8 +43,6 @@ if is_half == True: else: bert_model = bert_model.to(device) - -# bert_model=bert_model.to(device) def get_bert_feature(text, word2ph): with torch.no_grad(): inputs = tokenizer(text, return_tensors="pt") @@ -57,15 +56,8 @@ def get_bert_feature(text, word2ph): repeat_feature = res[i].repeat(word2ph[i], 1) phone_level_feature.append(repeat_feature) phone_level_feature = torch.cat(phone_level_feature, dim=0) - # if(is_half==True):phone_level_feature=phone_level_feature.half() return phone_level_feature.T - -n_semantic = 1024 - -dict_s2=torch.load(sovits_path,map_location="cpu") -hps=dict_s2["config"] - class DictToAttrRecursive(dict): def __init__(self, input_dict): super().__init__(input_dict) @@ -94,40 +86,48 @@ class DictToAttrRecursive(dict): raise AttributeError(f"Attribute {item} not found") -hps = DictToAttrRecursive(hps) - -hps.model.semantic_frame_rate = "25hz" -dict_s1 = torch.load(gpt_path, map_location="cpu") -config = dict_s1["config"] ssl_model = cnhubert.get_model() if is_half == True: ssl_model = ssl_model.half().to(device) else: ssl_model = ssl_model.to(device) -vq_model = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model -) -if is_half == True: - vq_model = vq_model.half().to(device) -else: - vq_model = vq_model.to(device) -vq_model.eval() -print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) -hz = 50 -max_sec = config["data"]["max_sec"] -t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False) -t2s_model.load_state_dict(dict_s1["weight"]) -if is_half == True: - t2s_model = t2s_model.half() -t2s_model = t2s_model.to(device) -t2s_model.eval() -total = sum([param.nelement() for param in t2s_model.parameters()]) -print("Number of parameter: %.2fM" % (total / 1e6)) +def change_sovits_weights(sovits_path): + global vq_model,hps + dict_s2=torch.load(sovits_path,map_location="cpu") + hps=dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model + ) + del vq_model.enc_q + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) +change_sovits_weights(sovits_path) +def change_gpt_weights(gpt_path): + global hz,max_sec,t2s_model,config + hz = 50 + dict_s1 = torch.load(gpt_path, map_location="cpu") + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + total = sum([param.nelement() for param in t2s_model.parameters()]) + print("Number of parameter: %.2fM" % (total / 1e6)) +change_gpt_weights(gpt_path) def get_spepc(hps, filename): audio = load_audio(filename, int(hps.data.sampling_rate)) @@ -325,14 +325,46 @@ def cut3(inp): inp = inp.strip("\n") return "\n".join(["%s。" % item for item in inp.strip("。").split("。")]) +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split('(\d+)', s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + +def change_choices(): + SoVITS_names, GPT_names = get_weights_names() + return {"choices": sorted(SoVITS_names,key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names,key=custom_sort_key), "__type__": "update"} + +pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth" +pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" +SoVITS_weight_root="SoVITS_weights" +GPT_weight_root="GPT_weights" +os.makedirs(SoVITS_weight_root,exist_ok=True) +os.makedirs(GPT_weight_root,exist_ok=True) +def get_weights_names(): + SoVITS_names = [pretrained_sovits_name] + for name in os.listdir(SoVITS_weight_root): + if name.endswith(".pth"):SoVITS_names.append("%s/%s"%(SoVITS_weight_root,name)) + GPT_names = [pretrained_gpt_name] + for name in os.listdir(GPT_weight_root): + if name.endswith(".ckpt"): GPT_names.append("%s/%s"%(GPT_weight_root,name)) + return SoVITS_names,GPT_names +SoVITS_names,GPT_names = get_weights_names() with gr.Blocks(title="GPT-SoVITS WebUI") as app: gr.Markdown( value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.
如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") ) - # with gr.Tabs(): - # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")): with gr.Group(): + gr.Markdown(value=i18n("模型切换")) + with gr.Row(): + GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path,interactive=True) + SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path,interactive=True) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + SoVITS_dropdown.change(change_sovits_weights,[SoVITS_dropdown],[]) + GPT_dropdown.change(change_gpt_weights,[GPT_dropdown],[]) gr.Markdown(value=i18n("*请上传并填写参考信息")) with gr.Row(): inp_ref = gr.Audio(label=i18n("请上传参考音频"), type="filepath")