mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-04-29 21:00:42 +08:00
feat:Added batch tts option
This commit is contained in:
parent
cb2b844f45
commit
fb50fc090f
@ -1,3 +1,7 @@
|
||||
{
|
||||
"running_on" : "local"
|
||||
"running_on" : "local",
|
||||
"Default":{
|
||||
"GPT_Path": "不训练直接推v3底模!",
|
||||
"SoVITS_Path": "不训练直接推v2ProPlus底模!"
|
||||
}
|
||||
}
|
||||
@ -24,6 +24,7 @@ class CNHubert(nn.Module):
|
||||
super().__init__()
|
||||
if base_path is None:
|
||||
base_path = cnhubert_base_path
|
||||
print(f"Loading CN-Hubert from \"{base_path}\"")
|
||||
if os.path.exists(base_path):
|
||||
...
|
||||
else:
|
||||
@ -69,6 +70,7 @@ class CNHubert(nn.Module):
|
||||
|
||||
|
||||
def get_model():
|
||||
print("cnhubert_base_path:", cnhubert_base_path)
|
||||
model = CNHubert()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@ -12,6 +12,7 @@ import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
|
||||
|
||||
@ -31,6 +32,7 @@ with open(merge_dir_txt2(get_my_dir(), "config.json"), "r", encoding="utf-8") as
|
||||
config_json = f.read()
|
||||
config_json = json.loads(config_json)
|
||||
running_on = config_json["running_on"]
|
||||
Default = config_json["Default"]
|
||||
|
||||
ROOT_DIR = str(get_parent_dir(get_my_dir()))
|
||||
sys.path.append(get_my_dir())
|
||||
@ -124,6 +126,7 @@ with open("./weight.json", "r", encoding="utf-8") as file:
|
||||
if isinstance(sovits_path, list):
|
||||
sovits_path = sovits_path[0]
|
||||
|
||||
|
||||
# print(2333333)
|
||||
# print(os.environ["gpt_path"])
|
||||
# print(gpt_path)
|
||||
@ -150,7 +153,7 @@ import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
cnhubert.cnhubert_base_path = merge_dir_txt2(ROOT_DIR, cnhubert_base_path)
|
||||
|
||||
import random
|
||||
|
||||
@ -184,6 +187,12 @@ language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
|
||||
|
||||
if gpt_path in [None, "",]:
|
||||
gpt_path = str(merge_dir_txt2(ROOT_DIR, name2gpt_path[i18n(Default["GPT_Path"])]))
|
||||
if sovits_path in [None, "",]:
|
||||
sovits_path = str(merge_dir_txt2(ROOT_DIR, name2sovits_path[i18n(Default["SoVITS_Path"])]))
|
||||
|
||||
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
|
||||
|
||||
if torch.cuda.is_available():
|
||||
@ -214,8 +223,8 @@ dict_language_v2 = {
|
||||
}
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(str(merge_dir_txt2(ROOT_DIR,bert_path)))
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(str(merge_dir_txt2(ROOT_DIR,bert_path)))
|
||||
if is_half == True:
|
||||
bert_model = bert_model.half().to(device)
|
||||
else:
|
||||
@ -428,6 +437,7 @@ except:
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
print("gpt_path:", gpt_path)
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
global hz, max_sec, t2s_model, config
|
||||
@ -1205,7 +1215,204 @@ def get_tts_wav(
|
||||
else:
|
||||
return opt_sr, (audio_opt * 32767).astype(np.int16)
|
||||
|
||||
def batched_tts_wav(
|
||||
ref_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
texts,
|
||||
text_language,
|
||||
how_to_cut=i18n("不切"),
|
||||
top_k=20,
|
||||
top_p=0.6,
|
||||
temperature=0.6,
|
||||
ref_free=False,
|
||||
speed=1,
|
||||
if_freeze=False,
|
||||
inp_refs=None,
|
||||
sample_steps=8,
|
||||
if_sr=False,
|
||||
pause_second=0.3,
|
||||
|
||||
SaveSvEmb=False,
|
||||
SaveRefers=False,
|
||||
SaveSvEmbName="sv_emb.voice",
|
||||
SaveRefersName="refers.voice",
|
||||
|
||||
SaveGE=False,
|
||||
SaveGEName="ge.voice",
|
||||
|
||||
InjectSvEmb=False,
|
||||
InjectRefers=False,
|
||||
InjectSvEmbName="sv_emb.voice",
|
||||
InjectRefersName="refers.voice",
|
||||
|
||||
EnableAudioLoad=True,
|
||||
|
||||
SaveOutputAsUndecoded=False,
|
||||
SaveOutputAsUndecodedName="output.voice",
|
||||
AddRandomSaltToSaveOutputAsUndecodedName=False,
|
||||
|
||||
ReturnWay = "yield", # "yield" or "return"
|
||||
):
|
||||
count = 0
|
||||
out = []
|
||||
SaveDir = merge_dir_txt2(ROOT_DIR,"output","tts_output",f"batch_{uuid.uuid4()}")
|
||||
if not os.path.exists(SaveDir):
|
||||
os.makedirs(SaveDir,exist_ok=True)
|
||||
for text in texts:
|
||||
if text in [None, " ", ""]:
|
||||
gr.Warning(i18n(f"输入文本第{count}行中有空行,已跳过"))
|
||||
continue
|
||||
else:
|
||||
unparsed = get_tts_wav(
|
||||
ref_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
how_to_cut,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
ref_free,
|
||||
speed,
|
||||
if_freeze,
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr,
|
||||
pause_second,
|
||||
|
||||
SaveSvEmb,
|
||||
SaveRefers,
|
||||
SaveSvEmbName,
|
||||
SaveRefersName,
|
||||
|
||||
SaveGE,
|
||||
SaveGEName,
|
||||
|
||||
InjectSvEmb,
|
||||
InjectRefers,
|
||||
InjectSvEmbName,
|
||||
InjectRefersName,
|
||||
|
||||
EnableAudioLoad,
|
||||
|
||||
SaveOutputAsUndecoded,
|
||||
SaveOutputAsUndecodedName,
|
||||
AddRandomSaltToSaveOutputAsUndecodedName,
|
||||
"yield",
|
||||
)
|
||||
unparsed = list(unparsed)
|
||||
print(unparsed)
|
||||
a = text.strip().replace(' ','_').replace('\n','_')
|
||||
wav_path = os.path.join(SaveDir,f"tts_output_{a}_{str(uuid.uuid4())}.wav")
|
||||
write(wav_path, unparsed[0][0], unparsed[0][1])
|
||||
out.append(wav_path)
|
||||
count += 1
|
||||
if ReturnWay == "yield":
|
||||
yield SaveDir
|
||||
else:
|
||||
return SaveDir
|
||||
|
||||
def read_tts_batch_file(file_path):
|
||||
ret = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
for l in lines:
|
||||
if l.strip() in [None, " ", ""]:
|
||||
continue
|
||||
else:
|
||||
ret.append(l)
|
||||
return ret
|
||||
|
||||
def batch_tts(
|
||||
ref_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text_paths,
|
||||
text_language,
|
||||
how_to_cut=i18n("不切"),
|
||||
top_k=20,
|
||||
top_p=0.6,
|
||||
temperature=0.6,
|
||||
ref_free=False,
|
||||
speed=1,
|
||||
if_freeze=False,
|
||||
inp_refs=None,
|
||||
sample_steps=8,
|
||||
if_sr=False,
|
||||
pause_second=0.3,
|
||||
|
||||
SaveSvEmb=False,
|
||||
SaveRefers=False,
|
||||
SaveSvEmbName="sv_emb.voice",
|
||||
SaveRefersName="refers.voice",
|
||||
|
||||
SaveGE=False,
|
||||
SaveGEName="ge.voice",
|
||||
|
||||
InjectSvEmb=False,
|
||||
InjectRefers=False,
|
||||
InjectSvEmbName="sv_emb.voice",
|
||||
InjectRefersName="refers.voice",
|
||||
|
||||
EnableAudioLoad=True,
|
||||
|
||||
SaveOutputAsUndecoded=False,
|
||||
SaveOutputAsUndecodedName="output.voice",
|
||||
AddRandomSaltToSaveOutputAsUndecodedName=False,
|
||||
|
||||
ReturnWay = "yield", # "yield" or "return"
|
||||
):
|
||||
print(text_paths)
|
||||
text_list = []
|
||||
for i in text_paths:
|
||||
text_list.extend(read_tts_batch_file(i))
|
||||
out = batched_tts_wav(
|
||||
ref_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text_list,
|
||||
text_language,
|
||||
how_to_cut,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
ref_free,
|
||||
speed,
|
||||
if_freeze,
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr,
|
||||
pause_second,
|
||||
|
||||
SaveSvEmb,
|
||||
SaveRefers,
|
||||
SaveSvEmbName,
|
||||
SaveRefersName,
|
||||
|
||||
SaveGE,
|
||||
SaveGEName,
|
||||
|
||||
InjectSvEmb,
|
||||
InjectRefers,
|
||||
InjectSvEmbName,
|
||||
InjectRefersName,
|
||||
|
||||
EnableAudioLoad,
|
||||
|
||||
SaveOutputAsUndecoded,
|
||||
SaveOutputAsUndecodedName,
|
||||
AddRandomSaltToSaveOutputAsUndecodedName,
|
||||
|
||||
"yield"
|
||||
)
|
||||
out = list(out)
|
||||
|
||||
if ReturnWay == "yield":
|
||||
yield out
|
||||
else:
|
||||
return out
|
||||
def close_serv():
|
||||
if running_on == "local":
|
||||
sys.exit(0)
|
||||
@ -1540,6 +1747,25 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
show_label=True,
|
||||
visible=False if model_version != "v3" else True,
|
||||
)
|
||||
with gr.Row():
|
||||
gr.Markdown(html_center(i18n("批量语音合成参数"), "h3"))
|
||||
with gr.Column(scale=13):
|
||||
txt_paths = gr.File(label=i18n("批量语音合成文本文件,每行一个文本"),
|
||||
file_types=[".txt"],
|
||||
interactive=True,
|
||||
file_count="multiple",
|
||||
scale=13)
|
||||
with gr.Column(scale=7):
|
||||
out = gr.File(label=i18n("批量合成输出的语音文件"),
|
||||
file_types=[".wav"],
|
||||
file_count="directory",)
|
||||
start_batch_btn = gr.Button(i18n("开始批量合成"),
|
||||
variant="primary",
|
||||
size="lg",
|
||||
interactive=True,
|
||||
scale=25)
|
||||
|
||||
|
||||
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
|
||||
with gr.Row():
|
||||
with gr.Column(scale=13):
|
||||
@ -1648,7 +1874,51 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
],
|
||||
[output],
|
||||
|
||||
api_name="get_tts_wav",
|
||||
)
|
||||
|
||||
start_batch_btn.click(
|
||||
batch_tts,
|
||||
[
|
||||
inp_ref,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
txt_paths,
|
||||
text_language,
|
||||
how_to_cut,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
ref_text_free,
|
||||
speed,
|
||||
if_freeze,
|
||||
inp_refs,
|
||||
sample_steps,
|
||||
if_sr_Checkbox,
|
||||
pause_second_slider,
|
||||
|
||||
SaveSvEmb,
|
||||
SaveRefers,
|
||||
SaveSvEmbName,
|
||||
SaveRefersName,
|
||||
SaveGE,
|
||||
SaveGEName,
|
||||
InjectSvEmb,
|
||||
InjectRefers,
|
||||
InjectSvEmbName,
|
||||
InjectRefersName,
|
||||
EnableAudioLoad,
|
||||
|
||||
SaveOutputAsUndecoded,
|
||||
SaveOutputAsUndecodedName,
|
||||
AddRandomSaltToSaveOutputAsUndecodedName,
|
||||
|
||||
],
|
||||
[out],
|
||||
|
||||
api_name="batch_tts",
|
||||
)
|
||||
|
||||
SoVITS_dropdown.change(
|
||||
change_sovits_weights,
|
||||
[SoVITS_dropdown, prompt_language, text_language],
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
|
||||
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
||||
sys.path.append(f"{str(Path(os.path.dirname(os.path.abspath(__file__))).parent)}/GPT_SoVITS/eres2net")
|
||||
sv_path = f"{str(Path(os.path.dirname(os.path.abspath(__file__))).parent)}/GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
||||
from ERes2NetV2 import ERes2NetV2
|
||||
import kaldi as Kaldi
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user