mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 00:10:00 +08:00
rebuild api_v2
This commit is contained in:
parent
67e3796d36
commit
3eea6570da
177
api_v2.py
177
api_v2.py
@ -22,10 +22,10 @@ POST:
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"text": "", # str.(required) text to be synthesized
|
"text": "", # str.(required) text to be synthesized
|
||||||
"text_lang": "", # str.(required) language of the text to be synthesized
|
|
||||||
"ref_audio_path": "", # str.(required) reference audio path.
|
"ref_audio_path": "", # str.(required) reference audio path.
|
||||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
"text_lang": "auto", # str.(optional) language of the text to be synthesized
|
||||||
|
"prompt_lang": "auto", # str.(optional) language of the prompt text for the reference audio
|
||||||
"top_k": 5, # int.(optional) top k sampling
|
"top_k": 5, # int.(optional) top k sampling
|
||||||
"top_p": 1, # float.(optional) top p sampling
|
"top_p": 1, # float.(optional) top p sampling
|
||||||
"temperature": 1, # float.(optional) temperature for sampling
|
"temperature": 1, # float.(optional) temperature for sampling
|
||||||
@ -117,8 +117,12 @@ from io import BytesIO
|
|||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
|
||||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
# print(sys.path)
|
# print(sys.path)
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
cut_method_names = get_cut_method_names()
|
cut_method_names = get_cut_method_names()
|
||||||
@ -141,12 +145,14 @@ tts_config = TTS_Config(config_path)
|
|||||||
tts_pipeline = TTS(tts_config)
|
tts_pipeline = TTS(tts_config)
|
||||||
|
|
||||||
APP = FastAPI()
|
APP = FastAPI()
|
||||||
|
|
||||||
|
# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
|
||||||
class TTS_Request(BaseModel):
|
class TTS_Request(BaseModel):
|
||||||
text: str = None
|
text: str = None
|
||||||
text_lang: str = None
|
|
||||||
ref_audio_path: str = None
|
ref_audio_path: str = None
|
||||||
prompt_lang: str = None
|
|
||||||
prompt_text: str = ""
|
prompt_text: str = ""
|
||||||
|
text_lang: str = "auto"
|
||||||
|
prompt_lang: str = "auto"
|
||||||
top_k:int = 5
|
top_k:int = 5
|
||||||
top_p:float = 1
|
top_p:float = 1
|
||||||
temperature:float = 1
|
temperature:float = 1
|
||||||
@ -160,6 +166,41 @@ class TTS_Request(BaseModel):
|
|||||||
media_type:str = "wav"
|
media_type:str = "wav"
|
||||||
streaming_mode:bool = False
|
streaming_mode:bool = False
|
||||||
|
|
||||||
|
# 青春版 from TTS_Task from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/TTS_Instance.py
|
||||||
|
def update(self, req:dict):
|
||||||
|
for key in req:
|
||||||
|
if hasattr(self, key):
|
||||||
|
type_ = type(getattr(self, key))
|
||||||
|
value = unquote(req[key])
|
||||||
|
if type_ == bool:
|
||||||
|
value = value.lower() in ["true", "1"]
|
||||||
|
elif type_ == int:
|
||||||
|
value = int(value)
|
||||||
|
elif type_ == float:
|
||||||
|
value = float(value)
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return self.model_dump()
|
||||||
|
|
||||||
|
def check(self):
|
||||||
|
if (self.text_lang in [None, ""]) or self.text_lang.lower() not in tts_config.languages:
|
||||||
|
self.text_lang = "auto"
|
||||||
|
if (self.prompt_lang in [None, ""]) or self.prompt_lang.lower() not in tts_config.languages:
|
||||||
|
self.prompt_lang = "auto"
|
||||||
|
|
||||||
|
if self.text in [None, ""]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "text is required"})
|
||||||
|
if self.ref_audio_path in [None, ""]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
||||||
|
if self.streaming_mode and self.media_type not in ["wav", "raw", "ogg", "aac"]:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"media_type {self.media_type} is not supported in streaming mode"})
|
||||||
|
if self.text_split_method not in cut_method_names:
|
||||||
|
return JSONResponse(status_code=400, content={"message": f"text_split_method:{self.text_split_method} is not supported"})
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# 有点想删掉这些东西,为了streaming 写了一堆东西,但是貌似用streaming的时候,一般用的是wav
|
||||||
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
|
||||||
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
|
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
|
||||||
with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
|
with sf.SoundFile(io_buffer, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
|
||||||
@ -231,39 +272,8 @@ def handle_control(command:str):
|
|||||||
os.kill(os.getpid(), signal.SIGTERM)
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
# 不用写成异步的,反正要等,也不能并行
|
||||||
def check_params(req:dict):
|
def tts_handle(req:dict):
|
||||||
text:str = req.get("text", "")
|
|
||||||
text_lang:str = req.get("text_lang", "")
|
|
||||||
ref_audio_path:str = req.get("ref_audio_path", "")
|
|
||||||
streaming_mode:bool = req.get("streaming_mode", False)
|
|
||||||
media_type:str = req.get("media_type", "wav")
|
|
||||||
prompt_lang:str = req.get("prompt_lang", "")
|
|
||||||
text_split_method:str = req.get("text_split_method", "cut5")
|
|
||||||
|
|
||||||
if ref_audio_path in [None, ""]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
|
|
||||||
if text in [None, ""]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text is required"})
|
|
||||||
if (text_lang in [None, ""]) :
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text_lang is required"})
|
|
||||||
elif text_lang.lower() not in tts_config.languages:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "text_lang is not supported"})
|
|
||||||
if (prompt_lang in [None, ""]) :
|
|
||||||
return JSONResponse(status_code=400, content={"message": "prompt_lang is required"})
|
|
||||||
elif prompt_lang.lower() not in tts_config.languages:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "prompt_lang is not supported"})
|
|
||||||
if media_type not in ["wav", "raw", "ogg", "aac"]:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
|
|
||||||
elif media_type == "ogg" and not streaming_mode:
|
|
||||||
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
|
|
||||||
|
|
||||||
if text_split_method not in cut_method_names:
|
|
||||||
return JSONResponse(status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"})
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def tts_handle(req:dict):
|
|
||||||
"""
|
"""
|
||||||
Text to speech handler.
|
Text to speech handler.
|
||||||
|
|
||||||
@ -271,10 +281,10 @@ async def tts_handle(req:dict):
|
|||||||
req (dict):
|
req (dict):
|
||||||
{
|
{
|
||||||
"text": "", # str.(required) text to be synthesized
|
"text": "", # str.(required) text to be synthesized
|
||||||
"text_lang: "", # str.(required) language of the text to be synthesized
|
|
||||||
"ref_audio_path": "", # str.(required) reference audio path
|
"ref_audio_path": "", # str.(required) reference audio path
|
||||||
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
"prompt_text": "", # str.(optional) prompt text for the reference audio
|
||||||
"prompt_lang": "", # str.(required) language of the prompt text for the reference audio
|
"text_lang: "auto", # str. language of the text to be synthesized
|
||||||
|
"prompt_lang": "auto", # str. language of the prompt text for the reference audio
|
||||||
"top_k": 5, # int. top k sampling
|
"top_k": 5, # int. top k sampling
|
||||||
"top_p": 1, # float. top p sampling
|
"top_p": 1, # float. top p sampling
|
||||||
"temperature": 1, # float. temperature for sampling
|
"temperature": 1, # float. temperature for sampling
|
||||||
@ -292,13 +302,10 @@ async def tts_handle(req:dict):
|
|||||||
StreamingResponse: audio stream response.
|
StreamingResponse: audio stream response.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 已经检查过了,这里不再检查
|
||||||
streaming_mode = req.get("streaming_mode", False)
|
streaming_mode = req.get("streaming_mode", False)
|
||||||
media_type = req.get("media_type", "wav")
|
media_type = req.get("media_type", "wav")
|
||||||
|
|
||||||
check_res = check_params(req)
|
|
||||||
if check_res is not None:
|
|
||||||
return check_res
|
|
||||||
|
|
||||||
if streaming_mode:
|
if streaming_mode:
|
||||||
req["return_fragment"] = True
|
req["return_fragment"] = True
|
||||||
|
|
||||||
@ -316,17 +323,26 @@ async def tts_handle(req:dict):
|
|||||||
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
|
return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# 换用临时文件,支持更多格式,速度能更快,并且会避免占线
|
||||||
sr, audio_data = next(tts_generator)
|
sr, audio_data = next(tts_generator)
|
||||||
audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
|
format = media_type
|
||||||
return Response(audio_data, media_type=f"audio/{media_type}")
|
with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{format}') as tmp_file:
|
||||||
|
# 尝试写入用户指定的格式,如果失败则回退到 WAV 格式
|
||||||
|
try:
|
||||||
|
sf.write(tmp_file, audio_data, sr, format=format)
|
||||||
|
except Exception as e:
|
||||||
|
# 如果指定的格式无法写入,则回退到 WAV 格式
|
||||||
|
sf.write(tmp_file, audio_data, sr, format='wav')
|
||||||
|
format = 'wav' # 更新格式为 wav
|
||||||
|
|
||||||
|
tmp_file_path = tmp_file.name
|
||||||
|
# 返回文件响应,FileResponse 会负责将文件发送给客户端
|
||||||
|
return FileResponse(tmp_file_path, media_type=f"audio/{format}", filename=f"audio.{format}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
|
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@APP.get("/control")
|
@APP.get("/control")
|
||||||
async def control(command: str = None):
|
async def control(command: str = None):
|
||||||
if command is None:
|
if command is None:
|
||||||
@ -334,53 +350,24 @@ async def control(command: str = None):
|
|||||||
handle_control(command)
|
handle_control(command)
|
||||||
|
|
||||||
|
|
||||||
|
# modified from https://github.com/X-T-E-R/GPT-SoVITS-Inference/blob/stable/Inference/src/tts_backend.py
|
||||||
@APP.get("/tts")
|
@APP.get("/tts")
|
||||||
async def tts_get_endpoint(
|
|
||||||
text: str = None,
|
|
||||||
text_lang: str = None,
|
|
||||||
ref_audio_path: str = None,
|
|
||||||
prompt_lang: str = None,
|
|
||||||
prompt_text: str = "",
|
|
||||||
top_k:int = 5,
|
|
||||||
top_p:float = 1,
|
|
||||||
temperature:float = 1,
|
|
||||||
text_split_method:str = "cut0",
|
|
||||||
batch_size:int = 1,
|
|
||||||
batch_threshold:float = 0.75,
|
|
||||||
split_bucket:bool = True,
|
|
||||||
speed_factor:float = 1.0,
|
|
||||||
fragment_interval:float = 0.3,
|
|
||||||
seed:int = -1,
|
|
||||||
media_type:str = "wav",
|
|
||||||
streaming_mode:bool = False,
|
|
||||||
):
|
|
||||||
req = {
|
|
||||||
"text": text,
|
|
||||||
"text_lang": text_lang.lower(),
|
|
||||||
"ref_audio_path": ref_audio_path,
|
|
||||||
"prompt_text": prompt_text,
|
|
||||||
"prompt_lang": prompt_lang.lower(),
|
|
||||||
"top_k": top_k,
|
|
||||||
"top_p": top_p,
|
|
||||||
"temperature": temperature,
|
|
||||||
"text_split_method": text_split_method,
|
|
||||||
"batch_size":int(batch_size),
|
|
||||||
"batch_threshold":float(batch_threshold),
|
|
||||||
"speed_factor":float(speed_factor),
|
|
||||||
"split_bucket":split_bucket,
|
|
||||||
"fragment_interval":fragment_interval,
|
|
||||||
"seed":seed,
|
|
||||||
"media_type":media_type,
|
|
||||||
"streaming_mode":streaming_mode,
|
|
||||||
}
|
|
||||||
return await tts_handle(req)
|
|
||||||
|
|
||||||
|
|
||||||
@APP.post("/tts")
|
@APP.post("/tts")
|
||||||
async def tts_post_endpoint(request: TTS_Request):
|
async def tts_get_endpoint(request: Request):
|
||||||
req = request.dict()
|
|
||||||
return await tts_handle(req)
|
# 尝试从JSON中获取数据,如果不是JSON,则从查询参数中获取
|
||||||
|
if request.method == "GET":
|
||||||
|
data = request.query_params
|
||||||
|
else:
|
||||||
|
data = await request.json()
|
||||||
|
|
||||||
|
req = TTS_Request()
|
||||||
|
req.update(data)
|
||||||
|
res = req.check()
|
||||||
|
if res is not None:
|
||||||
|
return res
|
||||||
|
|
||||||
|
return tts_handle(req.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@APP.get("/set_refer_audio")
|
@APP.get("/set_refer_audio")
|
||||||
@ -436,7 +423,7 @@ async def set_sovits_weights(weights_path: str = None):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
uvicorn.run(APP, host=host, port=port, workers=1)
|
uvicorn.run(APP, host=host, port=port) # 删去workers=1,uvicorn这么写没法加 workers
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
os.kill(os.getpid(), signal.SIGTERM)
|
os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user