add new params

samiabat 2025-06-24 08:17:08 +03:00
parent 9d24ff72cf
commit 9e4313fb4e
3 changed files with 108 additions and 52 deletions

View File

@@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
 i18n = I18nAuto()
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
+               energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
     # Change model weights
     change_gpt_weights(gpt_path=GPT_model_path)
@@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
         prompt_text=ref_text,
         prompt_language=i18n(ref_language),
         text=target_text,
-        text_language=i18n(text_language), top_p=1, temperature=1)
+        text_language=i18n(text_language),
+        top_p=1,
+        temperature=1,
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect)
     result_list = list(synthesis_result)
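
For reference, a minimal usage sketch of the extended synthesize() signature. The weight paths, reference audio, texts and language labels below are placeholders rather than values from this commit; the new keyword arguments default to a pass-through (loudness_boost=False, gain=0, normalize=False, energy_scale=1.0, volume_scale=1.0, strain_effect=0.0), so existing callers are unaffected.

    # Hypothetical call into the updated helper; every literal is illustrative.
    synthesize(
        GPT_model_path="GPT_weights/example.ckpt",        # placeholder path
        SoVITS_model_path="SoVITS_weights/example.pth",   # placeholder path
        ref_audio_path="ref/example.wav",                 # placeholder path
        ref_text="Reference transcript.",                 # placeholder text
        ref_language="英文",                              # placeholder language label
        target_text="Text to be synthesized.",
        text_language="英文",
        output_path="output",                             # placeholder output directory
        gain=3,          # +3 dB applied to the decoded waveform in get_tts_wav
        normalize=True,  # peak-normalize to [-1, 1] before int16 conversion
    )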

View File

@@ -510,24 +510,27 @@ def audio_sr(audio,sr):
 ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
 # cache_tokens={}#暂未实现清理机制
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
+                top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
+                inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
+                loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
+                volume_scale=1.0, strain_effect=0.0):
     global cache
-    if ref_wav_path:pass
-    else:gr.Warning(i18n('请上传参考音频'))
-    if text:pass
-    else:gr.Warning(i18n('请填入推理文本'))
+    if ref_wav_path: pass
+    else: gr.Warning(i18n('请上传参考音频'))
+    if text: pass
+    else: gr.Warning(i18n('请填入推理文本'))
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version=="v3":
-        ref_free=False#s2v3暂不支持ref_free
+    if model_version == "v3":
+        ref_free = False # s2v3暂不支持ref_free
     else:
-        if_sr=False
+        if_sr = False
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
     if not ref_free:
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
@@ -567,7 +570,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)
     t1 = ttime()
-    t.append(t1-t0)
+    t.append(t1 - t0)
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -586,21 +589,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    ###s2v3暂不支持ref_free
+    ### s2v3暂不支持ref_free
     if not ref_free:
-        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
-    for i_text,text in enumerate(texts):
+        phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
+    for i_text, text in enumerate(texts):
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
         print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         print(i18n("前端处理后的文本(每句):"), norm_text2)
         if not ref_free:
             bert = torch.cat([bert1, bert2], 1)
-            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         else:
             bert = bert2
             all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -611,7 +614,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         t2 = ttime()
         # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
         # print(cache.keys(),if_freeze)
-        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+        if(i_text in cache and if_freeze == True): pred_semantic = cache[i_text]
         else:
             with torch.no_grad():
                 pred_semantic, idx = t2s_model.model.infer_panel(
@@ -626,33 +629,33 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                     early_stop_num=hz * max_sec,
                 )
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-            cache[i_text]=pred_semantic
+            cache[i_text] = pred_semantic
         t3 = ttime()
-        ###v3不存在以下逻辑和inp_refs
-        if model_version!="v3":
-            refers=[]
-            if(inp_refs):
+        ### v3不存在以下逻辑和inp_refs
+        if model_version != "v3":
+            refers = []
+            if inp_refs:
                 for path in inp_refs:
                     try:
                         refer = get_spepc(hps, path.name).to(dtype).to(device)
                         refers.append(refer)
                     except:
                         traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
+            if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0] # .cpu().detach().numpy()
         else:
             refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
+            ref_audio = ref_audio.to(device).float()
             if (ref_audio.shape[0] == 2):
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
-            # print("ref_audio",ref_audio.abs().mean())
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
+            # print("ref_audio", ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -663,12 +666,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
+            # print("fea_ref", fea_ref, fea_ref.shape)
+            # print("mel2", mel2)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
+            # print("fea_todo", fea_todo)
+            # print("ge", ge.abs().mean())
             cfm_resss = []
             idx = 0
             while (1):
@@ -686,30 +689,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None: init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0]#.cpu().detach().numpy()
-        max_audio=torch.abs(audio).max()#简单防止16bit爆音
-        if max_audio>1:audio=audio/max_audio
+                audio = wav_gen[0][0] # .cpu().detach().numpy()
+        # Initial clipping check
+        max_audio = torch.abs(audio).max() # 简单防止16bit爆音
+        if max_audio > 1:
+            audio = audio / max_audio
+        # Apply new parameters
+        audio = audio.to(torch.float32) # Ensure float32 for processing
+        if loudness_boost:
+            # Boost loudness using RMS-based scaling (adjust multiplier as needed)
+            rms = torch.sqrt(torch.mean(audio ** 2))
+            audio = audio * (rms * 1.5) if rms > 0 else audio
+        if gain > 0:
+            # Apply gain in dB
+            audio = audio * (10 ** (gain / 20))
+        if normalize:
+            # Normalize to [-1, 1]
+            max_abs = torch.abs(audio).max()
+            audio = audio / max_abs if max_abs > 0 else audio
+        if energy_scale != 1.0:
+            # Scale energy
+            audio = audio * torch.sqrt(torch.tensor(energy_scale))
+        if volume_scale != 1.0:
+            # Direct volume scaling
+            audio = audio * volume_scale
+        if strain_effect > 0.0:
+            # Add strain effect (basic distortion)
+            audio = audio + (audio ** 2 * strain_effect)
+        # Final clipping check after effects
+        max_audio = torch.abs(audio).max()
+        if max_audio > 1:
+            audio = audio / max_audio
         audio_opt.append(audio)
-        audio_opt.append(zero_wav_torch)#zero_wav
+        audio_opt.append(zero_wav_torch) # zero_wav
         t4 = ttime()
-        t.extend([t2 - t1,t3 - t2, t4 - t3])
+        t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
-    audio_opt=torch.cat(audio_opt, 0)#np.concatenate
-    sr=hps.data.sampling_rate if model_version!="v3"else 24000
-    if if_sr==True and sr==24000:
+    audio_opt = torch.cat(audio_opt, 0) # np.concatenate
+    sr = hps.data.sampling_rate if model_version != "v3" else 24000
+    if if_sr == True and sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-        max_audio=np.abs(audio_opt).max()
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
         if max_audio > 1: audio_opt /= max_audio
     else:
-        audio_opt=audio_opt.cpu().detach().numpy()
+        audio_opt = audio_opt.cpu().detach().numpy()
     yield sr, (audio_opt * 32767).astype(np.int16)
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", "，")
     if todo_text[-1] not in splits:
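
The block added to get_tts_wav above can also be read in isolation. Below is a minimal standalone sketch of the same post-processing chain for illustration; the function name apply_output_effects and the synthetic input are assumptions, not part of the commit.

    import torch

    def apply_output_effects(audio, loudness_boost=False, gain=0, normalize=False,
                             energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
        # Mirrors the per-segment block above; expects a mono float waveform tensor.
        audio = audio.to(torch.float32)
        if loudness_boost:
            # RMS-based scaling exactly as written in the commit (factor = rms * 1.5)
            rms = torch.sqrt(torch.mean(audio ** 2))
            audio = audio * (rms * 1.5) if rms > 0 else audio
        if gain > 0:
            audio = audio * (10 ** (gain / 20))  # gain is interpreted in dB
        if normalize:
            max_abs = torch.abs(audio).max()
            audio = audio / max_abs if max_abs > 0 else audio
        if energy_scale != 1.0:
            # energy scales with amplitude squared, hence the square root
            audio = audio * torch.sqrt(torch.tensor(energy_scale))
        if volume_scale != 1.0:
            audio = audio * volume_scale
        if strain_effect > 0.0:
            audio = audio + (audio ** 2 * strain_effect)  # simple quadratic distortion
        # final clip guard before the int16 export done by the caller
        max_audio = torch.abs(audio).max()
        if max_audio > 1:
            audio = audio / max_audio
        return audio

    # Example: +6 dB of gain followed by peak normalization on synthetic audio.
    wav = apply_output_effects(torch.randn(24000) * 0.01, gain=6, normalize=True)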

api.py
View File

@@ -1158,6 +1158,12 @@ def version_4_cli(
     character_name: str = "Kurari",
     model_id: int = 14,
     version: str = "v1", # v3 or v4
+    loudness_boost=False,
+    gain=0,
+    normalize=False,
+    energy_scale=1.0,
+    volume_scale=1.0,
+    strain_effect=0.0,
 ):
     # Create a temporary buffer to store the audio
     audio_buffer = io.BytesIO()
@@ -1187,7 +1193,13 @@ def version_4_cli(
         ref_language = ref_language,
         target_text = target_text,
         text_language = text_language,
-        output_path = output_path # Don't save to file
+        output_path = output_path, # Don't save to file
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect
     )
     # Get the last audio data and sample rate from synthesis result
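
Note that the gain value threaded through version_4_cli and synthesize is interpreted as decibels inside get_tts_wav, i.e. the waveform is multiplied by 10 ** (gain / 20) when gain > 0. A quick arithmetic check of that conversion:

    # dB-to-linear factors as used by the gain branch above (pure arithmetic).
    for gain_db in (0, 3, 6, 12):
        print(gain_db, "dB ->", round(10 ** (gain_db / 20), 3))
    # 0 dB -> 1.0, 3 dB -> 1.413, 6 dB -> 1.995, 12 dB -> 3.981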