mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-15 05:21:57 +08:00)

commit 9e4313fb4e ("add new params")
parent 9d24ff72cf
@@ -7,7 +7,8 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
 
 i18n = I18nAuto()
 
-def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path):
+def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_language, target_text, text_language, output_path, loudness_boost=False, gain=0, normalize=False,
+               energy_scale=1.0, volume_scale=1.0, strain_effect=0.0):
 
     # Change model weights
     change_gpt_weights(gpt_path=GPT_model_path)
@@ -18,7 +19,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text, ref_
                                    prompt_text=ref_text,
                                    prompt_language=i18n(ref_language),
                                    text=target_text,
-                                   text_language=i18n(text_language), top_p=1, temperature=1)
+                                   text_language=i18n(text_language),
+                                   top_p=1,
+                                   temperature=1,
+                                   loudness_boost=loudness_boost,
+                                   gain=gain,
+                                   normalize=normalize,
+                                   energy_scale=energy_scale,
+                                   volume_scale=volume_scale,
+                                   strain_effect=strain_effect)
 
     result_list = list(synthesis_result)
 
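For context, a minimal sketch of how the extended synthesize() signature could be called after this change. The module path (GPT_SoVITS.inference_cli), the weight/audio paths, and the language labels below are assumptions for illustration, not part of the commit:

    # Hypothetical caller; adjust module path, model paths, and languages to your setup.
    from GPT_SoVITS.inference_cli import synthesize  # assumed module path

    synthesize(
        GPT_model_path="GPT_weights/my_model.ckpt",       # assumed path
        SoVITS_model_path="SoVITS_weights/my_model.pth",  # assumed path
        ref_audio_path="ref.wav",
        ref_text="This is the reference transcript.",
        ref_language="英文",
        target_text="Hello, this is a synthesized sentence.",
        text_language="英文",
        output_path="output",
        # new parameters introduced by this commit (defaults as in the diff)
        loudness_boost=False,
        gain=0,             # extra gain in dB, applied as 10 ** (gain / 20)
        normalize=False,    # peak-normalize to [-1, 1]
        energy_scale=1.0,   # amplitude multiplied by sqrt(energy_scale)
        volume_scale=1.0,   # direct amplitude multiplier
        strain_effect=0.0,  # adds audio**2 * strain_effect (mild distortion)
    )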
@@ -510,24 +510,27 @@ def audio_sr(audio,sr):
 ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
 # cache_tokens={}#暂未实现清理机制
 cache= {}
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=15, top_p=1, temperature=1, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=16,if_sr=False,pause_second=0.3):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"),
+                top_k=15, top_p=1, temperature=1, ref_free=False, speed=1, if_freeze=False,
+                inp_refs=None, sample_steps=16, if_sr=False, pause_second=0.3,
+                loudness_boost=False, gain=0, normalize=False, energy_scale=1.0,
+                volume_scale=1.0, strain_effect=0.0):
     global cache
-    if ref_wav_path:pass
-    else:gr.Warning(i18n('请上传参考音频'))
-    if text:pass
-    else:gr.Warning(i18n('请填入推理文本'))
+    if ref_wav_path: pass
+    else: gr.Warning(i18n('请上传参考音频'))
+    if text: pass
+    else: gr.Warning(i18n('请填入推理文本'))
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version=="v3":
-        ref_free=False#s2v3暂不支持ref_free
+    if model_version == "v3":
+        ref_free = False # s2v3暂不支持ref_free
     else:
-        if_sr=False
+        if_sr = False
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
 
 
     if not ref_free:
         prompt_text = prompt_text.strip("\n")
         if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
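A minimal sketch of consuming the updated get_tts_wav() generator with the new keyword arguments; the texts, language labels, and the soundfile dependency are illustrative assumptions, not part of the commit:

    import soundfile as sf  # assumed dependency for writing the result

    gen = get_tts_wav(
        ref_wav_path="ref.wav",
        prompt_text="参考音频的文本",
        prompt_language=i18n("中文"),
        text="要合成的目标文本",
        text_language=i18n("中文"),
        gain=3,            # +3 dB before the final clip check
        normalize=True,    # rescale peaks to [-1, 1]
        volume_scale=0.8,  # attenuate the output slightly
    )
    for sr, audio_int16 in gen:
        sf.write("out.wav", audio_int16, sr)  # write the yielded int16 audio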
@@ -567,7 +570,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt = prompt_semantic.unsqueeze(0).to(device)
 
     t1 = ttime()
-    t.append(t1-t0)
+    t.append(t1 - t0)
 
     if (how_to_cut == i18n("凑四句一切")):
         text = cut1(text)
@@ -586,21 +589,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     texts = process_text(texts)
     texts = merge_short_text_in_array(texts, 5)
     audio_opt = []
-    ###s2v3暂不支持ref_free
+    ### s2v3暂不支持ref_free
     if not ref_free:
-        phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
+        phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
 
-    for i_text,text in enumerate(texts):
+    for i_text, text in enumerate(texts):
         # 解决输入目标文本的空行导致报错的问题
         if (len(text.strip()) == 0):
             continue
         if (text[-1] not in splits): text += "。" if text_language != "en" else "."
         print(i18n("实际输入的目标文本(每句):"), text)
-        phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         print(i18n("前端处理后的文本(每句):"), norm_text2)
         if not ref_free:
             bert = torch.cat([bert1, bert2], 1)
-            all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
+            all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         else:
             bert = bert2
             all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -611,7 +614,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         t2 = ttime()
         # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
         # print(cache.keys(),if_freeze)
-        if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
+        if(i_text in cache and if_freeze == True): pred_semantic = cache[i_text]
         else:
             with torch.no_grad():
                 pred_semantic, idx = t2s_model.model.infer_panel(
@@ -626,33 +629,33 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                     early_stop_num=hz * max_sec,
                 )
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
-        cache[i_text]=pred_semantic
+        cache[i_text] = pred_semantic
         t3 = ttime()
-        ###v3不存在以下逻辑和inp_refs
-        if model_version!="v3":
-            refers=[]
-            if(inp_refs):
+        ### v3不存在以下逻辑和inp_refs
+        if model_version != "v3":
+            refers = []
+            if inp_refs:
                 for path in inp_refs:
                     try:
                         refer = get_spepc(hps, path.name).to(dtype).to(device)
                         refers.append(refer)
                     except:
                         traceback.print_exc()
-            if(len(refers)==0):refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
-            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers,speed=speed)[0][0]#.cpu().detach().numpy()
+            if len(refers) == 0: refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+            audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0] # .cpu().detach().numpy()
         else:
             refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
-            phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
-            phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             # print(11111111, phoneme_ids0, phoneme_ids1)
-            fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
             ref_audio, sr = torchaudio.load(ref_wav_path)
-            ref_audio=ref_audio.to(device).float()
+            ref_audio = ref_audio.to(device).float()
             if (ref_audio.shape[0] == 2):
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr!=24000:
-                ref_audio=resample(ref_audio,sr)
-            # print("ref_audio",ref_audio.abs().mean())
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
+            # print("ref_audio", ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
@@ -663,12 +666,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 fea_ref = fea_ref[:, :, -468:]
                 T_min = 468
             chunk_len = 934 - T_min
-            # print("fea_ref",fea_ref,fea_ref.shape)
-            # print("mel2",mel2)
-            mel2=mel2.to(dtype)
-            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge,speed)
-            # print("fea_todo",fea_todo)
-            # print("ge",ge.abs().mean())
+            # print("fea_ref", fea_ref, fea_ref.shape)
+            # print("mel2", mel2)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
+            # print("fea_todo", fea_todo)
+            # print("ge", ge.abs().mean())
             cfm_resss = []
             idx = 0
             while (1):
@@ -686,30 +689,62 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 cfm_resss.append(cfm_res)
             cmf_res = torch.cat(cfm_resss, 2)
             cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model==None:init_bigvgan()
+            if bigvgan_model == None: init_bigvgan()
             with torch.inference_mode():
                 wav_gen = bigvgan_model(cmf_res)
-                audio=wav_gen[0][0]#.cpu().detach().numpy()
-        max_audio=torch.abs(audio).max()#简单防止16bit爆音
-        if max_audio>1:audio=audio/max_audio
+                audio = wav_gen[0][0] # .cpu().detach().numpy()
+
+        # Initial clipping check
+        max_audio = torch.abs(audio).max() # 简单防止16bit爆音
+        if max_audio > 1:
+            audio = audio / max_audio
+
+        # Apply new parameters
+        audio = audio.to(torch.float32) # Ensure float32 for processing
+        if loudness_boost:
+            # Boost loudness using RMS-based scaling (adjust multiplier as needed)
+            rms = torch.sqrt(torch.mean(audio ** 2))
+            audio = audio * (rms * 1.5) if rms > 0 else audio
+        if gain > 0:
+            # Apply gain in dB
+            audio = audio * (10 ** (gain / 20))
+        if normalize:
+            # Normalize to [-1, 1]
+            max_abs = torch.abs(audio).max()
+            audio = audio / max_abs if max_abs > 0 else audio
+        if energy_scale != 1.0:
+            # Scale energy
+            audio = audio * torch.sqrt(torch.tensor(energy_scale))
+        if volume_scale != 1.0:
+            # Direct volume scaling
+            audio = audio * volume_scale
+        if strain_effect > 0.0:
+            # Add strain effect (basic distortion)
+            audio = audio + (audio ** 2 * strain_effect)
+
+        # Final clipping check after effects
+        max_audio = torch.abs(audio).max()
+        if max_audio > 1:
+            audio = audio / max_audio
+
         audio_opt.append(audio)
-        audio_opt.append(zero_wav_torch)#zero_wav
+        audio_opt.append(zero_wav_torch) # zero_wav
         t4 = ttime()
-        t.extend([t2 - t1,t3 - t2, t4 - t3])
+        t.extend([t2 - t1, t3 - t2, t4 - t3])
         t1 = ttime()
 
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
-    audio_opt=torch.cat(audio_opt, 0)#np.concatenate
-    sr=hps.data.sampling_rate if model_version!="v3"else 24000
-    if if_sr==True and sr==24000:
+    audio_opt = torch.cat(audio_opt, 0) # np.concatenate
+    sr = hps.data.sampling_rate if model_version != "v3" else 24000
+    if if_sr == True and sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
-        max_audio=np.abs(audio_opt).max()
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
         if max_audio > 1: audio_opt /= max_audio
     else:
-        audio_opt=audio_opt.cpu().detach().numpy()
+        audio_opt = audio_opt.cpu().detach().numpy()
     yield sr, (audio_opt * 32767).astype(np.int16)
 
 
 def split(todo_text):
     todo_text = todo_text.replace("……", "。").replace("——", ",")
     if todo_text[-1] not in splits:
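For reference, the effect chain added above can be read as a standalone function. This is a restated sketch of the same logic for clarity, not code from the commit; gain is interpreted in dB (a linear factor of 10^(gain/20)), and a final peak check guards against clipping after all effects. As written in the commit, the RMS-based boost multiplies by rms * 1.5, so it only amplifies when the signal's RMS already exceeds roughly 0.67.

    import torch

    def apply_output_effects(audio: torch.Tensor, loudness_boost=False, gain=0.0,
                             normalize=False, energy_scale=1.0, volume_scale=1.0,
                             strain_effect=0.0) -> torch.Tensor:
        # Restated sketch of the post-processing added in this commit.
        audio = audio.to(torch.float32)
        if loudness_boost:
            rms = torch.sqrt(torch.mean(audio ** 2))
            if rms > 0:
                audio = audio * (rms * 1.5)        # RMS-based scaling as committed
        if gain > 0:
            audio = audio * (10 ** (gain / 20))    # dB -> linear amplitude factor
        if normalize:
            peak = audio.abs().max()
            if peak > 0:
                audio = audio / peak               # peak-normalize to [-1, 1]
        if energy_scale != 1.0:
            audio = audio * (energy_scale ** 0.5)  # energy scales with amplitude squared
        if volume_scale != 1.0:
            audio = audio * volume_scale           # direct amplitude multiplier
        if strain_effect > 0.0:
            audio = audio + audio ** 2 * strain_effect  # mild even-harmonic distortion
        peak = audio.abs().max()
        return audio / peak if peak > 1 else audio     # final anti-clipping guard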
api.py (14 changes)
@@ -1158,6 +1158,12 @@ def version_4_cli(
     character_name: str = "Kurari",
     model_id: int = 14,
     version: str = "v1", # v3 or v4
+    loudness_boost=False,
+    gain=0,
+    normalize=False,
+    energy_scale=1.0,
+    volume_scale=1.0,
+    strain_effect=0.0,
 ):
     # Create a temporary buffer to store the audio
     audio_buffer = io.BytesIO()
@@ -1187,7 +1193,13 @@ def version_4_cli(
         ref_language = ref_language,
         target_text = target_text,
         text_language = text_language,
-        output_path = output_path # Don't save to file
+        output_path = output_path, # Don't save to file
+        loudness_boost=loudness_boost,
+        gain=gain,
+        normalize=normalize,
+        energy_scale=energy_scale,
+        volume_scale=volume_scale,
+        strain_effect=strain_effect
     )
 
     # Get the last audio data and sample rate from synthesis result
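The six audio-shaping parameters are threaded unchanged from version_4_cli() in api.py through synthesize() into get_tts_wav(). One way to avoid repeating the six keyword arguments at each layer is to collect them once and unpack them; this is only a refactoring sketch under that assumption, with the non-audio argument names hypothetical, and it is not what the commit does:

    audio_opts = dict(
        loudness_boost=loudness_boost,
        gain=gain,
        normalize=normalize,
        energy_scale=energy_scale,
        volume_scale=volume_scale,
        strain_effect=strain_effect,
    )
    synthesis_result = synthesize(
        GPT_model_path=gpt_path,        # hypothetical names; the real call site
        SoVITS_model_path=sovits_path,  # passes the existing arguments shown above
        ref_audio_path=ref_audio_path,
        ref_text=ref_text,
        ref_language=ref_language,
        target_text=target_text,
        text_language=text_language,
        output_path=output_path,
        **audio_opts,
    )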