From 5cee016c434c79b42a6e439ea80131fc6386ea29 Mon Sep 17 00:00:00 2001 From: justoy Date: Thu, 3 Jul 2025 11:21:44 -0700 Subject: [PATCH] add how to cut param to api.py --- api.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/api.py b/api.py index cc0896a2..f3c1dba6 100644 --- a/api.py +++ b/api.py @@ -22,6 +22,7 @@ ·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` ·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` ·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` +·-htc` - `文本切分方式, 默认"不切", "不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"` `-hb` - `cnhubert路径` `-b` - `bert路径` @@ -55,6 +56,18 @@ POST: } ``` +使用执行参数指定的参考音频并设定文本切分方式: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&how_to_cut=按中文句号。切` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "how_to_cut": "按中文句号。切" +} +``` + 手动指定当次推理所使用的参考音频: GET: `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` @@ -75,7 +88,7 @@ RESP: 手动指定当次推理所使用的参考音频,并提供参数: GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&how_to_cut=按中文句号。切&inp_refs="456.wav"&inp_refs="789.wav"` POST: ```json { @@ -88,6 +101,7 @@ POST: "top_p": 0.6, "temperature": 0.6, "speed": 1, + "how_to_cut": "按中文句号。切", "inp_refs": ["456.wav","789.wav"] } ``` @@ -144,6 +158,7 @@ import argparse import os import re import sys +from string import punctuation now_dir = os.getcwd() sys.path.append(now_dir) @@ -827,12 +842,110 @@ splits = { } +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) + else: + opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + # print(opts) + if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut3(inp): + inp = inp.strip("\n") + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut4(inp): + inp = inp.strip("\n") + opts = re.split(r"(? 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): + items.append(char) + else: + items.append(char) + mergeitems.append("".join(items)) + items = [] + else: + items.append(char) + + if items: + mergeitems.append("".join(items)) + + opt = [item for item in mergeitems if not set(item).issubset(punds)] + return "\n".join(opt) + + def get_tts_wav( ref_wav_path, prompt_text, prompt_language, text, text_language, + how_to_cut="不切", top_k=15, top_p=0.6, temperature=0.6, @@ -909,6 +1022,20 @@ def get_tts_wav( refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) t1 = ttime() + # Apply text cutting based on how_to_cut parameter + if how_to_cut == "凑四句一切": + text = cut1(text) + elif how_to_cut == "凑50字一切": + text = cut2(text) + elif how_to_cut == "按中文句号。切": + text = cut3(text) + elif how_to_cut == "按英文句号.切": + text = cut4(text) + elif how_to_cut == "按标点符号切": + text = cut5(text) + while "\n\n" in text: + text = text.replace("\n\n", "\n") + # os.environ['version'] = version prompt_language = dict_language[prompt_language.lower()] text_language = dict_language[text_language.lower()] @@ -1104,6 +1231,7 @@ def handle( text, text_language, cut_punc, + how_to_cut, top_k, top_p, temperature, @@ -1140,6 +1268,7 @@ def handle( prompt_language, text, text_language, + how_to_cut, top_k, top_p, temperature, @@ -1211,6 +1340,9 @@ parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频 parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") # 切割常用分句符为 `python ./api.py -cp ".?!。?!"` +parser.add_argument("-htc", "--how_to_cut", type=str, default="不切", + choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"], + help="文本切分方式") parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") @@ -1223,6 +1355,7 @@ host = args.bind_addr cnhubert_base_path = args.hubert_path bert_path = args.bert_path default_cut_punc = args.cut_punc +how_to_cut = args.how_to_cut # 应用参数配置 default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) @@ -1348,6 +1481,7 @@ async def tts_endpoint(request: Request): json_post_raw.get("text"), json_post_raw.get("text_language"), json_post_raw.get("cut_punc"), + json_post_raw.get("how_to_cut"), json_post_raw.get("top_k", 15), json_post_raw.get("top_p", 1.0), json_post_raw.get("temperature", 1.0), @@ -1366,6 +1500,7 @@ async def tts_endpoint( text: str = None, text_language: str = None, cut_punc: str = None, + how_to_cut: str = Query(default=None), top_k: int = 15, top_p: float = 1.0, temperature: float = 1.0, @@ -1381,6 +1516,7 @@ async def tts_endpoint( text, text_language, cut_punc, + how_to_cut, top_k, top_p, temperature,