支持24k音频超分48k采样率

支持24k音频超分48k采样率
This commit is contained in:
RVC-Boss 2025-02-27 16:08:26 +08:00 committed by GitHub
parent 9208fb8157
commit 00bdc01113
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -238,7 +238,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else: else:
visible_sample_steps=False visible_sample_steps=False
visible_inp_refs=True visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False} yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False},{"__type__": "update", "visible":True if model_version=="v3"else False}
dict_s2 = load_sovits_new(sovits_path) dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"] hps = dict_s2["config"]
@ -289,7 +289,7 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
) )
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
print("loading sovits_v3_lora%s"%(lora_rank)) print("loading sovits_v3_lora%s"%(lora_rank))
vq_model.load_state_dict(dict_s2["weight"], strict=False)## vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.cfm = vq_model.cfm.merge_and_unload() vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth") # torch.save(vq_model.state_dict(),"merge_win.pth")
vq_model.eval() vq_model.eval()
@ -332,18 +332,18 @@ now_dir = os.getcwd()
import soundfile import soundfile
def init_bigvgan(): def init_bigvgan():
global model global bigvgan_model
from BigVGAN import bigvgan from BigVGAN import bigvgan
model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False) # if True, RuntimeError: Ninja is required to load C++ extensions
# remove weight norm in the model and set to eval mode # remove weight norm in the model and set to eval mode
model.remove_weight_norm() bigvgan_model.remove_weight_norm()
model = model.eval() bigvgan_model = bigvgan_model.eval()
if is_half == True: if is_half == True:
model = model.half().to(device) bigvgan_model = bigvgan_model.half().to(device)
else: else:
model = model.to(device) bigvgan_model = bigvgan_model.to(device)
if model_version!="v3":model=None if model_version!="v3":bigvgan_model=None
else:init_bigvgan() else:init_bigvgan()
@ -460,7 +460,13 @@ def get_phones_and_bert(text,language,version,final=False):
return phones,bert.to(dtype),norm_text return phones,bert.to(dtype),norm_text
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
mel_fn_args = { spec_min = -12
spec_max = 2
def norm_spec(x):
    """Linearly rescale ``x`` from ``[spec_min, spec_max]`` to ``[-1, 1]``.

    ``spec_min`` maps to -1 and ``spec_max`` maps to 1; values outside that
    interval are not clipped. Inverse of ``denorm_spec``.
    """
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
    """Linearly rescale ``x`` from ``[-1, 1]`` back to ``[spec_min, spec_max]``.

    -1 maps to ``spec_min`` and 1 maps to ``spec_max``; values outside that
    interval are not clipped. Inverse of ``norm_spec``.
    """
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **{
"n_fft": 1024, "n_fft": 1024,
"win_size": 1024, "win_size": 1024,
"hop_size": 256, "hop_size": 256,
@ -469,16 +475,7 @@ mel_fn_args = {
"fmin": 0, "fmin": 0,
"fmax": None, "fmax": None,
"center": False "center": False
} })
# Bounds of the spectrogram value range that is mapped onto [-1, 1].
spec_min = -12
spec_max = 2


def norm_spec(x):
    """Map ``x`` linearly so that ``spec_min`` -> -1 and ``spec_max`` -> 1.

    No clipping is applied, so inputs outside the bounds land outside
    ``[-1, 1]``.
    """
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1


def denorm_spec(x):
    """Undo ``norm_spec``: map -1 -> ``spec_min`` and 1 -> ``spec_max``.

    No clipping is applied, so inputs outside ``[-1, 1]`` land outside the
    bounds.
    """
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram_torch(x, **mel_fn_args)
def merge_short_text_in_array(texts, threshold): def merge_short_text_in_array(texts, threshold):
if (len(texts)) < 2: if (len(texts)) < 2:
@ -497,10 +494,23 @@ def merge_short_text_in_array(texts, threshold):
result[len(result) - 1] += text result[len(result) - 1] += text
return result return result
# Cached super-resolution model; built on the first call to audio_sr().
sr_model = None


def audio_sr(audio, sr):
    """Run the audio super-resolution model on ``audio``.

    The ``AP_BWE`` model is constructed on first use and cached in the
    module-level ``sr_model``. If constructing it raises
    ``FileNotFoundError``, a Gradio warning is shown and the audio is handed
    back without super-resolution.

    Args:
        audio: Waveform to process; the visible caller passes a torch tensor
            with a leading dimension added via ``unsqueeze(0)``.
        sr: Sampling rate of ``audio``.

    Returns:
        Whatever ``sr_model(audio, sr)`` returns (the caller unpacks it as
        ``(audio, sr)``), or, on the fallback path, ``(audio, sr)`` with the
        input audio converted to a NumPy array and the original ``sr``.
    """
    global sr_model
    if sr_model is None:
        # Imported only when super-resolution is first requested.
        from tools.audio_sr import AP_BWE

        try:
            sr_model = AP_BWE(device)
        except FileNotFoundError:
            gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
            # The caller calls ``.astype`` on the returned audio, which a
            # torch tensor does not provide, so convert before returning.
            # NOTE(review): the leading dimension added by the caller is kept
            # here -- confirm it matches the shape sr_model() returns.
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().cpu().numpy()
            return audio, sr
    return sr_model(audio, sr)
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
# cache_tokens={}#暂未实现清理机制 # cache_tokens={}#暂未实现清理机制
cache= {} cache= {}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8): def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=None,sample_steps=8,if_sr=False,pause_second=0.3):
global cache global cache
if ref_wav_path:pass if ref_wav_path:pass
else:gr.Warning(i18n('请上传参考音频')) else:gr.Warning(i18n('请上传参考音频'))
@ -524,9 +534,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
print(i18n("实际输入的目标文本:"), text) print(i18n("实际输入的目标文本:"), text)
zero_wav = np.zeros( zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3), int(hps.data.sampling_rate * pause_second),
dtype=np.float16 if is_half == True else np.float32, dtype=np.float16 if is_half == True else np.float32,
) )
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
zero_wav_torch = zero_wav_torch.half().to(device)
else:
zero_wav_torch = zero_wav_torch.to(device)
if not ref_free: if not ref_free:
with torch.no_grad(): with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000) wav16k, sr = librosa.load(ref_wav_path, sr=16000)
@ -534,13 +549,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
gr.Warning(i18n("参考音频在3~10秒范围外请更换")) gr.Warning(i18n("参考音频在3~10秒范围外请更换"))
raise OSError(i18n("参考音频在3~10秒范围外请更换")) raise OSError(i18n("参考音频在3~10秒范围外请更换"))
wav16k = torch.from_numpy(wav16k) wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True: if is_half == True:
wav16k = wav16k.half().to(device) wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else: else:
wav16k = wav16k.to(device) wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k = torch.cat([wav16k, zero_wav_torch]) wav16k = torch.cat([wav16k, zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[ ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
"last_hidden_state" "last_hidden_state"
@ -671,22 +683,26 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
cfm_resss.append(cfm_res) cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2) cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res) cmf_res = denorm_spec(cmf_res)
if model==None:init_bigvgan() if bigvgan_model==None:init_bigvgan()
with torch.inference_mode(): with torch.inference_mode():
wav_gen = model(cmf_res) wav_gen = bigvgan_model(cmf_res)
audio=wav_gen[0][0].cpu().detach().numpy() audio=wav_gen[0][0]#.cpu().detach().numpy()
max_audio=np.abs(audio).max()#简单防止16bit爆音 max_audio=torch.abs(audio).max()#简单防止16bit爆音#np.abs(audio).max()
if max_audio>1:audio/=max_audio if max_audio>1:audio/=max_audio
audio_opt.append(audio) audio_opt.append(audio)
audio_opt.append(zero_wav) audio_opt.append(zero_wav_torch)#zero_wav
t4 = ttime() t4 = ttime()
t.extend([t2 - t1,t3 - t2, t4 - t3]) t.extend([t2 - t1,t3 - t2, t4 - t3])
t1 = ttime() t1 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
(t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])) audio_opt=torch.cat(audio_opt, 0)#np.concatenate
)
sr=hps.data.sampling_rate if model_version!="v3"else 24000 sr=hps.data.sampling_rate if model_version!="v3"else 24000
yield sr, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) if if_sr==True and sr==24000:
print(i18n("音频超分中"))
audio_opt,sr=audio_sr(audio_opt.unsqueeze(0),sr)
else:
audio_opt=audio_opt.cpu().detach().numpy()
yield sr, (audio_opt * 32767).astype(np.int16)
def split(todo_text): def split(todo_text):
@ -864,7 +880,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文"), label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文"),
) )
inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple")if model_version!="v3"else gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple",visible=False) inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple")if model_version!="v3"else gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"),file_count="multiple",visible=False)
sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True)if model_version=="v3"else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=8,choices=[4,8,16,32],visible=False) sample_steps = gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),value=32,choices=[4,8,16,32],visible=True)if model_version=="v3"else gr.Radio(label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),choices=[4,8,16,32],visible=False)
if_sr_Checkbox=gr.Checkbox(label=i18n("v3输出如果觉得闷可以试试开超分"), value=False, interactive=True, show_label=True,visible=False if model_version!="v3"else True)
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3')) gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
with gr.Row(): with gr.Row():
with gr.Column(scale=13): with gr.Column(scale=13):
@ -881,7 +898,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
) )
gr.Markdown(value=html_center(i18n("语速调整,高为更快"))) gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True, scale=1) if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True, scale=1)
speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True, scale=1) with gr.Row():
speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True, scale=1)
pause_second_slider = gr.Slider(minimum=0.1,maximum=0.5,step=0.01,label=i18n("句间停顿秒数"),value=0.3,interactive=True, scale=1)
gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认)"))) gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认)")))
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True, scale=1) top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True, scale=1)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True, scale=1) top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True, scale=1)
@ -896,10 +915,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button.click( inference_button.click(
get_tts_wav, get_tts_wav,
[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps], [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs,sample_steps,if_sr_Checkbox,pause_second_slider],
[output], [output],
) )
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free]) SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language,sample_steps,inp_refs,ref_text_free,if_sr_Checkbox])
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], []) GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
# gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")) # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))