From 0a17694edee04ad4605d8352034339374691d94a Mon Sep 17 00:00:00 2001 From: Jin Date: Sat, 28 Sep 2024 19:42:27 +0800 Subject: [PATCH] api_v2.py: support ref_audio input as base64 string --- api_v2.py | 21 ++++++++++++++++++++- requirements.txt | 1 + 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/api_v2.py b/api_v2.py index ea1d0c7f..3008cb38 100644 --- a/api_v2.py +++ b/api_v2.py @@ -212,6 +212,22 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str): return io_buffer +_base64_audio_cache = {} +def save_base64_audio(b64str:str): + import filetype, base64, uuid + global _base64_audio_cache + if b64str in _base64_audio_cache: + return _base64_audio_cache[b64str] + savedir = 'TEMP/upload' + data = base64.b64decode(b64str) + ft = filetype.guess(data) + ext = f'.{ft.extension}' if ft else '' + os.makedirs(savedir, exist_ok=True) + saveto = f'{savedir}/{uuid.uuid1()}{ext}' + with open(saveto, 'wb') as outf: + outf.write(data) + _base64_audio_cache[b64str] = saveto + return saveto # from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): @@ -277,7 +293,7 @@ async def tts_handle(req:dict): { "text": "", # str.(required) text to be synthesized "text_lang: "", # str.(required) language of the text to be synthesized - "ref_audio_path": "", # str.(required) reference audio path + "ref_audio_path": "", # str.(required) reference audio path ; allow data of format base64:xxxxxx "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis "prompt_text": "", # str.(optional) prompt text for the reference audio "prompt_lang": "", # str.(required) language of the prompt text for the reference audio @@ -303,6 +319,9 @@ async def tts_handle(req:dict): streaming_mode = req.get("streaming_mode", False) return_fragment = req.get("return_fragment", False) media_type = req.get("media_type", "wav") + ref_audio_path = req.get("ref_audio_path", "") + if ref_audio_path.startswith("base64:"): + req['ref_audio_path'] = ref_audio_path = save_base64_audio(ref_audio_path[len("base64:"):]) check_res = check_params(req) if check_res is not None: diff --git a/requirements.txt b/requirements.txt index 280d9d99..3792d88b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ opencc; sys_platform != 'linux' opencc==1.1.1; sys_platform == 'linux' python_mecab_ko; sys_platform != 'win32' fastapi<0.112.2 +filetype