mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-08 07:49:59 +08:00
Added support for inference text slicing in api.py
This commit is contained in:
parent
9be39a8739
commit
e9625c3a9e
206
api.py
206
api.py
@ -35,7 +35,8 @@ POST:
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
|
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
|
||||||
"text_language": "zh"
|
"text_language": "zh",
|
||||||
|
"slice": "按标点符号切"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -120,7 +121,7 @@ RESP: 无
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os, re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
@ -166,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa
|
|||||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||||
|
#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
|
||||||
# bool值的用法为 `python ./api.py -fp ...`
|
# bool值的用法为 `python ./api.py -fp ...`
|
||||||
# 此时 full_precision==True, half_precision==False
|
# 此时 full_precision==True, half_precision==False
|
||||||
|
|
||||||
@ -375,6 +377,19 @@ dict_language = {
|
|||||||
"多语种混合": "auto"
|
"多语种混合": "auto"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slice_option = {
|
||||||
|
"凑四句一切": "凑四句一切",
|
||||||
|
"凑50字一切": "凑50字一切",
|
||||||
|
"按中文句号。切": "按中文句号。切",
|
||||||
|
"按英文句号.切": "按英文句号.切",
|
||||||
|
"按标点符号切": "按标点符号切",
|
||||||
|
"per 4 sentences": "凑四句一切",
|
||||||
|
"per 50 letters": "凑50字一切",
|
||||||
|
"per period": "按英文句号.切",
|
||||||
|
"per punctuation mark": "按标点符号切",
|
||||||
|
None: "No slice"
|
||||||
|
}
|
||||||
|
|
||||||
dtype=torch.float16 if is_half == True else torch.float32
|
dtype=torch.float16 if is_half == True else torch.float32
|
||||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||||
language=language.replace("all_","")
|
language=language.replace("all_","")
|
||||||
@ -447,13 +462,41 @@ def get_phones_and_bert(text,language):
|
|||||||
|
|
||||||
return phones,bert.to(dtype),norm_text
|
return phones,bert.to(dtype),norm_text
|
||||||
|
|
||||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
def merge_short_text_in_array(texts, threshold):
|
||||||
|
if (len(texts)) < 2:
|
||||||
|
return texts
|
||||||
|
result = []
|
||||||
|
text = ""
|
||||||
|
for ele in texts:
|
||||||
|
text += ele
|
||||||
|
if len(text) >= threshold:
|
||||||
|
result.append(text)
|
||||||
|
text = ""
|
||||||
|
if (len(text) > 0):
|
||||||
|
if len(result) == 0:
|
||||||
|
result.append(text)
|
||||||
|
else:
|
||||||
|
result[len(result) - 1] += text
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
|
||||||
|
# not supporting ref_free
|
||||||
t0 = ttime()
|
t0 = ttime()
|
||||||
|
prompt_language = dict_language[prompt_language]
|
||||||
prompt_text = prompt_text.strip("\n")
|
prompt_text = prompt_text.strip("\n")
|
||||||
prompt_language, text = prompt_language, text.strip("\n")
|
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
print(f"实际输入的参考文本: {prompt_text}")
|
||||||
|
text = text.strip("\n")
|
||||||
|
if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
||||||
|
print(f"实际输入的目标文本: {text}")
|
||||||
|
|
||||||
|
zero_wav = np.zeros(
|
||||||
|
int(hps.data.sampling_rate * 0.3),
|
||||||
|
dtype=np.float16 if is_half == True else np.float32
|
||||||
|
)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||||
|
# neglected error checking for reference audio duration
|
||||||
wav16k = torch.from_numpy(wav16k)
|
wav16k = torch.from_numpy(wav16k)
|
||||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||||
if (is_half == True):
|
if (is_half == True):
|
||||||
@ -463,20 +506,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
wav16k = wav16k.to(device)
|
wav16k = wav16k.to(device)
|
||||||
zero_wav_torch = zero_wav_torch.to(device)
|
zero_wav_torch = zero_wav_torch.to(device)
|
||||||
wav16k = torch.cat([wav16k, zero_wav_torch])
|
wav16k = torch.cat([wav16k, zero_wav_torch])
|
||||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
||||||
|
"last_hidden_state"
|
||||||
|
].transpose(
|
||||||
|
1, 2
|
||||||
|
) # .float()
|
||||||
codes = vq_model.extract_latent(ssl_content)
|
codes = vq_model.extract_latent(ssl_content)
|
||||||
|
|
||||||
prompt_semantic = codes[0, 0]
|
prompt_semantic = codes[0, 0]
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
prompt_language = dict_language[prompt_language]
|
|
||||||
text_language = dict_language[text_language]
|
# 文本切句
|
||||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
# default to no slice if no argument is provided
|
||||||
phones1 = cleaned_text_to_sequence(phones1)
|
how_to_cut = slice_option[how_to_cut]
|
||||||
|
print(f"[INFO] 文本切句選項: {how_to_cut}")
|
||||||
|
if (how_to_cut == "凑四句一切"):
|
||||||
|
text = cut1(text)
|
||||||
|
elif (how_to_cut == "凑50字一切"):
|
||||||
|
text = cut2(text)
|
||||||
|
elif (how_to_cut == "按英文句号.切"):
|
||||||
|
text = cut3(text)
|
||||||
|
elif (how_to_cut == "按英文句号.切"):
|
||||||
|
text = cut4(text)
|
||||||
|
elif (how_to_cut == "按标点符号切"):
|
||||||
|
text = cut5(text)
|
||||||
|
|
||||||
|
while "\n\n" in text:
|
||||||
|
text = text.replace("\n\n", "\n")
|
||||||
|
print(f"实际输入的目标文本(切句后): {text}")
|
||||||
texts = text.split("\n")
|
texts = text.split("\n")
|
||||||
|
texts = merge_short_text_in_array(texts, 5)
|
||||||
audio_opt = []
|
audio_opt = []
|
||||||
|
|
||||||
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
|
||||||
|
|
||||||
for text in texts:
|
for text in texts:
|
||||||
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
|
if (len(text.strip()) == 0):
|
||||||
|
continue
|
||||||
|
if (text[-1] not in splits): text += "。" if text_language != "en" else "."
|
||||||
|
print(f"实际输入的目标文本(每句): {text}")
|
||||||
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
|
||||||
|
print(f"前端处理后的文本(每句): {norm_text2}")
|
||||||
bert = torch.cat([bert1, bert2], 1)
|
bert = torch.cat([bert1, bert2], 1)
|
||||||
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
|
||||||
|
|
||||||
@ -496,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
early_stop_num=hz * max_sec)
|
early_stop_num=hz * max_sec)
|
||||||
t3 = ttime()
|
t3 = ttime()
|
||||||
# print(pred_semantic.shape,idx)
|
# print(pred_semantic.shape,idx)
|
||||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
pred_semantic = pred_semantic[:, -idx:].unsqueeze(
|
||||||
|
0
|
||||||
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||||
if (is_half == True):
|
if (is_half == True):
|
||||||
refer = refer.half().to(device)
|
refer = refer.half().to(device)
|
||||||
else:
|
else:
|
||||||
refer = refer.to(device)
|
refer = refer.to(device)
|
||||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||||
audio = \
|
audio = (
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
|
||||||
refer).detach().cpu().numpy()[
|
)
|
||||||
0, 0] ###试试重建不带上prompt部分
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()[0, 0]
|
||||||
|
) ###试试重建不带上prompt部分
|
||||||
|
max_audio=np.abs(audio).max()#简单防止16bit爆音
|
||||||
|
if max_audio>1:audio/=max_audio
|
||||||
audio_opt.append(audio)
|
audio_opt.append(audio)
|
||||||
audio_opt.append(zero_wav)
|
audio_opt.append(zero_wav)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
|
||||||
|
np.int16
|
||||||
|
)
|
||||||
|
|
||||||
|
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||||
|
|
||||||
|
def get_first(text):
|
||||||
|
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||||
|
text = re.split(pattern, text)[0].strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
def split(todo_text):
|
||||||
|
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||||
|
if todo_text[-1] not in splits:
|
||||||
|
todo_text += "。"
|
||||||
|
i_split_head = i_split_tail = 0
|
||||||
|
len_text = len(todo_text)
|
||||||
|
todo_texts = []
|
||||||
|
while 1:
|
||||||
|
if i_split_head >= len_text:
|
||||||
|
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||||
|
if todo_text[i_split_head] in splits:
|
||||||
|
i_split_head += 1
|
||||||
|
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
||||||
|
i_split_tail = i_split_head
|
||||||
|
else:
|
||||||
|
i_split_head += 1
|
||||||
|
return todo_texts
|
||||||
|
|
||||||
|
def cut1(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
inps = split(inp)
|
||||||
|
split_idx = list(range(0, len(inps), 4))
|
||||||
|
split_idx[-1] = None
|
||||||
|
if len(split_idx) > 1:
|
||||||
|
opts = []
|
||||||
|
for idx in range(len(split_idx) - 1):
|
||||||
|
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||||
|
else:
|
||||||
|
opts = [inp]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def cut2(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
inps = split(inp)
|
||||||
|
if len(inps) < 2:
|
||||||
|
return inp
|
||||||
|
opts = []
|
||||||
|
summ = 0
|
||||||
|
tmp_str = ""
|
||||||
|
for i in range(len(inps)):
|
||||||
|
summ += len(inps[i])
|
||||||
|
tmp_str += inps[i]
|
||||||
|
if summ > 50:
|
||||||
|
summ = 0
|
||||||
|
opts.append(tmp_str)
|
||||||
|
tmp_str = ""
|
||||||
|
if tmp_str != "":
|
||||||
|
opts.append(tmp_str)
|
||||||
|
# print(opts)
|
||||||
|
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
||||||
|
opts[-2] = opts[-2] + opts[-1]
|
||||||
|
opts = opts[:-1]
|
||||||
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
|
def cut3(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
|
||||||
|
|
||||||
|
|
||||||
|
def cut4(inp):
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
|
||||||
|
|
||||||
|
|
||||||
|
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||||
|
def cut5(inp):
|
||||||
|
# if not re.search(r'[^\w\s]', inp[-1]):
|
||||||
|
# inp += '。'
|
||||||
|
inp = inp.strip("\n")
|
||||||
|
punds = r'[,.;?!、,。?!;:…]'
|
||||||
|
items = re.split(f'({punds})', inp)
|
||||||
|
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
|
||||||
|
# 在句子不存在符号或句尾无符号的时候保证文本完整
|
||||||
|
if len(items)%2 == 1:
|
||||||
|
mergeitems.append(items[-1])
|
||||||
|
opt = "\n".join(mergeitems)
|
||||||
|
return opt
|
||||||
|
|
||||||
def handle_control(command):
|
def handle_control(command):
|
||||||
if command == "restart":
|
if command == "restart":
|
||||||
@ -541,7 +707,7 @@ def handle_change(path, text, language):
|
|||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
|
||||||
if (
|
if (
|
||||||
refer_wav_path == "" or refer_wav_path is None
|
refer_wav_path == "" or refer_wav_path is None
|
||||||
or prompt_text == "" or prompt_text is None
|
or prompt_text == "" or prompt_text is None
|
||||||
@ -557,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
gen = get_tts_wav(
|
gen = get_tts_wav(
|
||||||
refer_wav_path, prompt_text, prompt_language, text, text_language
|
refer_wav_path, prompt_text, prompt_language, text, text_language, slice
|
||||||
)
|
)
|
||||||
sampling_rate, audio_data = next(gen)
|
sampling_rate, audio_data = next(gen)
|
||||||
|
|
||||||
@ -628,6 +794,7 @@ async def tts_endpoint(request: Request):
|
|||||||
json_post_raw.get("prompt_language"),
|
json_post_raw.get("prompt_language"),
|
||||||
json_post_raw.get("text"),
|
json_post_raw.get("text"),
|
||||||
json_post_raw.get("text_language"),
|
json_post_raw.get("text_language"),
|
||||||
|
json_post_raw.get("slice"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -638,8 +805,9 @@ async def tts_endpoint(
|
|||||||
prompt_language: str = None,
|
prompt_language: str = None,
|
||||||
text: str = None,
|
text: str = None,
|
||||||
text_language: str = None,
|
text_language: str = None,
|
||||||
|
slice: str = None
|
||||||
):
|
):
|
||||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user