Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git, synced 2025-04-05 12:38:35 +08:00
Fix the resume epoch count being misread, which made every resume retrain an already-completed epoch
commit aa07216bba (parent 514fb692db)
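
The core fix is one line, epoch_str+=1: the checkpoint records the epoch that had already finished training, but the loop restarted from that same number, so every resume retrained one completed epoch. The same change is applied to each of the three training scripts below, together with two smaller tweaks: the "loaded pretrained %s" message is actually included in the print() call, and progress prints are added at the start of each epoch and after the last one. A minimal before/after sketch with illustrative stand-in values (saved_epoch, steps_per_epoch, and total_epochs are not identifiers from the repo):

# Minimal sketch of the resume off-by-one this commit fixes; the names
# below are illustrative stand-ins, not identifiers from the repo.
saved_epoch = 3        # epoch recorded in the checkpoint; it already finished
steps_per_epoch = 100  # stands in for len(train_loader)
total_epochs = 5

# Before the fix: the loop restarted AT the saved epoch, retraining it.
epoch_str = saved_epoch
print(list(range(epoch_str, total_epochs + 1)))  # [3, 4, 5] -- epoch 3 runs again

# After the fix: advance past the epoch that already finished.
epoch_str = saved_epoch
epoch_str += 1  # the one-line fix
global_step = (epoch_str - 1) * steps_per_epoch  # 300 = steps already taken
print(list(range(epoch_str, total_epochs + 1)))  # [4, 5] -- resumes where it left off
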
@@ -205,6 +205,7 @@ def run(rank, n_gpus, hps):
             net_g,
             optim_g,
         )
+        epoch_str+=1
         global_step = (epoch_str - 1) * len(train_loader)
         # epoch_str = 1
         # global_step = 0
@@ -215,7 +216,7 @@ def run(rank, n_gpus, hps):
         if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print(
+            print("loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
@@ -227,7 +228,7 @@ def run(rank, n_gpus, hps):
         if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
-            print(
+            print("loaded pretrained %s" % hps.train.pretrained_s2D,
                 net_d.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
                 ) if torch.cuda.is_available() else net_d.load_state_dict(
@@ -251,6 +252,7 @@ def run(rank, n_gpus, hps):
     scaler = GradScaler(enabled=hps.train.fp16_run)

     for epoch in range(epoch_str, hps.train.epochs + 1):
+        print("start training from epoch %s"%epoch)
         if rank == 0:
             train_and_evaluate(
                 rank,
@@ -280,6 +282,7 @@ def run(rank, n_gpus, hps):
             )
         scheduler_g.step()
         scheduler_d.step()
+    print("training done")


 def train_and_evaluate(

@@ -178,6 +178,7 @@ def run(rank, n_gpus, hps):
             net_g,
             optim_g,
         )
+        epoch_str+=1
         global_step = (epoch_str - 1) * len(train_loader)
         # epoch_str = 1
         # global_step = 0
@@ -188,7 +189,7 @@ def run(rank, n_gpus, hps):
         if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print(
+            print("loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
@@ -225,6 +226,7 @@ def run(rank, n_gpus, hps):

     net_d=optim_d=scheduler_d=None
     for epoch in range(epoch_str, hps.train.epochs + 1):
+        print("start training from epoch %s"%epoch)
         if rank == 0:
             train_and_evaluate(
                 rank,
@@ -254,6 +256,7 @@ def run(rank, n_gpus, hps):
             )
         scheduler_g.step()
         # scheduler_d.step()
+    print("training done")


 def train_and_evaluate(

@@ -161,6 +161,7 @@ def run(rank, n_gpus, hps):
             net_g,
             optim_g,
         )
+        epoch_str+=1
         global_step = (epoch_str - 1) * len(train_loader)
     except:  # if no checkpoint can be loaded on the first run, load the pretrained weights instead
         # traceback.print_exc()
@@ -170,7 +171,7 @@ def run(rank, n_gpus, hps):
         if hps.train.pretrained_s2G != ""and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
             if rank == 0:
                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
-            print(
+            print("loaded pretrained %s" % hps.train.pretrained_s2G,
                 net_g.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
@@ -198,6 +199,7 @@ def run(rank, n_gpus, hps):

     net_d=optim_d=scheduler_d=None
     for epoch in range(epoch_str, hps.train.epochs + 1):
+        print("start training from epoch %s"%epoch)
         if rank == 0:
             train_and_evaluate(
                 rank,
@@ -226,6 +228,7 @@ def run(rank, n_gpus, hps):
                 None,
             )
         scheduler_g.step()
+    print("training done")

 def train_and_evaluate(
     rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers