diff --git a/api.py b/api.py index 91294c8a..11c642bc 100644 --- a/api.py +++ b/api.py @@ -18,6 +18,7 @@ `-p` - `绑定端口, 默认9880, 可在 config.py 中指定` `-fp` - `覆盖 config.py 使用全精度` `-hp` - `覆盖 config.py 使用半精度` +`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` `-hb` - `cnhubert路径` `-b` - `bert路径` @@ -304,6 +305,19 @@ def get_spepc(hps, filename): return spec +def pack_audio(audio_bytes, data, rate): + with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file: + audio_file.write(data) + return audio_bytes + + +def read_clean_buffer(audio_bytes): + audio_chunk = audio_bytes.getvalue() + audio_bytes.truncate(0) + audio_bytes.seek(0) + return audio_bytes, audio_chunk + + def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): t0 = ttime() prompt_text = prompt_text.strip("\n") @@ -328,9 +342,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) text_language = dict_language[text_language.lower()] phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language) texts = text.split("\n") - audio_opt = [] + audio_bytes = BytesIO() for text in texts: + audio_opt = [] phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language) bert = torch.cat([bert1, bert2], 1) @@ -365,13 +380,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() + audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate) # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - ogg = BytesIO() - sf.write(ogg, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate, format="ogg") - ogg.seek(0) - chunk = ogg.read() - yield chunk - audio_opt = [] + if stream_mode == "normal": + audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) + yield audio_chunk + + if not stream_mode == "normal": + yield audio_bytes.getvalue() + def handle_control(command): @@ -465,6 +482,7 @@ parser.add_argument("-fp", "--full_precision", action="store_true", default=Fals parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度") # bool值的用法为 `python ./api.py -fp ...` # 此时 full_precision==True, half_precision==False +parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") @@ -507,6 +525,14 @@ if args.full_precision and args.half_precision: is_half = g_config.is_half # 炒饭fallback logger.info(f"半精: {is_half}") +# 流式返回模式 +if args.stream_mode.lower() in ["normal","n"]: + stream_mode = "normal" + logger.info("流式返回已开启") +else: + stream_mode = "close" + + # 初始化模型 cnhubert.cnhubert_base_path = cnhubert_base_path tokenizer = AutoTokenizer.from_pretrained(bert_path)