优化代码逻辑

This commit is contained in:
ChasonJiang 2024-04-04 16:17:40 +08:00
parent 97fe6039b7
commit 510347e6b0
3 changed files with 48 additions and 20 deletions

View File

@ -844,7 +844,7 @@ class TTS:
self.vits_model = None
self.init_t2s_weights(self.configs.t2s_weights_path)
self.init_vits_weights(self.configs.vits_weights_path)
raise
raise e
finally:
self.empty_cache()

View File

@ -16,6 +16,9 @@ def get_method(name:str)->Callable:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func

View File

@ -96,6 +96,7 @@ RESP:
"""
import os
import sys
import traceback
from typing import Generator
now_dir = os.getcwd()
@ -115,11 +116,12 @@ import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# print(sys.path)
i18n = I18nAuto()
cut_method_names = get_cut_method_names()
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
@ -230,6 +232,37 @@ def handle_control(command:str):
exit(0)
def check_params(req:dict):
text:str = req.get("text", "")
text_lang:str = req.get("text_lang", "")
ref_audio_path:str = req.get("ref_audio_path", "")
streaming_mode:bool = req.get("streaming_mode", False)
media_type:str = req.get("media_type", "wav")
prompt_lang:str = req.get("prompt_lang", "")
text_split_method:str = req.get("text_split_method", "cut5")
if ref_audio_path in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"})
if (text_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "text_language is required"})
elif text_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
if (prompt_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
elif prompt_lang.lower() not in tts_config.languages:
return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
if text_split_method not in cut_method_names:
return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
return None
async def tts_handle(req:dict):
"""
Text to speech handler.
@ -259,28 +292,16 @@ async def tts_handle(req:dict):
StreamingResponse: audio stream response.
"""
text:str = req.get("text", "")
text_lang:str = req.get("text_lang", "")
ref_audio_path:str = req.get("ref_audio_path", "")
streaming_mode = req.get("streaming_mode", False)
media_type = req.get("media_type", "wav")
check_res = check_params(req)
if check_res is not None:
return check_res
if streaming_mode:
req["return_fragment"] = True
if ref_audio_path in [None, ""]:
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"})
if (text_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "text_language is required"})
elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
try:
tts_generator=tts_pipeline.run(req)
@ -414,5 +435,9 @@ async def set_sovits_weights(weights_path: str = None):
if __name__ == "__main__":
try:
uvicorn.run(APP, host=host, port=port, workers=1)
except Exception as e:
traceback.print_exc()
os.kill(os.getpid(), signal.SIGTERM)
exit(0)