Tighten PR scope to single-GPU training path fixes

This commit is contained in:
东云 2026-04-18 15:02:38 +08:00
parent e8c53643e7
commit 43506a8a69

View File

@@ -342,28 +342,22 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
try: if hps.train.if_save_latest == 0:
if hps.train.if_save_latest == 0: utils.save_checkpoint(
utils.save_checkpoint( net_g,
net_g, optim_g,
optim_g, hps.train.learning_rate,
hps.train.learning_rate, epoch,
epoch, os.path.join(save_root, "G_{}.pth".format(global_step)),
os.path.join(save_root, "G_{}.pth".format(global_step)), )
) else:
else: utils.save_checkpoint(
utils.save_checkpoint( net_g,
net_g, optim_g,
optim_g, hps.train.learning_rate,
hps.train.learning_rate, epoch,
epoch, os.path.join(save_root, "G_{}.pth".format(233333333333)),
os.path.join(save_root, "G_{}.pth".format(233333333333)), )
)
except Exception as e:
if logger is not None:
logger.warning(f"skip large checkpoint save due to error: {e}")
else:
print(f"skip large checkpoint save due to error: {e}")
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict() ckpt = net_g.module.state_dict()