From 5b573f4880ce95a376dd71bb780ab9ba9df407d0 Mon Sep 17 00:00:00 2001 From: ChasonJiang <1440499136@qq.com> Date: Thu, 4 Apr 2024 15:25:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E6=8F=90=E5=8D=87=E5=81=A5=E5=A3=AE?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_v2.py | 97 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/api_v2.py b/api_v2.py index 0d8c68c2..92ed7850 100644 --- a/api_v2.py +++ b/api_v2.py @@ -1,5 +1,5 @@ """ -# api.py usage +# web api 接口文档 ` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` @@ -14,7 +14,7 @@ endpoint: `/tts` GET: - `http://127.0.0.1:9880/tts?ref_audio_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` + `http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true` POST: ```json @@ -27,7 +27,7 @@ POST: "top_k": 5, # int.(optional) top k sampling "top_p": 1, # float.(optional) top p sampling "temperature": 1, # float.(optional) temperature for sampling - "text_split_method": "cut0", # str.(optional) text split method, see text_segmentation_method.py for details. + "text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details. "batch_size": 1, # int.(optional) batch size for inference "batch_threshold": 0.75, # float.(optional) threshold for batch splitting. "split_bucket: True, # bool.(optional) whether to split the batch into multiple buckets. @@ -90,26 +90,21 @@ RESP: ### """ -from asyncio import sleep import os import sys +from typing import Generator now_dir = os.getcwd() sys.path.append(now_dir) sys.path.append("%s/GPT_SoVITS" % (now_dir)) import argparse -import io -# import os import subprocess -# import sys import wave - import signal import numpy as np -import torch import soundfile as sf -from fastapi import FastAPI, Request, HTTPException +from fastapi import FastAPI, Request, HTTPException, Response from fastapi.responses import StreamingResponse, JSONResponse from fastapi import FastAPI, UploadFile, File import uvicorn @@ -117,14 +112,13 @@ from io import BytesIO from tools.i18n.i18n import I18nAuto from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from fastapi.responses import StreamingResponse +from pydantic import BaseModel # print(sys.path) i18n = I18nAuto() parser = argparse.ArgumentParser(description="GPT-SoVITS api") -# parser.add_argument("-api_c", "--api_config", type=str, default="GPT_SoVITS/configs/api_config.yaml", help="api_config路径") -parser.add_argument("-tts_c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") -# parser.add_argument("-d", "--device", type=str, default="cpu", help="cuda / cpu / mps") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") args = parser.parse_args() @@ -137,11 +131,28 @@ argv = sys.argv if config_path in [None, ""]: config_path = "GPT-SoVITS/configs/tts_infer.yaml" -tts_config=TTS_Config(config_path) +tts_config = TTS_Config(config_path) tts_pipeline = TTS(tts_config) APP = FastAPI() - +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + prompt_lang: str = None + prompt_text: str = "" + top_k:int = 5 + top_p:float = 1 + temperature:float = 1 + text_split_method:str = "cut5" + batch_size:int = 1 + batch_threshold:float = 0.75 + split_bucket:bool = True + speed_factor:float = 1.0 + fragment_interval:float = 0.3 + seed:int = -1 + media_type:str = "wav" + streaming_mode:bool = False ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int): @@ -196,7 +207,7 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3 # This will create a wave header then append the frame input # It should be first on a streaming wav file # Other frames better should not have it (else you will hear some artifacts each chunk start) - wav_buf = io.BytesIO() + wav_buf = BytesIO() with wave.open(wav_buf, "wb") as vfout: vfout.setnchannels(channels) vfout.setsampwidth(sample_width) @@ -215,7 +226,7 @@ def handle_control(command:str): exit(0) -def tts_handle(req:dict): +async def tts_handle(req:dict): """ Text to speech handler. @@ -257,38 +268,32 @@ def tts_handle(req:dict): return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) if text in [None, ""]: return JSONResponse(status_code=400, content={"message": "text is required"}) - if (text_lang in [None, ""]) or (text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]): + if (text_lang in [None, ""]) : return JSONResponse(status_code=400, content={"message": "text_language is required"}) + elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]: + return JSONResponse(status_code=400, content={"message": "text_language is not supported"}) if media_type not in ["wav", "raw", "ogg", "aac"]: return JSONResponse(status_code=400, content={"message": "media_type is not supported"}) - + elif media_type == "ogg" and not streaming_mode: + return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) try: tts_generator=tts_pipeline.run(req) if streaming_mode: - if media_type == "wav": - def streaming_generator(tts_generator): + def streaming_generator(tts_generator:Generator, media_type:str): + if media_type == "wav": yield wave_header_chunk() - for sr, chunk in tts_generator: - yield pack_audio(BytesIO(), chunk, sr, "raw").getvalue() - else: - def streaming_generator(tts_generator): - for sr, chunk in tts_generator: - yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() - - _media_type = "" - if media_type in ["wav", "raw"]: - _media_type = f"audio/x-{media_type}" - else: - _media_type = f"audio/{media_type}" - return StreamingResponse(streaming_generator(tts_generator), media_type=_media_type) - + media_type = "raw" + for sr, chunk in tts_generator: + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}") + else: - audio_buffer = BytesIO() sr, audio_data = next(tts_generator) - audio_buffer = pack_audio(audio_buffer, audio_data, sr, media_type) - return StreamingResponse(audio_buffer, media_type=f"audio/{media_type}") + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") except Exception as e: return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)}) @@ -344,13 +349,13 @@ async def tts_get_endpoint( "media_type":media_type, "streaming_mode":streaming_mode, } - return tts_handle(req) + return await tts_handle(req) @APP.post("/tts") -async def tts_post_endpoint(request: Request): - req = await request.json() - return tts_handle(req) +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) @APP.get("/set_refer_audio") @@ -404,10 +409,6 @@ async def set_sovits_weights(weights_path: str = None): - if __name__ == "__main__": - try: - uvicorn.run(APP, host=host, port=port, workers=1) - except KeyboardInterrupt: - os.kill(os.getpid(), signal.SIGTERM) - exit(0) + uvicorn.run(APP, host=host, port=port, workers=1) +