diff --git a/api.py b/api.py
index 993f090b..eac528a4 100644
--- a/api.py
+++ b/api.py
@@ -35,7 +35,8 @@ POST:
 ```json
 {
     "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
-    "text_language": "zh"
+    "text_language": "zh",
+    "slice": "按标点符号切"
 }
 ```
 
@@ -120,7 +121,7 @@ RESP: 无
 
 import argparse
-import os
+import os, re
 import sys
 
 now_dir = os.getcwd()
@@ -166,6 +167,7 @@ parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="defa
 parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
 parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
 parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
+#parser.add_argument("-sl", "--slice", type=str, default="No slice", help="文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
 # bool值的用法为 `python ./api.py -fp ...`
 # 此时 full_precision==True, half_precision==False
 
@@ -375,6 +377,19 @@ dict_language = {
     "多语种混合": "auto"
 }
 
+slice_option = {
+    "凑四句一切": "凑四句一切",
+    "凑50字一切": "凑50字一切",
+    "按中文句号。切": "按中文句号。切",
+    "按英文句号.切": "按英文句号.切",
+    "按标点符号切": "按标点符号切",
+    "per 4 sentences": "凑四句一切",
+    "per 50 letters": "凑50字一切",
+    "per period": "按英文句号.切",
+    "per punctuation mark": "按标点符号切",
+    None: "No slice"
+}
+
 dtype=torch.float16 if is_half == True else torch.float32
 def get_bert_inf(phones, word2ph, norm_text, language):
     language=language.replace("all_","")
@@ -447,13 +462,41 @@ def get_phones_and_bert(text,language):
     return phones,bert.to(dtype),norm_text
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
+def merge_short_text_in_array(texts, threshold):
+    if (len(texts)) < 2:
+        return texts
+    result = []
+    text = ""
+    for ele in texts:
+        text += ele
+        if len(text) >= threshold:
+            result.append(text)
+            text = ""
+    if (len(text) > 0):
+        if len(result) == 0:
+            result.append(text)
+        else:
+            result[len(result) - 1] += text
+    return result
+
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut):
+    # not supporting ref_free
     t0 = ttime()
+    prompt_language = dict_language[prompt_language]
     prompt_text = prompt_text.strip("\n")
-    prompt_language, text = prompt_language, text.strip("\n")
-    zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
+    if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
+    print(f"实际输入的参考文本: {prompt_text}")
+    text = text.strip("\n")
+    if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
+    print(f"实际输入的目标文本: {text}")
+
+    zero_wav = np.zeros(
+        int(hps.data.sampling_rate * 0.3),
+        dtype=np.float16 if is_half == True else np.float32
+    )
     with torch.no_grad():
         wav16k, sr = librosa.load(ref_wav_path, sr=16000)
+        # neglected error checking for reference audio duration
         wav16k = torch.from_numpy(wav16k)
         zero_wav_torch = torch.from_numpy(zero_wav)
         if (is_half == True):
@@ -463,20 +506,48 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
             wav16k = wav16k.to(device)
             zero_wav_torch = zero_wav_torch.to(device)
         wav16k = torch.cat([wav16k, zero_wav_torch])
-        ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
+        ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+            "last_hidden_state"
+        ].transpose(
+            1, 2
+        ) # .float()
         codes = vq_model.extract_latent(ssl_content)
+
         prompt_semantic = codes[0, 0]
     t1 = ttime()
-    prompt_language = dict_language[prompt_language]
-    text_language = dict_language[text_language]
-    phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
-    phones1 = cleaned_text_to_sequence(phones1)
+
+    # 文本切句
+    # default to no slice if no argument is provided
+    how_to_cut = slice_option[how_to_cut]
+    print(f"[INFO] 文本切句选项: {how_to_cut}")
+    if (how_to_cut == "凑四句一切"):
+        text = cut1(text)
+    elif (how_to_cut == "凑50字一切"):
+        text = cut2(text)
+    elif (how_to_cut == "按中文句号。切"):
+        text = cut3(text)
+    elif (how_to_cut == "按英文句号.切"):
+        text = cut4(text)
+    elif (how_to_cut == "按标点符号切"):
+        text = cut5(text)
+
+    while "\n\n" in text:
+        text = text.replace("\n\n", "\n")
+    print(f"实际输入的目标文本(切句后): {text}")
     texts = text.split("\n")
+    texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
+    phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language)
     for text in texts:
+        # 解决输入目标文本的空行导致报错的问题
+        if (len(text.strip()) == 0):
+            continue
+        if (text[-1] not in splits): text += "。" if text_language != "en" else "."
+        print(f"实际输入的目标文本(每句): {text}")
         phones2,bert2,norm_text2=get_phones_and_bert(text, text_language)
+        print(f"前端处理后的文本(每句): {norm_text2}")
         bert = torch.cat([bert1, bert2], 1)
         all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
@@ -496,23 +567,118 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 early_stop_num=hz * max_sec)
         t3 = ttime()
         # print(pred_semantic.shape,idx)
-        pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
+        pred_semantic = pred_semantic[:, -idx:].unsqueeze(
+            0
+        ) # .unsqueeze(0)#mq要多unsqueeze一次
         refer = get_spepc(hps, ref_wav_path) # .to(device)
         if (is_half == True):
             refer = refer.half().to(device)
         else:
             refer = refer.to(device)
         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
-        audio = \
-            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refer).detach().cpu().numpy()[
-                0, 0] ###试试重建不带上prompt部分
+        audio = (
+            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
+            )
+            .detach()
+            .cpu()
+            .numpy()[0, 0]
+        ) ###试试重建不带上prompt部分
+        max_audio=np.abs(audio).max()#简单防止16bit爆音
+        if max_audio>1:audio/=max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
     t4 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
+    yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
+        np.int16
+    )
+splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
+
+def get_first(text):
+    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
+    text = re.split(pattern, text)[0].strip()
+    return text
+
+def split(todo_text):
+    todo_text = todo_text.replace("……", "。").replace("——", ",")
+    if todo_text[-1] not in splits:
+        todo_text += "。"
+    i_split_head = i_split_tail = 0
+    len_text = len(todo_text)
+    todo_texts = []
+    while 1:
+        if i_split_head >= len_text:
+            break  # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
+        if todo_text[i_split_head] in splits:
+            i_split_head += 1
+            todo_texts.append(todo_text[i_split_tail:i_split_head])
+            i_split_tail = i_split_head
+        else:
+            i_split_head += 1
+    return todo_texts
+
+def cut1(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    split_idx = list(range(0, len(inps), 4))
+    split_idx[-1] = None
+    if len(split_idx) > 1:
+        opts = []
+        for idx in range(len(split_idx) - 1):
+            opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
+    else:
+        opts = [inp]
+    return "\n".join(opts)
+
+
+def cut2(inp):
+    inp = inp.strip("\n")
+    inps = split(inp)
+    if len(inps) < 2:
+        return inp
+    opts = []
+    summ = 0
+    tmp_str = ""
+    for i in range(len(inps)):
+        summ += len(inps[i])
+        tmp_str += inps[i]
+        if summ > 50:
+            summ = 0
+            opts.append(tmp_str)
+            tmp_str = ""
+    if tmp_str != "":
+        opts.append(tmp_str)
+    # print(opts)
+    if len(opts) > 1 and len(opts[-1]) < 50:  ##如果最后一个太短了,和前一个合一起
+        opts[-2] = opts[-2] + opts[-1]
+        opts = opts[:-1]
+    return "\n".join(opts)
+
+
+def cut3(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip("。").split("。")])
+
+
+def cut4(inp):
+    inp = inp.strip("\n")
+    return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
+
+
+# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
+def cut5(inp):
+    # if not re.search(r'[^\w\s]', inp[-1]):
+    #     inp += '。'
+    inp = inp.strip("\n")
+    punds = r'[,.;?!、,。?!;:…]'
+    items = re.split(f'({punds})', inp)
+    mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
+    # 在句子不存在符号或句尾无符号的时候保证文本完整
+    if len(items)%2 == 1:
+        mergeitems.append(items[-1])
+    opt = "\n".join(mergeitems)
+    return opt
 
 
 def handle_control(command):
     if command == "restart":
@@ -541,7 +707,7 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
 
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -557,7 +723,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
 
     with torch.no_grad():
         gen = get_tts_wav(
-            refer_wav_path, prompt_text, prompt_language, text, text_language
+            refer_wav_path, prompt_text, prompt_language, text, text_language, slice
         )
         sampling_rate, audio_data = next(gen)
 
@@ -628,6 +794,7 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("prompt_language"),
         json_post_raw.get("text"),
         json_post_raw.get("text_language"),
+        json_post_raw.get("slice"),
     )
 
 
@@ -638,8 +805,9 @@ async def tts_endpoint(
     prompt_language: str = None,
     text: str = None,
     text_language: str = None,
+    slice: str = None
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, slice)
 
 
 if __name__ == "__main__":
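
Usage sketch (not part of the patch): a minimal client call against a server running this patched api.py. It assumes the server was started with the default port 9880, the tts endpoint mounted at `/`, and the `requests` package available; the paths, texts, and output filename are placeholders.

```python
# Hypothetical client for the patched endpoint; everything here is illustrative.
import requests

payload = {
    "refer_wav_path": "123.wav",   # placeholder reference audio (optional if defaults were set at startup)
    "prompt_text": "一二三。",       # placeholder reference transcript
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
    # "slice" must be one of the slice_option keys, e.g. "按标点符号切" or
    # "per punctuation mark"; omitting it maps to None -> "No slice". Any other
    # value raises KeyError, since slice_option[how_to_cut] is an unguarded lookup.
    "slice": "按标点符号切",
}

resp = requests.post("http://127.0.0.1:9880/", json=payload)
resp.raise_for_status()
with open("out.wav", "wb") as f:   # the endpoint streams audio/wav on success
    f.write(resp.content)
```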
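The heart of the new "按标点符号切" option is the capturing `re.split` in cut5. The standalone snippet below replays that logic (regex copied from the patch) on a sample sentence so the expected one-clause-per-line output is visible; the sample text and the output comment are illustrative only.

```python
# Standalone replay of the punctuation split performed by cut5 in this patch.
import re

punds = r'[,.;?!、,。?!;:…]'                      # same character class as cut5
inp = "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。"
items = re.split(f'({punds})', inp)               # capturing group keeps the marks
mergeitems = ["".join(g) for g in zip(items[::2], items[1::2])]  # clause + its mark
if len(items) % 2 == 1:                           # keep a trailing clause without punctuation
    mergeitems.append(items[-1])
print("\n".join(mergeitems))
# 先帝创业未半而中道崩殂,
# 今天下三分,
# 益州疲弊,
# 此诚危急存亡之秋也。
# (plus one empty trailing piece because the input ends in punctuation;
#  get_tts_wav later skips such blank lines before synthesis)
```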