From 510347e6b02fc336fcc1d6abbbb7e1f0fa400067 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Thu, 4 Apr 2024 16:17:40 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- GPT_SoVITS/TTS_infer_pack/TTS.py | 2 +- .../text_segmentation_method.py | 3 + api_v2.py | 63 +++++++++++++------ 3 files changed, 48 insertions(+), 20 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 08f71d9c..c5f227c1 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -844,7 +844,7 @@ class TTS: self.vits_model = None self.init_t2s_weights(self.configs.t2s_weights_path) self.init_vits_weights(self.configs.vits_weights_path) - raise + raise e finally: self.empty_cache() diff --git a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py index eb256106..2608f4c9 100644 --- a/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py +++ b/GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py @@ -16,6 +16,9 @@ def get_method(name:str)->Callable: raise ValueError(f"Method {name} not found") return method +def get_method_names()->list: + return list(METHODS.keys()) + def register_method(name): def decorator(func): METHODS[name] = func diff --git a/api_v2.py b/api_v2.py index 2595aed0..3fe7cd82 100644 --- a/api_v2.py +++ b/api_v2.py @@ -96,6 +96,7 @@ RESP: """ import os import sys +import traceback from typing import Generator now_dir = os.getcwd() @@ -115,11 +116,12 @@ import uvicorn from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names from fastapi.responses import StreamingResponse from pydantic import BaseModel # print(sys.path) i18n = I18nAuto() - +cut_method_names = get_cut_method_names() parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") @@ -230,6 +232,37 @@ def handle_control(command:str): exit(0) +def check_params(req:dict): + text:str = req.get("text", "") + text_lang:str = req.get("text_lang", "") + ref_audio_path:str = req.get("ref_audio_path", "") + streaming_mode:bool = req.get("streaming_mode", False) + media_type:str = req.get("media_type", "wav") + prompt_lang:str = req.get("prompt_lang", "") + text_split_method:str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if (text_lang in [None, ""]) : + return JSONResponse(status_code=400, content={"message": "text_language is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse(status_code=400, content={"message": "text_language is not supported"}) + if (prompt_lang in [None, ""]) : + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"}) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": "media_type is not supported"}) + elif media_type == "ogg" and not streaming_mode: + return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"}) + + return None + async def tts_handle(req:dict): """ Text to speech handler. @@ -259,28 +292,16 @@ async def tts_handle(req:dict): StreamingResponse: audio stream response. """ - text:str = req.get("text", "") - text_lang:str = req.get("text_lang", "") - ref_audio_path:str = req.get("ref_audio_path", "") streaming_mode = req.get("streaming_mode", False) media_type = req.get("media_type", "wav") + check_res = check_params(req) + if check_res is not None: + return check_res + if streaming_mode: req["return_fragment"] = True - if ref_audio_path in [None, ""]: - return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) - if text in [None, ""]: - return JSONResponse(status_code=400, content={"message": "text is required"}) - if (text_lang in [None, ""]) : - return JSONResponse(status_code=400, content={"message": "text_language is required"}) - elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]: - return JSONResponse(status_code=400, content={"message": "text_language is not supported"}) - if media_type not in ["wav", "raw", "ogg", "aac"]: - return JSONResponse(status_code=400, content={"message": "media_type is not supported"}) - elif media_type == "ogg" and not streaming_mode: - return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) - try: tts_generator=tts_pipeline.run(req) @@ -414,5 +435,9 @@ async def set_sovits_weights(weights_path: str = None): if __name__ == "__main__": - uvicorn.run(APP, host=host, port=port, workers=1) - + try: + uvicorn.run(APP, host=host, port=port, workers=1) + except Exception as e: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) \ No newline at end of file