diff --git a/finetune/utils/memory_utils.py b/finetune/utils/memory_utils.py index a7a136b..b341d22 100644 --- a/finetune/utils/memory_utils.py +++ b/finetune/utils/memory_utils.py @@ -51,6 +51,10 @@ def free_memory() -> None: # TODO(aryan): handle non-cuda devices +def unload_model(model): + model.to("cpu") + + def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: if isinstance(x, torch.Tensor): return x.contiguous()