From 49b0bc01802ed5cf1a859a1c8baf548e2d16aa1c Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:52:51 -0700 Subject: [PATCH] load ckpt to cpu (#428) * load ckpt to cpu * precommit --------- Co-authored-by: Benjamin Morris --- cyto_dl/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cyto_dl/utils/checkpoint.py b/cyto_dl/utils/checkpoint.py index 7e6e0043..e6f98aae 100644 --- a/cyto_dl/utils/checkpoint.py +++ b/cyto_dl/utils/checkpoint.py @@ -7,7 +7,7 @@ def load_checkpoint(model, load_params): "ckpt_path" ), "ckpt_path must be provided to with argument weights_only=True" # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights - state_dict = torch.load(load_params["ckpt_path"])["state_dict"] + state_dict = torch.load(load_params["ckpt_path"], map_location="cpu")["state_dict"] model.load_state_dict(state_dict, strict=load_params.get("strict", True)) # set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test load_params["ckpt_path"] = None