From ee142376096bd90527c6d3bb8a7965779a6a0209 Mon Sep 17 00:00:00 2001
From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com>
Date: Sun, 3 Mar 2024 09:07:53 +0000
Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AE=E9=BB=98?=
 =?UTF-8?q?=E8=AE=A4=E5=88=87=E5=88=86=E6=96=B9=E5=BC=8F=E4=BB=A5=E5=8F=8A?=
 =?UTF-8?q?=E5=88=87=E5=88=86=E6=96=B9=E5=BC=8F=E5=88=87=E5=88=86,?=
 =?UTF-8?q?=E4=BC=98=E5=8C=96=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0=E4=BF=AE?=
 =?UTF-8?q?=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

现在只需输入想修改的默认参数即可修改,无需输入全部参数
---
Review notes (text between "---" and the diff is ignored by `git am`):
* Line structure and indentation were rebuilt from the hunk headers; every
  replacement below keeps the hunk line counts unchanged.
* Doc examples: "how_to_cut" used "=" and the preceding JSON entry lacked a
  comma; both examples are valid JSON now.
* cut_info is resolved with a dict lookup so an unknown cut id no longer
  raises NameError at start-up or in /change_refer.
* handle(): the how_to_cut default is resolved after (not in between) the
  reference-audio fallback and its is_ready() guard, so neither depends on
  the other.
* split(): empty text no longer raises IndexError.
* splits / cut5 punctuation: duplicate ASCII members indicate the full-width
  forms were lost; they are restored.
* The "index" post-image id below is stale after these edits; regenerate the
  patch with `git format-patch` before publishing.

 api.py | 203 ++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 177 insertions(+), 26 deletions(-)

diff --git a/api.py b/api.py
index 754f0769..db8c7082 100644
--- a/api.py
+++ b/api.py
@@ -12,6 +12,7 @@
 `-dr` - `默认参考音频路径`
 `-dt` - `默认参考音频文本`
 `-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
+`-cut` - `默认切分方式,"凑四句一切","凑50字一切","按中文句号。切","按英文句号.切","按标点符号切",分别为cut1,cut2,cut3,cut4,cut5`
 
 `-d` - `推理设备, "cuda","cpu","mps"`
 `-a` - `绑定地址, 默认"127.0.0.1"`
@@ -28,7 +29,7 @@
 
 endpoint: `/`
 
-使用执行参数指定的参考音频:
+使用执行参数指定的参考音频与切分方式:
 GET:
     `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
 POST:
@@ -39,9 +40,9 @@ POST:
 }
 ```
 
-手动指定当次推理所使用的参考音频:
+手动指定当次推理所使用的参考音频与切分方式:
 GET:
-    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&how_to_cut=cut1`
 POST:
 ```json
 {
@@ -50,6 +51,7 @@ POST:
     "prompt_language": "zh",
     "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
-    "text_language": "zh"
+    "text_language": "zh",
+    "how_to_cut": "cut1"
 }
 ```
 
@@ -58,20 +60,21 @@ RESP:
 失败: 返回包含错误信息的 json, http code 400
 
 
-### 更换默认参考音频
+### 更换默认参考音频及切分方式
 
 endpoint: `/change_refer`
 
 key与推理端一样
 
 GET:
-    `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh`
+    `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&how_to_cut=cut1`
 POST:
 ```json
 {
     "refer_wav_path": "123.wav",
     "prompt_text": "一二三。",
-    "prompt_language": "zh"
+    "prompt_language": "zh",
+    "how_to_cut": "cut1"
 }
 ```
 
@@ -80,6 +83,7 @@ RESP:
 失败: json, 400
 
 
+
 ### 命令控制
 
 endpoint: `/control`
@@ -103,7 +107,7 @@ RESP: 无
 
 
 import argparse
-import os
+import os,re
 import sys
 
 now_dir = os.getcwd()
@@ -142,6 +146,7 @@ parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, hel
 parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
 parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
 parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
+parser.add_argument("-cut", "--default_how_to_cut", type=str, default="cut1", help="默认切分方式")
 
 parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
 parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
@@ -154,23 +159,27 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
 parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
 parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
 
+
 args = parser.parse_args()
 
 sovits_path = args.sovits_path
 gpt_path = args.gpt_path
 
+splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
 
 class DefaultRefer:
-    def __init__(self, path, text, language):
+    def __init__(self, path, text, language, how_to_cut):
         self.path = args.default_refer_path
         self.text = args.default_refer_text
         self.language = args.default_refer_language
+        self.how_to_cut = args.default_how_to_cut
 
     def is_ready(self) -> bool:
-        return is_full(self.path, self.text, self.language)
+        return is_full(self.path, self.text, self.language,self.how_to_cut)
 
 
-default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
+default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language, args.default_how_to_cut)
 
 device = args.device
 port = args.port
@@ -188,10 +197,23 @@ if default_refer.path == "" or default_refer.text == "" or default_refer.languag
     default_refer.path, default_refer.text, default_refer.language = "", "", ""
     print("[INFO] 未指定默认参考音频")
 else:
+    # Look the label up instead of walking an if/elif chain: an id outside
+    # cut1..cut5 is echoed unchanged, so a mistyped -cut value can no longer
+    # leave cut_info unbound and crash the print below with NameError.
+    cut_info = {
+        "cut1": "凑四句一切",
+        "cut2": "凑50字一切",
+        "cut3": "按中文句号。切",
+        "cut4": "按英文句号.切",
+        "cut5": "按标点符号切",
+    }.get(default_refer.how_to_cut, default_refer.how_to_cut)
+
     print(f"[INFO] 默认参考音频路径: {default_refer.path}")
     print(f"[INFO] 默认参考音频文本: {default_refer.text}")
     print(f"[INFO] 默认参考音频语种: {default_refer.language}")
+    print(f"[INFO] 默认切割方式: {cut_info}")
+
 
 is_half = g_config.is_half
 if args.full_precision:
     is_half = False
@@ -354,6 +376,86 @@ dict_language = {
 }
 
 
+def split(todo_text):
+    todo_text = todo_text.replace("……", "。").replace("——", ",")
+    if not todo_text or todo_text[-1] not in splits:
+        todo_text += "。"
+    i_split_head = i_split_tail = 0
+    len_text = len(todo_text)
+    todo_texts = []
+    while 1:
+        if i_split_head >= len_text:
+            break  # the text always ends with punctuation, so the last segment was already appended
+        if todo_text[i_split_head] in splits:
+            i_split_head += 1
+            todo_texts.append(todo_text[i_split_tail:i_split_head])
+            i_split_tail = i_split_head
+        else:
+            i_split_head += 1
+    return todo_texts
+
+
+def cut1(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    split_idx = list(range(0, len(inps), 4))
+    split_idx[-1] = None
+    if len(split_idx) > 1:
+        opts = []
+        for idx in range(len(split_idx) - 1):
+            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
+    else:
+        opts = [inp]
+    return "\n".join(opts)
+
+
+def cut2(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    if len(inps) < 2:
+        return inp
+    opts = []
+    summ = 0
+    tmp_str = ""
+    for i in range(len(inps)):
+        summ += len(inps[i])
+        tmp_str += inps[i]
+        if summ > 50:
+            summ = 0
+            opts.append(tmp_str)
+            tmp_str = ""
+    if tmp_str != "":
+        opts.append(tmp_str)
+    # print(opts)
+    if len(opts) > 1 and len(opts[-1]) < 50:  # merge a too-short last chunk into the previous one
+        opts[-2] = opts[-2] + opts[-1]
+        opts = opts[:-1]
+    return "\n".join(opts)
+
+
+def cut3(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
+
+
+def cut4(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
+
+
+# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
+def cut5(inp):
+    # if not re.search(r'[^\w\s]', inp[-1]):
+    # inp += '。'
+    inp = inp.strip("\n")
+    punds = r'[,.;?!、,。?!;:…]'
+    items = re.split(f'({punds})', inp)
+    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
+    # keep the trailing text when the input has no punctuation or does not end with one
+    if len(items)%2 == 1:
+        mergeitems.append(items[-1])
+    opt = "\n".join(mergeitems)
+    return opt
 def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
@@ -374,8 +476,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
     t1 = ttime()
+
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
+
     phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
     phones1 = cleaned_text_to_sequence(phones1)
     texts = text.split("\n")
@@ -438,26 +542,46 @@ def handle_control(command):
         exit(0)
 
 
-def handle_change(path, text, language):
-    if is_empty(path, text, language):
-        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
+def handle_change(path, text, language, how_to_cut):
+    if is_empty(path, text, language, how_to_cut):
+        return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language","how_to_cut"'}, status_code=400)
+
+    if path == "" or path is None:
+        path = default_refer.path
+    if text == "" or text is None:
+        text = default_refer.text
+    if language == "" or language is None:
+        language = default_refer.language
+    if how_to_cut == "" or how_to_cut is None:
+        how_to_cut = default_refer.how_to_cut
+
+    default_refer.path = path
+    default_refer.text = text
+    default_refer.language = language
+    default_refer.how_to_cut = how_to_cut
 
-    if path != "" or path is not None:
-        default_refer.path = path
-    if text != "" or text is not None:
-        default_refer.text = text
-    if language != "" or language is not None:
-        default_refer.language = language
+
+    # how_to_cut comes straight from the request here, so it may be any
+    # value: fall back to echoing it instead of leaving cut_info unbound,
+    # which used to turn an unknown id into a NameError (HTTP 500).
+    cut_info = {
+        "cut1": "凑四句一切",
+        "cut2": "凑50字一切",
+        "cut3": "按中文句号。切",
+        "cut4": "按英文句号.切",
+        "cut5": "按标点符号切",
+    }.get(default_refer.how_to_cut, default_refer.how_to_cut)
 
     print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
     print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
     print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
+    print(f"[INFO] 默认切割方式: {cut_info}")
     print(f"[INFO] is_ready: {default_refer.is_ready()}")
 
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
 
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -468,8 +592,30 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
             default_refer.text,
             default_refer.language,
         )
         if not default_refer.is_ready():
             return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
+    if (how_to_cut == "" or how_to_cut is None):
+        how_to_cut = default_refer.how_to_cut
+
+    if (how_to_cut == "cut1"):
+        text = cut1(text)  # group every four sentences
+
+    elif (how_to_cut == "cut2"):
+        text = cut2(text)  # group into chunks of about 50 characters
+
+    elif (how_to_cut == "cut3"):
+        text = cut3(text)  # split on the Chinese full stop
+
+    elif (how_to_cut == "cut4"):
+        text = cut4(text)  # split on the English period
+
+    elif (how_to_cut == "cut5"):
+        text = cut5(text)  # split on every punctuation mark
+
+    while "\n\n" in text:
+        text = text.replace("\n\n", "\n")
+
+    print("实际输入的目标文本(切句后):", text)
 
     with torch.no_grad():
         gen = get_tts_wav(
@@ -522,7 +668,8 @@ async def change_refer(request: Request):
     return handle_change(
         json_post_raw.get("refer_wav_path"),
         json_post_raw.get("prompt_text"),
-        json_post_raw.get("prompt_language")
+        json_post_raw.get("prompt_language"),
+        json_post_raw.get("how_to_cut")
     )
 
 
@@ -530,9 +677,11 @@ async def change_refer(request: Request):
 async def change_refer(
         refer_wav_path: str = None,
         prompt_text: str = None,
-        prompt_language: str = None
+        prompt_language: str = None,
+        how_to_cut: str = None
 ):
-    return handle_change(refer_wav_path, prompt_text, prompt_language)
+    return handle_change(refer_wav_path, prompt_text, prompt_language, how_to_cut)
+
 
 
 @app.post("/")
@@ -544,7 +693,8 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("prompt_language"),
         json_post_raw.get("text"),
         json_post_raw.get("text_language"),
-    )
+        json_post_raw.get("how_to_cut")
+    )
 
 
 @app.get("/")
@@ -554,8 +704,9 @@ async def tts_endpoint(
         prompt_language: str = None,
         text: str = None,
         text_language: str = None,
+        how_to_cut: str = None
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut)
 
 
 if __name__ == "__main__":