diff --git a/GPT_SoVITS/configs/s2.json b/GPT_SoVITS/configs/s2.json index e44e1eb..0bd6722 100644 --- a/GPT_SoVITS/configs/s2.json +++ b/GPT_SoVITS/configs/s2.json @@ -18,7 +18,8 @@ "warmup_epochs": 0, "c_mel": 45, "c_kl": 1.0, - "text_low_lr_rate": 0.4 + "text_low_lr_rate": 0.4, + "grad_ckpt": false }, "data": { "max_wav_value": 32768.0, diff --git a/GPT_SoVITS/f5_tts/model/backbones/dit.py b/GPT_SoVITS/f5_tts/model/backbones/dit.py index 71ce350..8546fc3 100644 --- a/GPT_SoVITS/f5_tts/model/backbones/dit.py +++ b/GPT_SoVITS/f5_tts/model/backbones/dit.py @@ -12,6 +12,7 @@ from __future__ import annotations import torch from torch import nn import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint from x_transformers.x_transformers import RotaryEmbedding @@ -121,6 +122,14 @@ class DiT(nn.Module): self.norm_out = AdaLayerNormZero_Final(dim) # final modulation self.proj_out = nn.Linear(dim, mel_dim) + def ckpt_wrapper(self, module): + # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward + def forward(#x, prompt_x, x_lens, t, style,cond self,#d is channel,n is T x0: float["b n d"], # nosied input audio # noqa: F722 @@ -129,11 +138,12 @@ class DiT(nn.Module): time: float["b"] | float[""], # time step # noqa: F821 F722 dt_base_bootstrap, text0, # : int["b nt"] # noqa: F722#####condition feature - + use_grad_ckpt, # bool ###no-use drop_audio_cond=False, # cfg for cond audio drop_text=False, # cfg for text # mask: bool["b n"] | None = None, # noqa: F722 + ): x=x0.transpose(2,1) @@ -158,7 +168,10 @@ class DiT(nn.Module): residual = x for block in self.transformer_blocks: - x = block(x, t, mask=mask, rope=rope) + if use_grad_ckpt: + x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False) + else: + x = block(x, t, mask=mask, rope=rope) if self.long_skip_connection is not None: x = self.long_skip_connection(torch.cat((x, residual), dim=-1)) diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py index cb6b951..d546fcd 100644 --- a/GPT_SoVITS/module/models.py +++ b/GPT_SoVITS/module/models.py @@ -1089,15 +1089,15 @@ class CFM(torch.nn.Module): t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d # v_pred = model(x, t_tensor, d_tensor, **extra_args) - v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu,drop_audio_cond=False,drop_text=False).transpose(2, 1) + v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1) if inference_cfg_rate>1e-5: - neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=True, drop_text=True).transpose(2, 1) + neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1) v_pred=v_pred+(v_pred-neg)*inference_cfg_rate x = x + d * v_pred t = t + d x[:, :, :prompt_len] = 0 return x - def forward(self, x1, x_lens, prompt_lens, mu): + def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt): b, _, t = x1.shape # random timestep @@ -1117,16 +1117,16 @@ class CFM(torch.nn.Module): d_input = d.clone() d_input[d_input < 1e-2] = 0 # with torch.no_grad(): - v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu).transpose(2, 1).detach() + v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach() x_mid = xt + d[:, None, None] * v_pred_1 # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach() - v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu).transpose(2, 1).detach() + v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach() vt = (v_pred_1 + v_pred_2) / 2 vt = vt.detach() dt = 2*d - vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu).transpose(2,1) + vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1) loss = 0 # print(45555555,estimator_out.shape,u.shape,x_lens,prompt_lens)#45555555 torch.Size([7, 465, 100]) torch.Size([7, 100, 465]) tensor([461, 461, 451, 451, 442, 442, 442], device='cuda:0') tensor([ 96, 93, 185, 59, 244, 262, 294], device='cuda:0') @@ -1220,7 +1220,7 @@ class SynthesizerTrnV3(nn.Module): self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1) self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim - def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now + def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now with autocast(enabled=False): y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) ge = self.ref_enc(y[:,:704] * y_mask, y_mask) @@ -1245,7 +1245,7 @@ class SynthesizerTrnV3(nn.Module): minn=min(mel.shape[-1],fea.shape[-1]) mel=mel[:,:,:minn] fea=fea[:,:,:minn] - cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea) + cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt) return cfm_loss @torch.no_grad() diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py index 597f98a..a5f7da7 100644 --- a/GPT_SoVITS/s2_train_v3.py +++ b/GPT_SoVITS/s2_train_v3.py @@ -304,7 +304,7 @@ def train_and_evaluate( text, text_lengths = text.to(device), text_lengths.to(device) with autocast(enabled=hps.train.fp16_run): - cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths) + cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt) loss_gen_all=cfm_loss optim_g.zero_grad() scaler.scale(loss_gen_all).backward() diff --git a/webui.py b/webui.py index 0ae88f5..af6544e 100644 --- a/webui.py +++ b/webui.py @@ -10,6 +10,7 @@ import json,yaml,torch,pdb,re,shutil import platform import psutil import signal +os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO' torch.manual_seed(233333) tmp = os.path.join(now_dir, "TEMP") os.makedirs(tmp, exist_ok=True) @@ -327,7 +328,7 @@ def close_denoise(): return "已终止语音降噪进程", {"__type__":"update","visible":True}, {"__type__":"update","visible":False} p_train_SoVITS=None -def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D): +def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt): global p_train_SoVITS if(p_train_SoVITS==None): with open("GPT_SoVITS/configs/s2.json")as f: @@ -349,6 +350,7 @@ def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_s data["train"]["if_save_every_weights"]=if_save_every_weights data["train"]["save_every_epoch"]=save_every_epoch data["train"]["gpu_numbers"]=gpu_numbers1Ba + data["train"]["grad_ckpt"]=if_grad_ckpt data["model"]["version"]=version data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir data["save_weight_dir"]=SoVITS_weight_root[int(version[-1])-1] @@ -779,7 +781,7 @@ def switch_version(version_): else: gr.Warning(i18n(f'未下载{version.upper()}模型')) set_default() - return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False} + return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False},{'__type__':'update',"interactive":True if version == "v3" else False} if os.path.exists('GPT_SoVITS/text/G2PWModel'):... else: @@ -1016,6 +1018,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: with gr.Column(): if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True) if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True) + if_grad_ckpt = gr.Checkbox(label="是否开启梯度检查点节省显存占用", value=False, interactive=True if version == "v3" else False, show_label=True) # 只有V3s2可以用 with gr.Row(): gpu_numbers1Ba = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True) with gr.Row(): @@ -1045,7 +1048,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: button1Bb_close = gr.Button(i18n("终止GPT训练"), variant="primary",visible=False) with gr.Row(): info1Bb=gr.Textbox(label=i18n("GPT训练进程输出信息")) - button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close]) + button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt], [info1Ba,button1Ba_open,button1Ba_close]) button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close]) button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_dpo,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close]) button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close]) @@ -1069,7 +1072,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app: tts_info = gr.Textbox(label=i18n("TTS推理WebUI进程输出信息")) open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts]) close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts]) - version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown,batch_size,total_epoch,save_every_epoch,text_low_lr_rate]) + version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown,batch_size,total_epoch,save_every_epoch,text_low_lr_rate, if_grad_ckpt]) with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音")) app.queue().launch(#concurrency_count=511, max_size=1022 server_name="0.0.0.0",