支持设置默认切分方式以及切分方式切分,优化默认参数修改

现在只需输入想修改的默认参数即可修改,无需输入全部参数
This commit is contained in:
XXXXRT666 2024-03-03 09:07:53 +00:00 committed by GitHub
parent 0ab0e5390f
commit ee14237609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

199
api.py
View File

@ -12,6 +12,7 @@
`-dr` - `默认参考音频路径`
`-dt` - `默认参考音频文本`
`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
`-cut` - `默认切分方式,"凑四句一切","凑50字一切","按中文句号。切","按英文句号.切","按标点符号切",分别为cut1,cut2,cut3,cut4,cut5`
`-d` - `推理设备, "cuda","cpu","mps"`
`-a` - `绑定地址, 默认"127.0.0.1"`
@ -28,7 +29,7 @@
endpoint: `/`
使用执行参数指定的参考音频:
使用执行参数指定的参考音频与切分方式:
GET:
`http://127.0.0.1:9880?text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
POST:
@ -39,9 +40,9 @@ POST:
}
```
手动指定当次推理所使用的参考音频:
手动指定当次推理所使用的参考音频与切分方式:
GET:
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh`
`http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&text=先帝创业未半而中道崩殂今天下三分益州疲弊此诚危急存亡之秋也&text_language=zh&how_to_cut=cut1`
POST:
```json
{
@ -50,6 +51,7 @@ POST:
"prompt_language": "zh",
"text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
"text_language": "zh"
"how_to_cut"= "cut1"
}
```
@ -58,20 +60,21 @@ RESP:
失败: 返回包含错误信息的 json, http code 400
### 更换默认参考音频
### 更换默认参考音频及切分方式
endpoint: `/change_refer`
key与推理端一样
GET:
`http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh`
`http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三&prompt_language=zh&how_to_cut=cut1`
POST:
```json
{
"refer_wav_path": "123.wav",
"prompt_text": "一二三。",
"prompt_language": "zh"
"how_to_cut": "cut1"
}
```
@ -80,6 +83,7 @@ RESP:
失败: json, 400
### 命令控制
endpoint: `/control`
@ -103,7 +107,7 @@ RESP: 无
import argparse
import os
import os,re
import sys
now_dir = os.getcwd()
@ -142,6 +146,7 @@ parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, hel
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-cut", "--default_how_to_cut", type=str, default="cut1", help="默认切分方式")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu / mps")
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
@ -154,23 +159,27 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
args = parser.parse_args()
sovits_path = args.sovits_path
gpt_path = args.gpt_path
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
class DefaultRefer:
def __init__(self, path, text, language):
def __init__(self, path, text, language, how_to_cut):
self.path = args.default_refer_path
self.text = args.default_refer_text
self.language = args.default_refer_language
self.how_to_cut = args.default_how_to_cut
def is_ready(self) -> bool:
return is_full(self.path, self.text, self.language)
return is_full(self.path, self.text, self.language,self.how_to_cut)
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language)
default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language, args.default_how_to_cut)
device = args.device
port = args.port
@ -188,10 +197,23 @@ if default_refer.path == "" or default_refer.text == "" or default_refer.languag
default_refer.path, default_refer.text, default_refer.language = "", "", ""
print("[INFO] 未指定默认参考音频")
else:
if (default_refer.how_to_cut == "cut1"):
cut_info = "凑四句一切"
elif (default_refer.how_to_cut == "cut2"):
cut_info = "凑50字一切"
elif (default_refer.how_to_cut == "cut3"):
cut_info = "按中文句号。切"
elif (default_refer.how_to_cut == "cut4"):
cut_info = "按英文句号.切"
elif (default_refer.how_to_cut == "cut5"):
cut_info = "按标点符号切"
print(f"[INFO] 默认参考音频路径: {default_refer.path}")
print(f"[INFO] 默认参考音频文本: {default_refer.text}")
print(f"[INFO] 默认参考音频语种: {default_refer.language}")
print(f"[INFO] 默认切割方式: {cut_info}")
is_half = g_config.is_half
if args.full_precision:
is_half = False
@ -354,6 +376,86 @@ dict_language = {
}
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:
todo_text += ""
i_split_head = i_split_tail = 0
len_text = len(todo_text)
todo_texts = []
while 1:
if i_split_head >= len_text:
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
if todo_text[i_split_head] in splits:
i_split_head += 1
todo_texts.append(todo_text[i_split_tail:i_split_head])
i_split_tail = i_split_head
else:
i_split_head += 1
return todo_texts
def cut1(inp):
inp = inp.strip("\n")
inps = split(inp)
split_idx = list(range(0, len(inps), 4))
split_idx[-1] = None
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
else:
opts = [inp]
return "\n".join(opts)
def cut2(inp):
inp = inp.strip("\n")
inps = split(inp)
if len(inps) < 2:
return inp
opts = []
summ = 0
tmp_str = ""
for i in range(len(inps)):
summ += len(inps[i])
tmp_str += inps[i]
if summ > 50:
summ = 0
opts.append(tmp_str)
tmp_str = ""
if tmp_str != "":
opts.append(tmp_str)
# print(opts)
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
opts[-2] = opts[-2] + opts[-1]
opts = opts[:-1]
return "\n".join(opts)
def cut3(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip("").split("")])
def cut4(inp):
inp = inp.strip("\n")
return "\n".join(["%s" % item for item in inp.strip(".").split(".")])
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
# if not re.search(r'[^\w\s]', inp[-1]):
# inp += '。'
inp = inp.strip("\n")
punds = r'[,.;?!、,。?!;:…]'
items = re.split(f'({punds})', inp)
mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])]
# 在句子不存在符号或句尾无符号的时候保证文本完整
if len(items)%2 == 1:
mergeitems.append(items[-1])
opt = "\n".join(mergeitems)
return opt
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
@ -374,8 +476,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
@ -438,26 +542,46 @@ def handle_control(command):
exit(0)
def handle_change(path, text, language):
if is_empty(path, text, language):
return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400)
def handle_change(path, text, language, how_to_cut):
if is_empty(path, text, language, how_to_cut):
return JSONResponse({"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language","how_to_cut"'}, status_code=400)
if path == "" or path is None:
path = default_refer.path
if text == "" or text is None:
text = default_refer.text
if language == "" or language is None:
language = default_refer.language
if how_to_cut == "" or how_to_cut is None:
how_to_cut = default_refer.how_to_cut
default_refer.path = path
default_refer.text = text
default_refer.language = language
default_refer.how_to_cut = how_to_cut
if path != "" or path is not None:
default_refer.path = path
if text != "" or text is not None:
default_refer.text = text
if language != "" or language is not None:
default_refer.language = language
if (default_refer.how_to_cut == "cut1"):
cut_info = "凑四句一切"
elif (default_refer.how_to_cut == "cut2"):
cut_info = "凑50字一切"
elif (default_refer.how_to_cut == "cut3"):
cut_info = "按中文句号。切"
elif (default_refer.how_to_cut == "cut4"):
cut_info = "按英文句号.切"
elif (default_refer.how_to_cut == "cut5"):
cut_info = "按标点符号切"
print(f"[INFO] 当前默认参考音频路径: {default_refer.path}")
print(f"[INFO] 当前默认参考音频文本: {default_refer.text}")
print(f"[INFO] 当前默认参考音频语种: {default_refer.language}")
print(f"[INFO] 默认切割方式: {cut_info}")
print(f"[INFO] is_ready: {default_refer.is_ready()}")
return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
def handle(refer_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut):
if (
refer_wav_path == "" or refer_wav_path is None
or prompt_text == "" or prompt_text is None
@ -468,8 +592,30 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
default_refer.text,
default_refer.language,
)
if (how_to_cut == "" or how_to_cut is None):
how_to_cut = default_refer.how_to_cut
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if (how_to_cut == "cut1"):
text = cut1(text) #凑四句一切
elif (how_to_cut == "cut2"):
text = cut2(text) #凑50字一切
elif (how_to_cut == "cut3"):
text = cut3(text) #按中文句号。切
elif (how_to_cut == "cut4"):
text = cut4(text) #按英文句号.切
elif (how_to_cut == "cut5"):
text = cut5(text) #按标点符号切
while "\n\n" in text:
text = text.replace("\n\n", "\n")
print("实际输入的目标文本(切句后):", text)
with torch.no_grad():
gen = get_tts_wav(
@ -522,7 +668,8 @@ async def change_refer(request: Request):
return handle_change(
json_post_raw.get("refer_wav_path"),
json_post_raw.get("prompt_text"),
json_post_raw.get("prompt_language")
json_post_raw.get("prompt_language"),
json_post_raw.get("how_to_cut")
)
@ -530,9 +677,11 @@ async def change_refer(request: Request):
async def change_refer(
refer_wav_path: str = None,
prompt_text: str = None,
prompt_language: str = None
prompt_language: str = None,
how_to_cut: str = None
):
return handle_change(refer_wav_path, prompt_text, prompt_language)
return handle_change(refer_wav_path, prompt_text, prompt_language, how_to_cut)
@app.post("/")
@ -544,7 +693,8 @@ async def tts_endpoint(request: Request):
json_post_raw.get("prompt_language"),
json_post_raw.get("text"),
json_post_raw.get("text_language"),
)
json_post_raw.get("how_to_cut")
)
@app.get("/")
@ -554,8 +704,9 @@ async def tts_endpoint(
prompt_language: str = None,
text: str = None,
text_language: str = None,
how_to_cut: str = None
):
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language)
return handle(refer_wav_path, prompt_text, prompt_language, text, text_language,how_to_cut)
if __name__ == "__main__":