From 8c26982fafff550a8068a329eabcc355bd3bf65a Mon Sep 17 00:00:00 2001 From: spicysama <122108331+AnyaCoder@users.noreply.github.com> Date: Wed, 17 Jan 2024 02:34:07 +0800 Subject: [PATCH] Update inference_webui.py To implement recursive construction while retaining the characteristics of the original dictionary, we can slightly modify the DictToAttrRecursive class. This allows each object to retain its characteristics as a dictionary while accessing keys and values as attributes. --- GPT_SoVITS/inference_webui.py | 69 +++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 4917d32f..230ba1a7 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -1,34 +1,22 @@ import os -gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt") -sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth") -cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base") -bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large") -infer_ttswebui=os.environ.get("infer_ttswebui",9872) -infer_ttswebui=int(infer_ttswebui) +gpt_path=os.environ.get("gpt_path","./pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt") +sovits_path=os.environ.get("sovits_path","./pretrained_models/s2G488k.pth") +cnhubert_base_path=os.environ.get("cnhubert_base_path","./pretrained_models/chinese-hubert-base") +bert_path=os.environ.get("bert_path","./pretrained_models/chinese-roberta-wwm-ext-large") if("_CUDA_VISIBLE_DEVICES"in os.environ): os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"] is_half=eval(os.environ.get("is_half","True")) import gradio as gr from transformers import AutoModelForMaskedLM, AutoTokenizer -import sys,torch,numpy as np -from pathlib import Path -import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile -# torch.backends.cuda.sdp_kernel("flash") -# torch.backends.cuda.enable_flash_sdp(True) -# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0 -# torch.backends.cuda.enable_math_sdp(True) -from random import shuffle -from AR.utils import get_newest_ckpt -from glob import glob -from tqdm import tqdm +import numpy as np +import librosa,torch from feature_extractor import cnhubert cnhubert.cnhubert_base_path=cnhubert_base_path -from io import BytesIO + from module.models import SynthesizerTrn from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from AR.utils.io import load_yaml_config from text import cleaned_text_to_sequence -from text.cleaner import text_to_sequence, clean_text +from text.cleaner import clean_text from time import time as ttime from module.mel_processing import spectrogram_torch from my_utils import load_audio @@ -58,14 +46,33 @@ def get_bert_feature(text, word2ph): n_semantic = 1024 dict_s2=torch.load(sovits_path,map_location="cpu") hps=dict_s2["config"] -class DictToAttrRecursive: + +class DictToAttrRecursive(dict): def __init__(self, input_dict): + super().__init__(input_dict) for key, value in input_dict.items(): if isinstance(value, dict): - # 如果值是字典,递归调用构造函数 - setattr(self, key, DictToAttrRecursive(value)) - else: - setattr(self, key, value) + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate="25hz" @@ -130,9 +137,9 @@ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language): for text in texts: phones2, word2ph2, norm_text2 = clean_text(text, text_language) phones2 = cleaned_text_to_sequence(phones2) - if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device) + if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1) else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device) - if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device) + if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2) else:bert2 = torch.zeros((1024, len(phones2))).to(bert1) bert = torch.cat([bert1, bert2], 1) @@ -234,14 +241,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Row(): inp_ref = gr.Audio(label="请上传参考音频", type="filepath") prompt_text= gr.Textbox(label="参考音频的文本",value="") - prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文") + prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"]) gr.Markdown( value= "*请填写需要合成的目标文本" ) with gr.Row(): text=gr.Textbox(label="需要合成的文本",value="") - text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文") + text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"]) inference_button=gr.Button("合成语音", variant="primary") output = gr.Audio(label="输出的语音") inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output]) @@ -267,6 +274,6 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: app.queue(concurrency_count=511, max_size=1022).launch( server_name="0.0.0.0", inbrowser=True, - server_port=infer_ttswebui, + server_port=9872, quiet=True, -) \ No newline at end of file +)