From 745bd44132b37559c9e63db631b20f90b631a5ee Mon Sep 17 00:00:00 2001
From: XTer
Date: Tue, 9 Apr 2024 03:51:52 +0800
Subject: [PATCH 1/2] Enhance api_v2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api_v2.py | 179 +++++++++++++++++++++++++-----------------------
 1 file changed, 83 insertions(+), 96 deletions(-)

diff --git a/api_v2.py b/api_v2.py
index 50180595..e58ff323 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -22,17 +22,17 @@ POST:
 ```json
 {
     "text": "",                   # str.(required) text to be synthesized
-    "text_lang": "",              # str.(required) language of the text to be synthesized
     "ref_audio_path": "",         # str.(required) reference audio path.
     "prompt_text": "",            # str.(optional) prompt text for the reference audio
-    "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+    "text_lang": "auto",          # str.(optional) language of the text to be synthesized
+    "prompt_lang": "auto",        # str.(optional) language of the prompt text for the reference audio
     "top_k": 5,                   # int.(optional) top k sampling
     "top_p": 1,                   # float.(optional) top p sampling
     "temperature": 1,             # float.(optional) temperature for sampling
     "text_split_method": "cut5",  # str.(optional) text split method, see text_segmentation_method.py for details.
     "batch_size": 1,              # int.(optional) batch size for inference
     "batch_threshold": 0.75,      # float.(optional) threshold for batch splitting.
-    "split_bucket": true,         # bool.(optional) whether to split the batch into multiple buckets.
+    "split_bucket": true,         # bool.(optional) whether to split the batch into multiple buckets.
     "speed_factor":1.0,           # float.(optional) control the speed of the synthesized audio.
     "fragment_interval":0.3,      # float.(optional) to control the interval of the audio fragment.
     "seed": -1,                   # int.(optional) random seed for reproducibility.
@@ -117,8 +117,12 @@ from io import BytesIO
 from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, FileResponse
 from pydantic import BaseModel
+import tempfile
+
+from urllib.parse import unquote
+
 # print(sys.path)
 i18n = I18nAuto()
 cut_method_names = get_cut_method_names()
@@ -141,12 +145,14 @@ tts_config = TTS_Config(config_path)
 tts_pipeline = TTS(tts_config)

 APP = FastAPI()
+
+# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
 class TTS_Request(BaseModel):
     text: str = None
-    text_lang: str = None
     ref_audio_path: str = None
-    prompt_lang: str = None
     prompt_text: str = ""
+    text_lang: str = "auto"
+    prompt_lang: str = "auto"
     top_k:int = 5
     top_p:float = 1
     temperature:float = 1
@@ -160,6 +166,41 @@ class TTS_Request(BaseModel):
     media_type:str = "wav"
     streaming_mode:bool = False

+    # Trimmed-down version of TTS_Task from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
+    def update(self, req:dict):
+        for key in req:
+            if hasattr(self, key):
+                type_ = type(getattr(self, key))
+                value = unquote(req[key])
+                if type_ == bool:
+                    value = value.lower() in ["true", "1"]
+                elif type_ == int:
+                    value = int(value)
+                elif type_ == float:
+                    value = float(value)
+                setattr(self, key, value)
+
+    def to_dict(self):
+        return self.model_dump()
+
+    def check(self):
+        if (self.text_lang in [None, ""]) or self.text_lang.lower() not in tts_config.languages:
+            self.text_lang = "auto"
+        if (self.prompt_lang in [None, ""]) or self.prompt_lang.lower() not in tts_config.languages:
+            self.prompt_lang = "auto"
+
+        if self.text in [None, ""]:
+            return JSONResponse(status_code=400, content={"message": "text is required"})
+        if self.ref_audio_path in [None, ""]:
+            return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
+        if self.streaming_mode and self.media_type not in ["wav", "raw", "ogg", "aac"]:
+            return JSONResponse(status_code=400, content={"message": f"media_type {self.media_type} is not supported in streaming mode"})
+        if self.text_split_method not in cut_method_names:
+            return JSONResponse(status_code=400, content={"message": f"text_split_method:{self.text_split_method} is not supported"})
+        return None
+
+
+# Tempted to delete these: a lot of code was written for streaming, but when streaming is actually used it is usually with wav anyway
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
     with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
@@ -231,39 +272,8 @@ def handle_control(command:str):
         os.kill(os.getpid(), signal.SIGTERM)
         exit(0)

-
-def check_params(req:dict):
-    text:str = req.get("text", "")
-    text_lang:str = req.get("text_lang", "")
-    ref_audio_path:str = req.get("ref_audio_path", "")
-    streaming_mode:bool = req.get("streaming_mode", False)
-    media_type:str = req.get("media_type", "wav")
-    prompt_lang:str = req.get("prompt_lang", "")
-    text_split_method:str = req.get("text_split_method", "cut5")
-
-    if ref_audio_path in [None, ""]:
-        return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
-    if text in [None, ""]:
-        return JSONResponse(status_code=400, content={"message": "text is required"})
-    if (text_lang in [None, ""]) :
-        return JSONResponse(status_code=400, content={"message": "text_lang is required"})
-    elif text_lang.lower() not in tts_config.languages:
-        return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
-    if (prompt_lang in [None, ""]) :
-        return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
-    elif prompt_lang.lower() not in tts_config.languages:
-        return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
-    if media_type not in ["wav", "raw", "ogg", "aac"]:
-        return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
-    elif media_type == "ogg" and not streaming_mode:
-        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
-
-    if text_split_method not in cut_method_names:
-        return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
-
-    return None
-
-async def tts_handle(req:dict):
+# No need to make this async: the caller has to wait anyway, and it cannot run in parallel
+def tts_handle(req:dict):
     """
     Text to speech handler.

@@ -271,10 +281,10 @@ async def tts_handle(req:dict):
     req (dict):
         {
             "text": "",                   # str.(required) text to be synthesized
-            "text_lang: "",               # str.(required) language of the text to be synthesized
             "ref_audio_path": "",         # str.(required) reference audio path
             "prompt_text": "",            # str.(optional) prompt text for the reference audio
-            "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
+            "text_lang": "auto",          # str. language of the text to be synthesized
+            "prompt_lang": "auto",        # str. language of the prompt text for the reference audio
             "top_k": 5,                   # int. top k sampling
             "top_p": 1,                   # float. top p sampling
             "temperature": 1,             # float. temperature for sampling
@@ -292,13 +302,10 @@ async def tts_handle(req:dict):
         StreamingResponse: audio stream response.
""" + # 已经检查过了,这里不再检查 streaming_mode = req.get("streaming_mode", False) media_type = req.get("media_type", "wav") - check_res = check_params(req) - if check_res is not None: - return check_res - if streaming_mode: req["return_fragment"] = True @@ -316,17 +323,26 @@ async def tts_handle(req:dict): return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") else: + # 换用临时文件,支持更多格式,速度能更快,并且会避免占线 sr, audio_data = next(tts_generator) - audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() - return Response(audio_data, media_type=f"audio/{media_type}") + format = media_type + with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{format}') as tmp_file: + # 尝试写入用户指定的格式,如果失败则回退到 WAV 格式 + try: + sf.write(tmp_file, audio_data, sr, format=format) + except Exception as e: + # 如果指定的格式无法写入,则回退到 WAV 格式 + sf.write(tmp_file, audio_data, sr, format='wav') + format = 'wav' # 更新格式为 wav + + tmp_file_path = tmp_file.name + # 返回文件响应,FileResponse 会负责将文件发送给客户端 + return FileResponse(tmp_file_path, media_type=f"audio/{format}", filename=f"audio.{format}") except Exception as e: return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)}) - - - @APP.get("/control") async def control(command: str = None): if command is None: @@ -334,53 +350,24 @@ async def control(command: str = None): handle_control(command) - +# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/tts_backend.py @APP.get("/tts") -async def tts_get_endpoint( - text: str = None, - text_lang: str = None, - ref_audio_path: str = None, - prompt_lang: str = None, - prompt_text: str = "", - top_k:int = 5, - top_p:float = 1, - temperature:float = 1, - text_split_method:str = "cut0", - batch_size:int = 1, - batch_threshold:float = 0.75, - split_bucket:bool = True, - speed_factor:float = 1.0, - fragment_interval:float = 0.3, - seed:int = -1, - media_type:str = "wav", - streaming_mode:bool = False, - ): - req = { - "text": text, - "text_lang": text_lang.lower(), - "ref_audio_path": ref_audio_path, - "prompt_text": prompt_text, - "prompt_lang": prompt_lang.lower(), - "top_k": top_k, - "top_p": top_p, - "temperature": temperature, - "text_split_method": text_split_method, - "batch_size":int(batch_size), - "batch_threshold":float(batch_threshold), - "speed_factor":float(speed_factor), - "split_bucket":split_bucket, - "fragment_interval":fragment_interval, - "seed":seed, - "media_type":media_type, - "streaming_mode":streaming_mode, - } - return await tts_handle(req) - - @APP.post("/tts") -async def tts_post_endpoint(request: TTS_Request): - req = request.dict() - return await tts_handle(req) +async def tts_get_endpoint(request: Request): + + # 尝试从JSON中获取数据,如果不是JSON,则从查询参数中获取 + if request.method == "GET": + data = request.query_params + else: + data = await request.json() + + req = TTS_Request() + req.update(data) + res = req.check() + if res is not None: + return res + + return tts_handle(req.to_dict()) @APP.get("/set_refer_audio") @@ -436,7 +423,7 @@ async def set_sovits_weights(weights_path: str = None): if __name__ == "__main__": try: - uvicorn.run(APP, host=host, port=port, workers=1) + uvicorn.run(APP, host=host, port=port) # 删去workers=1,uvicorn这么写没法加 workers except Exception as e: traceback.print_exc() os.kill(os.getpid(), signal.SIGTERM) From fe136e949338be40ac2fa241494b17d46f9d3a6f Mon Sep 17 00:00:00 2001 From: XTer Date: Tue, 9 Apr 2024 04:00:55 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=AD=A3bug?= 
From fe136e949338be40ac2fa241494b17d46f9d3a6f Mon Sep 17 00:00:00 2001
From: XTer
Date: Tue, 9 Apr 2024 04:00:55 +0800
Subject: [PATCH 2/2] Fix bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 api_v2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api_v2.py b/api_v2.py
index e58ff323..8684338f 100644
--- a/api_v2.py
+++ b/api_v2.py
@@ -171,7 +171,7 @@ class TTS_Request(BaseModel):
         for key in req:
             if hasattr(self, key):
                 type_ = type(getattr(self, key))
-                value = unquote(req[key])
+                value = unquote(str(req[key]))
                 if type_ == bool:
                     value = value.lower() in ["true", "1"]
                 elif type_ == int:
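The second patch wraps the incoming value in `str()` before `unquote()`. Query parameters always arrive as strings, but a JSON body can carry native ints, floats, and booleans, which `urllib.parse.unquote` does not accept; casting to `str` lets both request styles go through the same coercion path in `TTS_Request.update()`. A standalone illustration of the difference (not code from the repository, example values are made up):

```python
from urllib.parse import unquote

# A JSON body, as accepted by the POST route, can mix strings with native types.
json_body = {"text": "hello%20world", "top_k": 5, "split_bucket": True}

# Before the fix: unquote() rejects non-string values coming from a JSON body.
try:
    unquote(json_body["top_k"])
except TypeError as err:
    print("unquote(5) raises:", err)

# After the fix: every value is cast to str first; update() then coerces the
# resulting string back to the declared field type (bool, int, float).
for key, value in json_body.items():
    print(key, repr(unquote(str(value))))  # 'hello world', '5', 'True'
```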