Update inference_webui.py

Make DictToAttrRecursive a subclass of dict so that nested dictionaries are still converted recursively, while every level keeps its normal dictionary behaviour (indexing, iteration, unpacking) and also exposes its keys as attributes.
This commit is contained in:
spicysama 2024-01-17 02:34:07 +08:00 committed by GitHub
parent 2078ad1177
commit 8c26982faf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,34 +1,22 @@
import os
# Model and checkpoint locations. Each value is read from the environment
# variable of the same name and falls back to the default path shown.
gpt_path=os.environ.get("gpt_path","pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path=os.environ.get("sovits_path","pretrained_models/s2G488k.pth")
cnhubert_base_path=os.environ.get("cnhubert_base_path","pretrained_models/chinese-hubert-base")
bert_path=os.environ.get("bert_path","pretrained_models/chinese-roberta-wwm-ext-large")
# Value of the "infer_ttswebui" environment variable (default 9872), coerced to int.
infer_ttswebui=os.environ.get("infer_ttswebui",9872)
infer_ttswebui=int(infer_ttswebui)
# NOTE(review): the four assignments below repeat the ones above with a "./"
# prefix on the default path; as written, these later assignments are the
# ones that take effect.
gpt_path=os.environ.get("gpt_path","./pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt")
sovits_path=os.environ.get("sovits_path","./pretrained_models/s2G488k.pth")
cnhubert_base_path=os.environ.get("cnhubert_base_path","./pretrained_models/chinese-hubert-base")
bert_path=os.environ.get("bert_path","./pretrained_models/chinese-roberta-wwm-ext-large")
# When "_CUDA_VISIBLE_DEVICES" is set, copy its value into "CUDA_VISIBLE_DEVICES".
if("_CUDA_VISIBLE_DEVICES"in os.environ):
os.environ["CUDA_VISIBLE_DEVICES"]=os.environ["_CUDA_VISIBLE_DEVICES"]
# NOTE(review): eval() executes whatever text the "is_half" environment
# variable holds (default "True"). This is only safe when the environment is
# trusted; consider parsing the flag explicitly instead.
is_half=eval(os.environ.get("is_half","True"))
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import sys,torch,numpy as np
from pathlib import Path
import os,pdb,utils,librosa,math,traceback,requests,argparse,torch,multiprocessing,pandas as pd,torch.multiprocessing as mp,soundfile
# torch.backends.cuda.sdp_kernel("flash")
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
# torch.backends.cuda.enable_math_sdp(True)
from random import shuffle
from AR.utils import get_newest_ckpt
from glob import glob
from tqdm import tqdm
import numpy as np
import librosa,torch
from feature_extractor import cnhubert
cnhubert.cnhubert_base_path=cnhubert_base_path
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from text import cleaned_text_to_sequence
from text.cleaner import text_to_sequence, clean_text
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
@ -58,14 +46,33 @@ def get_bert_feature(text, word2ph):
n_semantic = 1024
dict_s2=torch.load(sovits_path,map_location="cpu")
hps=dict_s2["config"]
class DictToAttrRecursive(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    Plain ``dict`` values are converted to ``DictToAttrRecursive`` when they
    are stored through the constructor, item assignment or attribute
    assignment, so nested configuration can be reached as ``cfg.a.b``
    while ``cfg["a"]["b"]`` and ``**cfg.a`` keep working.

    The mapping itself is the only storage: attribute access is a thin view
    over the items, so ``obj.key`` and ``obj["key"]`` can never disagree and
    ``del obj.key`` really removes the entry.

    Note: a key that has the same name as a ``dict`` method (for example
    ``"items"``) stays reachable through indexing only; attribute lookup
    returns the method.

    Args:
        input_dict: Mapping whose items are copied into the new object.
    """

    def __init__(self, input_dict):
        super().__init__()
        # Go through __setitem__ so nested dictionaries are converted.
        for key, value in input_dict.items():
            self[key] = value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``, wrapping plain nested dicts."""
        # Values that are already wrapped are stored as-is, so the caller's
        # object and the stored object stay the same instance.
        if isinstance(value, dict) and not isinstance(value, DictToAttrRecursive):
            value = DictToAttrRecursive(value)
        super().__setitem__(key, value)

    def __getattr__(self, item):
        """Return ``self[item]``; raise AttributeError when the key is absent."""
        # Only reached when normal attribute lookup fails, so dict methods
        # are never shadowed.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        """Store the attribute as a dictionary item."""
        self[key] = value

    def __delattr__(self, item):
        """Delete ``self[item]``; raise AttributeError when the key is absent."""
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate="25hz"
@ -130,9 +137,9 @@ def get_tts_wav(ref_wav_path,prompt_text,prompt_language,text,text_language):
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1)
else:bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2)
else:bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
@ -234,14 +241,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
with gr.Row():
inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
prompt_text= gr.Textbox(label="参考音频的文本",value="")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"],value="中文")
prompt_language= gr.Dropdown(label="参考音频的语种",choices=["中文","英文","日文"])
gr.Markdown(
value=
"*请填写需要合成的目标文本"
)
with gr.Row():
text=gr.Textbox(label="需要合成的文本",value="")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"],value="中文")
text_language = gr.Dropdown(label="需要合成的语种", choices=["中文", "英文", "日文"])
inference_button=gr.Button("合成语音", variant="primary")
output = gr.Audio(label="输出的语音")
inference_button.click(get_tts_wav, [inp_ref, prompt_text,prompt_language, text,text_language], [output])
@ -267,6 +274,6 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
app.queue(concurrency_count=511, max_size=1022).launch(
server_name="0.0.0.0",
inbrowser=True,
server_port=infer_ttswebui,
server_port=9872,
quiet=True,
)