mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-03 12:18:12 +08:00
Improve Windows single-GPU v3 LoRA training
This commit is contained in:
parent
2d9193b0d3
commit
96b8701186
@ -55,6 +55,10 @@ def main():
|
|||||||
n_gpus = torch.cuda.device_count()
|
n_gpus = torch.cuda.device_count()
|
||||||
else:
|
else:
|
||||||
n_gpus = 1
|
n_gpus = 1
|
||||||
|
if n_gpus <= 1:
|
||||||
|
run(0, n_gpus, hps)
|
||||||
|
return
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = "localhost"
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
||||||
|
|
||||||
@ -77,12 +81,14 @@ def run(rank, n_gpus, hps):
|
|||||||
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
||||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||||
|
|
||||||
dist.init_process_group(
|
use_ddp = n_gpus > 1
|
||||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
if use_ddp:
|
||||||
init_method="env://?use_libuv=False",
|
dist.init_process_group(
|
||||||
world_size=n_gpus,
|
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||||
rank=rank,
|
init_method="env://?use_libuv=False",
|
||||||
)
|
world_size=n_gpus,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
torch.manual_seed(hps.train.seed)
|
torch.manual_seed(hps.train.seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
@ -118,15 +124,20 @@ def run(rank, n_gpus, hps):
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
collate_fn = TextAudioSpeakerCollate()
|
collate_fn = TextAudioSpeakerCollate()
|
||||||
train_loader = DataLoader(
|
worker_count = 0 if os.name == "nt" and n_gpus <= 1 else min(2 if os.name == "nt" else 5, os.cpu_count() or 1)
|
||||||
train_dataset,
|
loader_kwargs = dict(
|
||||||
num_workers=5,
|
num_workers=worker_count,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
pin_memory=True,
|
pin_memory=torch.cuda.is_available(),
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
batch_sampler=train_sampler,
|
batch_sampler=train_sampler,
|
||||||
persistent_workers=True,
|
)
|
||||||
prefetch_factor=3,
|
if worker_count > 0:
|
||||||
|
loader_kwargs["persistent_workers"] = True
|
||||||
|
loader_kwargs["prefetch_factor"] = 2 if os.name == "nt" else 3
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
**loader_kwargs,
|
||||||
)
|
)
|
||||||
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
|
||||||
os.makedirs(save_root, exist_ok=True)
|
os.makedirs(save_root, exist_ok=True)
|
||||||
@ -156,7 +167,9 @@ def run(rank, n_gpus, hps):
|
|||||||
|
|
||||||
def model2cuda(net_g, rank):
|
def model2cuda(net_g, rank):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
|
net_g = net_g.cuda(rank)
|
||||||
|
if use_ddp:
|
||||||
|
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
net_g = net_g.to(device)
|
net_g = net_g.to(device)
|
||||||
return net_g
|
return net_g
|
||||||
@ -242,6 +255,8 @@ def run(rank, n_gpus, hps):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
scheduler_g.step()
|
scheduler_g.step()
|
||||||
|
if use_ddp and dist.is_initialized():
|
||||||
|
dist.destroy_process_group()
|
||||||
print("training done")
|
print("training done")
|
||||||
|
|
||||||
|
|
||||||
@ -327,22 +342,28 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
||||||
if hps.train.if_save_latest == 0:
|
try:
|
||||||
utils.save_checkpoint(
|
if hps.train.if_save_latest == 0:
|
||||||
net_g,
|
utils.save_checkpoint(
|
||||||
optim_g,
|
net_g,
|
||||||
hps.train.learning_rate,
|
optim_g,
|
||||||
epoch,
|
hps.train.learning_rate,
|
||||||
os.path.join(save_root, "G_{}.pth".format(global_step)),
|
epoch,
|
||||||
)
|
os.path.join(save_root, "G_{}.pth".format(global_step)),
|
||||||
else:
|
)
|
||||||
utils.save_checkpoint(
|
else:
|
||||||
net_g,
|
utils.save_checkpoint(
|
||||||
optim_g,
|
net_g,
|
||||||
hps.train.learning_rate,
|
optim_g,
|
||||||
epoch,
|
hps.train.learning_rate,
|
||||||
os.path.join(save_root, "G_{}.pth".format(233333333333)),
|
epoch,
|
||||||
)
|
os.path.join(save_root, "G_{}.pth".format(233333333333)),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if logger is not None:
|
||||||
|
logger.warning(f"skip large checkpoint save due to error: {e}")
|
||||||
|
else:
|
||||||
|
print(f"skip large checkpoint save due to error: {e}")
|
||||||
if rank == 0 and hps.train.if_save_every_weights == True:
|
if rank == 0 and hps.train.if_save_every_weights == True:
|
||||||
if hasattr(net_g, "module"):
|
if hasattr(net_g, "module"):
|
||||||
ckpt = net_g.module.state_dict()
|
ckpt = net_g.module.state_dict()
|
||||||
|
|||||||
@ -69,7 +69,8 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
|||||||
name = os.path.basename(path)
|
name = os.path.basename(path)
|
||||||
tmp_path = "%s.pth" % (ttime())
|
tmp_path = "%s.pth" % (ttime())
|
||||||
torch.save(fea, tmp_path)
|
torch.save(fea, tmp_path)
|
||||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
os.replace(tmp_path, "%s/%s" % (dir, name))
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user