mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 00:10:00 +08:00
优化代码逻辑
This commit is contained in:
parent
97fe6039b7
commit
510347e6b0
@ -844,7 +844,7 @@ class TTS:
|
||||
self.vits_model = None
|
||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||
self.init_vits_weights(self.configs.vits_weights_path)
|
||||
raise
|
||||
raise e
|
||||
finally:
|
||||
self.empty_cache()
|
||||
|
||||
|
@ -16,6 +16,9 @@ def get_method(name:str)->Callable:
|
||||
raise ValueError(f"Method {name} not found")
|
||||
return method
|
||||
|
||||
def get_method_names()->list:
|
||||
return list(METHODS.keys())
|
||||
|
||||
def register_method(name):
|
||||
def decorator(func):
|
||||
METHODS[name] = func
|
||||
|
63
api_v2.py
63
api_v2.py
@ -96,6 +96,7 @@ RESP:
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Generator
|
||||
|
||||
now_dir = os.getcwd()
|
||||
@ -115,11 +116,12 @@ import uvicorn
|
||||
from io import BytesIO
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
# print(sys.path)
|
||||
i18n = I18nAuto()
|
||||
|
||||
cut_method_names = get_cut_method_names()
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
||||
@ -230,6 +232,37 @@ def handle_control(command:str):
|
||||
exit(0)
|
||||
|
||||
|
||||
def check_params(req:dict):
|
||||
text:str = req.get("text", "")
|
||||
text_lang:str = req.get("text_lang", "")
|
||||
ref_audio_path:str = req.get("ref_audio_path", "")
|
||||
streaming_mode:bool = req.get("streaming_mode", False)
|
||||
media_type:str = req.get("media_type", "wav")
|
||||
prompt_lang:str = req.get("prompt_lang", "")
|
||||
text_split_method:str = req.get("text_split_method", "cut5")
|
||||
|
||||
if ref_audio_path in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
||||
if text in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "text is required"})
|
||||
if (text_lang in [None, ""]) :
|
||||
return JSONResponse(status_code=400, content={"message": "text_language is required"})
|
||||
elif text_lang.lower() not in tts_config.languages:
|
||||
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
|
||||
if (prompt_lang in [None, ""]) :
|
||||
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
|
||||
elif prompt_lang.lower() not in tts_config.languages:
|
||||
return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
|
||||
elif media_type == "ogg" and not streaming_mode:
|
||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||
|
||||
if text_split_method not in cut_method_names:
|
||||
return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
|
||||
|
||||
return None
|
||||
|
||||
async def tts_handle(req:dict):
|
||||
"""
|
||||
Text to speech handler.
|
||||
@ -259,28 +292,16 @@ async def tts_handle(req:dict):
|
||||
StreamingResponse: audio stream response.
|
||||
"""
|
||||
|
||||
text:str = req.get("text", "")
|
||||
text_lang:str = req.get("text_lang", "")
|
||||
ref_audio_path:str = req.get("ref_audio_path", "")
|
||||
streaming_mode = req.get("streaming_mode", False)
|
||||
media_type = req.get("media_type", "wav")
|
||||
|
||||
check_res = check_params(req)
|
||||
if check_res is not None:
|
||||
return check_res
|
||||
|
||||
if streaming_mode:
|
||||
req["return_fragment"] = True
|
||||
|
||||
if ref_audio_path in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
||||
if text in [None, ""]:
|
||||
return JSONResponse(status_code=400, content={"message": "text is required"})
|
||||
if (text_lang in [None, ""]) :
|
||||
return JSONResponse(status_code=400, content={"message": "text_language is required"})
|
||||
elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
|
||||
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
|
||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
|
||||
elif media_type == "ogg" and not streaming_mode:
|
||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||
|
||||
try:
|
||||
tts_generator=tts_pipeline.run(req)
|
||||
|
||||
@ -414,5 +435,9 @@ async def set_sovits_weights(weights_path: str = None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(APP, host=host, port=port, workers=1)
|
||||
|
||||
try:
|
||||
uvicorn.run(APP, host=host, port=port, workers=1)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
exit(0)
|
Loading…
x
Reference in New Issue
Block a user