Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-08-15 05:21:57 +08:00

add new params

This commit is contained in:
parent 9d24ff72cf
commit 9e4313fb4e
@@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights

 i18n = I18nAuto()

-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
+               energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
     # Change model weights
     change_gpt_weights(gpt_path=GPT_model_path)
@@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
         prompt_text=ref_text,
         prompt_language=i18n(ref_language),
         text=target_text,
-        text_language=i18n(text_language), top_p=1, temperature=1)
+        text_language=i18n(text_language),
+        top_p=1,
+        temperature=1,
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect)

     result_list = list(synthesis_result)
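These first two hunks (from what appears to be GPT_SoVITS/inference_cli.py, the CLI wrapper) only thread the six new keyword arguments through to the underlying TTS call. A minimal sketch of invoking the updated synthesize follows; the import path, model paths, and texts are assumptions for illustration, not values from this commit:

```python
# Hypothetical usage sketch. Only the parameter names and defaults come from
# the diff; the import path and all values below are placeholders.
from GPT_SoVITS.inference_cli import synthesize  # import path is a guess

synthesize(
    GPT_model_path="GPT_weights/example-e15.ckpt",   # placeholder path
    SoVITS_model_path="SoVITS_weights/example.pth",  # placeholder path
    ref_audio_path="ref.wav",
    ref_text="reference transcript here",
    ref_language="英文",       # i18n key for English in this repo's UI
    target_text="text to synthesize",
    text_language="英文",
    output_path="output",
    gain=3,              # +3 dB => amplitude x 10**(3/20) ~= 1.41
    normalize=True,      # peak-normalize to [-1, 1] after the gain stage
    strain_effect=0.1,   # mild quadratic distortion
)
```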
@@ -510,24 +510,27 @@ def audio_sr(audio,sr):
 ##ref_wav_path+prompt_text+prompt_language+text(single)+text_language+top_k+top_p+temperature
 # cache_tokens={}  # no cache-cleanup mechanism implemented yet
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
+                top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
+                inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
+                loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
+                volume_scale=1.0, strain_effect=0.0):
     global cache
-    if ref_wav_path:pass
-    else:gr.Warning(i18n('请上传参考音频'))
-    if text:pass
-    else:gr.Warning(i18n('请填入推理文本'))
+    if ref_wav_path: pass
+    else: gr.Warning(i18n('请上传参考音频'))
+    if text: pass
+    else: gr.Warning(i18n('请填入推理文本'))
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version=="v3":
-        ref_free=False  # s2v3 does not support ref_free yet
+    if model_version == "v3":
+        ref_free = False  # s2v3 does not support ref_free yet
     else:
-        if_sr=False
+        if_sr = False
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]


     if not ref_free:
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
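All six new parameters default to inert values (loudness_boost=False, gain=0, normalize=False, energy_scale=1.0, volume_scale=1.0, strain_effect=0.0), so existing callers of get_tts_wav keep their previous behavior.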
@@ -567,7 +570,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)

     t1 = ttime()
-    t.append(t1-t0)
+    t.append(t1 - t0)

     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -586,21 +589,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    ###s2v3 does not support ref_free yet
+    ### s2v3 does not support ref_free yet
     if not ref_free:
-        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
+        phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)

-    for i_text,text in enumerate(texts):
+    for i_text, text in enumerate(texts):
         # Avoid errors caused by blank lines in the input target text
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
         print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         print(i18n("前端处理后的文本(每句):"), norm_text2)
         if not ref_free:
             bert = torch.cat([bert1, bert2], 1)
-            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         else:
             bert = bert2
             all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -611,7 +614,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         t2 = ttime()
         # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
         # print(cache.keys(),if_freeze)
-        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+        if(i_text in cache and if_freeze == True): pred_semantic = cache[i_text]
         else:
             with torch.no_grad():
                 pred_semantic, idx = t2s_model.model.infer_panel(
@@ -626,33 +629,33 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                     early_stop_num=hz * max_sec,
                 )
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-            cache[i_text]=pred_semantic
+            cache[i_text] = pred_semantic
         t3 = ttime()
-        ###the following logic and inp_refs do not exist in v3
-        if model_version!="v3":
-            refers=[]
-            if(inp_refs):
+        ### the following logic and inp_refs do not exist in v3
+        if model_version != "v3":
+            refers = []
+            if inp_refs:
                 for path in inp_refs:
                     try:
                         refer = get_spepc(hps, path.name).to(dtype).to(device)
                         refers.append(refer)
                     except:
                         traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
+            if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0]  # .cpu().detach().numpy()
         else:
             refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
+            ref_audio = ref_audio.to(device).float()
             if (ref_audio.shape[0] == 2):
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
-            # print("ref_audio",ref_audio.abs().mean())
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
+            # print("ref_audio", ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -663,12 +666,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
+            # print("fea_ref", fea_ref, fea_ref.shape)
+            # print("mel2", mel2)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
+            # print("fea_todo", fea_todo)
+            # print("ge", ge.abs().mean())
             cfm_resss = []
             idx = 0
             while (1):
@@ -686,30 +689,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None: init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0]#.cpu().detach().numpy()
-        max_audio=torch.abs(audio).max()  # simple guard against 16-bit clipping
-        if max_audio>1:audio=audio/max_audio
+                audio = wav_gen[0][0]  # .cpu().detach().numpy()
+
+        # Initial clipping check
+        max_audio = torch.abs(audio).max()  # simple guard against 16-bit clipping
+        if max_audio > 1:
+            audio = audio / max_audio
+
+        # Apply new parameters
+        audio = audio.to(torch.float32)  # Ensure float32 for processing
+        if loudness_boost:
+            # Boost loudness using RMS-based scaling (adjust multiplier as needed)
+            rms = torch.sqrt(torch.mean(audio ** 2))
+            audio = audio * (rms * 1.5) if rms > 0 else audio
+        if gain > 0:
+            # Apply gain in dB
+            audio = audio * (10 ** (gain / 20))
+        if normalize:
+            # Normalize to [-1, 1]
+            max_abs = torch.abs(audio).max()
+            audio = audio / max_abs if max_abs > 0 else audio
+        if energy_scale != 1.0:
+            # Scale energy
+            audio = audio * torch.sqrt(torch.tensor(energy_scale))
+        if volume_scale != 1.0:
+            # Direct volume scaling
+            audio = audio * volume_scale
+        if strain_effect > 0.0:
+            # Add strain effect (basic distortion)
+            audio = audio + (audio ** 2 * strain_effect)
+
+        # Final clipping check after effects
+        max_audio = torch.abs(audio).max()
+        if max_audio > 1:
+            audio = audio / max_audio
+
         audio_opt.append(audio)
-        audio_opt.append(zero_wav_torch)#zero_wav
+        audio_opt.append(zero_wav_torch)  # zero_wav
         t4 = ttime()
-        t.extend([t2 - t1,t3 - t2, t4 - t3])
+        t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()

     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
-    audio_opt=torch.cat(audio_opt, 0)#np.concatenate
-    sr=hps.data.sampling_rate if model_version!="v3"else 24000
-    if if_sr==True and sr==24000:
+    audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
+    sr = hps.data.sampling_rate if model_version != "v3" else 24000
+    if if_sr == True and sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-        max_audio=np.abs(audio_opt).max()
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
         if max_audio > 1: audio_opt /= max_audio
     else:
-        audio_opt=audio_opt.cpu().detach().numpy()
+        audio_opt = audio_opt.cpu().detach().numpy()
     yield sr, (audio_opt * 32767).astype(np.int16)


 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if todo_text[-1] not in splits:
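The added block above is a per-segment post-processing chain: an optional RMS-based loudness adjustment, a dB gain stage (a gain of g dB multiplies amplitude by 10^(g/20)), peak normalization, an energy scale applied as its square root in amplitude (energy grows with the square of amplitude), a linear volume scale, and a quadratic "strain" distortion, capped by a final clipping guard. Note that the loudness_boost branch multiplies the signal by 1.5 × RMS, so for segments whose RMS is below 2/3 it attenuates rather than boosts. A self-contained sketch of the same chain on a bare tensor — the function name and test tone are illustrative, not part of the commit:

```python
import torch

def apply_post_fx(audio: torch.Tensor, gain_db: float = 0.0, normalize: bool = False,
                  energy_scale: float = 1.0, volume_scale: float = 1.0,
                  strain_effect: float = 0.0) -> torch.Tensor:
    """Illustrative re-implementation of the commit's effect chain (names are mine)."""
    audio = audio.to(torch.float32)
    if gain_db > 0:
        audio = audio * (10 ** (gain_db / 20))       # dB -> linear amplitude factor
    if normalize:
        peak = audio.abs().max()
        audio = audio / peak if peak > 0 else audio  # peak-normalize to [-1, 1]
    if energy_scale != 1.0:
        audio = audio * energy_scale ** 0.5          # energy ~ amplitude**2
    if volume_scale != 1.0:
        audio = audio * volume_scale                 # plain linear volume
    if strain_effect > 0.0:
        audio = audio + audio ** 2 * strain_effect   # asymmetric quadratic distortion
    peak = audio.abs().max()
    return audio / peak if peak > 1 else audio       # final clipping guard

# 440 Hz test tone at 24 kHz, half amplitude
tone = 0.5 * torch.sin(2 * torch.pi * 440 * torch.arange(24000) / 24000)
out = apply_post_fx(tone, gain_db=3, strain_effect=0.1)
```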
api.py (14 changed lines)
@@ -1158,6 +1158,12 @@ def version_4_cli(
     character_name: str = "Kurari",
     model_id: int = 14,
     version: str = "v1",  # v3 or v4
+    loudness_boost=False,
+    gain=0,
+    normalize=False,
+    energy_scale=1.0,
+    volume_scale=1.0,
+    strain_effect=0.0,
 ):
     # Create a temporary buffer to store the audio
     audio_buffer = io.BytesIO()
@@ -1187,7 +1193,13 @@ def version_4_cli(
         ref_language = ref_language,
         target_text = target_text,
         text_language = text_language,
-        output_path = output_path  # Don't save to file
+        output_path = output_path,  # Don't save to file
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect
     )

     # Get the last audio data and sample rate from synthesis result
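The handler above simply forwards the six new parameters into synthesize. Assuming version_4_cli is mounted as an HTTP GET endpoint on the API server's default port, a request might look like the following — the route and host are guesses for illustration, and only the parameter names come from this diff:

```python
import requests

# Hypothetical route and host; adjust to however api.py mounts version_4_cli.
resp = requests.get(
    "http://127.0.0.1:9880/version_4_cli",
    params={
        "character_name": "Kurari",
        "model_id": 14,
        "version": "v1",
        "gain": 3,            # +3 dB gain stage
        "normalize": True,    # peak-normalize after gain
        "volume_scale": 0.8,  # then scale down linearly
    },
)
with open("out.wav", "wb") as f:
    f.write(resp.content)  # response body is the rendered audio buffer
```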