Improve Windows single-GPU v3 LoRA training

This commit is contained in:
东云 2026-04-18 13:50:17 +08:00
parent 2d9193b0d3
commit 96b8701186
2 changed files with 52 additions and 30 deletions

View File

@@ -55,6 +55,10 @@ def main():
n_gpus = torch.cuda.device_count() n_gpus = torch.cuda.device_count()
else: else:
n_gpus = 1 n_gpus = 1
if n_gpus <= 1:
run(0, n_gpus, hps)
return
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555)) os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@@ -77,12 +81,14 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group( use_ddp = n_gpus > 1
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", if use_ddp:
init_method="env://?use_libuv=False", dist.init_process_group(
world_size=n_gpus, backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
rank=rank, init_method="env://?use_libuv=False",
) world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed) torch.manual_seed(hps.train.seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
@@ -118,15 +124,20 @@ def run(rank, n_gpus, hps):
shuffle=True, shuffle=True,
) )
collate_fn = TextAudioSpeakerCollate() collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader( worker_count = 0 if os.name == "nt" and n_gpus <= 1 else min(2 if os.name == "nt" else 5, os.cpu_count() or 1)
train_dataset, loader_kwargs = dict(
num_workers=5, num_workers=worker_count,
shuffle=False, shuffle=False,
pin_memory=True, pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn, collate_fn=collate_fn,
batch_sampler=train_sampler, batch_sampler=train_sampler,
persistent_workers=True, )
prefetch_factor=3, if worker_count > 0:
loader_kwargs["persistent_workers"] = True
loader_kwargs["prefetch_factor"] = 2 if os.name == "nt" else 3
train_loader = DataLoader(
train_dataset,
**loader_kwargs,
) )
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank) save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root, exist_ok=True) os.makedirs(save_root, exist_ok=True)
@@ -156,7 +167,9 @@ def run(rank, n_gpus, hps):
def model2cuda(net_g, rank): def model2cuda(net_g, rank):
if torch.cuda.is_available(): if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) net_g = net_g.cuda(rank)
if use_ddp:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
else: else:
net_g = net_g.to(device) net_g = net_g.to(device)
return net_g return net_g
@@ -242,6 +255,8 @@ def run(rank, n_gpus, hps):
None, None,
) )
scheduler_g.step() scheduler_g.step()
if use_ddp and dist.is_initialized():
dist.destroy_process_group()
print("training done") print("training done")
@@ -327,22 +342,28 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
global_step += 1 global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0: if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if hps.train.if_save_latest == 0: try:
utils.save_checkpoint( if hps.train.if_save_latest == 0:
net_g, utils.save_checkpoint(
optim_g, net_g,
hps.train.learning_rate, optim_g,
epoch, hps.train.learning_rate,
os.path.join(save_root, "G_{}.pth".format(global_step)), epoch,
) os.path.join(save_root, "G_{}.pth".format(global_step)),
else: )
utils.save_checkpoint( else:
net_g, utils.save_checkpoint(
optim_g, net_g,
hps.train.learning_rate, optim_g,
epoch, hps.train.learning_rate,
os.path.join(save_root, "G_{}.pth".format(233333333333)), epoch,
) os.path.join(save_root, "G_{}.pth".format(233333333333)),
)
except Exception as e:
if logger is not None:
logger.warning(f"skip large checkpoint save due to error: {e}")
else:
print(f"skip large checkpoint save due to error: {e}")
if rank == 0 and hps.train.if_save_every_weights == True: if rank == 0 and hps.train.if_save_every_weights == True:
if hasattr(net_g, "module"): if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict() ckpt = net_g.module.state_dict()

View File

@@ -69,7 +69,8 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
name = os.path.basename(path) name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime()) tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path) torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name)) os.makedirs(dir, exist_ok=True)
os.replace(tmp_path, "%s/%s" % (dir, name))
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):