mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-15 05:21:57 +08:00
add new params
This commit is contained in:
parent
9d24ff72cf
commit
9e4313fb4e
@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
|
||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
|
||||
energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
|
||||
|
||||
# Change model weights
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(text_language), top_p=1, temperature=1)
|
||||
text_language=i18n(text_language),
|
||||
top_p=1,
|
||||
temperature=1,
|
||||
loudness_boost=loudness_boost,
|
||||
gain=gain,
|
||||
normalize=normalize,
|
||||
energy_scale=energy_scale,
|
||||
volume_scale=volume_scale,
|
||||
strain_effect=strain_effect)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
|
@ -510,7 +510,11 @@ def audio_sr(audio,sr):
|
||||
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
|
||||
# cache_tokens={}#暂未实现清理机制
|
||||
cache= {}
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
|
||||
top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
|
||||
inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
|
||||
loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
|
||||
volume_scale=1.0, strain_effect=0.0):
|
||||
global cache
|
||||
if ref_wav_path: pass
|
||||
else: gr.Warning(i18n('请上传参考音频'))
|
||||
@ -527,7 +531,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
|
||||
|
||||
if not ref_free:
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
|
||||
@ -631,14 +634,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
### v3不存在以下逻辑和inp_refs
|
||||
if model_version != "v3":
|
||||
refers = []
|
||||
if(inp_refs):
|
||||
if inp_refs:
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer = get_spepc(hps, path.name).to(dtype).to(device)
|
||||
refers.append(refer)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0] # .cpu().detach().numpy()
|
||||
else:
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
@ -690,13 +693,46 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
with torch.inference_mode():
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
audio = wav_gen[0][0] # .cpu().detach().numpy()
|
||||
|
||||
# Initial clipping check
|
||||
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
|
||||
if max_audio>1:audio=audio/max_audio
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
|
||||
# Apply new parameters
|
||||
audio = audio.to(torch.float32) # Ensure float32 for processing
|
||||
if loudness_boost:
|
||||
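# NOTE(review): intended as an RMS-based loudness boost, but the line below multiplies the signal by (rms * 1.5). The peak was already limited to 1 above, so this factor is < 1 whenever rms < ~0.667 and the audio gets quieter, not louder. Scaling toward a target level (target_rms / rms) is probably what was meant; confirm.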
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
audio = audio * (rms * 1.5) if rms > 0 else audio
|
||||
if gain > 0:
|
||||
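# Apply gain in dB (amplitude factor 10 ** (gain / 20)); because of the `gain > 0` check above, zero or negative gain values are ignored, so attenuation is not possible here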
|
||||
audio = audio * (10 ** (gain / 20))
|
||||
if normalize:
|
||||
# Normalize to [-1, 1]
|
||||
max_abs = torch.abs(audio).max()
|
||||
audio = audio / max_abs if max_abs > 0 else audio
|
||||
if energy_scale != 1.0:
|
||||
# Scale energy
|
||||
audio = audio * torch.sqrt(torch.tensor(energy_scale))
|
||||
if volume_scale != 1.0:
|
||||
# Direct volume scaling
|
||||
audio = audio * volume_scale
|
||||
if strain_effect > 0.0:
|
||||
# Add strain effect (basic distortion)
|
||||
audio = audio + (audio ** 2 * strain_effect)
|
||||
|
||||
# Final clipping check after effects
|
||||
max_audio = torch.abs(audio).max()
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav_torch) # zero_wav
|
||||
t4 = ttime()
|
||||
t.extend([t2 - t1, t3 - t2, t4 - t3])
|
||||
t1 = ttime()
|
||||
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
|
||||
audio_opt = torch.cat(audio_opt, 0) # np.concatenate
|
||||
sr = hps.data.sampling_rate if model_version != "v3" else 24000
|
||||
@ -709,7 +745,6 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
||||
audio_opt = audio_opt.cpu().detach().numpy()
|
||||
yield sr, (audio_opt * 32767).astype(np.int16)
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||
if todo_text[-1] not in splits:
|
||||
|
14
api.py
14
api.py
@ -1158,6 +1158,12 @@ def version_4_cli(
|
||||
character_name: str = "Kurari",
|
||||
model_id: int = 14,
|
||||
version: str = "v1", # v3 or v4
|
||||
loudness_boost=False,
|
||||
gain=0,
|
||||
normalize=False,
|
||||
energy_scale=1.0,
|
||||
volume_scale=1.0,
|
||||
strain_effect=0.0,
|
||||
):
|
||||
# Create a temporary buffer to store the audio
|
||||
audio_buffer = io.BytesIO()
|
||||
@ -1187,7 +1193,13 @@ def version_4_cli(
|
||||
ref_language = ref_language,
|
||||
target_text = target_text,
|
||||
text_language = text_language,
|
||||
output_path = output_path # Don't save to file
|
||||
output_path = output_path, # Don't save to file
|
||||
loudness_boost=loudness_boost,
|
||||
gain=gain,
|
||||
normalize=normalize,
|
||||
energy_scale=energy_scale,
|
||||
volume_scale=volume_scale,
|
||||
strain_effect=strain_effect
|
||||
)
|
||||
|
||||
# Get the last audio data and sample rate from synthesis result
|
||||
|
Loading…
x
Reference in New Issue
Block a user