Skip to content

Commit

Permalink
add strict=false loading to eval (#414)
Browse files Browse the repository at this point in the history
* add strict=false loading to eval

* precommit

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 965b371 commit fd4d67a
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
4 changes: 3 additions & 1 deletion cyto_dl/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)

model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))

log.info("Starting testing!")
method = trainer.test if cfg.get("test", False) else trainer.predict
output = method(model=model, dataloaders=data, ckpt_path=cfg.checkpoint.ckpt_path)
output = method(model=model, dataloaders=data, ckpt_path=load_params.ckpt_path)
metric_dict = trainer.callback_metrics

return metric_dict, object_dict, output
Expand Down
12 changes: 1 addition & 11 deletions cyto_dl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,7 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:

if cfg.get("train"):
log.info("Starting training!")

load_params = cfg.get("checkpoint")
if load_params.get("weights_only"):
assert load_params.get(
"ckpt_path"
), "ckpt_path must be provided to with argument weights_only=True"
# load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
state_dict = torch.load(load_params["ckpt_path"])["state_dict"]
model.load_state_dict(state_dict, strict=load_params.get("strict", True))
load_params["ckpt_path"] = None

model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
if isinstance(data, LightningDataModule):
trainer.fit(model=model, datamodule=data, ckpt_path=load_params.get("ckpt_path"))
else:
Expand Down
1 change: 1 addition & 0 deletions cyto_dl/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .array import create_dataloader, extract_array_predictions
from .checkpoint import load_checkpoint
from .config import kv_to_dict, remove_aux_key
from .pylogger import get_pylogger
from .rich_utils import enforce_tags, print_config_tree
Expand Down
16 changes: 16 additions & 0 deletions cyto_dl/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch


def load_checkpoint(model, load_params):
if load_params.get("weights_only"):
assert load_params.get(
"ckpt_path"
), "ckpt_path must be provided to with argument weights_only=True"
# load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
state_dict = torch.load(load_params["ckpt_path"])["state_dict"]
model.load_state_dict(state_dict, strict=load_params.get("strict", True))
# set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test
load_params["ckpt_path"] = None
elif not load_params.get("strict"):
raise ValueError("To use `strict=False`, `weights_only` must be set to True")
return model, load_params

0 comments on commit fd4d67a

Please sign in to comment.