diff --git a/api.py b/api.py
index aa822ca..151e6c2 100644
--- a/api.py
+++ b/api.py
@@ -72,6 +72,28 @@
 RESP:
 成功: 直接返回 wav 音频流, http code 200
 失败: 返回包含错误信息的 json, http code 400
+手动指定当次推理所使用的参考音频,并提供参数:
+GET:
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1`
+POST:
+```json
+{
+    "refer_wav_path": "123.wav",
+    "prompt_text": "一二三。",
+    "prompt_language": "zh",
+    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
+    "text_language": "zh",
+    "top_k": 20,
+    "top_p": 0.6,
+    "temperature": 0.6,
+    "speed": 1
+}
+```
+
+RESP:
+成功: 直接返回 wav 音频流, http code 200
+失败: 返回包含错误信息的 json, http code 400
+
 
 
 ### 更换默认参考音频
@@ -446,7 +468,7 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k=20, top_p=0.6, temperature=0.6, speed=1):
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
@@ -494,7 +516,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
                 prompt,
                 bert,
                 # prompt_phone_len=ph_offset,
-                top_k=config['inference']['top_k'],
+                top_k=top_k,
+                top_p=top_p,
+                temperature=temperature,
                 early_stop_num=hz * max_sec)
         t3 = ttime()
         # print(pred_semantic.shape,idx)
@@ -507,7 +531,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
         audio = \
             vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refer).detach().cpu().numpy()[
+                            refer, speed=speed).detach().cpu().numpy()[
                 0, 0]  ###试试重建不带上prompt部分
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
@@ -553,7 +577,7 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
 
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -572,7 +596,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
     else:
         text = cut_text(text,cut_punc)
 
-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type)
+    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
 
 
 
@@ -755,6 +779,10 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("text"),
         json_post_raw.get("text_language"),
         json_post_raw.get("cut_punc"),
+        json_post_raw.get("top_k", 10),
+        json_post_raw.get("top_p", 1.0),
+        json_post_raw.get("temperature", 1.0),
+        json_post_raw.get("speed", 1.0)
     )
 
 
@@ -766,8 +794,12 @@ async def tts_endpoint(
         text: str = None,
         text_language: str = None,
         cut_punc: str = None,
+        top_k: int = 10,
+        top_p: float = 1.0,
+        temperature: float = 1.0,
+        speed: float = 1.0
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
 
 
 if __name__ == "__main__":