优化了代码逻辑,提升健壮性

This commit is contained in:
ChasonJiang 2024-04-04 15:25:30 +08:00
parent 488278bbe3
commit 5b573f4880

View File

@ -1,5 +1,5 @@
""" """
# api.py usage # web api 接口文档
` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` ` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml `
@ -14,7 +14,7 @@
endpoint: `/tts` endpoint: `/tts`
GET: GET:
`http://127.0.0.1:9880/tts?ref_audio_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh` `http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是罗浮云骑将军景元不必拘谨将军只是一时的身份你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true`
POST: POST:
```json ```json
@ -27,7 +27,7 @@ POST:
"top_k": 5, # int.(optional) top k sampling "top_k": 5, # int.(optional) top k sampling
"top_p": 1, # float.(optional) top p sampling "top_p": 1, # float.(optional) top p sampling
"temperature": 1, # float.(optional) temperature for sampling "temperature": 1, # float.(optional) temperature for sampling
"text_split_method": "cut0", # str.(optional) text split method, see text_segmentation_method.py for details. "text_split_method": "cut5", # str.(optional) text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int.(optional) batch size for inference "batch_size": 1, # int.(optional) batch size for inference
"batch_threshold": 0.75, # float.(optional) threshold for batch splitting. "batch_threshold": 0.75, # float.(optional) threshold for batch splitting.
"split_bucket: True, # bool.(optional) whether to split the batch into multiple buckets. "split_bucket: True, # bool.(optional) whether to split the batch into multiple buckets.
@ -90,26 +90,21 @@ RESP:
### ###
""" """
from asyncio import sleep
import os import os
import sys import sys
from typing import Generator
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir)) sys.path.append("%s/GPT_SoVITS" % (now_dir))
import argparse import argparse
import io
# import os
import subprocess import subprocess
# import sys
import wave import wave
import signal import signal
import numpy as np import numpy as np
import torch
import soundfile as sf import soundfile as sf
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from fastapi import FastAPI, UploadFile, File from fastapi import FastAPI, UploadFile, File
import uvicorn import uvicorn
@ -117,14 +112,13 @@ from io import BytesIO
from tools.i18n.i18n import I18nAuto from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# print(sys.path) # print(sys.path)
i18n = I18nAuto() i18n = I18nAuto()
parser = argparse.ArgumentParser(description="GPT-SoVITS api") parser = argparse.ArgumentParser(description="GPT-SoVITS api")
# parser.add_argument("-api_c", "--api_config", type=str, default="GPT_SoVITS/configs/api_config.yaml", help="api_config路径") parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
parser.add_argument("-tts_c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
# parser.add_argument("-d", "--device", type=str, default="cpu", help="cuda / cpu / mps")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880")
args = parser.parse_args() args = parser.parse_args()
@ -137,11 +131,28 @@ argv = sys.argv
if config_path in [None, ""]: if config_path in [None, ""]:
config_path = "GPT-SoVITS/configs/tts_infer.yaml" config_path = "GPT-SoVITS/configs/tts_infer.yaml"
tts_config=TTS_Config(config_path) tts_config = TTS_Config(config_path)
tts_pipeline = TTS(tts_config) tts_pipeline = TTS(tts_config)
APP = FastAPI() APP = FastAPI()
class TTS_Request(BaseModel):
text: str = None
text_lang: str = None
ref_audio_path: str = None
prompt_lang: str = None
prompt_text: str = ""
top_k:int = 5
top_p:float = 1
temperature:float = 1
text_split_method:str = "cut5"
batch_size:int = 1
batch_threshold:float = 0.75
split_bucket:bool = True
speed_factor:float = 1.0
fragment_interval:float = 0.3
seed:int = -1
media_type:str = "wav"
streaming_mode:bool = False
### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files ### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files
def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int): def pack_ogg(io_buffer:BytesIO, data:np.ndarray, rate:int):
@ -196,7 +207,7 @@ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=3
# This will create a wave header then append the frame input # This will create a wave header then append the frame input
# It should be first on a streaming wav file # It should be first on a streaming wav file
# Other frames better should not have it (else you will hear some artifacts each chunk start) # Other frames better should not have it (else you will hear some artifacts each chunk start)
wav_buf = io.BytesIO() wav_buf = BytesIO()
with wave.open(wav_buf, "wb") as vfout: with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels) vfout.setnchannels(channels)
vfout.setsampwidth(sample_width) vfout.setsampwidth(sample_width)
@ -215,7 +226,7 @@ def handle_control(command:str):
exit(0) exit(0)
def tts_handle(req:dict): async def tts_handle(req:dict):
""" """
Text to speech handler. Text to speech handler.
@ -257,38 +268,32 @@ def tts_handle(req:dict):
return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"})
if text in [None, ""]: if text in [None, ""]:
return JSONResponse(status_code=400, content={"message": "text is required"}) return JSONResponse(status_code=400, content={"message": "text is required"})
if (text_lang in [None, ""]) or (text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]): if (text_lang in [None, ""]) :
return JSONResponse(status_code=400, content={"message": "text_language is required"}) return JSONResponse(status_code=400, content={"message": "text_language is required"})
elif text_lang.lower() not in ["auto", "en", "zh", "ja", "all_zh", "all_ja"]:
return JSONResponse(status_code=400, content={"message": "text_language is not supported"})
if media_type not in ["wav", "raw", "ogg", "aac"]: if media_type not in ["wav", "raw", "ogg", "aac"]:
return JSONResponse(status_code=400, content={"message": "media_type is not supported"}) return JSONResponse(status_code=400, content={"message": "media_type is not supported"})
elif media_type == "ogg" and not streaming_mode:
return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"})
try: try:
tts_generator=tts_pipeline.run(req) tts_generator=tts_pipeline.run(req)
if streaming_mode: if streaming_mode:
def streaming_generator(tts_generator:Generator, media_type:str):
if media_type == "wav": if media_type == "wav":
def streaming_generator(tts_generator):
yield wave_header_chunk() yield wave_header_chunk()
for sr, chunk in tts_generator: media_type = "raw"
yield pack_audio(BytesIO(), chunk, sr, "raw").getvalue()
else:
def streaming_generator(tts_generator):
for sr, chunk in tts_generator: for sr, chunk in tts_generator:
yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue()
# _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}"
_media_type = "" return StreamingResponse(streaming_generator(tts_generator, media_type, ), media_type=f"audio/{media_type}")
if media_type in ["wav", "raw"]:
_media_type = f"audio/x-{media_type}"
else:
_media_type = f"audio/{media_type}"
return StreamingResponse(streaming_generator(tts_generator), media_type=_media_type)
else: else:
audio_buffer = BytesIO()
sr, audio_data = next(tts_generator) sr, audio_data = next(tts_generator)
audio_buffer = pack_audio(audio_buffer, audio_data, sr, media_type) audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue()
return StreamingResponse(audio_buffer, media_type=f"audio/{media_type}") return Response(audio_data, media_type=f"audio/{media_type}")
except Exception as e: except Exception as e:
return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)}) return JSONResponse(status_code=400, content={"message": f"tts failed", "Exception": str(e)})
@ -344,13 +349,13 @@ async def tts_get_endpoint(
"media_type":media_type, "media_type":media_type,
"streaming_mode":streaming_mode, "streaming_mode":streaming_mode,
} }
return tts_handle(req) return await tts_handle(req)
@APP.post("/tts") @APP.post("/tts")
async def tts_post_endpoint(request: Request): async def tts_post_endpoint(request: TTS_Request):
req = await request.json() req = request.dict()
return tts_handle(req) return await tts_handle(req)
@APP.get("/set_refer_audio") @APP.get("/set_refer_audio")
@ -404,10 +409,6 @@ async def set_sovits_weights(weights_path: str = None):
if __name__ == "__main__": if __name__ == "__main__":
try:
uvicorn.run(APP, host=host, port=port, workers=1) uvicorn.run(APP, host=host, port=port, workers=1)
except KeyboardInterrupt:
os.kill(os.getpid(), signal.SIGTERM)
exit(0)