diff --git a/api.py b/api.py index 754f0769..aa5df909 100644 --- a/api.py +++ b/api.py @@ -456,13 +456,14 @@ def handle_change(path, text, language): return JSONResponse({"code": 0, "message": "Success"}, status_code=200) +def _empty_parameter(*items): + for item in items: + if item is None or item == "": + return True + return False def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): - if ( - refer_wav_path == "" or refer_wav_path is None - or prompt_text == "" or prompt_text is None - or prompt_language == "" or prompt_language is None - ): + if (_empty_parameter(refer_wav_path, prompt_text, prompt_language, text, text_language)): refer_wav_path, prompt_text, prompt_language = ( default_refer.path, default_refer.text, @@ -470,22 +471,26 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): ) if not default_refer.is_ready(): return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + if _empty_parameter(text, text_language): + return JSONResponse({"code": 400, "message": "缺少参数: text 或 text_language"}, status_code=400) + try: + with torch.no_grad(): + gen = get_tts_wav( + refer_wav_path, prompt_text, prompt_language, text, text_language + ) + sampling_rate, audio_data = next(gen) - with torch.no_grad(): - gen = get_tts_wav( - refer_wav_path, prompt_text, prompt_language, text, text_language - ) - sampling_rate, audio_data = next(gen) + wav = BytesIO() + sf.write(wav, audio_data, sampling_rate, format="wav") + wav.seek(0) - wav = BytesIO() - sf.write(wav, audio_data, sampling_rate, format="wav") - wav.seek(0) - - torch.cuda.empty_cache() - if device == "mps": - print('executed torch.mps.empty_cache()') - torch.mps.empty_cache() - return StreamingResponse(wav, media_type="audio/wav") + torch.cuda.empty_cache() + if device == "mps": + print('executed torch.mps.empty_cache()') + torch.mps.empty_cache() + return StreamingResponse(wav, media_type="audio/wav") + except Exception as e: + return JSONResponse({"code": 400, "message": f"error: {e}"}, status_code=400) app = FastAPI()