From 5cee016c434c79b42a6e439ea80131fc6386ea29 Mon Sep 17 00:00:00 2001
From: justoy
Date: Thu, 3 Jul 2025 11:21:44 -0700
Subject: [PATCH] add how to cut param to api.py
---
api.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 137 insertions(+), 1 deletion(-)
diff --git a/api.py b/api.py
index cc0896a2..f3c1dba6 100644
--- a/api.py
+++ b/api.py
@@ -22,6 +22,7 @@
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入`
+·-htc` - `文本切分方式, 默认"不切", "不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"`
`-hb` - `cnhubert路径`
`-b` - `bert路径`
@@ -55,6 +56,18 @@ POST:
}
```
+使用执行参数指定的参考音频并设定文本切分方式:
+GET:
+ `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&how_to_cut=按中文句号。切`
+POST:
+```json
+{
+ "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
+ "text_language": "zh",
+ "how_to_cut": "按中文句号。切"
+}
+```
+
手动指定当次推理所使用的参考音频:
GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
@@ -75,7 +88,7 @@ RESP:
手动指定当次推理所使用的参考音频,并提供参数:
GET:
- `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"`
+ `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&how_to_cut=按中文句号。切&inp_refs="456.wav"&inp_refs="789.wav"`
POST:
```json
{
@@ -88,6 +101,7 @@ POST:
"top_p": 0.6,
"temperature": 0.6,
"speed": 1,
+ "how_to_cut": "按中文句号。切",
"inp_refs": ["456.wav","789.wav"]
}
```
@@ -144,6 +158,7 @@ import argparse
import os
import re
import sys
+from string import punctuation
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -827,12 +842,110 @@ splits = {
}
+def split(todo_text):
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
+ if todo_text[-1] not in splits:
+ todo_text += "。"
+ i_split_head = i_split_tail = 0
+ len_text = len(todo_text)
+ todo_texts = []
+ while 1:
+ if i_split_head >= len_text:
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
+ if todo_text[i_split_head] in splits:
+ i_split_head += 1
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
+ i_split_tail = i_split_head
+ else:
+ i_split_head += 1
+ return todo_texts
+
+
+def cut1(inp):
+ inp = inp.strip("\n")
+ inps = split(inp)
+ split_idx = list(range(0, len(inps), 4))
+ split_idx[-1] = None
+ if len(split_idx) > 1:
+ opts = []
+ for idx in range(len(split_idx) - 1):
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
+ else:
+ opts = [inp]
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
+ return "\n".join(opts)
+
+
+def cut2(inp):
+ inp = inp.strip("\n")
+ inps = split(inp)
+ if len(inps) < 2:
+ return inp
+ opts = []
+ summ = 0
+ tmp_str = ""
+ for i in range(len(inps)):
+ summ += len(inps[i])
+ tmp_str += inps[i]
+ if summ > 50:
+ summ = 0
+ opts.append(tmp_str)
+ tmp_str = ""
+ if tmp_str != "":
+ opts.append(tmp_str)
+ # print(opts)
+ if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
+ opts[-2] = opts[-2] + opts[-1]
+ opts = opts[:-1]
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
+ return "\n".join(opts)
+
+
+def cut3(inp):
+ inp = inp.strip("\n")
+ opts = ["%s" % item for item in inp.strip("。").split("。")]
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
+ return "\n".join(opts)
+
+
+def cut4(inp):
+ inp = inp.strip("\n")
+ opts = re.split(r"(? 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
+ items.append(char)
+ else:
+ items.append(char)
+ mergeitems.append("".join(items))
+ items = []
+ else:
+ items.append(char)
+
+ if items:
+ mergeitems.append("".join(items))
+
+ opt = [item for item in mergeitems if not set(item).issubset(punds)]
+ return "\n".join(opt)
+
+
def get_tts_wav(
ref_wav_path,
prompt_text,
prompt_language,
text,
text_language,
+ how_to_cut="不切",
top_k=15,
top_p=0.6,
temperature=0.6,
@@ -909,6 +1022,20 @@ def get_tts_wav(
refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
t1 = ttime()
+ # Apply text cutting based on how_to_cut parameter
+ if how_to_cut == "凑四句一切":
+ text = cut1(text)
+ elif how_to_cut == "凑50字一切":
+ text = cut2(text)
+ elif how_to_cut == "按中文句号。切":
+ text = cut3(text)
+ elif how_to_cut == "按英文句号.切":
+ text = cut4(text)
+ elif how_to_cut == "按标点符号切":
+ text = cut5(text)
+ while "\n\n" in text:
+ text = text.replace("\n\n", "\n")
+
# os.environ['version'] = version
prompt_language = dict_language[prompt_language.lower()]
text_language = dict_language[text_language.lower()]
@@ -1104,6 +1231,7 @@ def handle(
text,
text_language,
cut_punc,
+ how_to_cut,
top_k,
top_p,
temperature,
@@ -1140,6 +1268,7 @@ def handle(
prompt_language,
text,
text_language,
+ how_to_cut,
top_k,
top_p,
temperature,
@@ -1211,6 +1340,9 @@ parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
+parser.add_argument("-htc", "--how_to_cut", type=str, default="不切",
+ choices=["不切", "凑四句一切", "凑50字一切", "按中文句号。切", "按英文句号.切", "按标点符号切"],
+ help="文本切分方式")
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
@@ -1223,6 +1355,7 @@ host = args.bind_addr
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
default_cut_punc = args.cut_punc
+how_to_cut = args.how_to_cut
# 应用参数配置
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
@@ -1348,6 +1481,7 @@ async def tts_endpoint(request: Request):
json_post_raw.get("text"),
json_post_raw.get("text_language"),
json_post_raw.get("cut_punc"),
+ json_post_raw.get("how_to_cut"),
json_post_raw.get("top_k", 15),
json_post_raw.get("top_p", 1.0),
json_post_raw.get("temperature", 1.0),
@@ -1366,6 +1500,7 @@ async def tts_endpoint(
text: str = None,
text_language: str = None,
cut_punc: str = None,
+ how_to_cut: str = Query(default=None),
top_k: int = 15,
top_p: float = 1.0,
temperature: float = 1.0,
@@ -1381,6 +1516,7 @@ async def tts_endpoint(
text,
text_language,
cut_punc,
+ how_to_cut,
top_k,
top_p,
temperature,