support sovits v3 lora training, 8G GPU memory is enough

support sovits v3 lora training, 8G GPU memory is enough
This commit is contained in:
RVC-Boss 2025-02-23 00:37:14 +08:00 committed by GitHub
parent 56509a17c9
commit e937b625e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 488 additions and 55 deletions

View File

@ -28,10 +28,13 @@ try:
analytics.version_check = lambda:None
except:...
version=model_version=os.environ.get("version","v2")
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth","GPT_SoVITS/pretrained_models/s2Gv3.pth"]
path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3=os.path.exists(path_sovits_v3)
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"]
_ =[[],[]]
for i in range(3):
if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i])
@ -73,6 +76,7 @@ is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
# is_half=False
punctuation = set(['!', '?', '', ',', '.', '-'," "])
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
@ -83,13 +87,26 @@ from feature_extractor import cnhubert
cnhubert.cnhubert_base_path = cnhubert_base_path
from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3
import numpy as np
import random
def set_seed(seed):
if seed == -1:
seed = random.randint(0, 1000000)
seed = int(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# set_seed(42)
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from tools.my_utils import load_audio
from tools.i18n.i18n import I18nAuto, scan_language_list
from peft import LoraConfig, PeftModel, get_peft_model
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@ -192,38 +209,17 @@ def resample(audio_tensor, sr0):
).to(device)
return resample_transform_dict[sr0](audio_tensor)
###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
#symbol_version-model_version-if_lora_v3
from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
global vq_model, hps, version, model_version, dict_language
'''
v1:about 82942KB
half thr:82978KB
v2:about 83014KB
half thr:100MB
v1base:103490KB
half thr:103520KB
v2base:103551KB
v3:about 750MB
~82978K~100M~103420~700M
v1-v2-v1base-v2base-v3
version:
symbols version and timebre_embedding version
model_version:
sovits is v1/2 (VITS) or v3 (shortcut CFM DiT)
'''
size=os.path.getsize(sovits_path)
if size<82978*1024:
model_version=version="v1"
elif size<100*1024*1024:
model_version=version="v2"
elif size<103520*1024:
model_version=version="v1"
elif size<700*1024*1024:
model_version = version = "v2"
else:
version = "v2"
model_version="v3"
global vq_model, hps, version, model_version, dict_language,if_lora_v3
version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path)
# print(sovits_path,version, model_version, if_lora_v3)
if if_lora_v3==True and is_exist_s2gv3==False:
info=i18n("GPT_SoVITS/pretrained_models/s2Gv3.pth v3sovits的底模没下载对识别为v3sovits的lora没法加载")
gr.Warning(info)
raise FileExistsError(info)
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
@ -244,11 +240,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
visible_inp_refs=True
yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False}
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"]
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if 'enc_p.text_embedding.weight'not in dict_s2['weight']:
hps.model.version = "v2"#v3model,v2sybomls
elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
hps.model.version = "v1"
else:
hps.model.version = "v2"
@ -278,7 +276,23 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
else:
vq_model = vq_model.to(device)
vq_model.eval()
print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
if if_lora_v3==False:
print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False))
else:
print("loading sovits_v3pretrained_G", vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False))
lora_rank=dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
print("loading sovits_v3_lora%s"%(lora_rank),vq_model.load_state_dict(dict_s2["weight"], strict=False))
vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth")
vq_model.eval()
with open("./weight.json")as f:
data=f.read()
data=json.loads(data)
@ -333,7 +347,8 @@ else:init_bigvgan()
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
# audio = load_audio(filename, int(hps.data.sampling_rate))
audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
maxx=audio.abs().max()
if(maxx>1):audio/=min(2,maxx)
@ -443,11 +458,7 @@ def get_phones_and_bert(text,language,version,final=False):
return phones,bert.to(dtype),norm_text
from module.mel_processing import spectrogram_torch,spec_to_mel_torch
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
spec=spectrogram_torch(y,n_fft,sampling_rate,hop_size,win_size,center)
mel=spec_to_mel_torch(spec,n_fft,num_mels,sampling_rate,fmin,fmax)
return mel
from module.mel_processing import spectrogram_torch,mel_spectrogram_torch
mel_fn_args = {
"n_fft": 1024,
"win_size": 1024,
@ -465,7 +476,7 @@ def norm_spec(x):
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn=lambda x: mel_spectrogram(x, **mel_fn_args)
mel_fn=lambda x: mel_spectrogram_torch(x, **mel_fn_args)
def merge_short_text_in_array(texts, threshold):
@ -617,6 +628,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)#######这里要重采样切到32k,因为src是24k的没有单独的32k的src所以不能改成2个路径
phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio=ref_audio.to(device).float()
@ -624,7 +636,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr!=24000:
ref_audio=resample(ref_audio,sr)
mel2 = mel_fn(ref_audio.to(dtype))
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
@ -634,7 +647,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
mel2=mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while (1):
@ -642,9 +660,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
if (fea_todo_chunk.shape[-1] == 0): break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
cfm_res = cfm_res[:, :, mel2.shape[2]:]
mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
@ -653,8 +674,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
with torch.inference_mode():
wav_gen = model(cmf_res)
audio=wav_gen[0][0].cpu().detach().numpy()
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
max_audio=np.abs(audio).max()#简单防止16bit爆音
if max_audio>1:audio/=max_audio
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()

View File

@ -14,7 +14,24 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
def savee(ckpt, name, epoch, steps, hps):
'''
00:v1
01:v2
02:v3
03:v3lora
'''
from io import BytesIO
def my_save2(fea,path):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
data = b'03' + data[2:]###temp for v3lora only, todo
with open(path, "wb") as f: f.write(data)
def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@ -24,8 +41,66 @@ def savee(ckpt, name, epoch, steps, hps):
opt["weight"][key] = ckpt[key].half()
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
# torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
if lora_rank:
opt["lora_rank"]=lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()
head2version={
b'00':["v1","v1",False],
b'01':["v2","v2",False],
b'02':["v2","v3",False],
b'03':["v2","v3",True],
}
hash_pretrained_dict={
"dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained
}
import hashlib
def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash=get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
with open(sovits_path,"rb")as f:version=f.read(2)
if version!=b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3=False
size=os.path.getsize(sovits_path)
'''
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
'''
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
model_version = version = "v2"
else:
version = "v2"
model_version = "v3"
return version,model_version,if_lora_v3
def load_sovits_new(sovits_path):
f=open(sovits_path,"rb")
meta=f.read(2)
if meta!="PK":
data = b'PK' + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False)

View File

@ -26,12 +26,7 @@ from AR.utils import get_newest_ckpt
from collections import OrderedDict
from time import time as ttime
import shutil
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
dir=os.path.dirname(path)
name=os.path.basename(path)
tmp_path="%s.pth"%(ttime())
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
from process_ckpt import my_save
class my_model_ckpt(ModelCheckpoint):

View File

@ -0,0 +1,342 @@
import warnings
warnings.filterwarnings("ignore")
import utils, os
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging, traceback
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from peft import LoraConfig, PeftModel, get_peft_model
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
from collections import OrderedDict as od
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
else:
n_gpus = 1
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
mp.spawn(
run,
nprocs=n_gpus,
args=(
n_gpus,
hps,
),
)
def run(rank, n_gpus, hps):
global global_step,no_grad_names,save_root,lora_rank
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
# 1100,
# 1200,
# 1300,
# 1400,
# 1500,
# 1600,
# 1700,
# 1800,
# 1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=6,
shuffle=False,
pin_memory=True,
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=4,
)
save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank)
os.makedirs(save_root,exist_ok=True)
lora_rank=int(hps.train.lora_rank)
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
r=lora_rank,
lora_alpha=lora_rank,
init_lora_weights=True,
)
def get_model(hps):return SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
)
def get_optim(net_g):
return torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
hps.train.learning_rate,
betas=hps.train.betas,
eps=hps.train.eps,
)
def model2cuda(net_g,rank):
if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
return net_g
try:# 如果能加载自动resume
net_g = get_model(hps)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g=get_optim(net_g)
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(save_root, "G_*.pth"),
net_g,
optim_g,
)
global_step = (epoch_str - 1) * len(train_loader)
except: # 如果首次不能加载加载pretrain
# traceback.print_exc()
epoch_str = 1
global_step = 0
net_g = get_model(hps)
if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
print(
net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
)
net_g.cfm = get_peft_model(net_g.cfm, lora_config)
net_g=model2cuda(net_g,rank)
optim_g = get_optim(net_g)
no_grad_names=set()
for name, param in net_g.named_parameters():
if not param.requires_grad:
no_grad_names.add(name.replace("module.",""))
# print(name, "not requires_grad")
# print(no_grad_names)
# os._exit(233333)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=-1
)
for _ in range(epoch_str):
scheduler_g.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d=optim_d=scheduler_d=None
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
# [train_loader, eval_loader], logger, [writer, writer_eval])
[train_loader, None],
logger,
[writer, writer_eval],
)
else:
train_and_evaluate(
rank,
epoch,
hps,
[net_g, net_d],
[optim_g, optim_d],
[scheduler_g, scheduler_d],
scaler,
[train_loader, None],
None,
None,
)
scheduler_g.step()
def train_and_evaluate(
rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
train_loader, eval_loader = loaders
if writers is not None:
writer, writer_eval = writers
train_loader.batch_sampler.set_epoch(epoch)
global global_step
net_g.train()
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
rank, non_blocking=True
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(
rank, non_blocking=True
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
rank, non_blocking=True
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
mel, mel_lengths = mel.to(device), mel_lengths.to(device)
ssl = ssl.to(device)
ssl.requires_grad = False
text, text_lengths = text.to(device), text_lengths.to(device)
with autocast(enabled=hps.train.fp16_run):
cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt)
loss_gen_all=cfm_loss
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
losses = [cfm_loss]
logger.info('Train Epoch: {} [{:.0f}%]'.format(
epoch,
100. * batch_idx / len(train_loader)))
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
utils.summarize(
writer=writer,
global_step=global_step,
scalars=scalar_dict)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(global_step)
),
)
else:
utils.save_checkpoint(
net_g,
optim_g,
hps.train.learning_rate,
epoch,
os.path.join(
save_root, "G_{}.pth".format(233333333333)
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
sim_ckpt=od()
for key in ckpt:
# if "cfm"not in key:
# print(key)
if key not in no_grad_names:
sim_ckpt[key]=ckpt[key].half().cpu()
logger.info(
"saving ckpt %s_e%s:%s"
% (
hps.name,
epoch,
savee(
sim_ckpt,
hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank),
epoch,
global_step,
hps,lora_rank=lora_rank
),
)
)
if rank == 0:
logger.info("====> Epoch: {}".format(epoch))
if __name__ == "__main__":
main()