mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 12:38:35 +08:00
feat: api.py change refer
This commit is contained in:
parent
93dd8334f4
commit
5111713ed7
90
api.py
90
api.py
@ -7,7 +7,7 @@ import torch
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
import uvicorn
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
@ -51,10 +51,18 @@ args = parser.parse_args()
|
||||
sovits_path = args.sovits_path
|
||||
gpt_path = args.gpt_path
|
||||
|
||||
default_refer_path = args.default_refer_path
|
||||
default_refer_text = args.default_refer_text
|
||||
default_refer_language = args.default_refer_language
|
||||
has_preset = False
|
||||
|
||||
class DefaultRefer:
    """Container for the default reference audio used when a caller supplies none.

    Attributes are plain and mutable so they can be replaced at runtime:
        path:     location of the reference audio file
        text:     transcript of the reference audio
        language: language of the reference audio
    """

    def __init__(self, path, text, language):
        # Store the values actually passed in instead of re-reading the
        # global CLI ``args``; otherwise the parameters would be ignored.
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        """Return True only when path, text and language all pass ``is_full``."""
        return is_full(self.path, self.text, self.language)
|
||||
|
||||
|
||||
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
|
||||
|
||||
device = args.device
|
||||
port = args.port
|
||||
@ -68,15 +76,13 @@ if gpt_path == "":
|
||||
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
|
||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
|
||||
default_refer_path, default_refer_text, default_refer_language = "", "", ""
|
||||
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||
default_refer.path, default_refer.text, default_refer.language = "", "", ""
|
||||
print("[INFO] 未指定默认参考音频")
|
||||
has_preset = False
|
||||
else:
|
||||
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
|
||||
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
|
||||
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
|
||||
has_preset = True
|
||||
print(f"[INFO] 默认参考音频路径: {default_refer.path}")
|
||||
print(f"[INFO] 默认参考音频文本: {default_refer.text}")
|
||||
print(f"[INFO] 默认参考音频语种: {default_refer.language}")
|
||||
|
||||
is_half = g_config.is_half
|
||||
if args.full_precision:
|
||||
@ -100,6 +106,20 @@ else:
|
||||
bert_model = bert_model.to(device)
|
||||
|
||||
|
||||
def is_empty(*items):
    """Return True when every item is ``None`` or ``""``.

    Any item that is neither ``None`` nor the empty string makes the result
    False. Called with no arguments, the result is True.
    """
    return all(entry is None or entry == "" for entry in items)
|
||||
|
||||
|
||||
def is_full(*items):
    """Return True when no item is ``None`` or ``""``.

    A single ``None`` or empty-string item makes the result False. Called
    with no arguments, the result is True.
    """
    return not any(entry is None or entry == "" for entry in items)
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
@ -264,6 +284,25 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||
|
||||
|
||||
def handle_change(path, text, language):
    """Update the module-level ``default_refer`` and return a success response.

    Each argument that is neither ``None`` nor ``""`` replaces the matching
    attribute on ``default_refer``; arguments left empty keep the previously
    stored value, so callers may change a single field at a time.

    Args:
        path: new reference audio path, or None/"" to leave unchanged.
        text: new reference transcript, or None/"" to leave unchanged.
        language: new reference language, or None/"" to leave unchanged.

    Returns:
        A ``JSONResponse`` with ``{"code": 0, "message": "Success"}`` and status 200.

    Raises:
        HTTPException: status 400 when all three arguments are empty.
    """
    if is_empty(path, text, language):
        raise HTTPException(status_code=400, detail='缺少任意一项以下参数: "path", "text", "language"')

    # Both conditions must hold: a value is applied only when it is not None
    # AND not the empty string. Joining them with `or` would be always true
    # and would overwrite stored fields with None/"" on partial updates.
    if path is not None and path != "":
        default_refer.path = path
    if text is not None and text != "":
        default_refer.text = text
    if language is not None and language != "":
        default_refer.language = language

    print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
    print(f"[INFO] is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||
|
||||
|
||||
def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
if command == "/restart":
|
||||
os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
|
||||
@ -277,11 +316,11 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
|
||||
or prompt_language == "" or prompt_language is None
|
||||
):
|
||||
refer_wav_path, prompt_text, prompt_language = (
|
||||
default_refer_path,
|
||||
default_refer_text,
|
||||
default_refer_language,
|
||||
default_refer.path,
|
||||
default_refer.text,
|
||||
default_refer.language,
|
||||
)
|
||||
if not has_preset:
|
||||
if not default_refer.is_ready():
|
||||
raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
|
||||
|
||||
with torch.no_grad():
|
||||
@ -301,6 +340,25 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/change_refer")
async def change_refer(request: Request):
    """POST handler: read "path", "text" and "language" from the JSON body.

    Keys missing from the body are passed on as ``None``; the result of
    ``handle_change`` is returned unchanged.
    """
    payload = await request.json()
    path, text, language = (payload.get(key) for key in ("path", "text", "language"))
    return handle_change(path, text, language)
|
||||
|
||||
|
||||
@app.get("/change_refer")
async def change_refer(
        path: str = None,
        text: str = None,
        language: str = None
):
    """GET handler: forward the three parameters (each defaulting to None) to ``handle_change``."""
    return handle_change(path=path, text=text, language=language)
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def tts_endpoint(request: Request):
|
||||
json_post_raw = await request.json()
|
||||
|
Loading…
x
Reference in New Issue
Block a user