diff --git a/configs/eval.yaml b/configs/eval.yaml index 352327d6..b03052fb 100644 --- a/configs/eval.yaml +++ b/configs/eval.yaml @@ -20,4 +20,7 @@ task_name: "eval" tags: ["dev"] # passing checkpoint path is necessary for evaluation -ckpt_path: ??? +checkpoint: + ckpt_path: ??? + weights_only: null + strict: True diff --git a/configs/experiment/im2im/segmentation_plugin.yaml b/configs/experiment/im2im/segmentation_plugin.yaml index f8c5a3ac..7056292a 100644 --- a/configs/experiment/im2im/segmentation_plugin.yaml +++ b/configs/experiment/im2im/segmentation_plugin.yaml @@ -17,7 +17,10 @@ defaults: tags: ["dev"] seed: 12345 -ckpt_path: null # must override for prediction +checkpoint: + ckpt_path: null # must override for prediction + weights_only: null + strict: False experiment_name: experiment_name run_name: run_name diff --git a/configs/train.yaml b/configs/train.yaml index 8b297edc..f755de87 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -47,7 +47,9 @@ train: True test: True # simply provide checkpoint path to resume training -ckpt_path: null - +checkpoint: + ckpt_path: null + weights_only: null + strict: True # seed for random number generators in pytorch, numpy and python.random seed: null diff --git a/cyto_dl/compile.py b/cyto_dl/compile.py index d7ae3752..c05787b7 100644 --- a/cyto_dl/compile.py +++ b/cyto_dl/compile.py @@ -83,7 +83,7 @@ def compile(cfg: DictConfig) -> Tuple[dict, dict]: { "model_file": str(Path(pkg_root) / cfg.model_file), "handler": str(Path(pkg_root) / cfg.handler_file), - "serialized_file": cfg.ckpt_path, + "serialized_file": cfg.checkpoint.ckpt_path, "model_name": name, "version": version, "extra_files": str(cfg_path), diff --git a/cyto_dl/eval.py b/cyto_dl/eval.py index 7961a151..9e37b4bd 100644 --- a/cyto_dl/eval.py +++ b/cyto_dl/eval.py @@ -33,7 +33,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]: Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. """ - if not cfg.ckpt_path: + if not cfg.checkpoint.ckpt_path: raise ValueError("Checkpoint path must be included for testing") # resolve config to avoid unresolvable interpolations in the stored config @@ -84,7 +84,7 @@ def evaluate(cfg: DictConfig, data=None) -> Tuple[dict, dict, dict]: log.info("Starting testing!") method = trainer.test if cfg.get("test", False) else trainer.predict - output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path) + output = method(model=model, dataloaders=data, ckpt_path=cfg.checkpoint.ckpt_path) metric_dict = trainer.callback_metrics return metric_dict, object_dict, output diff --git a/cyto_dl/train.py b/cyto_dl/train.py index 0c3e03c5..dd60c82a 100644 --- a/cyto_dl/train.py +++ b/cyto_dl/train.py @@ -95,24 +95,25 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]: if cfg.get("train"): log.info("Starting training!") - - if cfg.get("weights_only"): - assert cfg.get( + + load_params = cfg.get("checkpoint") + if load_params.get("weights_only"): + assert load_params.get( "ckpt_path" ), "ckpt_path must be provided to with argument weights_only=True" # load model from state dict to get around trainer.max_epochs limit, useful for resuming model training from existing weights - state_dict = torch.load(cfg["ckpt_path"])["state_dict"] - model.load_state_dict(state_dict) - cfg["ckpt_path"] = None + state_dict = torch.load(load_params["ckpt_path"])["state_dict"] + model.load_state_dict(state_dict, strict=load_params.get("strict", True)) + load_params["ckpt_path"] = None if isinstance(data, LightningDataModule): - trainer.fit(model=model, datamodule=data, ckpt_path=cfg.get("ckpt_path")) + trainer.fit(model=model, datamodule=data, ckpt_path=load_params.get("ckpt_path")) else: trainer.fit( model=model, train_dataloaders=data.train_dataloaders, val_dataloaders=data.val_dataloaders, - ckpt_path=cfg.get("ckpt_path"), + ckpt_path=load_params.get("ckpt_path"), ) train_metrics = trainer.callback_metrics diff --git a/tests/conftest.py b/tests/conftest.py index 6c66b577..48815e21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,7 +53,7 @@ def cfg_eval_global(request) -> DictConfig: config_name="eval.yaml", return_hydra_config=True, overrides=[ - "ckpt_path=.", + "checkpoint.ckpt_path=.", f"experiment=im2im/{request.param}.yaml", "trainer=cpu.yaml", ], diff --git a/tests/test_array_models.py b/tests/test_array_models.py index 6d369599..6c807163 100644 --- a/tests/test_array_models.py +++ b/tests/test_array_models.py @@ -42,7 +42,7 @@ def test_array_train_predict(tmp_path): "logger": None, "trainer.accelerator": "cpu", "trainer.devices": 1, - "ckpt_path": ckpt_path, + "checkpoint.ckpt_path": ckpt_path, } model.load_default_experiment( diff --git a/tests/test_eval.py b/tests/test_eval.py index f5d20cfe..1b86b577 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -35,7 +35,7 @@ def test_train_eval(tmp_path, cfg_train, cfg_eval, spatial_dims): assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") with open_dict(cfg_eval): - cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_eval.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") cfg_eval.test = True cfg_eval.spatial_dims = spatial_dims diff --git a/tests/test_train.py b/tests/test_train.py index 22a88daf..54c959f3 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -63,7 +63,7 @@ def test_train_resume(tmp_path, cfg_train, spatial_dims): assert "epoch_000.ckpt" in files with open_dict(cfg_train): - cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_train.checkpoint.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") cfg_train.trainer.max_epochs = 2 metric_dict_2, _ = train(cfg_train)