add how to cut param to api.py

This commit is contained in:
justoy 2025-07-03 11:21:44 -07:00 committed by 1
parent 6df61f58e4
commit 5cee016c43

138
api.py
View File

@ -22,6 +22,7 @@
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
·-cp` - `文本切分符号设定, 默认为空, ",.,。"字符串的方式传入`
·-htc` - `文本切分方式, 默认"不切", "不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"`
`-hb` - `cnhubert路径`
`-b` - `bert路径`
@ -55,6 +56,18 @@ POST:
}
```
使用执行参数指定的参考音频并设定文本切分方式:
GET:
`http://127.0.0.1:9880?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&how_to_cut=按中文句号`
POST:
```json
{
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
"text_language": "zh",
"how_to_cut": "按中文句号。切"
}
```
手动指定当次推理所使用的参考音频:
GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
@ -75,7 +88,7 @@ RESP:
手动指定当次推理所使用的参考音频并提供参数:
GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&how_to_cut=按中文句号&inp_refs="456.wav"&inp_refs="789.wav"`
POST:
```json
{
@ -88,6 +101,7 @@ POST:
"top_p": 0.6,
"temperature": 0.6,
"speed": 1,
"how_to_cut": "按中文句号。切",
"inp_refs": ["456.wav","789.wav"]
}
```
@ -144,6 +158,7 @@ import argparse
import os
import re
import sys
from string import punctuation
now_dir = os.getcwd()
sys.path.append(now_dir)
@ -827,12 +842,110 @@ splits = {
}
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
opts = ["%s" % item for item in inp.strip("").split("")]
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
def cut5(inp):
inp = inp.strip("\n")
punds = {",", ".", ";", "?", "!", "", "", "", "", "", ";", "", ""}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
mergeitems.append("".join(items))
items = []
else:
items.append(char)
if items:
mergeitems.append("".join(items))
opt = [item for item in mergeitems if not set(item).issubset(punds)]
return "\n".join(opt)
def get_tts_wav(
ref_wav_path,
prompt_text,
prompt_language,
text,
text_language,
how_to_cut="不切",
top_k=15,
top_p=0.6,
temperature=0.6,
@ -909,6 +1022,20 @@ def get_tts_wav(
refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
t1 = ttime()
# Apply text cutting based on how_to_cut parameter
if how_to_cut == "凑四句一切":
text = cut1(text)
elif how_to_cut == "凑50字一切":
text = cut2(text)
elif how_to_cut == "按中文句号。切":
text = cut3(text)
elif how_to_cut == "按英文句号.切":
text = cut4(text)
elif how_to_cut == "按标点符号切":
text = cut5(text)
while "\n\n" in text:
text = text.replace("\n\n", "\n")
# os.environ['version'] = version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
@ -1104,6 +1231,7 @@ def handle(
text,
text_language,
cut_punc,
how_to_cut,
top_k,
top_p,
temperature,
@ -1140,6 +1268,7 @@ def handle(
prompt_language,
text,
text_language,
how_to_cut,
top_k,
top_p,
temperature,
@ -1211,6 +1340,9 @@ parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
parser.add_argument("-htc", "--how_to_cut", type=str, default="不切",
choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"],
help="文本切分方式")
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
@ -1223,6 +1355,7 @@ host = args.bind_addr
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
default_cut_punc = args.cut_punc
how_to_cut = args.how_to_cut
# 应用参数配置
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
@ -1348,6 +1481,7 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"),
json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"),
json_post_raw.get("how_to_cut"),
json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0),
@ -1366,6 +1500,7 @@ async def tts_endpoint(
text: str = None,
text_language: str = None,
cut_punc: str = None,
how_to_cut: str = Query(default=None),
top_k: int = 15,
top_p: float = 1.0,
temperature: float = 1.0,
@ -1381,6 +1516,7 @@ async def tts_endpoint(
text,
text_language,
cut_punc,
how_to_cut,
top_k,
top_p,
temperature,