Optimized the code logic to improve robustness

ChasonJiang 2024-04-04 15:25:30 +08:00
parent 488278bbe3
commit 5b573f4880


@@ -1,5 +1,5 @@
"""
# api.py usage
# web api 接口文档
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
@@ -14,7 +14,7 @@
 endpoint: `/tts`
 GET:
-`http://127.0.0.1:9880/tts?ref_audio_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
+`http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是罗浮云骑将军景元不必拘谨将军只是一时的身份你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true`
 POST:
 ```json
@@ -27,7 +27,7 @@ POST:
"top_k": 5, # int.(optional) top k sampling
"top_p": 1, # float.(optional) top p sampling
"temperature": 1, # float.(optional) temperature for sampling
"text_split_method": "cut0", # str.(optional) text split method, see text_segmentation_method.py for details.
"text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int.(optional) batch size for inference
"batch_threshold": 0.75, # float.(optional) threshold for batch splitting.
"split_bucket: True, # bool.(optional) whether to split the batch into multiple buckets.
@@ -90,26 +90,21 @@ RESP:
 ###
 """
-from asyncio import sleep
 import os
 import sys
+from typing import Generator

 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))

 import argparse
-import io
-# import os
 import subprocess
-# import sys
 import wave
 import signal
 import numpy as np
 import torch
 import soundfile as sf
-from fastapi import FastAPI, Request, HTTPException
+from fastapi import FastAPI, Request, HTTPException, Response
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi import FastAPI, UploadFile, File
 import uvicorn
@@ -117,14 +112,13 @@ from io import BytesIO
 from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from fastapi.responses import StreamingResponse
-# print(sys.path)
+from pydantic import BaseModel

 i18n = I18nAuto()

 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-# parser.add_argument("-api_c", "--api_config", type=str, default="GPT_SoVITS/configs/api_config.yaml", help="path to api_config")
-parser.add_argument("-tts_c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to tts_infer")
-# parser.add_argument("-d", "--device", type=str, default="cpu", help="cuda / cpu / mps")
+parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="path to tts_infer")
 parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
 args = parser.parse_args()
@@ -137,11 +131,28 @@ argv = sys.argv
 if config_path in [None, ""]:
     config_path = "GPT_SoVITS/configs/tts_infer.yaml"

-tts_config=TTS_Config(config_path)
+tts_config = TTS_Config(config_path)
 tts_pipeline = TTS(tts_config)
 APP = FastAPI()
+class TTS_Request(BaseModel):
+    text: str = None
+    text_lang: str = None
+    ref_audio_path: str = None
+    prompt_lang: str = None
+    prompt_text: str = ""
+    top_k: int = 5
+    top_p: float = 1
+    temperature: float = 1
+    text_split_method: str = "cut5"
+    batch_size: int = 1
+    batch_threshold: float = 0.75
+    split_bucket: bool = True
+    speed_factor: float = 1.0
+    fragment_interval: float = 0.3
+    seed: int = -1
+    media_type: str = "wav"
+    streaming_mode: bool = False
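A quick sanity check of how this new model behaves, using pydantic v1 semantics to match the `request.dict()` call in the POST endpoint further down. The values are illustrative; the point is that a request body only needs the text/reference fields, with everything else falling back to the defaults above.

```python
req = TTS_Request(
    text="你好",
    text_lang="zh",
    ref_audio_path="archive_jingyuan_1.wav",
    prompt_lang="zh",
)
print(req.dict()["text_split_method"])  # "cut5" (the new default)
print(req.dict()["streaming_mode"])     # False
```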
 ### modified from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int):
@@ -196,7 +207,7 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3
     # This will create a wave header, then append the frame input.
     # It should come first in a streaming wav file.
     # Later frames should not include it (else you will hear artifacts at each chunk start).
-    wav_buf = io.BytesIO()
+    wav_buf = BytesIO()
     with wave.open(wav_buf, "wb") as vfout:
         vfout.setnchannels(channels)
         vfout.setsampwidth(sample_width)
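`pack_audio` itself sits outside the hunks shown here; its call sites below (`pack_audio(BytesIO(), chunk, sr, media_type).getvalue()`) imply a per-format dispatcher along these lines. The writer choices in this sketch are assumptions, not the file's actual implementation (the module also imports `subprocess`, presumably for an ffmpeg-backed AAC path not sketched here).

```python
def pack_audio_sketch(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str) -> BytesIO:
    # Sketch only: signature inferred from the call sites in tts_handle below.
    if media_type == "ogg":
        sf.write(io_buffer, data, rate, format="ogg")  # OGG/Vorbis container
    elif media_type == "wav":
        sf.write(io_buffer, data, rate, format="wav")  # self-contained WAV with header
    else:
        # "raw": headerless PCM samples, meant to follow wave_header_chunk()
        io_buffer.write(data.tobytes())
    io_buffer.seek(0)
    return io_buffer
```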
@@ -215,7 +226,7 @@ def handle_control(command:str):
         exit(0)

-def tts_handle(req:dict):
+async def tts_handle(req: dict):
     """
     Text to speech handler.
@@ -257,38 +268,32 @@ def tts_handle(req:dict):
         return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
     if text in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "text is required"})
-    if (text_lang in [None, ""]) or (text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]):
+    if text_lang in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "text_language is required"})
+    elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
+        return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
     if media_type not in ["wav", "raw", "ogg", "aac"]:
         return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
     elif media_type == "ogg" and not streaming_mode:
         return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
     try:
         tts_generator = tts_pipeline.run(req)
         if streaming_mode:
-            if media_type == "wav":
-                def streaming_generator(tts_generator):
+            def streaming_generator(tts_generator: Generator, media_type: str):
+                if media_type == "wav":
                     yield wave_header_chunk()
-                    for sr, chunk in tts_generator:
-                        yield pack_audio(BytesIO(), chunk, sr, "raw").getvalue()
-            else:
-                def streaming_generator(tts_generator):
-                    for sr, chunk in tts_generator:
-                        yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
-            _media_type = ""
-            if media_type in ["wav", "raw"]:
-                _media_type = f"audio/x-{media_type}"
-            else:
-                _media_type = f"audio/{media_type}"
-            return StreamingResponse(streaming_generator(tts_generator), media_type=_media_type)
+                    media_type = "raw"
+                for sr, chunk in tts_generator:
+                    yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
+            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
+            return StreamingResponse(streaming_generator(tts_generator, media_type), media_type=f"audio/{media_type}")
         else:
-            audio_buffer = BytesIO()
             sr, audio_data = next(tts_generator)
-            audio_buffer = pack_audio(audio_buffer, audio_data, sr, media_type)
-            return StreamingResponse(audio_buffer, media_type=f"audio/{media_type}")
+            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+            return Response(audio_data, media_type=f"audio/{media_type}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)})
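To make the streaming contract above concrete, here is a client-side sketch (default host/port assumed, filename illustrative). With `media_type=wav` the first chunk is the placeholder header from `wave_header_chunk()` and every subsequent chunk is raw PCM, so the bytes can simply be appended as they arrive:

```python
import requests

params = {
    "text": "先帝创业未半而中道崩殂",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",
    "prompt_lang": "zh",
    "prompt_text": "我是罗浮云骑将军景元",
    "media_type": "wav",
    "streaming_mode": "true",
}
with requests.get("http://127.0.0.1:9880/tts", params=params, stream=True) as resp:
    resp.raise_for_status()
    with open("streamed.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=None):  # one piece per yielded chunk
            f.write(chunk)
```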
@@ -344,13 +349,13 @@ async def tts_get_endpoint(
"media_type":media_type,
"streaming_mode":streaming_mode,
}
return tts_handle(req)
return await tts_handle(req)
@APP.post("/tts")
async def tts_post_endpoint(request: Request):
req = await request.json()
return tts_handle(req)
async def tts_post_endpoint(request: TTS_Request):
req = request.dict()
return await tts_handle(req)
@APP.get("/set_refer_audio")
@@ -404,10 +409,6 @@ async def set_sovits_weights(weights_path: str = None):
 if __name__ == "__main__":
-    try:
-        uvicorn.run(APP, host=host, port=port, workers=1)
-    except KeyboardInterrupt:
-        os.kill(os.getpid(), signal.SIGTERM)
-        exit(0)
+    uvicorn.run(APP, host=host, port=port, workers=1)
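Finally, the hot-swap endpoints touched above can be exercised the same way. `weights_path` is taken from the `set_sovits_weights` signature in the hunk header; the `refer_audio_path` parameter name and the checkpoint path are assumptions for illustration:

```python
import requests

BASE = "http://127.0.0.1:9880"

# Switch the reference audio without restarting the server
# (parameter name assumed; see the /set_refer_audio endpoint above).
requests.get(f"{BASE}/set_refer_audio", params={"refer_audio_path": "archive_jingyuan_1.wav"})

# Hot-swap SoVITS weights (weights_path comes from the signature shown above;
# the .pth path itself is illustrative).
requests.get(f"{BASE}/set_sovits_weights", params={"weights_path": "SoVITS_weights/your_model.pth"})
```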