Add stream mode arg

This commit is contained in:
KamioRinn 2024-03-29 12:11:56 +08:00
parent bbc4e2080f
commit d35dfd92b3

40
api.py
View File

@ -18,6 +18,7 @@
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
`-fp` - `覆盖 config.py 使用全精度`
`-hp` - `覆盖 config.py 使用半精度`
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
`-hb` - `cnhubert路径`
`-b` - `bert路径`
@ -304,6 +305,19 @@ def get_spepc(hps, filename):
return spec
def pack_audio(audio_bytes, data, rate):
with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
audio_file.write(data)
return audio_bytes
def read_clean_buffer(audio_bytes):
audio_chunk = audio_bytes.getvalue()
audio_bytes.truncate(0)
audio_bytes.seek(0)
return audio_bytes, audio_chunk
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
@ -328,9 +342,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
text_language = dict_language[text_language.lower()]
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
texts = text.split("\n")
audio_opt = []
audio_bytes = BytesIO()
for text in texts:
audio_opt = []
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
bert = torch.cat([bert1, bert2], 1)
@ -365,13 +380,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
ogg = BytesIO()
sf.write(ogg, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate, format="ogg")
ogg.seek(0)
chunk = ogg.read()
yield chunk
audio_opt = []
if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
yield audio_chunk
if not stream_mode == "normal":
yield audio_bytes.getvalue()
def handle_control(command):
@ -465,6 +482,7 @@ parser.add_argument("-fp", "--full_precision", action="store_true", default=Fals
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# bool值的用法为 `python ./api.py -fp ...`
# 此时 full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
@ -507,6 +525,14 @@ if args.full_precision and args.half_precision:
is_half = g_config.is_half # 炒饭fallback
logger.info(f"半精: {is_half}")
# 流式返回模式
if args.stream_mode.lower() in ["normal","n"]:
stream_mode = "normal"
logger.info("流式返回已开启")
else:
stream_mode = "close"
# 初始化模型
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)