Add int32

This commit is contained in:
KamioRinn 2024-08-19 23:30:19 +08:00
parent dbb6b42fdb
commit f20e90fac6

27
api.py
View File

@ -20,6 +20,7 @@
`-hp` - `覆盖 config.py 使用半精度`
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
·-cp` - `文本切分符号设定, 默认为空, ",.,。"字符串的方式传入`
`-hb` - `cnhubert路径`
@ -487,9 +488,14 @@ def pack_raw(audio_bytes, data, rate):
def pack_wav(audio_bytes, rate):
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV')
if is_int32:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int32)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV', subtype='PCM_32')
else:
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='WAV')
return wav_bytes
@ -631,8 +637,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
# audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
if is_int32:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32),hps.data.sampling_rate)
else:
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
if stream_mode == "normal":
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
@ -749,6 +757,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
# 此时 full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32")
parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…")
# 切割常用分句符为 `python ./api.py -cp ".?!。?!"`
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
@ -810,6 +819,14 @@ else:
media_type = "ogg"
logger.info(f"编码格式: {media_type}")
# 音频数据类型
if args.sub_type.lower() == 'int32':
is_int32 = True
logger.info(f"数据类型: int32")
else:
is_int32 = False
logger.info(f"数据类型: int16")
# 初始化模型
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)