diff --git a/cyto_dl/eval.py b/cyto_dl/eval.py
index 9e37b4bd..d437599a 100644
--- a/cyto_dl/eval.py
+++ b/cyto_dl/eval.py
@@ -82,9 +82,11 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]:
         log.info("Logging hyperparameters!")
         utils.log_hyperparameters(object_dict)
 
+    model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
+
     log.info("Starting testing!")
     method = trainer.test if cfg.get("test", False) else trainer.predict
-    output = method(model=model, dataloaders=data, ckpt_path=cfg.checkpoint.ckpt_path)
+    output = method(model=model, dataloaders=data, ckpt_path=load_params.ckpt_path)
     metric_dict = trainer.callback_metrics
 
     return metric_dict, object_dict, output
diff --git a/cyto_dl/train.py b/cyto_dl/train.py
index a6f64398..ee359a91 100644
--- a/cyto_dl/train.py
+++ b/cyto_dl/train.py
@@ -95,17 +95,7 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-
-        load_params = cfg.get("checkpoint")
-        if load_params.get("weights_only"):
-            assert load_params.get(
-                "ckpt_path"
-            ), "ckpt_path must be provided to with argument weights_only=True"
-            # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
-            state_dict = torch.load(load_params["ckpt_path"])["state_dict"]
-            model.load_state_dict(state_dict, strict=load_params.get("strict", True))
-            load_params["ckpt_path"] = None
-
+        model, load_params = utils.load_checkpoint(model, cfg.get("checkpoint"))
         if isinstance(data, LightningDataModule):
             trainer.fit(model=model, datamodule=data, ckpt_path=load_params.get("ckpt_path"))
         else:
diff --git a/cyto_dl/utils/__init__.py b/cyto_dl/utils/__init__.py
index ac3acaea..371815c0 100644
--- a/cyto_dl/utils/__init__.py
+++ b/cyto_dl/utils/__init__.py
@@ -1,4 +1,5 @@
 from .array import create_dataloader, extract_array_predictions
+from .checkpoint import load_checkpoint
 from .config import kv_to_dict, remove_aux_key
 from .pylogger import get_pylogger
 from .rich_utils import enforce_tags, print_config_tree
diff --git a/cyto_dl/utils/checkpoint.py b/cyto_dl/utils/checkpoint.py
new file mode 100644
index 00000000..7e6e0043
--- /dev/null
+++ b/cyto_dl/utils/checkpoint.py
@@ -0,0 +1,23 @@
+import torch
+
+
+def load_checkpoint(model, load_params):
+    """Optionally load model weights from a checkpoint before fit/test.
+
+    With ``weights_only`` set, the state dict is loaded directly onto the
+    model (sidestepping trainer state such as ``max_epochs``) and
+    ``ckpt_path`` is cleared so fit/test does not reload the checkpoint.
+    """
+    if load_params.get("weights_only"):
+        assert load_params.get(
+            "ckpt_path"
+        ), "ckpt_path must be provided with argument weights_only=True"
+        # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights
+        state_dict = torch.load(load_params["ckpt_path"])["state_dict"]
+        model.load_state_dict(state_dict, strict=load_params.get("strict", True))
+        # set ckpt_path to None to avoid loading checkpoint again with model.fit/model.test
+        load_params["ckpt_path"] = None
+    elif load_params.get("strict", True) is False:
+        # `strict` only affects the manual load above; Lightning's own restore is always strict.
+        raise ValueError("To use `strict=False`, `weights_only` must be set to True")
+    return model, load_params