supported top_k, top_p, temperature, speed for api.py (#1340)
parent 4e87bd1a25
commit de0266d2d9

api.py · 44 lines changed
@@ -72,6 +72,28 @@ RESP:
 Success: returns the wav audio stream directly, HTTP code 200
 Failure: returns a JSON with the error message, HTTP code 400
 
+Manually specify the reference audio used for this inference, and provide parameters:
+GET:
+    `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1`
+POST:
+```json
+{
+    "refer_wav_path": "123.wav",
+    "prompt_text": "一二三。",
+    "prompt_language": "zh",
+    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
+    "text_language": "zh",
+    "top_k": 20,
+    "top_p": 0.6,
+    "temperature": 0.6,
+    "speed": 1
+}
+```
+
+RESP:
+Success: returns the wav audio stream directly, HTTP code 200
+Failure: returns a JSON with the error message, HTTP code 400
+
 
 ### Change the default reference audio
 
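For reference, a minimal Python client for the documented POST form might look like the sketch below. It assumes the api.py server is running locally on the default port 9880 and that POST requests go to the same base URL as the GET example above; the output file name `out.wav` and the request values are illustrative only.

```python
# Minimal client sketch for the POST form of the endpoint documented above.
# Assumes api.py is running on 127.0.0.1:9880; file names and texts are examples.
import requests

payload = {
    "refer_wav_path": "123.wav",
    "prompt_text": "一二三。",
    "prompt_language": "zh",
    "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。",
    "text_language": "zh",
    "top_k": 20,
    "top_p": 0.6,
    "temperature": 0.6,
    "speed": 1,
}

resp = requests.post("http://127.0.0.1:9880", json=payload)
if resp.status_code == 200:
    # Success: the response body is the raw wav stream.
    with open("out.wav", "wb") as f:
        f.write(resp.content)
else:
    # Failure: the response body is a JSON object describing the error.
    print(resp.status_code, resp.json())
```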
@@ -446,7 +468,7 @@ def only_punc(text):
     return not any(t.isalnum() or t.isalpha() for t in text)
 
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k=20, top_p=0.6, temperature=0.6, speed=1):
     t0 = ttime()
     prompt_text = prompt_text.strip("\n")
     prompt_language, text = prompt_language, text.strip("\n")
@@ -494,7 +516,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
             prompt,
             bert,
             # prompt_phone_len=ph_offset,
-            top_k=config['inference']['top_k'],
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
             early_stop_num=hz * max_sec)
         t3 = ttime()
         # print(pred_semantic.shape,idx)
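As context for the parameters now forwarded into the autoregressive step, the sketch below shows how top_k, top_p, and temperature are conventionally applied to a logits vector before sampling. It is a generic illustration only, not the project's own sampling code inside `infer_panel`.

```python
# Generic illustration of top-k / top-p (nucleus) filtering with temperature,
# applied to a 1-D logits tensor before sampling the next token.
# This is NOT GPT-SoVITS's implementation; it only shows what the forwarded
# top_k / top_p / temperature parameters typically control.
import torch

def sample_next_token(logits: torch.Tensor, top_k: int = 20, top_p: float = 0.6,
                      temperature: float = 0.6) -> int:
    logits = logits / max(temperature, 1e-5)      # temperature < 1 sharpens, > 1 flattens
    if top_k > 0:                                 # keep only the k highest logits
        kth = torch.topk(logits, min(top_k, logits.numel())).values[-1]
        logits[logits < kth] = float("-inf")
    if top_p < 1.0:                               # nucleus filtering on cumulative probability
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()          # shift so the first token over the threshold stays
        remove[0] = False
        logits[sorted_idx[remove]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```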
@@ -507,7 +531,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
         audio = \
             vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refer).detach().cpu().numpy()[
+                            refer, speed=speed).detach().cpu().numpy()[
             0, 0]  ### try reconstructing without including the prompt part
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
@@ -553,7 +577,7 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)
 
 
-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -572,7 +596,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
     else:
         text = cut_text(text, cut_punc)
 
-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type)
+    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed), media_type="audio/"+media_type)
 
 
 
@@ -755,6 +779,10 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("text"),
         json_post_raw.get("text_language"),
         json_post_raw.get("cut_punc"),
+        json_post_raw.get("top_k", 10),
+        json_post_raw.get("top_p", 1.0),
+        json_post_raw.get("temperature", 1.0),
+        json_post_raw.get("speed", 1.0)
     )
 
 
@@ -766,8 +794,12 @@ async def tts_endpoint(
         text: str = None,
         text_language: str = None,
         cut_punc: str = None,
+        top_k: int = 10,
+        top_p: float = 1.0,
+        temperature: float = 1.0,
+        speed: float = 1.0
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed)
 
 
 if __name__ == "__main__":