mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-07 15:19:59 +08:00
Add stream mode arg
This commit is contained in:
parent
bbc4e2080f
commit
d35dfd92b3
40
api.py
40
api.py
@ -18,6 +18,7 @@
|
||||
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
|
||||
`-fp` - `覆盖 config.py 使用全精度`
|
||||
`-hp` - `覆盖 config.py 使用半精度`
|
||||
`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
|
||||
|
||||
`-hb` - `cnhubert路径`
|
||||
`-b` - `bert路径`
|
||||
@ -304,6 +305,19 @@ def get_spepc(hps, filename):
|
||||
return spec
|
||||
|
||||
|
||||
def pack_audio(audio_bytes, data, rate):
    """Encode ``data`` as single-channel OGG into ``audio_bytes``.

    Args:
        audio_bytes: Writable file-like object handed to ``sf.SoundFile``.
        data: Audio samples forwarded to ``SoundFile.write``.
        rate: Sample rate used when opening the sound file.

    Returns:
        The same ``audio_bytes`` object that was passed in.
    """
    ogg_writer = sf.SoundFile(audio_bytes, mode='w', samplerate=rate, channels=1, format='ogg')
    # The SoundFile object is used as its own context manager so it is
    # released as soon as the samples have been written.
    with ogg_writer:
        ogg_writer.write(data)
    return audio_bytes
|
||||
|
||||
|
||||
def read_clean_buffer(audio_bytes):
    """Drain an in-memory buffer and reset it for reuse.

    Args:
        audio_bytes: Buffer exposing ``getvalue``, ``seek`` and ``truncate``
            (the caller in this file passes a ``BytesIO``).

    Returns:
        A ``(audio_bytes, audio_chunk)`` tuple: the same buffer object, now
        empty and positioned at offset 0, and the bytes it held beforehand.
    """
    pending = audio_bytes.getvalue()
    # Rewind, then cut at the cursor: the buffer ends up empty at offset 0.
    audio_bytes.seek(0)
    audio_bytes.truncate()
    return audio_bytes, pending
|
||||
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
@ -328,9 +342,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
text_language = dict_language[text_language.lower()]
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
|
||||
texts = text.split("\n")
|
||||
audio_opt = []
|
||||
audio_bytes = BytesIO()
|
||||
|
||||
for text in texts:
|
||||
audio_opt = []
|
||||
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
@ -365,13 +380,15 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
audio_bytes = pack_audio(audio_bytes,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16),hps.data.sampling_rate)
|
||||
# logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
ogg = BytesIO()
|
||||
sf.write(ogg, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate, format="ogg")
|
||||
ogg.seek(0)
|
||||
chunk = ogg.read()
|
||||
yield chunk
|
||||
audio_opt = []
|
||||
if stream_mode == "normal":
|
||||
audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
|
||||
yield audio_chunk
|
||||
|
||||
if not stream_mode == "normal":
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
|
||||
|
||||
def handle_control(command):
|
||||
@ -465,6 +482,7 @@ parser.add_argument("-fp", "--full_precision", action="store_true", default=Fals
|
||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||
# bool值的用法为 `python ./api.py -fp ...`
|
||||
# 此时 full_precision==True, half_precision==False
|
||||
parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive")
|
||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
||||
|
||||
@ -507,6 +525,14 @@ if args.full_precision and args.half_precision:
|
||||
is_half = g_config.is_half # 炒饭fallback
|
||||
logger.info(f"半精: {is_half}")
|
||||
|
||||
# Streaming response mode: only "normal" / "n" (case-insensitive) turns
# streaming on; every other value of --stream_mode resolves to "close".
stream_mode = "normal" if args.stream_mode.lower() in ("normal", "n") else "close"
if stream_mode == "normal":
    logger.info("流式返回已开启")
|
||||
|
||||
|
||||
# 初始化模型
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
|
Loading…
x
Reference in New Issue
Block a user