mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-09 16:40:00 +08:00
Add media type arg
This commit is contained in:
parent
d35dfd92b3
commit
b4d5ed7e85
62
api.py
62
api.py
@ -19,6 +19,7 @@
|
||||
`-fp` - `覆盖 config.py 使用全精度`
|
||||
`-hp` - `覆盖 config.py 使用半精度`
|
||||
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
|
||||
`-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
|
||||
|
||||
`-hb` - `cnhubert路径`
|
||||
`-b` - `bert路径`
|
||||
@ -127,6 +128,7 @@ from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
import config as global_config
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
|
||||
class DefaultRefer:
|
||||
@ -306,8 +308,54 @@ def get_spepc(hps, filename):
|
||||
|
||||
|
||||
def pack_audio(audio_bytes, data, rate):
    """Write one chunk of audio into ``audio_bytes`` using the configured encoding.

    The encoder is selected from the module-level ``media_type``: ``"ogg"``
    and ``"aac"`` are routed to their dedicated packers, and every other
    value is stored as raw sample bytes.

    Args:
        audio_bytes: Writable binary buffer that receives the chunk.
        data: Audio samples for this chunk, passed through to the packer.
        rate: Sample rate, passed through to the packer.

    Returns:
        The buffer returned by the selected packer.
    """
    if media_type == "ogg":
        return pack_ogg(audio_bytes, data, rate)
    if media_type == "aac":
        return pack_aac(audio_bytes, data, rate)
    # WAV cannot be streamed, so stash the raw samples for now.
    return pack_raw(audio_bytes, data, rate)
|
||||
|
||||
|
||||
def pack_ogg(audio_bytes, data, rate):
    """Encode ``data`` as single-channel OGG and write it into ``audio_bytes``.

    Args:
        audio_bytes: File-like target opened for writing by ``sf.SoundFile``.
        data: Audio samples to encode.
        rate: Sample rate of ``data``.

    Returns:
        The same ``audio_bytes`` object that was passed in.
    """
    with sf.SoundFile(
        audio_bytes,
        mode='w',
        samplerate=rate,
        channels=1,
        format='ogg',
    ) as ogg_file:
        ogg_file.write(data)
    return audio_bytes
|
||||
|
||||
|
||||
def pack_raw(audio_bytes, data, rate):
    """Write the raw bytes of ``data`` into ``audio_bytes`` without any container.

    Args:
        audio_bytes: Writable binary buffer.
        data: Object exposing ``tobytes()`` whose bytes are written.
        rate: Unused here; kept so all packers share the same signature.

    Returns:
        The same ``audio_bytes`` object that was passed in.
    """
    raw_chunk = data.tobytes()
    audio_bytes.write(raw_chunk)
    return audio_bytes
|
||||
|
||||
|
||||
def pack_wav(audio_bytes, rate):
    """Wrap the raw samples buffered in ``audio_bytes`` in a WAV container.

    The whole content of ``audio_bytes`` is read as 16-bit integer samples
    and written into a new in-memory buffer with ``sf.write``.

    Args:
        audio_bytes: Buffer holding the raw sample bytes.
        rate: Sample rate to record in the WAV output.

    Returns:
        A new ``BytesIO`` holding the WAV data; ``audio_bytes`` is not modified.
    """
    wav_buffer = BytesIO()
    pcm_samples = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16)
    sf.write(wav_buffer, pcm_samples, rate, format='wav')
    return wav_buffer
|
||||
|
||||
|
||||
def pack_aac(audio_bytes, data, rate):
    """Encode a chunk of PCM samples to AAC with ffmpeg and write it to ``audio_bytes``.

    The bytes of ``data`` are piped to an ``ffmpeg`` subprocess as signed
    16-bit little-endian mono PCM, and the ADTS stream read from its stdout
    is written into ``audio_bytes``. The ``ffmpeg`` executable must be
    resolvable by ``subprocess``.

    If ffmpeg exits with a non-zero status, a warning containing its stderr
    output is logged instead of being silently discarded. Whatever output was
    produced is still written, so the call stays best-effort.

    Args:
        audio_bytes: Writable binary buffer that receives the AAC data.
        data: Object exposing ``tobytes()`` with the PCM samples.
        rate: Sample rate of ``data``, passed to ffmpeg via ``-ar``.

    Returns:
        The same ``audio_bytes`` object that was passed in.
    """
    import logging

    ffmpeg_cmd = [
        'ffmpeg',
        '-f', 's16le',      # input is signed 16-bit little-endian PCM
        '-ar', str(rate),   # input sample rate
        '-ac', '1',         # mono
        '-i', 'pipe:0',     # read input from stdin
        '-c:a', 'aac',      # AAC audio encoder
        '-b:a', '192k',     # bitrate
        '-vn',              # no video
        '-f', 'adts',       # output as an ADTS AAC stream
        'pipe:1',           # write output to stdout
    ]
    # ``input=`` makes subprocess.run open stdin as a pipe and feed it,
    # which is what Popen(...).communicate(input=...) did before.
    result = subprocess.run(
        ffmpeg_cmd,
        input=data.tobytes(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        # Surface encoder failures; previously stderr and the exit code were dropped.
        logging.getLogger(__name__).warning(
            "ffmpeg exited with code %s while encoding AAC: %s",
            result.returncode,
            result.stderr.decode("utf-8", errors="replace").strip(),
        )
    audio_bytes.write(result.stdout)
    return audio_bytes
|
||||
|
||||
|
||||
@ -315,6 +363,7 @@ def read_clean_buffer(audio_bytes):
|
||||
audio_chunk = audio_bytes.getvalue()
|
||||
audio_bytes.truncate(0)
|
||||
audio_bytes.seek(0)
|
||||
|
||||
return audio_bytes, audio_chunk
|
||||
|
||||
|
||||
@ -387,6 +436,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
yield audio_chunk
|
||||
|
||||
if not stream_mode == "normal":
|
||||
if media_type == "wav":
|
||||
audio_bytes = pack_wav(audio_bytes,hps.data.sampling_rate)
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
|
||||
@ -433,7 +484,7 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||
|
||||
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/ogg")
|
||||
return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language), media_type="audio/"+media_type)
|
||||
|
||||
|
||||
|
||||
@ -483,6 +534,7 @@ parser.add_argument("-hp", "--half_precision", action="store_true", default=Fals
|
||||
# bool值的用法为 `python ./api.py -fp ...`
|
||||
# 此时 full_precision==True, half_precision==False
|
||||
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
|
||||
parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac")
|
||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
||||
|
||||
@ -532,6 +584,14 @@ if args.stream_mode.lower() in ["normal","n"]:
|
||||
else:
|
||||
stream_mode = "close"
|
||||
|
||||
# Audio encoding format: "aac" and "ogg" are taken from the command line as-is;
# any other value becomes "wav" when stream_mode is "close" and "ogg" otherwise.
media_type = args.media_type.lower()
if media_type not in ("aac", "ogg"):
    media_type = "wav" if stream_mode == "close" else "ogg"
logger.info(f"编码格式: {media_type}")
|
||||
|
||||
# 初始化模型
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
|
Loading…
x
Reference in New Issue
Block a user