diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index f903b73..9595835 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -28,10 +28,13 @@ try: analytics.version_check = lambda:None except:... version=model_version=os.environ.get("version","v2") -pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth","GPT_SoVITS/pretrained_models/s2Gv3.pth"] +path_sovits_v3="GPT_SoVITS/pretrained_models/s2Gv3.pth" +is_exist_s2gv3=os.path.exists(path_sovits_v3) +pretrained_sovits_name=["GPT_SoVITS/pretrained_models/s2G488k.pth", "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",path_sovits_v3] pretrained_gpt_name=["GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt","GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1v3.ckpt"] + _ =[[],[]] for i in range(3): if os.path.exists(pretrained_gpt_name[i]):_[0].append(pretrained_gpt_name[i]) @@ -73,6 +76,7 @@ is_share = eval(is_share) if "_CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +# is_half=False punctuation = set(['!', '?', '…', ',', '.', '-'," "]) import gradio as gr from transformers import AutoModelForMaskedLM, AutoTokenizer @@ -83,13 +87,26 @@ from feature_extractor import cnhubert cnhubert.cnhubert_base_path = cnhubert_base_path from GPT_SoVITS.module.models import SynthesizerTrn,SynthesizerTrnV3 +import numpy as np +import random +def set_seed(seed): + if seed == -1: + seed = random.randint(0, 1000000) + seed = int(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) +# set_seed(42) + from AR.models.t2s_lightning_module import Text2SemanticLightningModule from text import cleaned_text_to_sequence from text.cleaner import clean_text from time import time as ttime -from module.mel_processing import spectrogram_torch from tools.my_utils import load_audio from tools.i18n.i18n import I18nAuto, scan_language_list +from peft import LoraConfig, PeftModel, get_peft_model language=os.environ.get("language","Auto") language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language @@ -192,38 +209,17 @@ def resample(audio_tensor, sr0): ).to(device) return resample_transform_dict[sr0](audio_tensor) +###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt +#symbol_version-model_version-if_lora_v3 +from process_ckpt import get_sovits_version_from_path_fast,load_sovits_new def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): - global vq_model, hps, version, model_version, dict_language - ''' - v1:about 82942KB - half thr:82978KB - v2:about 83014KB - half thr:100MB - v1base:103490KB - half thr:103520KB - v2base:103551KB - v3:about 750MB - - ~82978K~100M~103420~700M - v1-v2-v1base-v2base-v3 - version: - symbols version and timebre_embedding version - model_version: - sovits is v1/2 (VITS) or v3 (shortcut CFM DiT) - ''' - size=os.path.getsize(sovits_path) - if size<82978*1024: - model_version=version="v1" - elif size<100*1024*1024: - model_version=version="v2" - elif size<103520*1024: - model_version=version="v1" - elif size<700*1024*1024: - model_version = version = "v2" - else: - version = "v2" - model_version="v3" - + global vq_model, hps, version, model_version, dict_language,if_lora_v3 + version, model_version, if_lora_v3=get_sovits_version_from_path_fast(sovits_path) + # print(sovits_path,version, model_version, if_lora_v3) + if if_lora_v3==True and is_exist_s2gv3==False: + info=i18n("GPT_SoVITS/pretrained_models/s2Gv3.pth v3sovits的底模没下载对,识别为v3sovits的lora没法加载") + gr.Warning(info) + raise FileExistsError(info) dict_language = dict_language_v1 if version =='v1' else dict_language_v2 if prompt_language is not None and text_language is not None: if prompt_language in list(dict_language.keys()): @@ -244,11 +240,13 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): visible_inp_refs=True yield {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update,{"__type__": "update", "visible": visible_sample_steps},{"__type__": "update", "visible": visible_inp_refs},{"__type__": "update", "value": False,"interactive":True if model_version!="v3"else False} - dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False) + dict_s2 = load_sovits_new(sovits_path) hps = dict_s2["config"] hps = DictToAttrRecursive(hps) hps.model.semantic_frame_rate = "25hz" - if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: + if 'enc_p.text_embedding.weight'not in dict_s2['weight']: + hps.model.version = "v2"#v3model,v2sybomls + elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322: hps.model.version = "v1" else: hps.model.version = "v2" @@ -278,7 +276,23 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None): else: vq_model = vq_model.to(device) vq_model.eval() - print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False)) + if if_lora_v3==False: + print("loading sovits_%s"%model_version,vq_model.load_state_dict(dict_s2["weight"], strict=False)) + else: + print("loading sovits_v3pretrained_G", vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)) + lora_rank=dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + print("loading sovits_v3_lora%s"%(lora_rank),vq_model.load_state_dict(dict_s2["weight"], strict=False)) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + with open("./weight.json")as f: data=f.read() data=json.loads(data) @@ -333,7 +347,8 @@ else:init_bigvgan() def get_spepc(hps, filename): - audio = load_audio(filename, int(hps.data.sampling_rate)) + # audio = load_audio(filename, int(hps.data.sampling_rate)) + audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate)) audio = torch.FloatTensor(audio) maxx=audio.abs().max() if(maxx>1):audio/=min(2,maxx) @@ -443,11 +458,7 @@ def get_phones_and_bert(text,language,version,final=False): return phones,bert.to(dtype),norm_text -from module.mel_processing import spectrogram_torch,spec_to_mel_torch -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - spec=spectrogram_torch(y,n_fft,sampling_rate,hop_size,win_size,center) - mel=spec_to_mel_torch(spec,n_fft,num_mels,sampling_rate,fmin,fmax) - return mel +from module.mel_processing import spectrogram_torch,mel_spectrogram_torch mel_fn_args = { "n_fft": 1024, "win_size": 1024, @@ -465,7 +476,7 @@ def norm_spec(x): return (x - spec_min) / (spec_max - spec_min) * 2 - 1 def denorm_spec(x): return (x + 1) / 2 * (spec_max - spec_min) + spec_min -mel_fn=lambda x: mel_spectrogram(x, **mel_fn_args) +mel_fn=lambda x: mel_spectrogram_torch(x, **mel_fn_args) def merge_short_text_in_array(texts, threshold): @@ -617,6 +628,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)#######这里要重采样切到32k,因为src是24k的,没有单独的32k的src,所以不能改成2个路径 phoneme_ids0=torch.LongTensor(phones1).to(device).unsqueeze(0) phoneme_ids1=torch.LongTensor(phones2).to(device).unsqueeze(0) + # print(11111111, phoneme_ids0, phoneme_ids1) fea_ref,ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) ref_audio, sr = torchaudio.load(ref_wav_path) ref_audio=ref_audio.to(device).float() @@ -624,7 +636,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, ref_audio = ref_audio.mean(0).unsqueeze(0) if sr!=24000: ref_audio=resample(ref_audio,sr) - mel2 = mel_fn(ref_audio.to(dtype)) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) mel2 = norm_spec(mel2) T_min = min(mel2.shape[2], fea_ref.shape[2]) mel2 = mel2[:, :, :T_min] @@ -634,7 +647,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, fea_ref = fea_ref[:, :, -468:] T_min = 468 chunk_len = 934 - T_min + # print("fea_ref",fea_ref,fea_ref.shape) + # print("mel2",mel2) + mel2=mel2.to(dtype) fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge) + # print("fea_todo",fea_todo) + # print("ge",ge.abs().mean()) cfm_resss = [] idx = 0 while (1): @@ -642,9 +660,12 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, if (fea_todo_chunk.shape[-1] == 0): break idx += chunk_len fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + # set_seed(123) cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0) cfm_res = cfm_res[:, :, mel2.shape[2]:] mel2 = cfm_res[:, :, -T_min:] + # print("fea", fea) + # print("mel2in", mel2) fea_ref = fea_todo_chunk[:, :, -T_min:] cfm_resss.append(cfm_res) cmf_res = torch.cat(cfm_resss, 2) @@ -653,8 +674,8 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, with torch.inference_mode(): wav_gen = model(cmf_res) audio=wav_gen[0][0].cpu().detach().numpy() - max_audio=np.abs(audio).max()#简单防止16bit爆音 - if max_audio>1:audio/=max_audio + max_audio=np.abs(audio).max()#简单防止16bit爆音 + if max_audio>1:audio/=max_audio audio_opt.append(audio) audio_opt.append(zero_wav) t4 = ttime() diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py index 3a436f1..36ef434 100644 --- a/GPT_SoVITS/process_ckpt.py +++ b/GPT_SoVITS/process_ckpt.py @@ -14,7 +14,24 @@ def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path torch.save(fea,tmp_path) shutil.move(tmp_path,"%s/%s"%(dir,name)) -def savee(ckpt, name, epoch, steps, hps): +''' +00:v1 +01:v2 +02:v3 +03:v3lora + + +''' +from io import BytesIO +def my_save2(fea,path): + bio = BytesIO() + torch.save(fea, bio) + bio.seek(0) + data = bio.getvalue() + data = b'03' + data[2:]###temp for v3lora only, todo + with open(path, "wb") as f: f.write(data) + +def savee(ckpt, name, epoch, steps, hps,lora_rank=None): try: opt = OrderedDict() opt["weight"] = {} @@ -24,8 +41,66 @@ def savee(ckpt, name, epoch, steps, hps): opt["weight"][key] = ckpt[key].half() opt["config"] = hps opt["info"] = "%sepoch_%siteration" % (epoch, steps) - # torch.save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) - my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) + if lora_rank: + opt["lora_rank"]=lora_rank + my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) + else: + my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) return "Success." except: return traceback.format_exc() + +head2version={ + b'00':["v1","v1",False], + b'01':["v2","v2",False], + b'02':["v2","v3",False], + b'03':["v2","v3",True], +} +hash_pretrained_dict={ + "dc3c97e17592963677a4a1681f30c653":["v2","v2",False],#s2G488k.pth#sovits_v1_pretrained + "43797be674a37c1c83ee81081941ed0f":["v2","v3",False],#s2Gv3.pth#sovits_v3_pretrained + "6642b37f3dbb1f76882b69937c95a5f3":["v2","v2",False],#s2G2333K.pth#sovits_v2_pretrained +} +import hashlib +def get_hash_from_file(sovits_path): + with open(sovits_path,"rb")as f:data=f.read(8192) + hash_md5 = hashlib.md5() + hash_md5.update(data) + return hash_md5.hexdigest() +def get_sovits_version_from_path_fast(sovits_path): + ###1-if it is pretrained sovits models, by hash + hash=get_hash_from_file(sovits_path) + if hash in hash_pretrained_dict: + return hash_pretrained_dict[hash] + ###2-new weights or old weights, by head + with open(sovits_path,"rb")as f:version=f.read(2) + if version!=b"PK": + return head2version[version] + ###3-old weights, by file size + if_lora_v3=False + size=os.path.getsize(sovits_path) + ''' + v1weights:about 82942KB + half thr:82978KB + v2weights:about 83014KB + v3weights:about 750MB + ''' + if size < 82978 * 1024: + model_version = version = "v1" + elif size < 700 * 1024 * 1024: + model_version = version = "v2" + else: + version = "v2" + model_version = "v3" + return version,model_version,if_lora_v3 + +def load_sovits_new(sovits_path): + f=open(sovits_path,"rb") + meta=f.read(2) + if meta!="PK": + data = b'PK' + f.read() + bio = BytesIO() + bio.write(data) + bio.seek(0) + return torch.load(bio, map_location="cpu", weights_only=False) + return torch.load(sovits_path,map_location="cpu", weights_only=False) \ No newline at end of file diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 4b510d9..4311db9 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -26,12 +26,7 @@ from AR.utils import get_newest_ckpt from collections import OrderedDict from time import time as ttime import shutil -def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path - dir=os.path.dirname(path) - name=os.path.basename(path) - tmp_path="%s.pth"%(ttime()) - torch.save(fea,tmp_path) - shutil.move(tmp_path,"%s/%s"%(dir,name)) +from process_ckpt import my_save class my_model_ckpt(ModelCheckpoint): diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py new file mode 100644 index 0000000..f10bde1 --- /dev/null +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -0,0 +1,342 @@ +import warnings +warnings.filterwarnings("ignore") +import utils, os +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.multiprocessing as mp +import torch.distributed as dist, traceback +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler +from tqdm import tqdm +import logging, traceback + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from random import randint +from module import commons +from peft import LoraConfig, PeftModel, get_peft_model +from module.data_utils import ( + TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader, + TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate, + DistributedBucketSampler, +) +from module.models import ( + SynthesizerTrnV3 as SynthesizerTrn, + MultiPeriodDiscriminator, +) +from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss +from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from process_ckpt import savee +from collections import OrderedDict as od +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step,no_grad_names,save_root,lora_rank + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + # 1100, + # 1200, + # 1300, + # 1400, + # 1500, + # 1600, + # 1700, + # 1800, + # 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + save_root="%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir,hps.model.version,hps.train.lora_rank) + os.makedirs(save_root,exist_ok=True) + lora_rank=int(hps.train.lora_rank) + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + def get_model(hps):return SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + def get_optim(net_g): + return torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + def model2cuda(net_g,rank): + if torch.cuda.is_available(): + net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + return net_g + try:# 如果能加载自动resume + net_g = get_model(hps) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g=model2cuda(net_g,rank) + optim_g=get_optim(net_g) + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(save_root, "G_*.pth"), + net_g, + optim_g, + ) + global_step = (epoch_str - 1) * len(train_loader) + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + net_g = get_model(hps) + if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ) + ) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g=model2cuda(net_g,rank) + optim_g = get_optim(net_g) + + no_grad_names=set() + for name, param in net_g.named_parameters(): + if not param.requires_grad: + no_grad_names.add(name.replace("module.","")) + # print(name, "not requires_grad") + # print(no_grad_names) + # os._exit(233333) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, gamma=hps.train.lr_decay, last_epoch=-1 + ) + for _ in range(epoch_str): + scheduler_g.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + net_d=optim_d=scheduler_d=None + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + +def train_and_evaluate( + rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers +): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)): + if torch.cuda.is_available(): + spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( + rank, non_blocking=True + ) + mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda( + rank, non_blocking=True + ) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda( + rank, non_blocking=True + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + mel, mel_lengths = mel.to(device), mel_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + cfm_loss = net_g(ssl, spec, mel,ssl_lengths,spec_lengths, text, text_lengths,mel_lengths, use_grad_ckpt=hps.train.grad_ckpt) + loss_gen_all=cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]['lr'] + losses = [cfm_loss] + logger.info('Train Epoch: {} [{:.0f}%]'.format( + epoch, + 100. * batch_idx / len(train_loader))) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict) + + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + save_root, "G_{}.pth".format(global_step) + ), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + save_root, "G_{}.pth".format(233333333333) + ), + ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + sim_ckpt=od() + for key in ckpt: + # if "cfm"not in key: + # print(key) + if key not in no_grad_names: + sim_ckpt[key]=ckpt[key].half().cpu() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + sim_ckpt, + hps.name + "_e%s_s%s_l%s" % (epoch, global_step,lora_rank), + epoch, + global_step, + hps,lora_rank=lora_rank + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +if __name__ == "__main__": + main()