mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
完善模型更改逻辑、文档
This commit is contained in:
parent
939971afe3
commit
7104a7e671
167
api.py
167
api.py
@ -14,7 +14,7 @@
|
|||||||
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
|
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
|
||||||
|
|
||||||
`-d` - `推理设备, "cuda","cpu","mps"`
|
`-d` - `推理设备, "cuda","cpu","mps"`
|
||||||
`-a` - `绑定地址, 默认"127.0.0.1"`
|
`-a` - `绑定地址, 默认"0.0.0.0"`
|
||||||
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
|
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
|
||||||
`-fp` - `覆盖 config.py 使用全精度`
|
`-fp` - `覆盖 config.py 使用全精度`
|
||||||
`-hp` - `覆盖 config.py 使用半精度`
|
`-hp` - `覆盖 config.py 使用半精度`
|
||||||
@ -54,13 +54,13 @@ POST:
|
|||||||
```
|
```
|
||||||
|
|
||||||
RESP:
|
RESP:
|
||||||
成功: 直接返回 wav 音频流, http code 200
|
成功: 直接返回 wav 音频流, http code 200
|
||||||
失败: 返回包含错误信息的 json, http code 400
|
失败: 返回包含错误信息的 json, http code 400
|
||||||
|
|
||||||
|
|
||||||
### 更换默认参考音频
|
### 更换默认参考音频
|
||||||
|
|
||||||
endpoint: `/change_refer`
|
endpoints: `/change_refer`, `/set_refer`
|
||||||
|
|
||||||
key与推理端一样
|
key与推理端一样
|
||||||
|
|
||||||
@ -76,8 +76,31 @@ POST:
|
|||||||
```
|
```
|
||||||
|
|
||||||
RESP:
|
RESP:
|
||||||
成功: json, http code 200
|
成功: json, http code 200
|
||||||
失败: json, 400
|
失败: json, 400
|
||||||
|
|
||||||
|
|
||||||
|
### 更换模型
|
||||||
|
|
||||||
|
endpoints: `/change_model`, `/change_weight`, `/set_model`, `/set_weight`
|
||||||
|
|
||||||
|
key alias:
|
||||||
|
"gpt", "gpt_path", "gpt_model_path"
|
||||||
|
"sovits", "sovits_path", "sovits_model_path"
|
||||||
|
|
||||||
|
GET:
|
||||||
|
`http://127.0.0.1:9880/change_weight?gpt=./GPT_weights/suijiSUI-e20.ckpt&sovits=./SoVITS_weights/suijiSUI_e20_s3280.pth`
|
||||||
|
POST:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"gpt": "./GPT_weights/suijiSUI-e20.ckpt",
|
||||||
|
"sovits": "./SoVITS_weights/suijiSUI_e20_s3280.pth"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
RESP:
|
||||||
|
成功: json, http code 200
|
||||||
|
失败: json, 400 | "Internal Server Error"
|
||||||
|
|
||||||
|
|
||||||
### 命令控制
|
### 命令控制
|
||||||
@ -85,8 +108,8 @@ RESP:
|
|||||||
endpoint: `/control`
|
endpoint: `/control`
|
||||||
|
|
||||||
command:
|
command:
|
||||||
"restart": 重新运行
|
"restart": 重新运行
|
||||||
"exit": 结束运行
|
"exit": 结束运行
|
||||||
|
|
||||||
GET:
|
GET:
|
||||||
`http://127.0.0.1:9880/control?command=restart`
|
`http://127.0.0.1:9880/control?command=restart`
|
||||||
@ -101,7 +124,6 @@ RESP: 无
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -132,8 +154,6 @@ import config as global_config
|
|||||||
|
|
||||||
g_config = global_config.Config()
|
g_config = global_config.Config()
|
||||||
|
|
||||||
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||||
|
|
||||||
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
||||||
@ -216,17 +236,18 @@ else:
|
|||||||
|
|
||||||
def is_empty(*items): # 任意一项不为空返回False
|
def is_empty(*items): # 任意一项不为空返回False
|
||||||
for item in items:
|
for item in items:
|
||||||
if item is not None and item != "":
|
if item:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_full(*items): # 任意一项为空返回False
|
def is_full(*items): # 任意一项为空返回False
|
||||||
for item in items:
|
for item in items:
|
||||||
if item is None or item == "":
|
if not item:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def change_sovits_weights(sovits_path):
|
def change_sovits_weights(sovits_path):
|
||||||
global vq_model, hps
|
global vq_model, hps
|
||||||
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
dict_s2 = torch.load(sovits_path, map_location="cpu")
|
||||||
@ -249,6 +270,8 @@ def change_sovits_weights(sovits_path):
|
|||||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||||
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
with open("./sweight.txt", "w", encoding="utf-8") as f:
|
||||||
f.write(sovits_path)
|
f.write(sovits_path)
|
||||||
|
|
||||||
|
|
||||||
def change_gpt_weights(gpt_path):
|
def change_gpt_weights(gpt_path):
|
||||||
global hz, max_sec, t2s_model, config
|
global hz, max_sec, t2s_model, config
|
||||||
hz = 50
|
hz = 50
|
||||||
@ -438,9 +461,10 @@ def handle_control(command):
|
|||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
def handle_change(path, text, language):
|
def handle_change_refer(path, text, language):
|
||||||
if is_empty(path, text, language):
|
if is_empty(path, text, language):
|
||||||
return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
|
return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'},
|
||||||
|
status_code=400)
|
||||||
|
|
||||||
if path != "" or path is not None:
|
if path != "" or path is not None:
|
||||||
default_refer.path = path
|
default_refer.path = path
|
||||||
@ -457,12 +481,12 @@ def handle_change(path, text, language):
|
|||||||
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
def handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||||
if (
|
if ( # 缺任意一个
|
||||||
refer_wav_path == "" or refer_wav_path is None
|
not refer_wav_path
|
||||||
or prompt_text == "" or prompt_text is None
|
or not prompt_text
|
||||||
or prompt_language == "" or prompt_language is None
|
or not prompt_language
|
||||||
):
|
): # 使用全局
|
||||||
refer_wav_path, prompt_text, prompt_language = (
|
refer_wav_path, prompt_text, prompt_language = (
|
||||||
default_refer.path,
|
default_refer.path,
|
||||||
default_refer.text,
|
default_refer.text,
|
||||||
@ -481,6 +505,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
|||||||
sf.write(wav, audio_data, sampling_rate, format="wav")
|
sf.write(wav, audio_data, sampling_rate, format="wav")
|
||||||
wav.seek(0)
|
wav.seek(0)
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
if device == "mps":
|
if device == "mps":
|
||||||
print('executed torch.mps.empty_cache()')
|
print('executed torch.mps.empty_cache()')
|
||||||
@ -488,22 +513,88 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
|||||||
return StreamingResponse(wav, media_type="audio/wav")
|
return StreamingResponse(wav, media_type="audio/wav")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_change_weights(gpt, sovits):
|
||||||
|
if is_empty(gpt, sovits):
|
||||||
|
return JSONResponse({"code": 400, "message": f"缺少任意一项以下参数: {gpt_alias}, {sovits_alias}"},
|
||||||
|
status_code=400)
|
||||||
|
|
||||||
|
global gpt_path, sovits_path
|
||||||
|
|
||||||
|
if gpt:
|
||||||
|
gpt_path = gpt
|
||||||
|
print(f"New gpt_path: {gpt_path}")
|
||||||
|
change_gpt_weights(gpt_path)
|
||||||
|
|
||||||
|
if sovits:
|
||||||
|
sovits_path = sovits
|
||||||
|
print(f"New sovits_path: {sovits_path}")
|
||||||
|
change_sovits_weights(sovits_path)
|
||||||
|
|
||||||
|
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
gpt_alias = (
|
||||||
|
"gpt",
|
||||||
|
"gpt_path",
|
||||||
|
"gpt_model_path" # @JavaAndPython55 用的这个key, 嫌太长直接alias了
|
||||||
|
)
|
||||||
|
sovits_alias = (
|
||||||
|
"sovits",
|
||||||
|
"sovits_path",
|
||||||
|
"sovits_model_path"
|
||||||
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
#clark新增-----2024-02-21
|
|
||||||
#可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
|
# clark新增-----2024-02-21
|
||||||
|
# 可在启动后动态修改模型,以此满足同一个api不同的朗读者请求
|
||||||
@app.post("/set_model")
|
@app.post("/set_model")
|
||||||
async def set_model(request: Request):
|
@app.post("/set_weight")
|
||||||
|
@app.post("/change_model")
|
||||||
|
@app.post("/change_weight")
|
||||||
|
async def change_weight(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
global gpt_path
|
|
||||||
gpt_path=json_post_raw.get("gpt_model_path")
|
gpt, sovits = "", ""
|
||||||
global sovits_path
|
for ga in gpt_alias:
|
||||||
sovits_path=json_post_raw.get("sovits_model_path")
|
g = json_post_raw.get(ga)
|
||||||
print("gptpath"+gpt_path+";vitspath"+sovits_path)
|
if g:
|
||||||
change_sovits_weights(sovits_path)
|
gpt = g
|
||||||
change_gpt_weights(gpt_path)
|
break
|
||||||
return "ok"
|
for sa in sovits_alias:
|
||||||
# 新增-----end------
|
s = json_post_raw.get(sa)
|
||||||
|
if s:
|
||||||
|
sovits = s
|
||||||
|
break
|
||||||
|
|
||||||
|
return handle_change_weights(gpt, sovits)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/set_model")
|
||||||
|
@app.get("/set_weight")
|
||||||
|
@app.get("/change_model")
|
||||||
|
@app.get("/change_weight")
|
||||||
|
async def change_weight(
|
||||||
|
gpt: str = None,
|
||||||
|
gpt_path: str = None,
|
||||||
|
gpt_model_path: str = None,
|
||||||
|
sovits: str = None,
|
||||||
|
sovits_path: str = None,
|
||||||
|
sovits_model_path: str = None,
|
||||||
|
):
|
||||||
|
GPT, SOVITS = "", ""
|
||||||
|
for gg in (gpt, gpt_path, gpt_model_path):
|
||||||
|
if gg:
|
||||||
|
GPT = gg
|
||||||
|
break
|
||||||
|
for ss in (sovits, sovits_path, sovits_model_path):
|
||||||
|
if ss:
|
||||||
|
SOVITS = ss
|
||||||
|
break
|
||||||
|
|
||||||
|
return handle_change_weights(GPT, SOVITS)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/control")
|
@app.post("/control")
|
||||||
async def control(request: Request):
|
async def control(request: Request):
|
||||||
@ -516,29 +607,31 @@ async def control(command: str = None):
|
|||||||
return handle_control(command)
|
return handle_control(command)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/set_refer")
|
||||||
@app.post("/change_refer")
|
@app.post("/change_refer")
|
||||||
async def change_refer(request: Request):
|
async def change_refer(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
return handle_change(
|
return handle_change_refer(
|
||||||
json_post_raw.get("refer_wav_path"),
|
json_post_raw.get("refer_wav_path"),
|
||||||
json_post_raw.get("prompt_text"),
|
json_post_raw.get("prompt_text"),
|
||||||
json_post_raw.get("prompt_language")
|
json_post_raw.get("prompt_language")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/set_refer")
|
||||||
@app.get("/change_refer")
|
@app.get("/change_refer")
|
||||||
async def change_refer(
|
async def change_refer(
|
||||||
refer_wav_path: str = None,
|
refer_wav_path: str = None,
|
||||||
prompt_text: str = None,
|
prompt_text: str = None,
|
||||||
prompt_language: str = None
|
prompt_language: str = None
|
||||||
):
|
):
|
||||||
return handle_change(refer_wav_path, prompt_text, prompt_language)
|
return handle_change_refer(refer_wav_path, prompt_text, prompt_language)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/")
|
||||||
async def tts_endpoint(request: Request):
|
async def tts_endpoint(request: Request):
|
||||||
json_post_raw = await request.json()
|
json_post_raw = await request.json()
|
||||||
return handle(
|
return handle_refer(
|
||||||
json_post_raw.get("refer_wav_path"),
|
json_post_raw.get("refer_wav_path"),
|
||||||
json_post_raw.get("prompt_text"),
|
json_post_raw.get("prompt_text"),
|
||||||
json_post_raw.get("prompt_language"),
|
json_post_raw.get("prompt_language"),
|
||||||
@ -555,7 +648,7 @@ async def tts_endpoint(
|
|||||||
text: str = None,
|
text: str = None,
|
||||||
text_language: str = None,
|
text_language: str = None,
|
||||||
):
|
):
|
||||||
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
return handle_refer(refer_wav_path, prompt_text, prompt_language, text, text_language)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user