Fixed a bug in the t2s model when no prompt input is given   GPT_SoVITS/AR/models/t2s_model.py

Added some new features and fixed some bugs   GPT_SoVITS/TTS_infer_pack/TTS.py
	Improved the web page layout   GPT_SoVITS/inference_webui.py
This commit is contained in:
chasonjiang 2024-03-10 01:20:42 +08:00
parent 2fe3207d71
commit ed2ffe1356
3 changed files with 194 additions and 101 deletions

GPT_SoVITS/AR/models/t2s_model.py

@@ -549,7 +549,6 @@ class Text2SemanticDecoder(nn.Module):
y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
cache_y_emb = y_emb
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask)
@@ -589,8 +588,6 @@
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if cache_y_emb is not None:
cache_y_emb = torch.index_select(cache_y_emb, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
@@ -617,8 +614,8 @@
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx]
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
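Note on the `.to(dtype=..., device=...)` change above: `self.ar_audio_position.pe` is typically a float32 buffer, so slicing it and adding it to a half-precision `y_emb` would otherwise promote the result back to float32 (or fail outright if the buffer lives on another device). A minimal illustration, with hypothetical shapes standing in for the real ones:

import torch

pe = torch.randn(1, 4096, 512, dtype=torch.float32)    # positional table, fp32 buffer (illustrative shape)
y_emb = torch.randn(4, 1, 512, dtype=torch.float16)    # current-step audio embedding in fp16

idx = 10
bad = y_emb + pe[:, idx]                                # silently promoted to float32
good = y_emb + pe[:, idx].to(dtype=y_emb.dtype, device=y_emb.device)
assert bad.dtype == torch.float32 and good.dtype == torch.float16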

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -1,8 +1,7 @@
import os, sys
import ffmpeg
now_dir = os.getcwd()
sys.path.append(now_dir)
import ffmpeg
import os
from typing import Generator, List, Union
import numpy as np
@@ -164,6 +163,9 @@ class TTS:
"bert_features":None,
"norm_text":None,
}
self.stop_flag:bool = False
def _init_models(self,):
self.init_t2s_weights(self.configs.t2s_weights_path)
@@ -310,7 +312,7 @@ class TTS:
batch = torch.stack(padded_sequences)
return batch
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75):
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
_data:list = []
index_and_len_list = []
@@ -318,30 +320,35 @@
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
index_and_len_list.sort(key=lambda x: x[1])
# index_and_len_batch_list = [index_and_len_list[idx:min(idx+batch_size,len(index_and_len_list))] for idx in range(0,len(index_and_len_list),batch_size)]
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
# for batch_idx, index_and_len_batch in enumerate(index_and_len_batch_list):
batch_index_list = []
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/batch.mean()
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/batch.mean()
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
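For readers of the new `split_bucket` path above: after sorting by normalized-text length it greedily takes up to `batch_size` items, then shrinks the window until the window's middle length divided by its mean reaches `threshold`, so every batch holds texts of similar length and padding waste stays bounded. A self-contained sketch of the same heuristic (simplified to plain index lists; names are illustrative):

import numpy as np

def bucket_batches(lengths: list, batch_size: int = 5, threshold: float = 0.75) -> list:
    """Group item indices into batches of similar length (greedy window shrinking)."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # indices sorted by length
    batches, pos = [], 0
    while pos < len(order):
        pos_end = min(pos + batch_size, len(order))
        while pos < pos_end:
            window = np.array([lengths[i] for i in order[pos:pos_end]], dtype=np.float32)
            score = window[(pos_end - pos) // 2] / window.mean()   # middle length / mean length
            if score >= threshold or pos_end - pos == 1:
                batches.append(order[pos:pos_end])                 # lengths are similar enough
                pos = pos_end
                break
            pos_end -= 1                                           # drop the longest item and retry
    return batches

print(bucket_batches([3, 50, 5, 48, 4, 52], batch_size=3, threshold=0.75))
# -> [[0, 4, 2], [3, 1, 5]]: short texts batch together, long texts together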
@@ -399,7 +406,8 @@ class TTS:
_data[index] = data[i][j]
return _data
def stop(self,):
self.stop_flag = True
def run(self, inputs:dict):
@@ -409,22 +417,26 @@
Args:
inputs (dict):
{
"text": "",
"text_lang: "",
"ref_audio_path": "",
"prompt_text": "",
"prompt_lang": "",
"top_k": 5,
"top_p": 0.9,
"temperature": 0.6,
"text_split_method": "",
"batch_size": 1,
"batch_threshold": 0.75,
"speed_factor":1.0,
"text": "", # str. text to be synthesized
"text_lang: "", # str. language of the text to be synthesized
"ref_audio_path": "", # str. reference audio path
"prompt_text": "", # str. prompt text for the reference audio
"prompt_lang": "", # str. language of the prompt text for the reference audio
"top_k": 5, # int. top k sampling
"top_p": 0.9, # float. top p sampling
"temperature": 0.6, # float. temperature for sampling
"text_split_method": "", # str. text split method, see text_segmentaion_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
}
returns:
tuple[int, np.ndarray]: sampling rate and audio data.
"""
self.stop_flag:bool = False
text:str = inputs.get("text", "")
text_lang:str = inputs.get("text_lang", "")
ref_audio_path:str = inputs.get("ref_audio_path", "")
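For reference, `run` is a generator, so a caller drives it with `next()` or a `for` loop; the WebUI wrapper further below does exactly `next(tts_pipline.run(inputs))`. A hypothetical call (the language codes and cut-method name are placeholders — the real values come from `dict_language` and `cut_method` in inference_webui.py):

inputs = {
    "text": "要合成的文本。",             # placeholder text
    "text_lang": "zh",                    # placeholder language code
    "ref_audio_path": "ref.wav",          # placeholder path
    "prompt_text": "参考音频的文本",       # leave empty for no-prompt mode
    "prompt_lang": "zh",
    "top_k": 5,
    "top_p": 0.9,
    "temperature": 0.6,
    "text_split_method": "cut1",          # placeholder cut-method name
    "batch_size": 4,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "return_fragment": False,
    "speed_factor": 1.0,
}
sr, audio = next(tts_pipline.run(inputs))  # one (sampling_rate, int16 audio) pair when return_fragment is False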
@@ -437,7 +449,20 @@
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
if return_fragment:
split_bucket = False
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket:
print(i18n("分桶处理模式已开启"))
no_prompt_text = False
if prompt_text in [None, ""]:
no_prompt_text = True
@@ -481,7 +506,9 @@
data, batch_index_list = self.to_batch(data,
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold)
threshold=batch_threshold,
split_bucket=split_bucket
)
t2 = ttime()
zero_wav = torch.zeros(
int(self.configs.sampling_rate * 0.3),
@@ -557,27 +584,57 @@
audio_fragment.cpu().numpy()
) ### try reconstructing without including the prompt part
audio.append(batch_audio_fragment)
# audio.append(zero_wav)
t5 = ttime()
t_45 += t5 - t4
if return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
yield self.audio_postprocess(batch_audio_fragment,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
else:
audio.append(batch_audio_fragment)
if self.stop_flag:
yield self.configs.sampling_rate, (zero_wav.cpu().numpy()).astype(np.int16)
return
if not return_fragment:
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
yield self.audio_postprocess(audio,
self.configs.sampling_rate,
batch_index_list,
speed_factor,
split_bucket)
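When `return_fragment` is True, each batch is yielded as soon as it is synthesized, and `stop()` (which only sets `stop_flag`) takes effect at the next batch boundary. A hedged consumer sketch reusing the inputs dict from the sketch above (`handle_fragment` and `should_stop` are hypothetical):

inputs["return_fragment"] = True            # split_bucket is forced off in this mode
for sr, fragment in tts_pipline.run(inputs):
    handle_fragment(fragment, sr)           # hypothetical sink, e.g. append to a playback buffer
    if should_stop():                       # hypothetical external condition
        tts_pipline.stop()                  # run() then yields a short silence and returns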
def audio_postprocess(self,
audio:np.ndarray,
sr:int,
batch_index_list:list=None,
speed_factor:float=1.0,
split_bucket:bool=True)->tuple[int, np.ndarray]:
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
audio = [item for batch in audio for item in batch]
audio = self.recovery_order(audio, batch_index_list)
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
try:
if speed_factor != 1.0:
audio = speed_change(audio, speed=speed_factor, sr=int(self.configs.sampling_rate))
audio = speed_change(audio, speed=speed_factor, sr=int(sr))
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
yield self.configs.sampling_rate, audio
return sr, audio
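`recovery_order` itself is outside this diff, but from `batch_index_list` (which records the original text index of every batch slot) it evidently inverts the bucketing permutation so fragments come back in reading order. A sketch of that inversion, assuming that interface:

def recovery_order_sketch(fragments: list, batch_index_list: list) -> list:
    # Undo bucket shuffling: put each fragment back at its original text position.
    flat = [i for batch in batch_index_list for i in batch]   # batch slot -> original index
    restored = [None] * len(flat)
    for pos, original_idx in enumerate(flat):
        restored[original_idx] = fragments[pos]
    return restored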
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# Convert the NumPy array to a raw PCM stream
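The body of `speed_change` is truncated in this view; given the `import ffmpeg` at the top of the file and the int16 PCM coming out of `audio_postprocess`, a plausible implementation pipes raw samples through ffmpeg's `atempo` filter. A sketch under those assumptions (not necessarily the repository's exact code):

import numpy as np
import ffmpeg  # the ffmpeg-python binding imported at the top of TTS.py

def speed_change_sketch(input_audio: np.ndarray, speed: float, sr: int) -> np.ndarray:
    raw = input_audio.astype(np.int16).tobytes()         # NumPy array -> raw PCM bytes
    out, _ = (
        ffmpeg
        .input("pipe:", format="s16le", acodec="pcm_s16le", ar=str(sr), ac=1)
        .filter("atempo", speed)                         # atempo handles roughly 0.5-2.0 per pass
        .output("pipe:", format="s16le", acodec="pcm_s16le")
        .run(input=raw, capture_stdout=True, capture_stderr=True)
    )
    return np.frombuffer(out, np.int16)                  # raw PCM bytes -> NumPy array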

GPT_SoVITS/inference_webui.py

@@ -6,8 +6,11 @@
Recognize all text as English
Recognize all text as Japanese
'''
import os, re, logging
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
import os, re, logging
logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
@@ -18,10 +21,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
infer_ttswebui = int(infer_ttswebui)
@@ -34,6 +34,7 @@ import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import cut1, cut2, cut3, cut4, cut5
from tools.i18n.i18n import I18nAuto
from TTS_infer_pack.text_segmentation_method import get_method
i18n = I18nAuto()
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # Ensure this is also set when the inference UI is launched directly.
@@ -68,19 +69,28 @@ tts_pipline = TTS(tts_config)
gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
def inference(text, text_lang, ref_audio_path, prompt_text, prompt_lang, top_k, top_p, temperature, text_split_method, batch_size, speed_factor):
def inference(text, text_lang,
ref_audio_path, prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket
):
inputs={
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
"prompt_text": prompt_text,
"prompt_text": prompt_text if not ref_text_free else "",
"prompt_lang": dict_language[prompt_lang],
"top_k": top_k,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor)
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
}
yield next(tts_pipline.run(inputs))
@@ -121,7 +131,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
)
with gr.Group():
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
@@ -130,61 +142,88 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
SoVITS_dropdown.change(tts_pipline.init_vits_weights, [SoVITS_dropdown], [])
GPT_dropdown.change(tts_pipline.init_t2s_weights, [GPT_dropdown], [])
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频超过会报错"), type="filepath")
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="")
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
with gr.Row():
text = gr.Textbox(label=i18n("需要合成的文本"), value="")
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("参考音频的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=16, max_lines=16)
text_language = gr.Dropdown(
label=i18n("需要合成的语种"), choices=[i18n("中文"), i18n("英文"), i18n("日文"), i18n("中英混合"), i18n("日英混合"), i18n("多语种混合")], value=i18n("中文")
)
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
with gr.Row():
gr.Markdown(value=i18n("gpt采样参数(无参考文本时不要太低)"))
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=20,step=1,label=i18n("batch_size"),value=1,interactive=True)
speed_factor = gr.Slider(minimum=0.25,maximum=4,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
inference_button = gr.Button(i18n("合成语音"), variant="primary")
output = gr.Audio(label=i18n("输出的语音"))
with gr.Column():
how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
split_bucket = gr.Checkbox(label=i18n("数据分桶(可能会降低一点计算量,选就对了)"), value=True, interactive=True, show_label=True)
# with gr.Column():
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
inference,
[text,text_language, inp_ref, prompt_text, prompt_language, top_k, top_p, temperature, how_to_cut, batch_size, speed_factor],
[
text,text_language, inp_ref,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket
],
[output],
)
stop_infer.click(tts_pipline.stop, [], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
button1 = gr.Button(i18n("凑四句一切"), variant="primary")
button2 = gr.Button(i18n("凑50字一切"), variant="primary")
button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
button5 = gr.Button(i18n("按标点符号切"), variant="primary")
text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
button1.click(cut1, [text_inp], [text_opt])
button2.click(cut2, [text_inp], [text_opt])
button3.click(cut3, [text_inp], [text_opt])
button4.click(cut4, [text_inp], [text_opt])
button5.click(cut5, [text_inp], [text_opt])
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text= gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
app.queue(concurrency_count=511, max_size=1022).launch(