mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
feat: api.py change refer
This commit is contained in:
parent
93dd8334f4
commit
5111713ed7
90
api.py
90
api.py
@ -7,7 +7,7 @@ import torch
|
|||||||
import librosa
|
import librosa
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -51,10 +51,18 @@ args = parser.parse_args()
|
|||||||
sovits_path = args.sovits_path
|
sovits_path = args.sovits_path
|
||||||
gpt_path = args.gpt_path
|
gpt_path = args.gpt_path
|
||||||
|
|
||||||
default_refer_path = args.default_refer_path
|
|
||||||
default_refer_text = args.default_refer_text
|
class DefaultRefer:
    """Container for the default reference audio settings (path, text, language)."""

    def __init__(self, path, text, language):
        # Store the values that were actually passed in. The previous version
        # ignored its parameters and read the module-level `args` directly,
        # which made the constructor arguments meaningless.
        self.path = path
        self.text = text
        self.language = language

    def is_ready(self) -> bool:
        """Return True only when path, text and language are all non-empty."""
        return is_full(self.path, self.text, self.language)
|
||||||
|
|
||||||
|
|
||||||
|
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
|
||||||
|
|
||||||
device = args.device
|
device = args.device
|
||||||
port = args.port
|
port = args.port
|
||||||
@ -68,15 +76,13 @@ if gpt_path == "":
|
|||||||
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||||
|
|
||||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||||
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
|
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||||
default_refer_path, default_refer_text, default_refer_language = "", "", ""
|
default_refer.path, default_refer.text, default_refer.language = "", "", ""
|
||||||
print("[INFO] 未指定默认参考音频")
|
print("[INFO] 未指定默认参考音频")
|
||||||
has_preset = False
|
|
||||||
else:
|
else:
|
||||||
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
|
print(f"[INFO] 默认参考音频路径: {default_refer.path}")
|
||||||
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
|
print(f"[INFO] 默认参考音频文本: {default_refer.text}")
|
||||||
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
|
print(f"[INFO] 默认参考音频语种: {default_refer.language}")
|
||||||
has_preset = True
|
|
||||||
|
|
||||||
is_half = g_config.is_half
|
is_half = g_config.is_half
|
||||||
if args.full_precision:
|
if args.full_precision:
|
||||||
@ -100,6 +106,20 @@ else:
|
|||||||
bert_model = bert_model.to(device)
|
bert_model = bert_model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def is_empty(*items):
    """Return True when every item is None or "" (False if any item has a value).

    Called with no arguments, returns True.
    """
    return all(item is None or item == "" for item in items)
|
||||||
|
|
||||||
|
|
||||||
|
def is_full(*items):
    """Return True when no item is None or "" (False if any item is empty).

    Called with no arguments, returns True.
    """
    return not any(item is None or item == "" for item in items)
|
||||||
|
|
||||||
|
|
||||||
def get_bert_feature(text, word2ph):
|
def get_bert_feature(text, word2ph):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = tokenizer(text, return_tensors="pt")
|
inputs = tokenizer(text, return_tensors="pt")
|
||||||
@ -264,6 +284,25 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|||||||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_change(path, text, language):
    """Update the stored default reference audio with the supplied fields.

    Only fields that are neither None nor "" are applied; the others keep
    their current value.

    Args:
        path: New reference audio path, or None/"" to leave it unchanged.
        text: New reference audio text, or None/"" to leave it unchanged.
        language: New reference audio language, or None/"" to leave it unchanged.

    Returns:
        JSONResponse with body {"code": 0, "message": "Success"} and status 200.

    Raises:
        HTTPException: status 400 when all three arguments are None or "".
    """
    if is_empty(path, text, language):
        raise HTTPException(status_code=400, detail='缺少任意一项以下参数: "path", "text", "language"')

    # The guards used to be `x != "" or x is not None`, which is true for every
    # value, so an omitted field overwrote the stored default with None/"".
    # Both conditions must hold for the field to count as provided.
    if path is not None and path != "":
        default_refer.path = path
    if text is not None and text != "":
        default_refer.text = text
    if language is not None and language != "":
        default_refer.language = language

    print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
    print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
    print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
    print(f"[INFO] is_ready: {default_refer.is_ready()}")

    return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
|
def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||||
if command == "/restart":
|
if command == "/restart":
|
||||||
os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
|
os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
|
||||||
@ -277,11 +316,11 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
|
|||||||
or prompt_language == "" or prompt_language is None
|
or prompt_language == "" or prompt_language is None
|
||||||
):
|
):
|
||||||
refer_wav_path, prompt_text, prompt_language = (
|
refer_wav_path, prompt_text, prompt_language = (
|
||||||
default_refer_path,
|
default_refer.path,
|
||||||
default_refer_text,
|
default_refer.text,
|
||||||
default_refer_language,
|
default_refer.language,
|
||||||
)
|
)
|
||||||
if not has_preset:
|
if not default_refer.is_ready():
|
||||||
raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
|
raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -301,6 +340,25 @@ def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_lan
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/change_refer")
async def change_refer(request: Request):
    # POST variant: read "path", "text" and "language" from the JSON body
    # (missing keys become None) and hand them to handle_change.
    payload = await request.json()
    new_path = payload.get("path")
    new_text = payload.get("text")
    new_language = payload.get("language")
    return handle_change(new_path, new_text, new_language)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/change_refer")
async def change_refer(
        path: str = None,
        text: str = None,
        language: str = None
):
    # GET variant: the three values arrive as optional query parameters
    # (None when omitted) and are forwarded unchanged to handle_change.
    return handle_change(path=path, text=text, language=language)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/")
|
||||||
async def tts_endpoint(request: Request):
|
async def tts_endpoint(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user