mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
Added support for inference text slicing in api.py
This commit is contained in:
parent
9be39a8739
commit
e9625c3a9e
206
api.py
206
api.py
@ -35,7 +35,8 @@ POST:
|
||||
```json
|
||||
{
|
||||
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
|
||||
"text_language": "zh"
|
||||
"text_language": "zh",
|
||||
"slice": "按标点符号切"
|
||||
}
|
||||
```
|
||||
|
||||
@ -120,7 +121,7 @@ RESP: 无
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import os, re
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
@ -166,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa
|
||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||
#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
|
||||
# bool值的用法为 `python ./api.py -fp ...`
|
||||
# 此时 full_precision==True, half_precision==False
|
||||
|
||||
@ -375,6 +377,19 @@ dict_language = {
|
||||
"多语种混合": "auto"
|
||||
}
|
||||
|
||||
slice_option = {
|
||||
"凑四句一切": "凑四句一切",
|
||||
"凑50字一切": "凑50字一切",
|
||||
"按中文句号。切": "按中文句号。切",
|
||||
"按英文句号.切": "按英文句号.切",
|
||||
"按标点符号切": "按标点符号切",
|
||||
"per 4 sentences": "凑四句一切",
|
||||
"per 50 letters": "凑50字一切",
|
||||
"per period": "按英文句号.切",
|
||||
"per punctuation mark": "按标点符号切",
|
||||
None: "No slice"
|
||||
}
|
||||
|
||||
dtype=torch.float16 if is_half == True else torch.float32
|
||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
language=language.replace("all_","")
|
||||
@ -447,13 +462,41 @@ def get_phones_and_bert(text,language):
|
||||
|
||||
return phones,bert.to(dtype),norm_text
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
def merge_short_text_in_array(texts, threshold):
|
||||
if (len(texts)) < 2:
|
||||
return texts
|
||||
result = []
|
||||
text = ""
|
||||
for ele in texts:
|
||||
text += ele
|
||||
if len(text) >= threshold:
|
||||
result.append(text)
|
||||
text = ""
|
||||
if (len(text) > 0):
|
||||
if len(result) == 0:
|
||||
result.append(text)
|
||||
else:
|
||||
result[len(result) - 1] += text
|
||||
return result
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
|
||||
# not supporting ref_free
|
||||
t0 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
prompt_language, text = prompt_language, text.strip("\n")
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||
print(f"实际输入的参考文本: {prompt_text}")
|
||||
text = text.strip("\n")
|
||||
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
||||
print(f"实际输入的目标文本: {text}")
|
||||
|
||||
zero_wav = np.zeros(
|
||||
int(hps.data.sampling_rate * 0.3),
|
||||
dtype=np.float16 if is_half == True else np.float32
|
||||
)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
# neglected error checking for reference audio duration
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if (is_half == True):
|
||||
@ -463,20 +506,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||
"last_hidden_state"
|
||||
].transpose(
|
||||
1, 2
|
||||
) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
|
||||
prompt_semantic = codes[0, 0]
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
||||
phones1 = cleaned_text_to_sequence(phones1)
|
||||
|
||||
# 文本切句
|
||||
# default to no slice if no argument is provided
|
||||
how_to_cut = slice_option[how_to_cut]
|
||||
print(f"[INFO] 文本切句選項: {how_to_cut}")
|
||||
if (how_to_cut == "凑四句一切"):
|
||||
text = cut1(text)
|
||||
elif (how_to_cut == "凑50字一切"):
|
||||
text = cut2(text)
|
||||
elif (how_to_cut == "按英文句号.切"):
|
||||
text = cut3(text)
|
||||
elif (how_to_cut == "按英文句号.切"):
|
||||
text = cut4(text)
|
||||
elif (how_to_cut == "按标点符号切"):
|
||||
text = cut5(text)
|
||||
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
print(f"实际输入的目标文本(切句后): {text}")
|
||||
texts = text.split("\n")
|
||||
texts = merge_short_text_in_array(texts, 5)
|
||||
audio_opt = []
|
||||
|
||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||
|
||||
for text in texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||
print(f"实际输入的目标文本(每句): {text}")
|
||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||
print(f"前端处理后的文本(每句): {norm_text2}")
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||
|
||||
@ -496,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
early_stop_num=hz * max_sec)
|
||||
t3 = ttime()
|
||||
# print(pred_semantic.shape,idx)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
||||
0
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||
if (is_half == True):
|
||||
refer = refer.half().to(device)
|
||||
else:
|
||||
refer = refer.to(device)
|
||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refer).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
audio = (
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
) ###试试重建不带上prompt部分
|
||||
max_audio=np.abs(audio).max()#简单防止16bit爆音
|
||||
if max_audio>1:audio/=max_audio
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
|
||||
np.int16
|
||||
)
|
||||
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||
|
||||
def get_first(text):
|
||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||
text = re.split(pattern, text)[0].strip()
|
||||
return text
|
||||
|
||||
def split(todo_text):
|
||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||
if todo_text[-1] not in splits:
|
||||
todo_text += "。"
|
||||
i_split_head = i_split_tail = 0
|
||||
len_text = len(todo_text)
|
||||
todo_texts = []
|
||||
while 1:
|
||||
if i_split_head >= len_text:
|
||||
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||
if todo_text[i_split_head] in splits:
|
||||
i_split_head += 1
|
||||
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
||||
i_split_tail = i_split_head
|
||||
else:
|
||||
i_split_head += 1
|
||||
return todo_texts
|
||||
|
||||
def cut1(inp):
|
||||
inp = inp.strip("\n")
|
||||
inps = split(inp)
|
||||
split_idx = list(range(0, len(inps), 4))
|
||||
split_idx[-1] = None
|
||||
if len(split_idx) > 1:
|
||||
opts = []
|
||||
for idx in range(len(split_idx) - 1):
|
||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||
else:
|
||||
opts = [inp]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut2(inp):
|
||||
inp = inp.strip("\n")
|
||||
inps = split(inp)
|
||||
if len(inps) < 2:
|
||||
return inp
|
||||
opts = []
|
||||
summ = 0
|
||||
tmp_str = ""
|
||||
for i in range(len(inps)):
|
||||
summ += len(inps[i])
|
||||
tmp_str += inps[i]
|
||||
if summ > 50:
|
||||
summ = 0
|
||||
opts.append(tmp_str)
|
||||
tmp_str = ""
|
||||
if tmp_str != "":
|
||||
opts.append(tmp_str)
|
||||
# print(opts)
|
||||
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||
opts[-2] = opts[-2] + opts[-1]
|
||||
opts = opts[:-1]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut3(inp):
|
||||
inp = inp.strip("\n")
|
||||
return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
|
||||
|
||||
|
||||
def cut4(inp):
|
||||
inp = inp.strip("\n")
|
||||
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
|
||||
|
||||
|
||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||
def cut5(inp):
|
||||
# if not re.search(r'[^\w\s]', inp[-1]):
|
||||
# inp += '。'
|
||||
inp = inp.strip("\n")
|
||||
punds = r'[,.;?!、,。?!;:…]'
|
||||
items = re.split(f'({punds})', inp)
|
||||
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
|
||||
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
||||
if len(items)%2 == 1:
|
||||
mergeitems.append(items[-1])
|
||||
opt = "\n".join(mergeitems)
|
||||
return opt
|
||||
|
||||
def handle_control(command):
|
||||
if command == "restart":
|
||||
@ -541,7 +707,7 @@ def handle_change(path, text, language):
|
||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||
|
||||
|
||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
|
||||
if (
|
||||
refer_wav_path == "" or refer_wav_path is None
|
||||
or prompt_text == "" or prompt_text is None
|
||||
@ -557,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
|
||||
with torch.no_grad():
|
||||
gen = get_tts_wav(
|
||||
refer_wav_path, prompt_text, prompt_language, text, text_language
|
||||
refer_wav_path, prompt_text, prompt_language, text, text_language, slice
|
||||
)
|
||||
sampling_rate, audio_data = next(gen)
|
||||
|
||||
@ -628,6 +794,7 @@ async def tts_endpoint(request: Request):
|
||||
json_post_raw.get("prompt_language"),
|
||||
json_post_raw.get("text"),
|
||||
json_post_raw.get("text_language"),
|
||||
json_post_raw.get("slice"),
|
||||
)
|
||||
|
||||
|
||||
@ -638,8 +805,9 @@ async def tts_endpoint(
|
||||
prompt_language: str = None,
|
||||
text: str = None,
|
||||
text_language: str = None,
|
||||
slice: str = None
|
||||
):
|
||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user