Merge pull request #2 from ivy-consulting/test-branch

add param of ikko model
This commit is contained in:
Samuel Abatneh 2025-06-24 10:04:01 +03:00 committed by GitHub
commit b74efdb23b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 81 deletions

View File

@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
i18n = I18nAuto()
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
# Change model weights
change_gpt_weights(gpt_path=GPT_model_path)
@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(text_language), top_p=1, temperature=1)
text_language=i18n(text_language),
top_p=1,
temperature=1,
loudness_boost=loudness_boost,
gain=gain,
normalize=normalize,
energy_scale=energy_scale,
volume_scale=volume_scale,
strain_effect=strain_effect)
result_list = list(synthesis_result)

View File

@ -510,7 +510,11 @@ def audio_sr(audio,sr):
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
# cache_tokens={}#暂未实现清理机制
cache= {}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
volume_scale=1.0, strain_effect=0.0):
global cache
if ref_wav_path: pass
else: gr.Warning(i18n('请上传参考音频'))
@ -527,7 +531,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
if not ref_free:
prompt_text = prompt_text.strip("\n")
if (prompt_text[-1] not in splits): prompt_text += "" if prompt_language != "en" else "."
@ -631,14 +634,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
### v3不存在以下逻辑和inp_refs
if model_version != "v3":
refers = []
if(inp_refs):
if inp_refs:
for path in inp_refs:
try:
refer = get_spepc(hps, path.name).to(dtype).to(device)
refers.append(refer)
except:
traceback.print_exc()
if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0] # .cpu().detach().numpy()
else:
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
@ -690,13 +693,46 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()
# Initial clipping check
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
if max_audio>1:audio=audio/max_audio
if max_audio > 1:
audio = audio / max_audio
# Apply new parameters
audio = audio.to(torch.float32) # Ensure float32 for processing
if loudness_boost:
# Boost loudness using RMS-based scaling (adjust multiplier as needed)
rms = torch.sqrt(torch.mean(audio ** 2))
audio = audio * (rms * 1.5) if rms > 0 else audio
if gain > 0:
# Apply gain in dB
audio = audio * (10 ** (gain / 20))
if normalize:
# Normalize to [-1, 1]
max_abs = torch.abs(audio).max()
audio = audio / max_abs if max_abs > 0 else audio
if energy_scale != 1.0:
# Scale energy
audio = audio * torch.sqrt(torch.tensor(energy_scale))
if volume_scale != 1.0:
# Direct volume scaling
audio = audio * volume_scale
if strain_effect > 0.0:
# Add strain effect (basic distortion)
audio = audio + (audio ** 2 * strain_effect)
# Final clipping check after effects
max_audio = torch.abs(audio).max()
if max_audio > 1:
audio = audio / max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav_torch) # zero_wav
t4 = ttime()
t.extend([t2 - t1, t3 - t2, t4 - t3])
t1 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
audio_opt = torch.cat(audio_opt, 0) # np.concatenate
sr = hps.data.sampling_rate if model_version != "v3" else 24000
@ -709,7 +745,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
audio_opt = audio_opt.cpu().detach().numpy()
yield sr, (audio_opt * 32767).astype(np.int16)
def split(todo_text):
todo_text = todo_text.replace("……", "").replace("——", "")
if todo_text[-1] not in splits:

54
api.py
View File

@ -1158,6 +1158,12 @@ def version_4_cli(
character_name: str = "Kurari",
model_id: int = 14,
version: str = "v1", # v3 or v4
loudness_boost=False,
gain=0,
normalize=False,
energy_scale=1.0,
volume_scale=1.0,
strain_effect=0.0,
):
# Create a temporary buffer to store the audio
audio_buffer = io.BytesIO()
@ -1169,6 +1175,9 @@ def version_4_cli(
SoVITS_model_path = "GPT_SoVITS/pretrained_models/saotome_e9_s522_l32.pth"
ref_language = "日文"
elif character_name in ["Ikko", "ikko", "Ikka", "ikka"]:
if loudness_boost:
path = "idols/ikka/ikko_boost.wav"
else:
path = "idols/ikka/ikko.wav"
GPT_model_path = "GPT_SoVITS/pretrained_models/ikko-san-e45.ckpt"
SoVITS_model_path = "GPT_SoVITS/pretrained_models/ikko-san_e15_s1305_l32.pth"
@ -1187,7 +1196,13 @@ def version_4_cli(
ref_language = ref_language,
target_text = target_text,
text_language = text_language,
output_path = output_path # Don't save to file
output_path = output_path, # Don't save to file
loudness_boost=loudness_boost,
gain=gain,
normalize=normalize,
energy_scale=energy_scale,
volume_scale=volume_scale,
strain_effect=strain_effect
)
# Get the last audio data and sample rate from synthesis result
@ -1218,7 +1233,13 @@ async def tts_endpoint(
speed: float = 1.0,
sample_steps: int = 20,
if_sr: bool = False,
version: str = "v1", # v3 or v4
version: str = "v1",
loudness_boost: str = "false", # Accept as string from URL, convert to bool
gain: str = "0", # Accept as string from URL, convert to float
normalize: str = "false", # Accept as string from URL, convert to bool
energy_scale: str = "1.0", # Accept as string from URL, convert to float
volume_scale: str = "1.0", # Accept as string from URL, convert to float
strain_effect: str = "0.0" # Accept as string from URL, convert to float
):
if character == "kurari" or character == "Kurari":
prompt_text = "おはよう〜。今日はどんな1日過ごすーくらりはね〜いつでもあなたの味方だよ"
@ -1226,21 +1247,10 @@ async def tts_endpoint(
prompt_text = "朝ごはんにはトーストと卵、そしてコーヒーを飲みました。簡単だけど、朝の時間が少し幸せに感じられる瞬間でした。"
elif character in ["Ikko", "ikko", "Ikka", "ikka"]:
prompt_text = "せおいなげ、まじばな、らぶらぶ、あげあげ、まぼろし"
import warnings
warnings.warn(f"the character name is {character}. ")
if (character == "Kurari") or character == "saotome" or character == "ikka" or character == "Ikka" or character== "ikko" or character == "Ikko":
"""
"中文": "all_zh",
"粤语": "all_yue",
"英文": "en",
"日文": "all_ja",
"韩文": "all_ko",
"中英混合": "zh",
"粤英混合": "yue",
"日英混合": "ja",
"""
if character in ["Kurari", "saotome", "ikka", "Ikka", "ikko", "Ikko"]:
if text_language == "all_ja":
text_language = "日文"
elif text_language == "ja":
@ -1254,13 +1264,27 @@ async def tts_endpoint(
elif text_language == "ko":
text_language = "韩文"
# Convert string parameters from URL to appropriate types
loudness_boost = loudness_boost.lower() == "true"
gain = float(gain)
normalize = normalize.lower() == "true"
energy_scale = float(energy_scale)
volume_scale = float(volume_scale)
strain_effect = float(strain_effect)
audio_buffer, sample_rate = version_4_cli(
character_name=character,
ref_text=prompt_text,
ref_language="日文",
target_text=text,
text_language=text_language or "日文",
version=version, # v2 or v3
version=version,
loudness_boost=loudness_boost,
gain=gain,
normalize=normalize,
energy_scale=energy_scale,
volume_scale=volume_scale,
strain_effect=strain_effect
)
if audio_buffer:

BIN
idols/ikka/ikko_boost.wav Normal file

Binary file not shown.