mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 00:10:00 +08:00
优化代码逻辑
This commit is contained in:
parent
97fe6039b7
commit
510347e6b0
@ -844,7 +844,7 @@ class TTS:
|
|||||||
self.vits_model = None
|
self.vits_model = None
|
||||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||||
self.init_vits_weights(self.configs.vits_weights_path)
|
self.init_vits_weights(self.configs.vits_weights_path)
|
||||||
raise
|
raise e
|
||||||
finally:
|
finally:
|
||||||
self.empty_cache()
|
self.empty_cache()
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ def get_method(name:str)->Callable:
|
|||||||
raise ValueError(f"Method {name} not found")
|
raise ValueError(f"Method {name} not found")
|
||||||
return method
|
return method
|
||||||
|
|
||||||
|
def get_method_names()->list:
|
||||||
|
return list(METHODS.keys())
|
||||||
|
|
||||||
def register_method(name):
|
def register_method(name):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
METHODS[name] = func
|
METHODS[name] = func
|
||||||
|
63
api_v2.py
63
api_v2.py
@ -96,6 +96,7 @@ RESP:
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
@ -115,11 +116,12 @@ import uvicorn
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||||
|
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
# print(sys.path)
|
# print(sys.path)
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
cut_method_names = get_cut_method_names()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||||
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
|
||||||
@ -230,6 +232,37 @@ def handle_control(command:str):
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def check_params(req:dict):
|
||||||
|
text:str = req.get("text", "")
|
||||||
|
text_lang:str = req.get("text_lang", "")
|
||||||
|
ref_audio_path:str = req.get("ref_audio_path", "")
|
||||||
|
streaming_mode:bool = req.get("streaming_mode", False)
|
||||||
|
media_type:str = req.get("media_type", "wav")
|
||||||
|
prompt_lang:str = req.get("prompt_lang", "")
|
||||||
|
text_split_method:str = req.get("text_split_method", "cut5")
|
||||||
|
|
||||||
|
if ref_audio_path in [None, ""]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
||||||
|
if text in [None, ""]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "text is required"})
|
||||||
|
if (text_lang in [None, ""]) :
|
||||||
|
return JSONResponse(status_code=400, content={"message": "text_language is required"})
|
||||||
|
elif text_lang.lower() not in tts_config.languages:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
|
||||||
|
if (prompt_lang in [None, ""]) :
|
||||||
|
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
|
||||||
|
elif prompt_lang.lower() not in tts_config.languages:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
|
||||||
|
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
|
||||||
|
elif media_type == "ogg" and not streaming_mode:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
||||||
|
|
||||||
|
if text_split_method not in cut_method_names:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def tts_handle(req:dict):
|
async def tts_handle(req:dict):
|
||||||
"""
|
"""
|
||||||
Text to speech handler.
|
Text to speech handler.
|
||||||
@ -259,28 +292,16 @@ async def tts_handle(req:dict):
|
|||||||
StreamingResponse: audio stream response.
|
StreamingResponse: audio stream response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text:str = req.get("text", "")
|
|
||||||
text_lang:str = req.get("text_lang", "")
|
|
||||||
ref_audio_path:str = req.get("ref_audio_path", "")
|
|
||||||
streaming_mode = req.get("streaming_mode", False)
|
streaming_mode = req.get("streaming_mode", False)
|
||||||
media_type = req.get("media_type", "wav")
|
media_type = req.get("media_type", "wav")
|
||||||
|
|
||||||
|
check_res = check_params(req)
|
||||||
|
if check_res is not None:
|
||||||
|
return check_res
|
||||||
|
|
||||||
if streaming_mode:
|
if streaming_mode:
|
||||||
req["return_fragment"] = True
|
req["return_fragment"] = True
|
||||||
|
|
||||||
if ref_audio_path in [None, ""]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
|
||||||
if text in [None, ""]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text is required"})
|
|
||||||
if (text_lang in [None, ""]) :
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text_language is required"})
|
|
||||||
elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
|
|
||||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
|
|
||||||
elif media_type == "ogg" and not streaming_mode:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tts_generator=tts_pipeline.run(req)
|
tts_generator=tts_pipeline.run(req)
|
||||||
|
|
||||||
@ -414,5 +435,9 @@ async def set_sovits_weights(weights_path: str = None):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(APP, host=host, port=port, workers=1)
|
try:
|
||||||
|
uvicorn.run(APP, host=host, port=port, workers=1)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
exit(0)
|
Loading…
x
Reference in New Issue
Block a user