From 9e4313fb4ef305b8646b019c46aaaa63ac9b4d6c Mon Sep 17 00:00:00 2001
From: samiabat
Date: Tue, 24 Jun 2025 08:17:08 +0300
Subject: [PATCH] add new params

---
 GPT_SoVITS/inference_cli.py   |  13 +++-
 GPT_SoVITS/inference_webui.py | 133 +++++++++++++++++++++-------------
 api.py                        |  14 +++-
 3 files changed, 108 insertions(+), 52 deletions(-)

diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py
index 6a57ca2a..43004a21 100644
--- a/GPT_SoVITS/inference_cli.py
+++ b/GPT_SoVITS/inference_cli.py
@@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
 
 i18n = I18nAuto()
 
 
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
+               energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
     # Change model weights
     change_gpt_weights(gpt_path=GPT_model_path)
@@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
         prompt_text=ref_text,
         prompt_language=i18n(ref_language),
         text=target_text,
-        text_language=i18n(text_language), top_p=1, temperature=1)
+        text_language=i18n(text_language),
+        top_p=1,
+        temperature=1,
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect)
 
     result_list = list(synthesis_result)
 
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 9c135ca5..3e7a1d7b 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -510,24 +510,27 @@ def audio_sr(audio,sr):
 ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
 # cache_tokens={}#暂未实现清理机制
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
+                top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
+                inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
+                loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
+                volume_scale=1.0, strain_effect=0.0):
     global cache
-    if ref_wav_path:pass
-    else:gr.Warning(i18n('请上传参考音频'))
-    if text:pass
-    else:gr.Warning(i18n('请填入推理文本'))
+    if ref_wav_path: pass
+    else: gr.Warning(i18n('请上传参考音频'))
+    if text: pass
+    else: gr.Warning(i18n('请填入推理文本'))
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version=="v3":
-        ref_free=False#s2v3暂不支持ref_free
+    if model_version == "v3":
+        ref_free = False  # s2v3暂不支持ref_free
     else:
-        if_sr=False
+        if_sr = False
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
-
     if not ref_free:
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
@@ -567,7 +570,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)
     t1 = ttime()
-    t.append(t1-t0)
+    t.append(t1 - t0)
 
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -586,21 +589,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    ###s2v3暂不支持ref_free
+    ### s2v3暂不支持ref_free
     if not ref_free:
-        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
+        phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
 
-    for i_text,text in enumerate(texts):
+    for i_text, text in enumerate(texts):
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
         print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         print(i18n("前端处理后的文本(每句):"), norm_text2)
         if not ref_free:
             bert = torch.cat([bert1, bert2], 1)
-            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         else:
             bert = bert2
             all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -611,7 +614,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         t2 = ttime()
         # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
         # print(cache.keys(),if_freeze)
-        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+        if(i_text in cache and if_freeze == True): pred_semantic = cache[i_text]
         else:
             with torch.no_grad():
                 pred_semantic, idx = t2s_model.model.infer_panel(
@@ -626,33 +629,33 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                     early_stop_num=hz * max_sec,
                 )
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-            cache[i_text]=pred_semantic
+            cache[i_text] = pred_semantic
         t3 = ttime()
-        ###v3不存在以下逻辑和inp_refs
-        if model_version!="v3":
-            refers=[]
-            if(inp_refs):
+        ### v3不存在以下逻辑和inp_refs
+        if model_version != "v3":
+            refers = []
+            if inp_refs:
                 for path in inp_refs:
                     try:
                         refer = get_spepc(hps, path.name).to(dtype).to(device)
                         refers.append(refer)
                     except:
                         traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
+            if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0]  # .cpu().detach().numpy()
         else:
             refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
+            ref_audio = ref_audio.to(device).float()
             if (ref_audio.shape[0] == 2):
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
-            # print("ref_audio",ref_audio.abs().mean())
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
+            # print("ref_audio", ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -663,12 +666,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
+            # print("fea_ref", fea_ref, fea_ref.shape)
+            # print("mel2", mel2)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
+            # print("fea_todo", fea_todo)
+            # print("ge", ge.abs().mean())
             cfm_resss = []
             idx = 0
             while (1):
@@ -686,30 +689,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None: init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0]#.cpu().detach().numpy()
-        max_audio=torch.abs(audio).max()#简单防止16bit爆音
-        if max_audio>1:audio=audio/max_audio
+                audio = wav_gen[0][0]  # .cpu().detach().numpy()
+
+        # Initial clipping check
+        max_audio = torch.abs(audio).max()  # 简单防止16bit爆音
+        if max_audio > 1:
+            audio = audio / max_audio
+
+        # Apply new parameters
+        audio = audio.to(torch.float32)  # Ensure float32 for processing
+        if loudness_boost:
+            # Boost loudness using RMS-based scaling (adjust multiplier as needed)
+            rms = torch.sqrt(torch.mean(audio ** 2))
+            audio = audio * (rms * 1.5) if rms > 0 else audio
+        if gain > 0:
+            # Apply gain in dB
+            audio = audio * (10 ** (gain / 20))
+        if normalize:
+            # Normalize to [-1, 1]
+            max_abs = torch.abs(audio).max()
+            audio = audio / max_abs if max_abs > 0 else audio
+        if energy_scale != 1.0:
+            # Scale energy
+            audio = audio * torch.sqrt(torch.tensor(energy_scale))
+        if volume_scale != 1.0:
+            # Direct volume scaling
+            audio = audio * volume_scale
+        if strain_effect > 0.0:
+            # Add strain effect (basic distortion)
+            audio = audio + (audio ** 2 * strain_effect)
+
+        # Final clipping check after effects
+        max_audio = torch.abs(audio).max()
+        if max_audio > 1:
+            audio = audio / max_audio
+
         audio_opt.append(audio)
-        audio_opt.append(zero_wav_torch)#zero_wav
+        audio_opt.append(zero_wav_torch)  # zero_wav
         t4 = ttime()
-        t.extend([t2 - t1,t3 - t2, t4 - t3])
+        t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()
+
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
-    audio_opt=torch.cat(audio_opt, 0)#np.concatenate
-    sr=hps.data.sampling_rate if model_version!="v3"else 24000
-    if if_sr==True and sr==24000:
+    audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
+    sr = hps.data.sampling_rate if model_version != "v3" else 24000
+    if if_sr == True and sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-        max_audio=np.abs(audio_opt).max()
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
         if max_audio > 1: audio_opt /= max_audio
     else:
-        audio_opt=audio_opt.cpu().detach().numpy()
+        audio_opt = audio_opt.cpu().detach().numpy()
     yield sr, (audio_opt * 32767).astype(np.int16)
 
-
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if todo_text[-1] not in splits:
diff --git a/api.py b/api.py
index 76dbfe11..f42dbd96 100644
--- a/api.py
+++ b/api.py
@@ -1158,6 +1158,12 @@ def version_4_cli(
     character_name: str = "Kurari",
     model_id: int = 14,
     version: str = "v1", # v3 or v4
+    loudness_boost=False,
+    gain=0,
+    normalize=False,
+    energy_scale=1.0,
+    volume_scale=1.0,
+    strain_effect=0.0,
 ):
     # Create a temporary buffer to store the audio
     audio_buffer = io.BytesIO()
@@ -1187,7 +1193,13 @@ def version_4_cli(
         ref_language = ref_language,
         target_text = target_text,
         text_language = text_language,
-        output_path = output_path # Don't save to file
+        output_path = output_path, # Don't save to file
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect
     )
 
     # Get the last audio data and sample rate from synthesis result
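
Usage sketch (not part of the patch): a minimal call of the extended synthesize() from GPT_SoVITS/inference_cli.py. The weight paths, reference audio, and parameter values below are illustrative assumptions, not values defined by this change; the six new keyword arguments are simply forwarded to get_tts_wav(), which applies them to each generated segment before it is appended to the output.

    from GPT_SoVITS.inference_cli import synthesize

    # Hypothetical paths and values -- substitute real model weights and reference audio.
    synthesize(
        GPT_model_path="pretrained_models/gpt_weights.ckpt",
        SoVITS_model_path="pretrained_models/sovits_weights.pth",
        ref_audio_path="ref.wav",
        ref_text="Reference transcript.",
        ref_language="英文",
        target_text="Text to synthesize.",
        text_language="英文",
        output_path="output",
        loudness_boost=False,  # RMS-based loudness scaling
        gain=3,                # gain in dB, applied as 10 ** (gain / 20)
        normalize=True,        # peak-normalize to [-1, 1]
        energy_scale=1.0,      # multiplies the waveform by sqrt(energy_scale)
        volume_scale=1.0,      # direct amplitude multiplier
        strain_effect=0.0,     # adds audio ** 2 * strain_effect (mild distortion)
    )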