diff --git a/api.py b/api.py index 11c642bc..af65286b 100644 --- a/api.py +++ b/api.py @@ -19,6 +19,7 @@ `-fp` - `覆盖 config.py 使用全精度` `-hp` - `覆盖 config.py 使用半精度` `-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` +·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` `-hb` - `cnhubert路径` `-b` - `bert路径` @@ -127,6 +128,7 @@ from module.mel_processing import spectrogram_torch from my_utils import load_audio import config as global_config import logging +import subprocess class DefaultRefer: @@ -306,8 +308,54 @@ def get_spepc(hps, filename): def pack_audio(audio_bytes, data, rate): + if media_type == "ogg": + audio_bytes = pack_ogg(audio_bytes, data, rate) + elif media_type == "aac": + audio_bytes = pack_aac(audio_bytes, data, rate) + else: + # wav无法流式, 先暂存raw + audio_bytes = pack_raw(audio_bytes, data, rate) + + return audio_bytes + + +def pack_ogg(audio_bytes, data, rate): with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file: audio_file.write(data) + + return audio_bytes + + +def pack_raw(audio_bytes, data, rate): + audio_bytes.write(data.tobytes()) + + return audio_bytes + + +def pack_wav(audio_bytes, rate): + data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format='wav') + + return wav_bytes + + +def pack_aac(audio_bytes, data, rate): + process = subprocess.Popen([ + 'ffmpeg', + '-f', 's16le', # 输入16位有符号小端整数PCM + '-ar', str(rate), # 设置采样率 + '-ac', '1', # 单声道 + '-i', 'pipe:0', # 从管道读取输入 + '-c:a', 'aac', # 音频编码器为AAC + '-b:a', '192k', # 比特率 + '-vn', # 不包含视频 + '-f', 'adts', # 输出AAC数据流格式 + 'pipe:1' # 将输出写入管道 + ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, _ = process.communicate(input=data.tobytes()) + audio_bytes.write(out) + return audio_bytes @@ -315,6 +363,7 @@ def read_clean_buffer(audio_bytes): audio_chunk = audio_bytes.getvalue() audio_bytes.truncate(0) audio_bytes.seek(0) + return audio_bytes, audio_chunk @@ -387,6 +436,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language) yield audio_chunk if not stream_mode == "normal": + if media_type == "wav": + audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate) yield audio_bytes.getvalue() @@ -433,7 +484,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language): if not default_refer.is_ready(): return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) - return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/ogg") + return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type) @@ -483,6 +534,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals # bool值的用法为 `python ./api.py -fp ...` # 此时 full_precision==True, half_precision==False parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") +parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") @@ -532,6 +584,14 @@ if args.stream_mode.lower() in ["normal","n"]: else: stream_mode = "close" +# 音频编码格式 +if args.media_type.lower() in ["aac","ogg"]: + media_type = args.media_type.lower() +elif stream_mode == "close": + media_type = "wav" +else: + media_type = "ogg" +logger.info(f"编码格式: {media_type}") # 初始化模型 cnhubert.cnhubert_base_path = cnhubert_base_path