feat: api.py change refer

This commit is contained in:
Miuzarte 2024-01-24 20:16:39 +08:00
parent 93dd8334f4
commit 5111713ed7

90
api.py
View File

@ -7,7 +7,7 @@ import torch
import librosa import librosa
import soundfile as sf import soundfile as sf
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
@ -51,10 +51,18 @@ args = parser.parse_args()
sovits_path = args.sovits_path sovits_path = args.sovits_path
gpt_path = args.gpt_path gpt_path = args.gpt_path
default_refer_path = args.default_refer_path
class DefaultRefer:
    """Fallback reference audio used when a request does not supply its own.

    Attributes are plain and mutable so they can be replaced at runtime:
    ``path`` (reference audio file), ``text`` (its transcript) and
    ``language`` (language of the transcript).
    """

    def __init__(self, path, text, language):
        # Store the constructor arguments themselves instead of reading the
        # global ``args``, so the instance reflects what the caller passed in.
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        """Return True when path, text and language are all set (see is_full)."""
        return is_full(self.path, self.text, self.language)
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
device = args.device device = args.device
port = args.port port = args.port
@ -68,15 +76,13 @@ if gpt_path == "":
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}") print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 # 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "": if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
default_refer_path, default_refer_text, default_refer_language = "", "", "" default_refer.path, default_refer.text, default_refer.language = "", "", ""
print("[INFO] 未指定默认参考音频") print("[INFO] 未指定默认参考音频")
has_preset = False
else: else:
print(f"[INFO] 默认参考音频路径: {default_refer_path}") print(f"[INFO] 默认参考音频路径: {default_refer.path}")
print(f"[INFO] 默认参考音频文本: {default_refer_text}") print(f"[INFO] 默认参考音频文本: {default_refer.text}")
print(f"[INFO] 默认参考音频语种: {default_refer_language}") print(f"[INFO] 默认参考音频语种: {default_refer.language}")
has_preset = True
is_half = g_config.is_half is_half = g_config.is_half
if args.full_precision: if args.full_precision:
@ -100,6 +106,20 @@ else:
bert_model = bert_model.to(device) bert_model = bert_model.to(device)
def is_empty(*items) -> bool:
    """Return True when every item is None or the empty string.

    A single item that holds any other value makes the result False.
    Calling with no arguments returns True.
    """
    return not any(item is not None and item != "" for item in items)
def is_full(*items) -> bool:
    """Return True when no item is None or the empty string.

    A single None or "" makes the result False.
    Calling with no arguments returns True.
    """
    return not any(item is None or item == "" for item in items)
def get_bert_feature(text, word2ph): def get_bert_feature(text, word2ph):
with torch.no_grad(): with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt") inputs = tokenizer(text, return_tensors="pt")
@ -264,6 +284,25 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
def handle_change(path, text, language):
    """Update the module-level ``default_refer`` settings in place.

    Only arguments that carry a value (neither None nor "") overwrite the
    matching field; the other fields keep their current values.

    Raises:
        HTTPException: status 400 when all three arguments are empty.

    Returns:
        JSONResponse with body ``{"code": 0, "message": "Success"}`` and
        status 200.
    """
    if is_empty(path, text, language):
        raise HTTPException(status_code=400, detail='缺少任意一项以下参数: "path", "text", "language"')

    # Both conditions must hold: joined with `or` the test is true for every
    # value, which lets an omitted parameter wipe the stored setting.
    if path is not None and path != "":
        default_refer.path = path
    if text is not None and text != "":
        default_refer.text = text
    if language is not None and language != "":
        default_refer.language = language

    print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
    print(f"[INFO] is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language): def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
if command == "/restart": if command == "/restart":
os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
@ -277,11 +316,11 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
or prompt_language == "" or prompt_language is None or prompt_language == "" or prompt_language is None
): ):
refer_wav_path, prompt_text, prompt_language = ( refer_wav_path, prompt_text, prompt_language = (
default_refer_path, default_refer.path,
default_refer_text, default_refer.text,
default_refer_language, default_refer.language,
) )
if not has_preset: if not default_refer.is_ready():
raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设") raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
with torch.no_grad(): with torch.no_grad():
@ -301,6 +340,25 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
app = FastAPI() app = FastAPI()
@app.post("/change_refer")
async def change_refer(request: Request):
    """POST /change_refer: take "path", "text" and "language" from the JSON body.

    Keys missing from the body are forwarded to handle_change as None.
    """
    payload = await request.json()
    path, text, language = (payload.get(key) for key in ("path", "text", "language"))
    return handle_change(path, text, language)
@app.get("/change_refer")
async def change_refer(path: str = None, text: str = None, language: str = None):
    """GET /change_refer: forward the three optional parameters to handle_change.

    Each parameter defaults to None when the caller does not supply it.
    """
    return handle_change(path=path, text=text, language=language)
@app.post("/") @app.post("/")
async def tts_endpoint(request: Request): async def tts_endpoint(request: Request):
json_post_raw = await request.json() json_post_raw = await request.json()