Add media type arg

This commit is contained in:
KamioRinn 2024-03-29 22:32:16 +08:00
parent d35dfd92b3
commit b4d5ed7e85

62
api.py
View File

@ -19,6 +19,7 @@
`-fp` - `覆盖 config.py 使用全精度` `-fp` - `覆盖 config.py 使用全精度`
`-hp` - `覆盖 config.py 使用半精度` `-hp` - `覆盖 config.py 使用半精度`
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` `-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
`-hb` - `cnhubert路径` `-hb` - `cnhubert路径`
`-b` - `bert路径` `-b` - `bert路径`
@ -127,6 +128,7 @@ from module.mel_processing import spectrogram_torch
from my_utils import load_audio from my_utils import load_audio
import config as global_config import config as global_config
import logging import logging
import subprocess
class DefaultRefer: class DefaultRefer:
@ -306,8 +308,54 @@ def get_spepc(hps, filename):
def pack_audio(audio_bytes, data, rate): def pack_audio(audio_bytes, data, rate):
if media_type == "ogg":
audio_bytes = pack_ogg(audio_bytes, data, rate)
elif media_type == "aac":
audio_bytes = pack_aac(audio_bytes, data, rate)
else:
# wav无法流式, 先暂存raw
audio_bytes = pack_raw(audio_bytes, data, rate)
return audio_bytes
def pack_ogg(audio_bytes, data, rate):
with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file: with sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg') as audio_file:
audio_file.write(data) audio_file.write(data)
return audio_bytes
def pack_raw(audio_bytes, data, rate):
audio_bytes.write(data.tobytes())
return audio_bytes
def pack_wav(audio_bytes, rate):
data = np.frombuffer(audio_bytes.getvalue(),dtype=np.int16)
wav_bytes = BytesIO()
sf.write(wav_bytes, data, rate, format='wav')
return wav_bytes
def pack_aac(audio_bytes, data, rate):
process = subprocess.Popen([
'ffmpeg',
'-f', 's16le', # 输入16位有符号小端整数PCM
'-ar', str(rate), # 设置采样率
'-ac', '1', # 单声道
'-i', 'pipe:0', # 从管道读取输入
'-c:a', 'aac', # 音频编码器为AAC
'-b:a', '192k', # 比特率
'-vn', # 不包含视频
'-f', 'adts', # 输出AAC数据流格式
'pipe:1' # 将输出写入管道
], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = process.communicate(input=data.tobytes())
audio_bytes.write(out)
return audio_bytes return audio_bytes
@ -315,6 +363,7 @@ def read_clean_buffer(audio_bytes):
audio_chunk = audio_bytes.getvalue() audio_chunk = audio_bytes.getvalue()
audio_bytes.truncate(0) audio_bytes.truncate(0)
audio_bytes.seek(0) audio_bytes.seek(0)
return audio_bytes, audio_chunk return audio_bytes, audio_chunk
@ -387,6 +436,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
yield audio_chunk yield audio_chunk
if not stream_mode == "normal": if not stream_mode == "normal":
if media_type == "wav":
audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate)
yield audio_bytes.getvalue() yield audio_bytes.getvalue()
@ -433,7 +484,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
if not default_refer.is_ready(): if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/ogg") return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type)
@ -483,6 +534,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
# bool值的用法为 `python ./api.py -fp ...` # bool值的用法为 `python ./api.py -fp ...`
# 此时 full_precision==True, half_precision==False # 此时 full_precision==True, half_precision==False
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
@ -532,6 +584,14 @@ if args.stream_mode.lower() in ["normal","n"]:
else: else:
stream_mode = "close" stream_mode = "close"
# 音频编码格式
if args.media_type.lower() in ["aac","ogg"]:
media_type = args.media_type.lower()
elif stream_mode == "close":
media_type = "wav"
else:
media_type = "ogg"
logger.info(f"编码格式: {media_type}")
# 初始化模型 # 初始化模型
cnhubert.cnhubert_base_path = cnhubert_base_path cnhubert.cnhubert_base_path = cnhubert_base_path