Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-06 03:57:44 +08:00)
update Gradient Checkpointing to reduce VRAM usage (#2040)

* update Gradient Checkpointing to reduce VRAM usage
* fix inference

parent 86acb7a89d
commit c2b3298bed
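For context before the hunks: gradient checkpointing trades compute for memory by discarding a segment's intermediate activations in the forward pass and recomputing them during backward. A minimal sketch of the mechanism, assuming nothing beyond stock PyTorch (the toy blocks merely stand in for DiT's transformer_blocks):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    blocks = nn.ModuleList(nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4))
    x = torch.randn(8, 64, requires_grad=True)
    use_grad_ckpt = True  # mirrors the flag this commit threads through the model

    for block in blocks:
        if use_grad_ckpt:
            # Activations inside `block` are freed after the forward pass and
            # recomputed on demand when backward() reaches this segment.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)

    x.sum().backward()  # checkpointed blocks re-run their forward here

Peak VRAM falls with the number of checkpointed blocks, at the cost of roughly one extra forward pass through those blocks per backward.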
@@ -18,7 +18,8 @@
     "warmup_epochs": 0,
     "c_mel": 45,
     "c_kl": 1.0,
-    "text_low_lr_rate": 0.4
+    "text_low_lr_rate": 0.4,
+    "grad_ckpt": false
   },
   "data": {
     "max_wav_value": 32768.0,
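The hunk above adds the new "grad_ckpt" key, defaulting to false so existing configs behave as before. Judging by the keys and by the open("GPT_SoVITS/configs/s2.json") call in the webui.py hunks below, this is the s2 training config template, and the training script reads the flag back as hps.train.grad_ckpt. A hypothetical reader showing the path from file to flag (names follow the diff):

    import json

    with open("GPT_SoVITS/configs/s2.json") as f:
        data = json.load(f)
    # Stays off unless the WebUI checkbox wrote it as true.
    use_grad_ckpt = data["train"].get("grad_ckpt", False)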
@@ -12,6 +12,7 @@ from __future__ import annotations
 import torch
 from torch import nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint

 from x_transformers.x_transformers import RotaryEmbedding

@@ -121,6 +122,14 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)

+    def ckpt_wrapper(self, module):
+        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
+        def ckpt_forward(*inputs):
+            outputs = module(*inputs)
+            return outputs
+
+        return ckpt_forward
+
     def forward(#x, prompt_x, x_lens, t, style,cond
         self,#d is channel,n is T
         x0: float["b n d"], # nosied input audio # noqa: F722
@@ -129,11 +138,12 @@ class DiT(nn.Module):
         time: float["b"] | float[""], # time step # noqa: F821 F722
         dt_base_bootstrap,
         text0, # : int["b nt"] # noqa: F722#####condition feature
+        use_grad_ckpt, # bool
         ###no-use
         drop_audio_cond=False, # cfg for cond audio
         drop_text=False, # cfg for text
         # mask: bool["b n"] | None = None, # noqa: F722

     ):

         x=x0.transpose(2,1)
@@ -158,7 +168,10 @@ class DiT(nn.Module):
         residual = x

         for block in self.transformer_blocks:
-            x = block(x, t, mask=mask, rope=rope)
+            if use_grad_ckpt:
+                x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
+            else:
+                x = block(x, t, mask=mask, rope=rope)

         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
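Two details of this block loop are worth noting. First, checkpoint() replays its callable with the saved positional tensors, so mask and rope travel positionally through ckpt_wrapper, the small adapter borrowed from fast-DiT above, instead of as keywords. Second, use_reentrant=False selects the non-reentrant implementation that current PyTorch recommends. A condensed restatement of the equivalence, under the assumption (which the replaced call suggests) that mask and rope are the block's third and fourth parameters:

    # checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
    # recomputes block(x, t, mask, rope) during backward, matching the old
    # eager call block(x, t, mask=mask, rope=rope).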
@@ -1089,15 +1089,15 @@ class CFM(torch.nn.Module):
             t_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * t
             d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
             # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu,drop_audio_cond=False,drop_text=False).transpose(2, 1)
+            v_pred = self.estimator(x, prompt_x, x_lens, t_tensor,d_tensor, mu, use_grad_ckpt=False,drop_audio_cond=False,drop_text=False).transpose(2, 1)
             if inference_cfg_rate>1e-5:
-                neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, drop_audio_cond=True, drop_text=True).transpose(2, 1)
+                neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
                 v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
             x = x + d * v_pred
             t = t + d
             x[:, :, :prompt_len] = 0
         return x

-    def forward(self, x1, x_lens, prompt_lens, mu):
+    def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
         b, _, t = x1.shape

         # random timestep
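This hunk is the "fix inference" half of the commit message: the sampling loop runs at inference time (under torch.no_grad() in the surrounding code), where there is no backward pass to save memory for, so both estimator calls pin use_grad_ckpt=False rather than inheriting whatever the training run used.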
@@ -1117,16 +1117,16 @@ class CFM(torch.nn.Module):
             d_input = d.clone()
             d_input[d_input < 1e-2] = 0
             # with torch.no_grad():
-            v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu).transpose(2, 1).detach()
+            v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
             # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
             x_mid = xt + d[:, None, None] * v_pred_1
             # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
-            v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu).transpose(2, 1).detach()
+            v_pred_2 = self.estimator(x_mid, prompt, x_lens, t+d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
             vt = (v_pred_1 + v_pred_2) / 2
             vt = vt.detach()
             dt = 2*d

-        vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu).transpose(2,1)
+        vt_pred = self.estimator(xt, prompt, x_lens, t,dt, mu, use_grad_ckpt).transpose(2,1)
         loss = 0

         # print(45555555,estimator_out.shape,u.shape,x_lens,prompt_lens)#45555555 torch.Size([7, 465, 100]) torch.Size([7, 100, 465]) tensor([461, 461, 451, 451, 442, 442, 442], device='cuda:0') tensor([ 96, 93, 185, 59, 244, 262, 294], device='cuda:0')
@@ -1220,7 +1220,7 @@ class SynthesizerTrnV3(nn.Module):
         self.linear_mel=nn.Conv1d(inter_channels2,100,1,stride=1)
         self.cfm = CFM(100,DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),)#text_dim is condition feature dim

-    def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths):#ssl_lengths no need now
+    def forward(self, ssl, y, mel,ssl_lengths,y_lengths, text, text_lengths,mel_lengths, use_grad_ckpt):#ssl_lengths no need now
         with autocast(enabled=False):
             y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
             ge = self.ref_enc(y[:,:704] * y_mask, y_mask)
@@ -1245,7 +1245,7 @@ class SynthesizerTrnV3(nn.Module):
         minn=min(mel.shape[-1],fea.shape[-1])
         mel=mel[:,:,:minn]
         fea=fea[:,:,:minn]
-        cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea)
+        cfm_loss= self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
         return cfm_loss

     @torch.no_grad()
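Taken together, the new argument is threaded through the whole v3 stack; a condensed map of the plumbing as it appears in these hunks:

    train loop                 net_g(..., use_grad_ckpt=hps.train.grad_ckpt)
    SynthesizerTrnV3.forward   -> self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
    CFM.forward                -> self.estimator(..., use_grad_ckpt)
    DiT.forward                -> checkpoint(self.ckpt_wrapper(block), ...) per transformer block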
@@ -304,7 +304,7 @@ def train_and_evaluate(
         text, text_lengths = text.to(device), text_lengths.to(device)

         with autocast(enabled=hps.train.fp16_run):
-            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths)
+            cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
             loss_gen_all=cfm_loss
             optim_g.zero_grad()
             scaler.scale(loss_gen_all).backward()
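One interaction worth calling out: the forward above runs inside autocast(enabled=hps.train.fp16_run), and checkpoint restores the ambient autocast state when it replays a segment, so checkpointed blocks recompute under the same precision during scaler.scale(loss).backward(). A minimal sketch of that training-step shape, assuming a generic model rather than the project's net_g:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def train_step(model, batch, optimizer, scaler: GradScaler, fp16_run=True):
        with autocast(enabled=fp16_run):
            loss = model(batch)            # forward may checkpoint its blocks
        optimizer.zero_grad()
        scaler.scale(loss).backward()      # checkpointed segments re-run here
        scaler.step(optimizer)
        scaler.update()
        return loss.detach()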
webui.py
@@ -10,6 +10,7 @@ import json,yaml,torch,pdb,re,shutil
 import platform
 import psutil
 import signal
+os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
 torch.manual_seed(233333)
 tmp = os.path.join(now_dir, "TEMP")
 os.makedirs(tmp, exist_ok=True)
@@ -327,7 +328,7 @@ def close_denoise():
     return "已终止语音降噪进程", {"__type__":"update","visible":True}, {"__type__":"update","visible":False}

 p_train_SoVITS=None
-def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D):
+def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt):
     global p_train_SoVITS
     if(p_train_SoVITS==None):
         with open("GPT_SoVITS/configs/s2.json")as f:
@@ -349,6 +350,7 @@ def open1Ba(batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_s
         data["train"]["if_save_every_weights"]=if_save_every_weights
         data["train"]["save_every_epoch"]=save_every_epoch
         data["train"]["gpu_numbers"]=gpu_numbers1Ba
+        data["train"]["grad_ckpt"]=if_grad_ckpt
         data["model"]["version"]=version
         data["data"]["exp_dir"]=data["s2_ckpt_dir"]=s2_dir
         data["save_weight_dir"]=SoVITS_weight_root[int(version[-1])-1]
@@ -779,7 +781,7 @@ def switch_version(version_):
     else:
         gr.Warning(i18n(f'未下载{version.upper()}模型'))
     set_default()
-    return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False}
+    return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False},{'__type__':'update',"interactive":True if version == "v3" else False}

 if os.path.exists('GPT_SoVITS/text/G2PWModel'):...
 else:
@@ -1016,6 +1018,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     with gr.Column():
                         if_save_latest = gr.Checkbox(label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"), value=True, interactive=True, show_label=True)
                         if_save_every_weights = gr.Checkbox(label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), value=True, interactive=True, show_label=True)
+                        if_grad_ckpt = gr.Checkbox(label="是否开启梯度检查点节省显存占用", value=False, interactive=True if version == "v3" else False, show_label=True) # 只有V3s2可以用
                     with gr.Row():
                         gpu_numbers1Ba = gr.Textbox(label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True)
                 with gr.Row():
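The new checkbox's label 是否开启梯度检查点节省显存占用 reads "enable gradient checkpointing to reduce VRAM usage", and the trailing comment 只有V3s2可以用 means "only usable for v3 s2 training"; accordingly the control is created non-interactive unless the selected version is v3, and the switch_version hunk above toggles that interactivity when the version changes. A stripped-down sketch of the wiring pattern, with hypothetical names in place of the full UI:

    import gradio as gr

    def open1Ba_stub(if_grad_ckpt):
        # In webui.py this value lands in data["train"]["grad_ckpt"] before
        # the s2 config is dumped and the training subprocess is launched.
        return f"training started with grad_ckpt={if_grad_ckpt}"

    with gr.Blocks() as demo:
        if_grad_ckpt = gr.Checkbox(label="Enable gradient checkpointing to save VRAM", value=False)
        start = gr.Button("Start SoVITS training")
        info = gr.Textbox(label="Training process output")
        start.click(open1Ba_stub, [if_grad_ckpt], [info])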
@@ -1045,7 +1048,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                         button1Bb_close = gr.Button(i18n("终止GPT训练"), variant="primary",visible=False)
                     with gr.Row():
                         info1Bb=gr.Textbox(label=i18n("GPT训练进程输出信息"))
-            button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D], [info1Ba,button1Ba_open,button1Ba_close])
+            button1Ba_open.click(open1Ba, [batch_size,total_epoch,exp_name,text_low_lr_rate,if_save_latest,if_save_every_weights,save_every_epoch,gpu_numbers1Ba,pretrained_s2G,pretrained_s2D,if_grad_ckpt], [info1Ba,button1Ba_open,button1Ba_close])
             button1Ba_close.click(close1Ba, [], [info1Ba,button1Ba_open,button1Ba_close])
             button1Bb_open.click(open1Bb, [batch_size1Bb,total_epoch1Bb,exp_name,if_dpo,if_save_latest1Bb,if_save_every_weights1Bb,save_every_epoch1Bb,gpu_numbers1Bb,pretrained_s1], [info1Bb,button1Bb_open,button1Bb_close])
             button1Bb_close.click(close1Bb, [], [info1Bb,button1Bb_open,button1Bb_close])
@@ -1069,7 +1072,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 tts_info = gr.Textbox(label=i18n("TTS推理WebUI进程输出信息"))
                 open_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
                 close_tts.click(change_tts_inference, [bert_pretrained_dir,cnhubert_base_dir,gpu_number_1C,GPT_dropdown,SoVITS_dropdown, batched_infer_enabled], [tts_info,open_tts,close_tts])
-        version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown,batch_size,total_epoch,save_every_epoch,text_low_lr_rate])
+        version_checkbox.change(switch_version,[version_checkbox],[pretrained_s2G,pretrained_s2D,pretrained_s1,GPT_dropdown,SoVITS_dropdown,batch_size,total_epoch,save_every_epoch,text_low_lr_rate, if_grad_ckpt])
         with gr.TabItem(i18n("2-GPT-SoVITS-变声")):gr.Markdown(value=i18n("施工中,请静候佳音"))

     app.queue().launch(#concurrency_count=511, max_size=1022
         server_name="0.0.0.0",