add new params

samiabat 2025-06-24 08:17:08 +03:00
parent 9d24ff72cf
commit 9e4313fb4e
3 changed files with 108 additions and 52 deletions

View File

@@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
 i18n = I18nAuto()
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
+               energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
     # Change model weights
     change_gpt_weights(gpt_path=GPT_model_path)
@@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
         prompt_text=ref_text,
         prompt_language=i18n(ref_language),
         text=target_text,
-        text_language=i18n(text_language), top_p=1, temperature=1)
+        text_language=i18n(text_language),
+        top_p=1,
+        temperature=1,
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect)
     result_list = list(synthesis_result)
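
For reference, a minimal usage sketch of the extended synthesize() signature; the paths and language labels below are placeholders for illustration, not files or values from this repository:

    # Hypothetical paths and labels, for illustration only
    synthesize(
        GPT_model_path="models/gpt.ckpt",
        SoVITS_model_path="models/sovits.pth",
        ref_audio_path="ref.wav",
        ref_text="reference transcript",
        ref_language="English",
        target_text="text to synthesize",
        text_language="English",
        output_path="output",
        gain=3,            # +3 dB
        normalize=True,    # rescale peak to 1.0 afterwards
        volume_scale=0.9,
    )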

View File

@@ -510,7 +510,11 @@ def audio_sr(audio,sr):
 ##ref_wav_path+prompt_text+prompt_language+text(single)+text_language+top_k+top_p+temperature
 # cache_tokens={}  # cleanup mechanism not yet implemented
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
+                top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
+                inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
+                loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
+                volume_scale=1.0, strain_effect=0.0):
     global cache
     if ref_wav_path: pass
     else: gr.Warning(i18n('请上传参考音频'))
@@ -527,7 +531,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
     if not ref_free:
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
@@ -631,14 +634,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         ### v3 does not have the following logic or inp_refs
         if model_version != "v3":
             refers = []
-            if(inp_refs):
+            if inp_refs:
                 for path in inp_refs:
                     try:
                         refer = get_spepc(hps, path.name).to(dtype).to(device)
                         refers.append(refer)
                     except:
                         traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
             audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0]  # .cpu().detach().numpy()
         else:
             refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
@@ -690,13 +693,46 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
                 audio = wav_gen[0][0]  # .cpu().detach().numpy()
+        # Initial clipping check
         max_audio = torch.abs(audio).max()  # simple guard against 16-bit clipping
-        if max_audio>1:audio=audio/max_audio
+        if max_audio > 1:
+            audio = audio / max_audio
+        # Apply the new post-processing parameters
+        audio = audio.to(torch.float32)  # ensure float32 for processing
+        if loudness_boost:
+            # Boost loudness by a fixed 1.5x factor (skip silent audio)
+            rms = torch.sqrt(torch.mean(audio ** 2))
+            if rms > 0:
+                audio = audio * 1.5
+        if gain != 0:
+            # Apply gain in dB (negative values attenuate)
+            audio = audio * (10 ** (gain / 20))
+        if normalize:
+            # Normalize peak amplitude to [-1, 1]
+            max_abs = torch.abs(audio).max()
+            audio = audio / max_abs if max_abs > 0 else audio
+        if energy_scale != 1.0:
+            # Scale energy; amplitude scales with the square root of energy
+            audio = audio * (energy_scale ** 0.5)
+        if volume_scale != 1.0:
+            # Direct volume scaling
+            audio = audio * volume_scale
+        if strain_effect > 0.0:
+            # Add strain effect (basic asymmetric distortion)
+            audio = audio + (audio ** 2 * strain_effect)
+        # Final clipping check after effects
+        max_audio = torch.abs(audio).max()
+        if max_audio > 1:
+            audio = audio / max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav_torch)  # zero_wav
         t4 = ttime()
         t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
     audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
     sr = hps.data.sampling_rate if model_version != "v3" else 24000
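
The new block applies the effects in a fixed order: loudness boost, dB gain, peak normalization, energy scale, volume scale, strain, then a final clip guard. Below is a self-contained sketch of that same chain for testing the math outside the synthesis loop; apply_effects is an illustrative name, not a function in this repository:

    import torch

    def apply_effects(audio: torch.Tensor, loudness_boost=False, gain=0.0,
                      normalize=False, energy_scale=1.0, volume_scale=1.0,
                      strain_effect=0.0) -> torch.Tensor:
        audio = audio.to(torch.float32)
        if loudness_boost:
            rms = torch.sqrt(torch.mean(audio ** 2))
            if rms > 0:                          # skip silent audio
                audio = audio * 1.5              # fixed 1.5x boost
        if gain != 0:
            audio = audio * (10 ** (gain / 20))  # dB to linear; +6 dB ~ 1.995x
        if normalize:
            peak = audio.abs().max()
            if peak > 0:
                audio = audio / peak             # peak-normalize to [-1, 1]
        if energy_scale != 1.0:
            audio = audio * (energy_scale ** 0.5)  # energy ~ amplitude squared
        if volume_scale != 1.0:
            audio = audio * volume_scale
        if strain_effect > 0.0:
            audio = audio + audio ** 2 * strain_effect  # asymmetric distortion
        peak = audio.abs().max()
        return audio / peak if peak > 1 else audio      # final clip guard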
@@ -709,7 +745,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     audio_opt = audio_opt.cpu().detach().numpy()
     yield sr, (audio_opt * 32767).astype(np.int16)
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", "，")
     if todo_text[-1] not in splits:
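
Since get_tts_wav is a generator yielding (sample_rate, int16 numpy array) pairs, callers exercise the new parameters on each yielded segment. A minimal consumption sketch; the paths and language labels are placeholders:

    import soundfile as sf

    chunks = get_tts_wav("ref.wav", "reference transcript", "English",
                         "target text", "English",
                         gain=3, normalize=True)
    for sr, audio in chunks:
        sf.write("out.wav", audio, sr)  # overwrites, keeping the last segment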

api.py
View File

@@ -1158,6 +1158,12 @@ def version_4_cli(
     character_name: str = "Kurari",
     model_id: int = 14,
     version: str = "v1",  # v3 or v4
+    loudness_boost: bool = False,
+    gain: float = 0.0,
+    normalize: bool = False,
+    energy_scale: float = 1.0,
+    volume_scale: float = 1.0,
+    strain_effect: float = 0.0,
 ):
     # Create a temporary buffer to store the audio
     audio_buffer = io.BytesIO()
@@ -1187,7 +1193,13 @@ def version_4_cli(
         ref_language = ref_language,
         target_text = target_text,
         text_language = text_language,
-        output_path = output_path # Don't save to file
+        output_path = output_path,  # Don't save to file
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect
     )
     # Get the last audio data and sample rate from synthesis result
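
One ordering implication worth noting: normalize runs after gain, and peak normalization cancels any prior linear scaling, so gain only changes the output when normalize is left off (energy_scale, volume_scale, and strain_effect still apply afterwards either way). A quick check, assuming the effect order shown in the hunks above:

    import torch

    x = torch.rand(16000) - 0.5   # placeholder waveform, peak ~0.5
    g = x * (10 ** (6 / 20))      # +6 dB gain
    print(torch.allclose(g / g.abs().max(), x / x.abs().max()))  # True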