mirror of
https://github.com/THUDM/CogVideo.git
synced 2025-04-05 19:41:59 +08:00
Update autoencoder.py
This commit is contained in:
parent
8e07303dbf
commit
e43a7645fd
@ -52,24 +52,6 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
self.automatic_optimization = False
|
||||
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
self.init_from_ckpt(ckpt)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
|
||||
print("Missing keys: ", missing_keys)
|
||||
print("Unexpected keys: ", unexpected_keys)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
||||
if ckpt is None:
|
||||
return
|
||||
@ -81,7 +63,6 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
engine = instantiate_from_config(ckpt)
|
||||
engine(self)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -116,7 +97,9 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
Loading…
x
Reference in New Issue
Block a user