mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-30 06:22:52 +08:00
Add files via upload
This commit is contained in:
parent
16e9040056
commit
5f68cc072a
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os,re
|
||||||
|
import pdb
|
||||||
|
|
||||||
gpt_path = os.environ.get(
|
gpt_path = os.environ.get(
|
||||||
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
"gpt_path", "pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
@ -42,8 +43,6 @@ if is_half == True:
|
|||||||
else:
|
else:
|
||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(device)
|
||||||
|
|
||||||
|
|
||||||
# bert_model=bert_model.to(device)
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = tokenizer(text, return_tensors="pt")
|
inputs = tokenizer(text, return_tensors="pt")
|
||||||
@ -57,15 +56,8 @@ def get_bert_feature(text, word2ph):
|
|||||||
repeat_feature = res[i].repeat(word2ph[i], 1)
|
repeat_feature = res[i].repeat(word2ph[i], 1)
|
||||||
phone_level_feature.append(repeat_feature)
|
phone_level_feature.append(repeat_feature)
|
||||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||||
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
|
||||||
return phone_level_feature.T
|
return phone_level_feature.T
|
||||||
|
|
||||||
|
|
||||||
n_semantic = 1024
|
|
||||||
|
|
||||||
dict_s2=torch.load(sovits_path,map_location="cpu")
|
|
||||||
hps=dict_s2["config"]
|
|
||||||
|
|
||||||
class DictToAttrRecursive(dict):
|
class DictToAttrRecursive(dict):
|
||||||
def __init__(self, input_dict):
|
def __init__(self, input_dict):
|
||||||
super().__init__(input_dict)
|
super().__init__(input_dict)
|
||||||
@ -94,40 +86,48 @@ class DictToAttrRecursive(dict):
|
|||||||
raise AttributeError(f"Attribute {item} not found")
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
|
||||||
hps = DictToAttrRecursive(hps)
|
|
||||||
|
|
||||||
hps.model.semantic_frame_rate = "25hz"
|
|
||||||
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
|
||||||
config = dict_s1["config"]
|
|
||||||
ssl_model = cnhubert.get_model()
|
ssl_model = cnhubert.get_model()
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
ssl_model = ssl_model.half().to(device)
|
ssl_model = ssl_model.half().to(device)
|
||||||
else:
|
else:
|
||||||
ssl_model = ssl_model.to(device)
|
ssl_model = ssl_model.to(device)
|
||||||
|
|
||||||
vq_model = SynthesizerTrn(
|
def change_sovits_weights(sovits_path):
|
||||||
|
global vq_model,hps
|
||||||
|
dict_s2=torch.load(sovits_path,map_location="cpu")
|
||||||
|
hps=dict_s2["config"]
|
||||||
|
hps = DictToAttrRecursive(hps)
|
||||||
|
hps.model.semantic_frame_rate = "25hz"
|
||||||
|
vq_model = SynthesizerTrn(
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
n_speakers=hps.data.n_speakers,
|
n_speakers=hps.data.n_speakers,
|
||||||
**hps.model
|
**hps.model
|
||||||
)
|
)
|
||||||
if is_half == True:
|
del vq_model.enc_q
|
||||||
|
if is_half == True:
|
||||||
vq_model = vq_model.half().to(device)
|
vq_model = vq_model.half().to(device)
|
||||||
else:
|
else:
|
||||||
vq_model = vq_model.to(device)
|
vq_model = vq_model.to(device)
|
||||||
vq_model.eval()
|
vq_model.eval()
|
||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
hz = 50
|
change_sovits_weights(sovits_path)
|
||||||
max_sec = config["data"]["max_sec"]
|
|
||||||
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
|
||||||
if is_half == True:
|
|
||||||
t2s_model = t2s_model.half()
|
|
||||||
t2s_model = t2s_model.to(device)
|
|
||||||
t2s_model.eval()
|
|
||||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
|
||||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
|
||||||
|
|
||||||
|
def change_gpt_weights(gpt_path):
|
||||||
|
global hz,max_sec,t2s_model,config
|
||||||
|
hz = 50
|
||||||
|
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
||||||
|
config = dict_s1["config"]
|
||||||
|
max_sec = config["data"]["max_sec"]
|
||||||
|
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||||
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
|
if is_half == True:
|
||||||
|
t2s_model = t2s_model.half()
|
||||||
|
t2s_model = t2s_model.to(device)
|
||||||
|
t2s_model.eval()
|
||||||
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||||
|
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||||
|
change_gpt_weights(gpt_path)
|
||||||
|
|
||||||
def get_spepc(hps, filename):
|
def get_spepc(hps, filename):
|
||||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||||
@ -325,14 +325,46 @@ def cut3(inp):
|
|||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
|
return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
|
||||||
|
|
||||||
|
def custom_sort_key(s):
|
||||||
|
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||||
|
parts = re.split('(\d+)', s)
|
||||||
|
# 将数字部分转换为整数,非数字部分保持不变
|
||||||
|
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||||
|
return parts
|
||||||
|
|
||||||
|
def change_choices():
|
||||||
|
SoVITS_names, GPT_names = get_weights_names()
|
||||||
|
return {"choices": sorted(SoVITS_names,key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names,key=custom_sort_key), "__type__": "update"}
|
||||||
|
|
||||||
|
pretrained_sovits_name="GPT_SoVITS/pretrained_models/s2G488k.pth"
|
||||||
|
pretrained_gpt_name="GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
|
||||||
|
SoVITS_weight_root="SoVITS_weights"
|
||||||
|
GPT_weight_root="GPT_weights"
|
||||||
|
os.makedirs(SoVITS_weight_root,exist_ok=True)
|
||||||
|
os.makedirs(GPT_weight_root,exist_ok=True)
|
||||||
|
def get_weights_names():
|
||||||
|
SoVITS_names = [pretrained_sovits_name]
|
||||||
|
for name in os.listdir(SoVITS_weight_root):
|
||||||
|
if name.endswith(".pth"):SoVITS_names.append("%s/%s"%(SoVITS_weight_root,name))
|
||||||
|
GPT_names = [pretrained_gpt_name]
|
||||||
|
for name in os.listdir(GPT_weight_root):
|
||||||
|
if name.endswith(".ckpt"): GPT_names.append("%s/%s"%(GPT_weight_root,name))
|
||||||
|
return SoVITS_names,GPT_names
|
||||||
|
SoVITS_names,GPT_names = get_weights_names()
|
||||||
|
|
||||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
|
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
|
||||||
)
|
)
|
||||||
# with gr.Tabs():
|
|
||||||
# with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
|
gr.Markdown(value=i18n("模型切换"))
|
||||||
|
with gr.Row():
|
||||||
|
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path,interactive=True)
|
||||||
|
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path,interactive=True)
|
||||||
|
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||||
|
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||||
|
SoVITS_dropdown.change(change_sovits_weights,[SoVITS_dropdown],[])
|
||||||
|
GPT_dropdown.change(change_gpt_weights,[GPT_dropdown],[])
|
||||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inp_ref = gr.Audio(label=i18n("请上传参考音频"), type="filepath")
|
inp_ref = gr.Audio(label=i18n("请上传参考音频"), type="filepath")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user