Update api_v2.py

刘悦 2024-06-10 13:31:10 +08:00 committed by GitHub
parent 3c4f5462eb
commit c65b448304


@@ -114,6 +114,7 @@ import soundfile as sf
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, UploadFile, File
from fastapi.staticfiles import StaticFiles
import uvicorn
from io import BytesIO
from tools.i18n.i18n import I18nAuto
@@ -143,6 +144,8 @@ tts_config = TTS_Config(config_path)
tts_pipeline = TTS(tts_config)
APP = FastAPI()
APP.mount("/srt", StaticFiles(directory="./srt"), name="srt")
APP.mount("/audio", StaticFiles(directory="./audio"), name="audio")
class TTS_Request(BaseModel):
    text: str = None
    text_lang: str = None
@@ -329,7 +332,55 @@ async def tts_handle(req:dict):
        return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
async def tts_handle_srt(req: dict, request: Request):
    """
    Text-to-speech handler that returns JSON containing URLs to the generated
    subtitle (.srt) and audio files served from the static mounts.

    Args:
        req (dict):
            {
                "text": "",                   # str.(required) text to be synthesized
                "text_lang": "",              # str.(required) language of the text to be synthesized
                "ref_audio_path": "",         # str.(required) reference audio path
                "prompt_text": "",            # str.(optional) prompt text for the reference audio
                "prompt_lang": "",            # str.(required) language of the prompt text for the reference audio
                "top_k": 5,                   # int. top k sampling
                "top_p": 1,                   # float. top p sampling
                "temperature": 1,             # float. temperature for sampling
                "text_split_method": "cut5",  # str. text split method, see text_segmentation_method.py for details.
                "batch_size": 1,              # int. batch size for inference
                "batch_threshold": 0.75,      # float. threshold for batch splitting.
                "split_bucket": True,         # bool. whether to split the batch into multiple buckets.
                "speed_factor": 1.0,          # float. control the speed of the synthesized audio.
                "fragment_interval": 0.3,     # float. to control the interval of the audio fragment.
                "seed": -1,                   # int. random seed for reproducibility.
                "media_type": "wav",          # str. media type of the output audio, support "wav", "raw", "ogg", "aac".
                "streaming_mode": False,      # bool. whether to return a streaming response.
                "parallel_infer": True,       # bool.(optional) whether to use parallel inference.
                "repetition_penalty": 1.35    # float.(optional) repetition penalty for T2S model.
            }
        request (Request): the incoming request, used to build absolute URLs for the output files.

    returns:
        JSONResponse: {"code", "srt", "audio"}, where "srt" and "audio" are URLs to the
        generated files (see the example client after this function).
    """
    # streaming_mode/media_type are read for parity with tts_handle; this handler
    # always responds with JSON containing URLs to the generated files.
    streaming_mode = req.get("streaming_mode", False)
    media_type = req.get("media_type", "wav")

    check_res = check_params(req)
    if check_res is not None:
        return check_res

    try:
        tts_generator = tts_pipeline.run(req)
        sr, audio_data = next(tts_generator)
        # persist the synthesized audio so the URL returned below resolves;
        # the subtitle is expected to be written to ./srt/tts-out.srt separately.
        sf.write("./audio/audio.wav", audio_data, sr)
        return JSONResponse({
            "code": "200",
            "srt": f"http://{request.url.hostname}:{request.url.port}/srt/tts-out.srt",
            "audio": f"http://{request.url.hostname}:{request.url.port}/audio/audio.wav",
        })
    except Exception as e:
        return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
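For reference, a minimal client sketch for the GET /srt endpoint defined further below. It assumes the server runs at the script's default address (127.0.0.1:9880); "ref.wav" and the prompt text are placeholders to replace with a real reference audio available on the server:

import requests

params = {
    "text": "先帝创业未半而中道崩殂。",    # text to synthesize
    "text_lang": "zh",
    "ref_audio_path": "ref.wav",           # placeholder: reference audio path on the server
    "prompt_text": "参考音频的文字内容。",  # placeholder: transcript of the reference audio
    "prompt_lang": "zh",
}
resp = requests.get("http://127.0.0.1:9880/srt", params=params, timeout=600)
resp.raise_for_status()
result = resp.json()  # {"code": "200", "srt": "http://.../srt/tts-out.srt", "audio": "http://.../audio/audio.wav"}

# download the generated subtitle and audio through the static mounts
for key in ("srt", "audio"):
    file_resp = requests.get(result[key])
    with open(result[key].rsplit("/", 1)[-1], "wb") as f:
        f.write(file_resp.content)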
@@ -339,7 +390,55 @@ async def control(command: str = None):
        return JSONResponse(status_code=400, content={"message": "command is required"})
    handle_control(command)
@APP.get("/srt")
async def tts_get_endpoint_srt(
    request: Request,
    text: str = None,
    text_lang: str = None,
    ref_audio_path: str = None,
    prompt_lang: str = None,
    prompt_text: str = "",
    top_k: int = 5,
    top_p: float = 1,
    temperature: float = 1,
    text_split_method: str = "cut5",
    batch_size: int = 10,
    batch_threshold: float = 0.75,
    split_bucket: bool = True,
    speed_factor: float = 1.0,
    fragment_interval: float = 0.3,
    seed: int = -1,
    media_type: str = "wav",
    streaming_mode: bool = False,
    parallel_infer: bool = True,
    repetition_penalty: float = 1.35,
):
    req = {
        "text": text,
        "text_lang": text_lang.lower(),
        "ref_audio_path": ref_audio_path,
        "prompt_text": prompt_text,
        "prompt_lang": prompt_lang.lower(),
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "text_split_method": text_split_method,
        "batch_size": int(batch_size),
        "batch_threshold": float(batch_threshold),
        "speed_factor": float(speed_factor),
        "split_bucket": split_bucket,
        "fragment_interval": fragment_interval,
        "seed": seed,
        "media_type": media_type,
        "streaming_mode": streaming_mode,
        "parallel_infer": parallel_infer,
        "repetition_penalty": float(repetition_penalty),
    }
    return await tts_handle_srt(req, request)
@APP.post("/srt")
async def tts_post_endpoint_srt(request: TTS_Request, req1: Request):
    req = request.dict()
    return await tts_handle_srt(req, req1)
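The POST variant accepts the same fields as a JSON body (the TTS_Request model); a minimal sketch under the same default-address assumption, again with a placeholder reference audio path:

import requests

payload = {
    "text": "Hello from the /srt endpoint.",
    "text_lang": "en",
    "ref_audio_path": "ref.wav",   # placeholder: reference audio path on the server
    "prompt_lang": "en",
    "media_type": "wav",
}
resp = requests.post("http://127.0.0.1:9880/srt", json=payload, timeout=600)
print(resp.json())  # expected keys: "code", "srt", "audio"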
@APP.get("/tts")
async def tts_get_endpoint(