Update autoencoder.py

This commit is contained in:
zR 2024-11-08 21:49:02 +08:00
parent 8e07303dbf
commit e43a7645fd

View File

@ -52,24 +52,6 @@ class AbstractAutoencoder(pl.LightningModule):
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
@ -81,7 +63,6 @@ class AbstractAutoencoder(pl.LightningModule):
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
@ -116,7 +97,9 @@ class AbstractAutoencoder(pl.LightningModule):
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def configure_optimizers(self) -> Any:
raise NotImplementedError()