mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-08 16:00:01 +08:00
Optimized the code logic and improved robustness.
This commit is contained in:
parent 488278bbe3
commit 5b573f4880
85
api_v2.py
@@ -1,5 +1,5 @@
 """
-# api.py usage
+# web api 接口文档
 
 ` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
 
@@ -14,7 +14,7 @@
 
 endpoint: `/tts`
 GET:
-`http://127.0.0.1:9880/tts?ref_audio_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh`
+`http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true`
 
 POST:
 ```json
@@ -27,7 +27,7 @@ POST:
     "top_k": 5, # int.(optional) top k sampling
     "top_p": 1, # float.(optional) top p sampling
     "temperature": 1, # float.(optional) temperature for sampling
-    "text_split_method": "cut0", # str.(optional) text split method, see text_segmentation_method.py for details.
+    "text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details.
     "batch_size": 1, # int.(optional) batch size for inference
     "batch_threshold": 0.75, # float.(optional) threshold for batch splitting.
     "split_bucket: True, # bool.(optional) whether to split the batch into multiple buckets.
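For orientation, a minimal client sketch of the renamed GET query parameters (`text_lang`, `prompt_lang`, `text_split_method`, `media_type`) documented above; the host, port, reference clip, and output filename are assumptions for illustration, not part of this commit.

```python
import requests

# Assumes the server was started locally, e.g.:
#   python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml
params = {
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_lang": "zh",                           # renamed from text_language
    "ref_audio_path": "archive_jingyuan_1.wav",  # assumed reference clip on the server
    "prompt_lang": "zh",                         # renamed from prompt_language
    "prompt_text": "我是「罗浮」云骑将军景元。",
    "text_split_method": "cut5",
    "batch_size": 1,
    "media_type": "wav",
    "streaming_mode": "false",                   # non-streaming: full wav in one response
}
resp = requests.get("http://127.0.0.1:9880/tts", params=params)
resp.raise_for_status()
with open("output.wav", "wb") as f:              # illustrative output path
    f.write(resp.content)
```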
@@ -90,26 +90,21 @@ RESP:
 ###
 
 """
-from asyncio import sleep
 import os
 import sys
+from typing import Generator
 
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
 
 import argparse
-import io
-# import os
 import subprocess
-# import sys
 import wave
-
 import signal
 import numpy as np
-import torch
 import soundfile as sf
-from fastapi import FastAPI, Request, HTTPException
+from fastapi import FastAPI, Request, HTTPException, Response
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi import FastAPI, UploadFile, File
 import uvicorn
@@ -117,14 +112,13 @@ from io import BytesIO
 from tools.i18n.i18n import I18nAuto
 from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
 from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
 # print(sys.path)
 i18n = I18nAuto()
 
 
 parser = argparse.ArgumentParser(description="GPT-SoVITS api")
-# parser.add_argument("-api_c", "--api_config", type=str, default="GPT_SoVITS/configs/api_config.yaml", help="api_config路径")
-parser.add_argument("-tts_c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
-# parser.add_argument("-d", "--device", type=str, default="cpu", help="cuda / cpu / mps")
+parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
 parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
 args = parser.parse_args()
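With the `-tts_c` spelling consolidated into `-c`, the launch command shown at the top of the file matches the parser again. A minimal sketch of launching the server from Python; the working directory and host/port values are assumptions.

```python
import subprocess

# Assumes the current working directory is the GPT-SoVITS repo root.
subprocess.run(
    [
        "python", "api_v2.py",
        "-a", "127.0.0.1",
        "-p", "9880",
        "-c", "GPT_SoVITS/configs/tts_infer.yaml",  # "-c" now selects the tts_infer config
    ],
    check=True,
)
```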
@@ -141,7 +135,24 @@ tts_config=TTS_Config(config_path)
 tts_pipeline = TTS(tts_config)
 
 APP = FastAPI()
+class TTS_Request(BaseModel):
+    text: str = None
+    text_lang: str = None
+    ref_audio_path: str = None
+    prompt_lang: str = None
+    prompt_text: str = ""
+    top_k:int = 5
+    top_p:float = 1
+    temperature:float = 1
+    text_split_method:str = "cut5"
+    batch_size:int = 1
+    batch_threshold:float = 0.75
+    split_bucket:bool = True
+    speed_factor:float = 1.0
+    fragment_interval:float = 0.3
+    seed:int = -1
+    media_type:str = "wav"
+    streaming_mode:bool = False
 
 ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
 def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
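A hedged sketch of how the new `TTS_Request` model maps onto a POST body: the field names and defaults come from the class added above, while the endpoint URL, reference clip, and output filename are assumptions for illustration.

```python
import requests

payload = {
    # Required by the handler's validation; keys mirror the TTS_Request fields above.
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # assumed reference clip on the server
    "prompt_lang": "zh",
    "prompt_text": "我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可",
    # Optional keys fall back to the defaults declared on TTS_Request.
    "text_split_method": "cut5",
    "batch_size": 1,
    "media_type": "wav",
    "streaming_mode": False,
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload)  # assumed local host/port
resp.raise_for_status()
with open("tts_post_output.wav", "wb") as f:                     # illustrative output path
    f.write(resp.content)
```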
@@ -196,7 +207,7 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3
     # This will create a wave header then append the frame input
     # It should be first on a streaming wav file
     # Other frames better should not have it (else you will hear some artifacts each chunk start)
-    wav_buf = io.BytesIO()
+    wav_buf = BytesIO()
     with wave.open(wav_buf, "wb") as vfout:
         vfout.setnchannels(channels)
         vfout.setsampwidth(sample_width)
@@ -215,7 +226,7 @@ def handle_control(command:str):
         exit(0)
 
 
-def tts_handle(req:dict):
+async def tts_handle(req:dict):
     """
     Text to speech handler.
 
@@ -257,38 +268,32 @@ def tts_handle(req:dict):
         return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
     if text in [None, ""]:
         return JSONResponse(status_code=400, content={"message": "text is required"})
-    if (text_lang in [None, ""]) or (text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]):
+    if (text_lang in [None, ""]) :
         return JSONResponse(status_code=400, content={"message": "text_language is required"})
+    elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
+        return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
     if media_type not in ["wav", "raw", "ogg", "aac"]:
         return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
+    elif media_type == "ogg" and not streaming_mode:
+        return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
 
     try:
         tts_generator=tts_pipeline.run(req)
 
         if streaming_mode:
+            def streaming_generator(tts_generator:Generator, media_type:str):
                 if media_type == "wav":
-                def streaming_generator(tts_generator):
                     yield wave_header_chunk()
-                for sr, chunk in tts_generator:
-                    yield pack_audio(BytesIO(), chunk, sr, "raw").getvalue()
-            else:
-                def streaming_generator(tts_generator):
+                    media_type = "raw"
                 for sr, chunk in tts_generator:
                     yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
-            _media_type = ""
-            if media_type in ["wav", "raw"]:
-                _media_type = f"audio/x-{media_type}"
-            else:
-                _media_type = f"audio/{media_type}"
-            return StreamingResponse(streaming_generator(tts_generator), media_type=_media_type)
+            # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
+            return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
 
         else:
-            audio_buffer = BytesIO()
             sr, audio_data = next(tts_generator)
-            audio_buffer = pack_audio(audio_buffer, audio_data, sr, media_type)
-            return StreamingResponse(audio_buffer, media_type=f"audio/{media_type}")
+            audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
+            return Response(audio_data, media_type=f"audio/{media_type}")
     except Exception as e:
         return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
 
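Since the streaming branch above yields a WAV header chunk (`wave_header_chunk()`) followed by raw audio chunks, a client can simply append the bytes as they arrive. A minimal sketch, assuming the same local server as in the earlier examples; the text and filenames are illustrative.

```python
import requests

url = "http://127.0.0.1:9880/tts"  # assumed local server
params = {
    "text": "测试流式合成。",
    "text_lang": "zh",
    "ref_audio_path": "archive_jingyuan_1.wav",  # assumed reference clip on the server
    "prompt_lang": "zh",
    "prompt_text": "我是「罗浮」云骑将军景元。",
    "media_type": "wav",
    "streaming_mode": "true",  # exercises the StreamingResponse branch above
}
with requests.get(url, params=params, stream=True) as resp:
    resp.raise_for_status()
    with open("streamed.wav", "wb") as f:
        # The first chunk carries the WAV header; the rest is raw audio data.
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)
```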
@@ -344,13 +349,13 @@ async def tts_get_endpoint(
         "media_type":media_type,
         "streaming_mode":streaming_mode,
     }
-    return tts_handle(req)
+    return await tts_handle(req)
 
 
 @APP.post("/tts")
-async def tts_post_endpoint(request: Request):
-    req = await request.json()
-    return tts_handle(req)
+async def tts_post_endpoint(request: TTS_Request):
+    req = request.dict()
+    return await tts_handle(req)
 
 
 @APP.get("/set_refer_audio")
@@ -404,10 +409,6 @@ async def set_sovits_weights(weights_path: str = None):
 
 
 
 
 if __name__ == "__main__":
-    try:
         uvicorn.run(APP, host=host, port=port, workers=1)
-    except KeyboardInterrupt:
-        os.kill(os.getpid(), signal.SIGTERM)
-        exit(0)