Enhance api_v2

XTer 2024-04-09 03:51:52 +08:00
parent 3706ad1b8b
commit 745bd44132

api_v2.py (177 changed lines)

@@ -22,10 +22,10 @@ POST:
 ```json
 {
     "text": "",                   # str.(required) text to be synthesized
-    "text_lang": "",              # str.(required) language of the text to be synthesized
     "ref_audio_path": "",         # str.(required) reference audio path.
     "prompt_text": "",            # str.(optional) prompt text for the reference audio
-    "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+    "text_lang": "auto",          # str.(optional) language of the text to be synthesized
+    "prompt_lang": "auto",        # str.(optional) language of the prompt text for the reference audio
     "top_k": 5,                   # int.(optional) top k sampling
     "top_p": 1,                   # float.(optional) top p sampling
     "temperature": 1,             # float.(optional) temperature for sampling
@@ -117,8 +117,12 @@ from io import BytesIO
 from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, FileResponse
 from pydantic import BaseModel
+import tempfile
+from urllib.parse import unquote
 # print(sys.path)
 i18n = I18nAuto()
 cut_method_names = get_cut_method_names()
@@ -141,12 +145,14 @@ tts_config = TTS_Config(config_path)
 tts_pipeline = TTS(tts_config)
 APP = FastAPI()
+# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
 class TTS_Request(BaseModel):
     text: str = None
-    text_lang: str = None
     ref_audio_path: str = None
-    prompt_lang: str = None
     prompt_text: str = ""
+    text_lang: str = "auto"
+    prompt_lang: str = "auto"
     top_k:int = 5
     top_p:float = 1
     temperature:float = 1
@@ -160,6 +166,41 @@ class TTS_Request(BaseModel):
     media_type:str = "wav"
     streaming_mode:bool = False
+
+    # Lite version of TTS_Task, from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
+    def update(self, req:dict):
+        for key in req:
+            if hasattr(self, key):
+                type_ = type(getattr(self, key))
+                value = unquote(req[key])
+                if type_ == bool:
+                    value = value.lower() in ["true", "1"]
+                elif type_ == int:
+                    value = int(value)
+                elif type_ == float:
+                    value = float(value)
+                setattr(self, key, value)
+
+    def to_dict(self):
+        return self.model_dump()
+
+    def check(self):
+        if (self.text_lang in [None, ""]) or self.text_lang.lower() not in tts_config.languages:
+            self.text_lang = "auto"
+        if (self.prompt_lang in [None, ""]) or self.prompt_lang.lower() not in tts_config.languages:
+            self.prompt_lang = "auto"
+        if self.text in [None, ""]:
+            return JSONResponse(status_code=400, content={"message": "text is required"})
+        if self.ref_audio_path in [None, ""]:
+            return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
+        if self.streaming_mode and self.media_type not in ["wav", "raw", "ogg", "aac"]:
+            return JSONResponse(status_code=400, content={"message": f"media_type {self.media_type} is not supported in streaming mode"})
+        if self.text_split_method not in cut_method_names:
+            return JSONResponse(status_code=400, content={"message": f"text_split_method:{self.text_split_method} is not supported"})
+        return None
+
+# I kind of want to delete all of this: a lot of code was written for streaming, but it seems wav is what usually gets used when streaming anyway.
 
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
     with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
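
The new `update()` assumes every incoming value is a string (as query parameters are): it URL-decodes the value with `unquote` and then casts it to the type of the field's current value. Below is a self-contained sketch of that coercion using a stripped-down stand-in model rather than the real `TTS_Request`; pydantic v2 is assumed, consistent with the `model_dump()` call in the diff above.

```python
from urllib.parse import unquote
from pydantic import BaseModel

class DemoRequest(BaseModel):
    # Stand-in for TTS_Request with just enough fields to show the coercion.
    text: str = ""
    top_k: int = 5
    speed_factor: float = 1.0
    streaming_mode: bool = False

    def update(self, req: dict):
        # Same idea as the diff: decode the string, then cast it based on the default's type.
        for key in req:
            if hasattr(self, key):
                type_ = type(getattr(self, key))
                value = unquote(req[key])
                if type_ == bool:
                    value = value.lower() in ["true", "1"]
                elif type_ == int:
                    value = int(value)
                elif type_ == float:
                    value = float(value)
                setattr(self, key, value)

req = DemoRequest()
req.update({"text": "hello%20world", "top_k": "10", "streaming_mode": "True"})
print(req.model_dump())
# -> {'text': 'hello world', 'top_k': 10, 'speed_factor': 1.0, 'streaming_mode': True}
```
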
@@ -231,39 +272,8 @@ def handle_control(command:str):
         os.kill(os.getpid(), signal.SIGTERM)
         exit(0)
 
-def check_params(req:dict):
-    text:str = req.get("text", "")
-    text_lang:str = req.get("text_lang", "")
-    ref_audio_path:str = req.get("ref_audio_path", "")
-    streaming_mode:bool = req.get("streaming_mode", False)
-    media_type:str = req.get("media_type", "wav")
-    prompt_lang:str = req.get("prompt_lang", "")
-    text_split_method:str = req.get("text_split_method", "cut5")
-
-    if ref_audio_path in [None, ""]:
-        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
-    if text in [None, ""]:
-        return JSONResponse(status_code=400, content={"message": "text is required"})
-    if (text_lang in [None, ""]) :
-        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
-    elif text_lang.lower() not in tts_config.languages:
-        return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
-    if (prompt_lang in [None, ""]) :
-        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
-    elif prompt_lang.lower() not in tts_config.languages:
-        return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
-    if media_type not in ["wav", "raw", "ogg", "aac"]:
-        return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
-    elif media_type == "ogg" and not streaming_mode:
-        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
-    if text_split_method not in cut_method_names:
-        return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
-    return None
-
-async def tts_handle(req:dict):
+# No need to make this async: the caller has to wait for it anyway, and it cannot run in parallel.
+def tts_handle(req:dict):
     """
     Text to speech handler.
@@ -271,10 +281,10 @@ async def tts_handle(req:dict):
         req (dict):
             {
                 "text": "",                   # str.(required) text to be synthesized
-                "text_lang: "",               # str.(required) language of the text to be synthesized
                 "ref_audio_path": "",         # str.(required) reference audio path
                 "prompt_text": "",            # str.(optional) prompt text for the reference audio
-                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+                "text_lang: "auto",           # str. language of the text to be synthesized
+                "prompt_lang": "auto",        # str. language of the prompt text for the reference audio
                 "top_k": 5,                   # int. top k sampling
                 "top_p": 1,                   # float. top p sampling
                 "temperature": 1,             # float. temperature for sampling
@@ -292,13 +302,10 @@ async def tts_handle(req:dict):
         StreamingResponse: audio stream response.
     """
+    # Parameters have already been checked; no need to check them again here.
     streaming_mode = req.get("streaming_mode", False)
     media_type = req.get("media_type", "wav")
-    check_res = check_params(req)
-    if check_res is not None:
-        return check_res
-
     if streaming_mode:
         req["return_fragment"] = True
@@ -316,17 +323,26 @@ async def tts_handle(req:dict):
             return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
         else:
+            # Switch to a temporary file: more formats are supported, it can be faster, and it avoids tying up the connection.
             sr, audio_data = next(tts_generator)
-            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
-            return Response(audio_data, media_type=f"audio/{media_type}")
+            format = media_type
+            with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{format}') as tmp_file:
+                # Try to write in the user-specified format; fall back to WAV if that fails.
+                try:
+                    sf.write(tmp_file, audio_data, sr, format=format)
+                except Exception as e:
+                    # If the requested format cannot be written, fall back to WAV.
+                    sf.write(tmp_file, audio_data, sr, format='wav')
+                    format = 'wav'  # update the format to wav
+                tmp_file_path = tmp_file.name
+            # Return a file response; FileResponse takes care of sending the file to the client.
+            return FileResponse(tmp_file_path, media_type=f"audio/{format}", filename=f"audio.{format}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
 
 @APP.get("/control")
 async def control(command: str = None):
     if command is None:
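
One side effect of the new non-streaming path: the temporary file is created with `delete=False`, so it stays on disk after `FileResponse` has sent it. If cleanup is ever wanted, one option (not part of this commit) is Starlette's `BackgroundTask`, which `FileResponse` accepts through its `background` parameter. A hedged sketch with a hypothetical helper:

```python
import os
import tempfile

import soundfile as sf
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask

def audio_file_response(audio_data, sr: int, media_type: str) -> FileResponse:
    # Hypothetical helper, not in the commit: write the audio to a temp file and
    # schedule its deletion once the response has been fully sent.
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{media_type}") as tmp_file:
        sf.write(tmp_file, audio_data, sr, format=media_type)
        tmp_file_path = tmp_file.name
    return FileResponse(
        tmp_file_path,
        media_type=f"audio/{media_type}",
        filename=f"audio.{media_type}",
        background=BackgroundTask(os.remove, tmp_file_path),  # remove the temp file afterwards
    )
```
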
@@ -334,53 +350,24 @@ async def control(command: str = None):
     handle_control(command)
 
+# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/tts_backend.py
 @APP.get("/tts")
-async def tts_get_endpoint(
-                        text: str = None,
-                        text_lang: str = None,
-                        ref_audio_path: str = None,
-                        prompt_lang: str = None,
-                        prompt_text: str = "",
-                        top_k:int = 5,
-                        top_p:float = 1,
-                        temperature:float = 1,
-                        text_split_method:str = "cut0",
-                        batch_size:int = 1,
-                        batch_threshold:float = 0.75,
-                        split_bucket:bool = True,
-                        speed_factor:float = 1.0,
-                        fragment_interval:float = 0.3,
-                        seed:int = -1,
-                        media_type:str = "wav",
-                        streaming_mode:bool = False,
-                        ):
-    req = {
-        "text": text,
-        "text_lang": text_lang.lower(),
-        "ref_audio_path": ref_audio_path,
-        "prompt_text": prompt_text,
-        "prompt_lang": prompt_lang.lower(),
-        "top_k": top_k,
-        "top_p": top_p,
-        "temperature": temperature,
-        "text_split_method": text_split_method,
-        "batch_size":int(batch_size),
-        "batch_threshold":float(batch_threshold),
-        "speed_factor":float(speed_factor),
-        "split_bucket":split_bucket,
-        "fragment_interval":fragment_interval,
-        "seed":seed,
-        "media_type":media_type,
-        "streaming_mode":streaming_mode,
-    }
-    return await tts_handle(req)
-
 @APP.post("/tts")
-async def tts_post_endpoint(request: TTS_Request):
-    req = request.dict()
-    return await tts_handle(req)
+async def tts_get_endpoint(request: Request):
+    # Try to read the data from the JSON body; if the request is not JSON, read the query parameters instead.
+    if request.method == "GET":
+        data = request.query_params
+    else:
+        data = await request.json()
+    req = TTS_Request()
+    req.update(data)
+    res = req.check()
+    if res is not None:
+        return res
+    return tts_handle(req.to_dict())
 
 @APP.get("/set_refer_audio")
@@ -436,7 +423,7 @@ async def set_sovits_weights(weights_path: str = None):
 if __name__ == "__main__":
     try:
-        uvicorn.run(APP, host=host, port=port, workers=1)
+        uvicorn.run(APP, host=host, port=port)  # dropped workers=1: with uvicorn invoked this way, extra workers cannot be added
     except Exception as e:
         traceback.print_exc()
         os.kill(os.getpid(), signal.SIGTERM)
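
The inline comment reflects a real uvicorn constraint: `uvicorn.run()` only honors `workers` (or `reload`) when the application is passed as an import string, not as an object. If multiple workers were ever wanted again, the call would have to look roughly like the sketch below; the host and port values are placeholders for the script's parsed arguments.

```python
import uvicorn

if __name__ == "__main__":
    # workers > 1 requires the app as a "module:attribute" import string;
    # passing the APP object directly keeps uvicorn in a single process.
    uvicorn.run("api_v2:APP", host="127.0.0.1", port=9880, workers=2)
```
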