Added support for inference text slicing in api.py

This commit is contained in:
zih-an 2024-03-06 16:15:13 +00:00
parent 9be39a8739
commit e9625c3a9e

206
api.py
View File

@ -35,7 +35,8 @@ POST:
```json
{
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
"text_language": "zh"
"text_language": "zh",
"slice": "按标点符号切"
}
```
@ -120,7 +121,7 @@ RESP: 无
import argparse
import os
import os, re
import sys
now_dir = os.getcwd()
@ -166,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
# bool值的用法为 `python ./api.py -fp ...`
# 此时 full_precision==True, half_precision==False
@ -375,6 +377,19 @@ dict_language = {
"多语种混合": "auto"
}
# Maps the "slice" value accepted by the API to the canonical slicing mode
# that get_tts_wav dispatches on. Canonical (Chinese) names map to themselves,
# English aliases map to their canonical name, and a missing value (None)
# means "do not slice".
slice_option = {
    "凑四句一切": "凑四句一切",
    "凑50字一切": "凑50字一切",
    "按中文句号。切": "按中文句号。切",
    "按英文句号.切": "按英文句号.切",
    "按标点符号切": "按标点符号切",
    "per 4 sentences": "凑四句一切",
    "per 50 letters": "凑50字一切",
    "per period": "按英文句号.切",
    "per punctuation mark": "按标点符号切",
    # Accept the canonical "no slicing" value explicitly as well, so a client
    # that sends "No slice" does not hit a KeyError on lookup.
    "No slice": "No slice",
    None: "No slice",
}
dtype=torch.float16 if is_half == True else torch.float32
def get_bert_inf(phones, word2ph, norm_text, language):
language=language.replace("all_","")
@ -447,13 +462,41 @@ def get_phones_and_bert(text,language):
return phones,bert.to(dtype),norm_text
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
def merge_short_text_in_array(texts, threshold):
    """Concatenate consecutive strings until each chunk reaches ``threshold``.

    Args:
        texts: list of text fragments, in order.
        threshold: minimum length (in characters) a merged chunk must reach
            before it is emitted.

    Returns:
        A list of merged chunks. A list with fewer than two items is returned
        unchanged. Leftover text shorter than ``threshold`` is appended to the
        last emitted chunk, or becomes the only chunk when none was emitted.
    """
    if len(texts) < 2:
        return texts

    merged = []
    pending = ""
    for piece in texts:
        pending = pending + piece
        if len(pending) < threshold:
            continue
        merged.append(pending)
        pending = ""

    if pending:
        if merged:
            # Too short to stand alone: glue it onto the previous chunk.
            merged[-1] += pending
        else:
            merged.append(pending)
    return merged
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
# not supporting ref_free
t0 = ttime()
prompt_language = dict_language[prompt_language]
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
print(f"实际输入的参考文本: {prompt_text}")
text = text.strip("\n")
if (text[0] not in splits and len(get_first(text)) < 4): text = "" + text if text_language != "en" else "." + text
print(f"实际输入的目标文本: {text}")
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32
)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
# neglected error checking for reference audio duration
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if (is_half == True):
@ -463,20 +506,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state"
].transpose(
1, 2
) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
# Text slicing.
# Resolve the requested option to its canonical mode; a missing option (None)
# resolves to "No slice", which matches none of the branches below and leaves
# the text untouched.
how_to_cut = slice_option[how_to_cut]
print(f"[INFO] 文本切句選項: {how_to_cut}")
if (how_to_cut == "凑四句一切"):
    text = cut1(text)
elif (how_to_cut == "凑50字一切"):
    text = cut2(text)
elif (how_to_cut == "按中文句号。切"):
    # Fixed: this branch previously tested the English-period mode a second
    # time, so the Chinese-period mode never sliced and cut4 was unreachable.
    text = cut3(text)
elif (how_to_cut == "按英文句号.切"):
    text = cut4(text)
elif (how_to_cut == "按标点符号切"):
    text = cut5(text)
# Collapse runs of blank lines so each line is one segment to synthesize.
while "\n\n" in text:
    text = text.replace("\n\n", "\n")
print(f"实际输入的目标文本(切句后): {text}")
texts = text.split("\n")
# Merge segments shorter than 5 characters into their neighbours.
texts = merge_short_text_in_array(texts, 5)
audio_opt = []
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
for text in texts:
# 解决输入目标文本的空行导致报错的问题
if (len(text.strip()) == 0):
continue
if (text[-1] not in splits): text += "" if text_language != "en" else "."
print(f"实际输入的目标文本(每句): {text}")
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
print(f"前端处理后的文本(每句): {norm_text2}")
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
@ -496,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
early_stop_num=hz * max_sec)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
0
) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
)
.detach()
.cpu()
.numpy()[0, 0]
) ###试试重建不带上prompt部分
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
np.int16
)
# Punctuation marks treated as clause/sentence boundaries, in both full-width
# (Chinese) and ASCII forms. The full-width entries had degraded to empty
# strings, which left full-width punctuation unrecognised as a boundary.
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}


def get_first(text):
    """Return the part of ``text`` before its first boundary mark, stripped.

    If ``text`` contains no mark from ``splits`` the whole text is returned
    (stripped of surrounding whitespace).
    """
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text


def split(todo_text):
    """Split ``todo_text`` into segments that each end with a boundary mark.

    "……" is normalised to "。" and "——" to "," first. If the text does not end
    with a mark from ``splits``, a "。" is appended so the final segment is
    terminated (and therefore emitted) like every other one.

    Returns:
        A list of segments, each including its trailing mark. Empty input
        yields an empty list.
    """
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if not todo_text:
        # Nothing to split; avoids indexing into an empty string below.
        return []
    if todo_text[-1] not in splits:
        todo_text += "。"

    segments = []
    start = 0
    for pos, char in enumerate(todo_text):
        if char in splits:
            segments.append(todo_text[start:pos + 1])
            start = pos + 1
    # The text is guaranteed to end with a mark, so nothing is left over here.
    return segments
def cut1(inp):
    """Group the segments produced by ``split`` into newline-separated chunks.

    Chunk boundaries are taken at every fourth segment index, and the last
    boundary is replaced by ``None`` so the final chunk runs to the end of the
    segment list. When that leaves only a single boundary, the input (with
    surrounding newlines stripped) is returned as-is.
    """
    inp = inp.strip("\n")
    sentences = split(inp)
    bounds = list(range(0, len(sentences), 4))
    # The last boundary becomes open-ended so trailing segments are not lost.
    bounds[-1] = None
    if len(bounds) <= 1:
        return inp
    chunks = ["".join(sentences[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
    return "\n".join(chunks)
def cut2(inp):
    """Pack the segments produced by ``split`` into chunks of over 50 characters.

    Segments are accumulated until the running chunk is longer than 50
    characters, at which point it is emitted. Whatever remains is emitted as a
    final chunk; if that final chunk is shorter than 50 characters and is not
    the only one, it is merged into the preceding chunk. Chunks are joined
    with newlines. Input that splits into fewer than two segments is returned
    unchanged (after stripping surrounding newlines).
    """
    inp = inp.strip("\n")
    pieces = split(inp)
    if len(pieces) < 2:
        return inp

    chunks = []
    buffer = ""
    for piece in pieces:
        buffer += piece
        if len(buffer) > 50:
            chunks.append(buffer)
            buffer = ""
    if buffer:
        chunks.append(buffer)

    # If the last chunk is too short, merge it into the one before it.
    if len(chunks) > 1 and len(chunks[-1]) < 50:
        tail = chunks.pop()
        chunks[-1] += tail
    return "\n".join(chunks)
def cut3(inp):
    """Split ``inp`` on the Chinese full stop "。", one sentence per line.

    Surrounding newlines and leading/trailing "。" are stripped first; the
    separators themselves are not kept in the output.

    The separator literal had degraded to an empty string, and
    ``str.split("")`` raises ``ValueError: empty separator``.
    """
    inp = inp.strip("\n")
    return "\n".join(inp.strip("。").split("。"))
def cut4(inp):
    """Split ``inp`` on the ASCII period ".", one piece per line.

    Surrounding newlines and then leading/trailing periods are stripped; every
    remaining period is turned into a line break (the period is not kept).
    """
    trimmed = inp.strip("\n").strip(".")
    # Splitting on "." and re-joining with "\n" is a plain substitution.
    return trimmed.replace(".", "\n")
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    """Break ``inp`` after every punctuation mark, one clause per line.

    Each clause keeps its trailing mark. Text after the last mark (possibly
    empty, which yields a trailing newline) is emitted as the final line, so
    input without any punctuation comes back unchanged.
    """
    inp = inp.strip("\n")
    punds = r'[,.;?!、,。?!;:…]'
    # The capturing group makes re.split return text and marks alternately.
    parts = re.split(f'({punds})', inp)
    bodies = parts[0::2]
    marks = parts[1::2]
    merged = [body + mark for body, mark in zip(bodies, marks)]
    # Keep the text complete when there is no mark at all or none at the end.
    if len(parts) % 2:
        merged.append(parts[-1])
    return "\n".join(merged)
def handle_control(command):
if command == "restart":
@ -541,7 +707,7 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
@ -557,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
with torch.no_grad():
gen = get_tts_wav(
refer_wav_path, prompt_text, prompt_language, text, text_language
refer_wav_path, prompt_text, prompt_language, text, text_language, slice
)
sampling_rate, audio_data = next(gen)
@ -628,6 +794,7 @@ async def tts_endpoint(request: Request):
json_post_raw.get("prompt_language"),
json_post_raw.get("text"),
json_post_raw.get("text_language"),
json_post_raw.get("slice"),
)
@ -638,8 +805,9 @@ async def tts_endpoint(
prompt_language: str = None,
text: str = None,
text_language: str = None,
slice: str = None
):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)
if __name__ == "__main__":