From e3be776d1f5c354cd83ad58eaea9203ce0783b10 Mon Sep 17 00:00:00 2001 From: samiabat Date: Tue, 24 Jun 2025 09:25:23 +0300 Subject: [PATCH] add params to the endpoint --- api.py | 65 +++++++++++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/api.py b/api.py index f42dbd96..3f5d537e 100644 --- a/api.py +++ b/api.py @@ -1218,19 +1218,25 @@ def version_4_cli( @app.get("/") async def tts_endpoint( - prompt_text: str = "今日は友達と一緒に映画を見に行く予定ですが、天気が悪くて少し心配です。", - prompt_language: str = "all_ja", - character: str = "saotome", - text: str = None, - text_language: str = None, - cut_punc: str = None, - top_k: int = 15, - top_p: float = 1.0, - temperature: float = 1.0, - speed: float = 1.0, - sample_steps: int = 20, - if_sr: bool = False, - version: str = "v1", # v3 or v4 + prompt_text: str = "今日は友達と一緒に映画を見に行く予定ですが、天気が悪くて少し心配です。", + prompt_language: str = "all_ja", + character: str = "saotome", + text: str = None, + text_language: str = None, + cut_punc: str = None, + top_k: int = 15, + top_p: float = 1.0, + temperature: float = 1.0, + speed: float = 1.0, + sample_steps: int = 20, + if_sr: bool = False, + version: str = "v1", + loudness_boost: str = "false", # Accept as string from URL, convert to bool + gain: str = "0", # Accept as string from URL, convert to float + normalize: str = "false", # Accept as string from URL, convert to bool + energy_scale: str = "1.0", # Accept as string from URL, convert to float + volume_scale: str = "1.0", # Accept as string from URL, convert to float + strain_effect: str = "0.0" # Accept as string from URL, convert to float ): if character == "kurari" or character == "Kurari": prompt_text = "おはよう〜。今日はどんな1日過ごすー?くらりはね〜いつでもあなたの味方だよ" @@ -1238,21 +1244,10 @@ async def tts_endpoint( prompt_text = "朝ごはんにはトーストと卵、そしてコーヒーを飲みました。簡単だけど、朝の時間が少し幸せに感じられる瞬間でした。" elif character in ["Ikko", "ikko", "Ikka", "ikka"]: prompt_text = "せおいなげ、まじばな、らぶらぶ、あげあげ、まぼろし" - import warnings warnings.warn(f"the character name is {character}. ") - - if (character == "Kurari") or character == "saotome" or character == "ikka" or character == "Ikka" or character== "ikko" or character == "Ikko": - """ - "中文": "all_zh", - "粤语": "all_yue", - "英文": "en", - "日文": "all_ja", - "韩文": "all_ko", - "中英混合": "zh", - "粤英混合": "yue", - "日英混合": "ja", - """ + + if character in ["Kurari", "saotome", "ikka", "Ikka", "ikko", "Ikko"]: if text_language == "all_ja": text_language = "日文" elif text_language == "ja": @@ -1266,15 +1261,29 @@ async def tts_endpoint( elif text_language == "ko": text_language = "韩文" + # Convert string parameters from URL to appropriate types + loudness_boost = loudness_boost.lower() == "true" + gain = float(gain) + normalize = normalize.lower() == "true" + energy_scale = float(energy_scale) + volume_scale = float(volume_scale) + strain_effect = float(strain_effect) + audio_buffer, sample_rate = version_4_cli( character_name=character, ref_text=prompt_text, ref_language="日文", target_text=text, text_language=text_language or "日文", - version=version, # v2 or v3 + version=version, + loudness_boost=loudness_boost, + gain=gain, + normalize=normalize, + energy_scale=energy_scale, + volume_scale=volume_scale, + strain_effect=strain_effect ) - + if audio_buffer: return StreamingResponse( audio_buffer,