add params to the endpoint

This commit is contained in:
samiabat 2025-06-24 09:25:23 +03:00
parent 9e4313fb4e
commit e3be776d1f

61
api.py
View File

@ -1218,19 +1218,25 @@ def version_4_cli(
@app.get("/") @app.get("/")
async def tts_endpoint( async def tts_endpoint(
prompt_text: str = "今日は友達と一緒に映画を見に行く予定ですが、天気が悪くて少し心配です。", prompt_text: str = "今日は友達と一緒に映画を見に行く予定ですが、天気が悪くて少し心配です。",
prompt_language: str = "all_ja", prompt_language: str = "all_ja",
character: str = "saotome", character: str = "saotome",
text: str = None, text: str = None,
text_language: str = None, text_language: str = None,
cut_punc: str = None, cut_punc: str = None,
top_k: int = 15, top_k: int = 15,
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
speed: float = 1.0, speed: float = 1.0,
sample_steps: int = 20, sample_steps: int = 20,
if_sr: bool = False, if_sr: bool = False,
version: str = "v1", # v3 or v4 version: str = "v1",
loudness_boost: str = "false", # Accept as string from URL, convert to bool
gain: str = "0", # Accept as string from URL, convert to float
normalize: str = "false", # Accept as string from URL, convert to bool
energy_scale: str = "1.0", # Accept as string from URL, convert to float
volume_scale: str = "1.0", # Accept as string from URL, convert to float
strain_effect: str = "0.0" # Accept as string from URL, convert to float
): ):
if character == "kurari" or character == "Kurari": if character == "kurari" or character == "Kurari":
prompt_text = "おはよう〜。今日はどんな1日過ごすーくらりはね〜いつでもあなたの味方だよ" prompt_text = "おはよう〜。今日はどんな1日過ごすーくらりはね〜いつでもあなたの味方だよ"
@ -1238,21 +1244,10 @@ async def tts_endpoint(
prompt_text = "朝ごはんにはトーストと卵、そしてコーヒーを飲みました。簡単だけど、朝の時間が少し幸せに感じられる瞬間でした。" prompt_text = "朝ごはんにはトーストと卵、そしてコーヒーを飲みました。簡単だけど、朝の時間が少し幸せに感じられる瞬間でした。"
elif character in ["Ikko", "ikko", "Ikka", "ikka"]: elif character in ["Ikko", "ikko", "Ikka", "ikka"]:
prompt_text = "せおいなげ、まじばな、らぶらぶ、あげあげ、まぼろし" prompt_text = "せおいなげ、まじばな、らぶらぶ、あげあげ、まぼろし"
import warnings import warnings
warnings.warn(f"the character name is {character}. ") warnings.warn(f"the character name is {character}. ")
if (character == "Kurari") or character == "saotome" or character == "ikka" or character == "Ikka" or character== "ikko" or character == "Ikko": if character in ["Kurari", "saotome", "ikka", "Ikka", "ikko", "Ikko"]:
"""
"中文": "all_zh",
"粤语": "all_yue",
"英文": "en",
"日文": "all_ja",
"韩文": "all_ko",
"中英混合": "zh",
"粤英混合": "yue",
"日英混合": "ja",
"""
if text_language == "all_ja": if text_language == "all_ja":
text_language = "日文" text_language = "日文"
elif text_language == "ja": elif text_language == "ja":
@ -1266,13 +1261,27 @@ async def tts_endpoint(
elif text_language == "ko": elif text_language == "ko":
text_language = "韩文" text_language = "韩文"
# Convert string parameters from URL to appropriate types
loudness_boost = loudness_boost.lower() == "true"
gain = float(gain)
normalize = normalize.lower() == "true"
energy_scale = float(energy_scale)
volume_scale = float(volume_scale)
strain_effect = float(strain_effect)
audio_buffer, sample_rate = version_4_cli( audio_buffer, sample_rate = version_4_cli(
character_name=character, character_name=character,
ref_text=prompt_text, ref_text=prompt_text,
ref_language="日文", ref_language="日文",
target_text=text, target_text=text,
text_language=text_language or "日文", text_language=text_language or "日文",
version=version, # v2 or v3 version=version,
loudness_boost=loudness_boost,
gain=gain,
normalize=normalize,
energy_scale=energy_scale,
volume_scale=volume_scale,
strain_effect=strain_effect
) )
if audio_buffer: if audio_buffer: