"""Stage-2 training script for SynthesizerTrnV3 (GPT-SoVITS v3).

`main` spawns one `run` worker per CUDA device (a single worker on MPS/CPU).
Each worker trains the generator on the CFM loss that `net_g` returns. Rank 0
owns logging, TensorBoard output, the progress display and checkpoint saving.
"""

import logging
import os
import platform
import sys
import warnings
from contextlib import nullcontext
from random import randint

import torch
import torch.distributed as dist
from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.multiprocessing.spawn import spawn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import GPT_SoVITS.utils as utils
from GPT_SoVITS.Accelerate import console, logger
from GPT_SoVITS.Accelerate.logger import SpeedColumnIteration
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.data_utils import (
    DistributedBucketSampler,
    TextAudioSpeakerCollateV3,
    TextAudioSpeakerLoaderV3,
)
from GPT_SoVITS.module.models import SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import save_ckpt

hps = utils.get_hparams(stage=2)
warnings.filterwarnings("ignore")
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
# Original author's note: fp32 is fast on A100 anyway, so try TF32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(True)
# Lowest float32 matmul precision, marginally faster; per the original author's
# note it has no noticeable effect on results.
torch.set_float32_matmul_precision("medium")

# Number of optimizer steps taken so far; shared by `run` and `train_and_evaluate`.
global_step = 0

if torch.cuda.is_available():
    device_str = "cuda"
elif torch.backends.mps.is_available():
    device_str = "mps"
else:
    device_str = "cpu"

# DDP / process groups are only used with more than one CUDA device.
multigpu = torch.cuda.device_count() > 1 if torch.cuda.is_available() else False


def main():
    """Spawn one `run` worker per CUDA device, or a single worker without CUDA."""
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

    os.environ["MASTER_ADDR"] = "localhost"
    # Random rendezvous port; presumably so several jobs on one host don't clash.
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    if platform.system() == "Windows":
        # NOTE(review): presumably disables the libuv TCPStore backend on Windows - confirm.
        os.environ["USE_LIBUV"] = "0"

    spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )


def run(rank, n_gpus, hps):
    """Train SynthesizerTrnV3 in one worker process.

    Args:
        rank: Index of this worker; also used as the device index.
        n_gpus: Number of workers (world size for the process group and sampler).
        hps: Hyper-parameter object produced by ``utils.get_hparams``.

    Resumes from the newest ``G_*.pth`` in ``{exp_dir}/logs_s2_{version}`` when
    one can be loaded. Otherwise it starts at epoch 1 and, if
    ``hps.train.pretrained_s2G`` points to an existing file, loads its
    ``"weight"`` state dict on every rank. Ends the process with ``sys.exit(0)``.
    """
    global global_step
    device = torch.device(f"{device_str}:{rank}")

    if rank == 0:
        logger.add(
            os.path.join(hps.data.exp_dir, "train.log"),
            level="INFO",
            enqueue=True,
            backtrace=True,
            diagnose=True,
            format="{time:YY-MM-DD HH:mm:ss}\t{name}\t{level}\t{message}",
        )
        console.print(hps.to_dict())
        writer: SummaryWriter | None = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval: SummaryWriter | None = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
    else:
        writer = writer_eval = None

    if multigpu:
        dist.init_process_group(
            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
            init_method="env://",
            world_size=n_gpus,
            rank=rank,
        )

    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoaderV3(hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollateV3()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=4,
    )
    # No evaluation loader is built; train_and_evaluate receives eval_loader=None.

    net_g = SynthesizerTrnV3(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)

    optim_g = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, net_g.parameters()),  # one learning rate for all layers
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    if multigpu:
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)  # type: ignore

    ckpt_dir = f"{hps.data.exp_dir}/logs_s2_{hps.model.version}"
    try:
        # Auto-resume from the newest generator checkpoint if one can be loaded.
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(ckpt_dir, "G_*.pth"),
            net_g,
            optim_g,
        )
        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception as exc:
        # Nothing resumable (e.g. first run): start at epoch 1, optionally from pretrained weights.
        if rank == 0:
            logger.info(f"no checkpoint resumed from {ckpt_dir} ({exc!r}); starting from epoch 1")
        epoch_str = 1
        global_step = 0
        pretrained_s2G = hps.train.pretrained_s2G
        if pretrained_s2G and os.path.exists(pretrained_s2G):
            # Load on EVERY rank: DDP broadcasts rank 0's parameters only when it is
            # constructed (above), so weights loaded afterwards are not synchronized.
            # NOTE(review): weights_only=False unpickles arbitrary objects - only use trusted files.
            state_dict = torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"]
            model = net_g.module if multigpu else net_g  # type: ignore
            load_result = model.load_state_dict(state_dict, strict=False)
            if rank == 0:
                logger.info(f"loaded pretrained {pretrained_s2G}")
                console.print(f"loaded pretrained {pretrained_s2G}", load_result)

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
    # Advance the decay `epoch_str` times so a resumed run continues its schedule.
    for _ in range(epoch_str):
        scheduler_g.step()

    scaler = GradScaler(device=device.type, enabled=hps.train.fp16_run)
    # Only the generator is trained here; the discriminator slots are placeholders.
    net_d = optim_d = scheduler_d = None

    if rank == 0:
        logger.info(f"start training from epoch {epoch_str}")

    epoch_progress = (
        Progress(
            TextColumn("[cyan]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=console,
            redirect_stderr=True,
            redirect_stdout=True,
        )
        if rank == 0
        else nullcontext()
    )
    with epoch_progress as progress:
        epoch_task: TaskID | None = None
        if isinstance(progress, Progress):
            # Epochs before epoch_str are already done when resuming.
            epoch_task = progress.add_task(
                "Epoch",
                total=int(hps.train.epochs),
                completed=int(epoch_str) - 1,
            )

        for epoch in range(epoch_str, hps.train.epochs + 1):
            train_and_evaluate(
                device,
                epoch,
                hps,
                (net_g, net_d),
                (optim_g, optim_d),
                (scheduler_g, scheduler_d),
                scaler,
                (train_loader, None),
                logger if rank == 0 else None,
                (writer, writer_eval),  # both None on non-zero ranks
            )
            # Exactly one tick per finished epoch.
            if progress is not None and epoch_task is not None:
                progress.advance(epoch_task, 1)
            scheduler_g.step()

    if rank == 0:
        logger.info("Training Done")
        # Flush pending TensorBoard events before the process exits.
        for summary_writer in (writer, writer_eval):
            if summary_writer is not None:
                summary_writer.close()
    if multigpu and dist.is_initialized():
        dist.destroy_process_group()
    sys.exit(0)


def train_and_evaluate(
    device: torch.device,
    epoch,
    hps,
    nets,
    optims,
    schedulers,
    scaler,
    loaders,
    logger,
    writers,
):
    """Run one training epoch of the generator; on rank 0 also log and save.

    Args:
        device: This worker's device; ``device.index == 0`` marks the rank that
            logs, shows progress and saves checkpoints.
        epoch: Epoch number, used for sampler shuffling, logging and file names.
        hps: Hyper-parameter object.
        nets: ``(net_g, net_d)``; only ``net_g`` is used.
        optims: ``(optim_g, optim_d)``; only ``optim_g`` is used.
        schedulers: ``(scheduler_g, scheduler_d)``; unused (the caller steps them).
        scaler: ``GradScaler`` used for (optionally fp16) mixed precision.
        loaders: ``(train_loader, eval_loader)``; the eval loader is unused.
        logger: Logger used on rank 0 (``None`` on other ranks).
        writers: ``(writer, writer_eval)`` TensorBoard writers; only ``writer`` is used.

    Side effects: increments the module-level ``global_step`` once per batch.
    """
    global global_step

    net_g, _ = nets
    optim_g, _ = optims
    train_loader, _ = loaders
    writer, _ = writers
    is_main = device.index == 0

    train_loader.batch_sampler.set_epoch(epoch)
    net_g.train()

    step_progress = (
        Progress(
            TextColumn("[cyan]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            SpeedColumnIteration(show_speed=True),
            TimeRemainingColumn(elapsed_when_finished=True),
            console=console,
            redirect_stderr=True,
            redirect_stdout=True,
            # Leave the step bar on screen only after the final epoch.
            transient=not (int(epoch) == int(hps.train.epochs)),
        )
        if is_main
        else nullcontext()
    )
    with step_progress as progress:
        step_task: TaskID | None = None
        if isinstance(progress, Progress):
            step_task = progress.add_task("Steps", total=len(train_loader))

        for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
            train_loader
        ):
            spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
            mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
            ssl = ssl.to(device, non_blocking=True)
            ssl.requires_grad = False
            # ssl_lengths is intentionally left where the loader produced it.
            text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
                cfm_loss = net_g(
                    ssl,
                    spec,
                    mel,
                    ssl_lengths,
                    spec_lengths,
                    text,
                    text_lengths,
                    mel_lengths,
                    use_grad_ckpt=hps.train.grad_ckpt,
                )
                loss_gen_all = cfm_loss

            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
            # Unscale first so the gradient norm below is in real units.
            scaler.unscale_(optim_g)
            # NOTE(review): with a clip value of None this presumably only measures
            # the gradient norm (no clipping) - confirm in commons.clip_grad_value_.
            grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
            scaler.step(optim_g)
            scaler.update()

            if is_main:
                if progress is not None and step_task is not None:
                    progress.advance(step_task, 1)
                if global_step % hps.train.log_interval == 0:
                    lr = optim_g.param_groups[0]["lr"]
                    losses = [cfm_loss]
                    logger.info(
                        "Train Epoch: {} [{:.0f}%]".format(
                            epoch,
                            100.0 * batch_idx / len(train_loader),
                        )
                    )
                    logger.info([x.item() for x in losses] + [global_step, lr])
                    scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
                    utils.summarize(
                        writer=writer,
                        global_step=global_step,
                        scalars=scalar_dict,
                    )
            global_step += 1

    # Only rank 0 writes files, and only every `save_every_epoch` epochs, so ranks
    # never race on the same checkpoint path.
    if is_main and epoch % hps.train.save_every_epoch == 0:
        ckpt_name = f"G_{global_step}.pth" if hps.train.if_save_latest == 0 else "G_233333333333.pth"
        utils.save_checkpoint(
            net_g,
            optim_g,
            hps.train.learning_rate,
            epoch,
            os.path.join(f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", ckpt_name),
            logger,
        )
        if hps.train.if_save_every_weights is True:
            # Unwrap DDP so the saved weights have no "module." prefix.
            model = net_g.module if hasattr(net_g, "module") else net_g
            save_info = save_ckpt(
                model.state_dict(),
                hps.name + f"_e{epoch}_s{global_step}",
                epoch,
                global_step,
                hps,
            )
            logger.info(f"saving ckpt {hps.name}_e{epoch}:{save_info}")


if __name__ == "__main__":
    main()