perf(api.py): 推理过程发生异常后返回报错信息

This commit is contained in:
Soulter 2024-03-05 17:57:03 +08:00
parent 78ab26ea17
commit fd04d26f06

43
api.py
View File

@ -456,13 +456,14 @@ def handle_change(path, text, language):
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def _empty_parameter(*items):
for item in items:
if item is None or item == "":
return True
return False
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
or prompt_language == "" or prompt_language is None
):
if (_empty_parameter(refer_wav_path, prompt_text, prompt_language, text, text_language)):
refer_wav_path, prompt_text, prompt_language = (
default_refer.path,
default_refer.text,
@ -470,22 +471,26 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
)
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if _empty_parameter(text, text_language):
return JSONResponse({"code": 400, "message": "缺少参数: text 或 text_language"}, status_code=400)
try:
with torch.no_grad():
gen = get_tts_wav(
refer_wav_path, prompt_text, prompt_language, text, text_language
)
sampling_rate, audio_data = next(gen)
with torch.no_grad():
gen = get_tts_wav(
refer_wav_path, prompt_text, prompt_language, text, text_language
)
sampling_rate, audio_data = next(gen)
wav = BytesIO()
sf.write(wav, audio_data, sampling_rate, format="wav")
wav.seek(0)
wav = BytesIO()
sf.write(wav, audio_data, sampling_rate, format="wav")
wav.seek(0)
torch.cuda.empty_cache()
if device == "mps":
print('executed torch.mps.empty_cache()')
torch.mps.empty_cache()
return StreamingResponse(wav, media_type="audio/wav")
torch.cuda.empty_cache()
if device == "mps":
print('executed torch.mps.empty_cache()')
torch.mps.empty_cache()
return StreamingResponse(wav, media_type="audio/wav")
except Exception as e:
return JSONResponse({"code": 400, "message": f"error: {e}"}, status_code=400)
app = FastAPI()