From 73ef86cf49f6a4ecfb36d5920f33a4fd0eee8bb9 Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Mon, 15 Apr 2024 11:49:01 -0700 Subject: [PATCH] move overrides to top-level --- .../experiment/im2im/segmentation_plugin.yaml | 34 +++++----- .../api/cyto_dl_model/cyto_dl_base_model.py | 66 +++++++++++-------- .../segmentation_plugin_model.py | 24 ++----- .../test_segmentation_plugin_model.py | 15 +---- 4 files changed, 65 insertions(+), 74 deletions(-) diff --git a/configs/experiment/im2im/segmentation_plugin.yaml b/configs/experiment/im2im/segmentation_plugin.yaml index c69274c39..bff9895c9 100644 --- a/configs/experiment/im2im/segmentation_plugin.yaml +++ b/configs/experiment/im2im/segmentation_plugin.yaml @@ -6,16 +6,21 @@ defaults: - override /model: im2im/segmentation_plugin.yaml - override /callbacks: default.yaml - override /trainer: gpu.yaml - - override /logger: mlflow.yaml + - override /logger: csv.yaml # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters +# parameters with value MUST_OVERRIDE must be overridden before using this config, all other +# parameters have a reasonable default value + tags: ["dev"] seed: 12345 -experiment_name: YOUR_EXP_NAME -run_name: YOUR_RUN_NAME +ckpt_path: null # must override for prediction + +experiment_name: experiment_name +run_name: run_name # manifest columns source_col: raw @@ -26,24 +31,21 @@ exclude_mask_col: exclude_mask base_image_col: base_image # data params -spatial_dims: 3 +spatial_dims: MUST_OVERRIDE # int value, req for first training, should not change after input_channel: 0 raw_im_channels: 1 trainer: - max_epochs: 100 + max_epochs: 1 # must override for training + accelerator: gpu data: - path: ${paths.data_dir}/example_experiment_data/s3_data - cache_dir: ${paths.data_dir}/example_experiment_data/cache + path: MUST_OVERRIDE # string path to manifest + split_column: null batch_size: 1 _aux: - patch_shape: -# small, medium, large -# 32 pix, 64 pix, 128 pix - -# OVERRIDE: -# data._aux.patch_shape -# model._aux.strides -# model._aux.kernel_size -# model._aux.upsample_kernel_size + patch_shape: [16, 32, 32] + +paths: + output_dir: MUST_OVERRIDE + work_dir: ${paths.output_dir} # it's unclear to me if this is necessary or used diff --git a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py index 8986da834..2261b951d 100644 --- a/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py +++ b/cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py @@ -12,60 +12,74 @@ from cyto_dl.eval import evaluate as evaluate_model from cyto_dl.train import train as train_model +# TODO: encapsulate experiment management (file system) details here, will require passing output_dir +# into the factory methods, maybe + class CytoDLBaseModel(ABC): """A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model.""" - def __init__(self, config_filepath: Optional[Path] = None): - """ - :param config_filepath: path to a .yaml config file that will be used as the basis - for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants - to use it). If None, a default config will be used instead. + def __init__(self, cfg: DictConfig): + """Not intended for direct use by clients. + + Please see the classmethod factory methods instead. """ - self._cfg: DictConfig = ( - OmegaConf.load(config_filepath) if config_filepath else self._generate_default_config() - ) + self._cfg: DictConfig = cfg + @classmethod @abstractmethod - def _get_experiment_type(self) -> ExperimentType: + def _get_experiment_type(cls) -> ExperimentType: """Return experiment type for this config (e.g. segmentation_plugin, gan, etc)""" pass - @abstractmethod - def _set_max_epochs(self, max_epochs: int) -> None: - pass + @classmethod + def from_existing_config(cls, config_filepath: Path): + """Returns a model from an existing config. - @abstractmethod - def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None: - pass + :param config_filepath: path to a .yaml config file that will be used as the basis + for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants + to use it). + """ + return cls(OmegaConf.load(config_filepath)) - @abstractmethod - def _set_output_dir(self, output_dir: Union[str, Path]) -> None: - pass + # TODO: if spatial_dims is only ever 2 or 3, create an enum for it + @classmethod + def from_default_config(cls, spatial_dims: int): + """Returns a model from the default config. - def _generate_default_config(self) -> DictConfig: + :param spatial_dims: dimensions for the model (e.g. 2) + """ cfg_dir: Path = ( pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md")) / "configs" ) GlobalHydra.instance().clear() with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)): - # the overrides for 'paths.' are necessary to make later overrides work (for some reason I don't fully understand) - # I'd prefer to do this in experiment configs instead of the overrides here cfg: DictConfig = compose( - config_name="train.yaml", # only using train.yaml after conversation w/ Benji + config_name="train.yaml", # train.yaml can work for prediction too return_hydra_config=True, overrides=[ - f"experiment=im2im/{self._get_experiment_type().name.lower()}", - "paths.output_dir=PLACEHOLDER", - "paths.work_dir=PLACEHOLDER", + f"experiment=im2im/{cls._get_experiment_type().name.lower()}", + f"spatial_dims={spatial_dims}", ], ) with open_dict(cfg): del cfg["hydra"] cfg.extras.enforce_tags = False cfg.extras.print_config = False - return cfg + return cls(cfg) + + @abstractmethod + def _set_max_epochs(self, max_epochs: int) -> None: + pass + + @abstractmethod + def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None: + pass + + @abstractmethod + def _set_output_dir(self, output_dir: Union[str, Path]) -> None: + pass def _key_exists(self, k: str) -> bool: keys: List[str] = k.split(".") diff --git a/cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py b/cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py index 68a0b9deb..0e014dbae 100644 --- a/cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py +++ b/cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import List, Optional, Tuple, Union -from omegaconf import ListConfig +from omegaconf import DictConfig, ListConfig from cyto_dl.api.cyto_dl_model import CytoDLBaseModel from cyto_dl.api.data import ExperimentType, HardwareType, PatchSize @@ -11,13 +11,12 @@ class SegmentationPluginModel(CytoDLBaseModel): """A SegmentationPluginModel handles configuration, training, and prediction using the default segmentation_plugin experiment from CytoDL.""" - def __init__(self, config_filepath: Optional[Path] = None): - super().__init__(config_filepath) + def __init__(self, cfg: DictConfig): + super().__init__(cfg) self._has_split_column = False - # we currently have an override for ['mode'] in ml-seg, but I can't find top-level 'mode' in the configs, - # do we need to support this? - def _get_experiment_type(self) -> ExperimentType: + @classmethod + def _get_experiment_type(cls) -> ExperimentType: return ExperimentType.SEGMENTATION_PLUGIN def _set_max_epochs(self, max_epochs: int) -> None: @@ -28,18 +27,8 @@ def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None: def _set_output_dir(self, output_dir: Union[str, Path]) -> None: self._set_cfg("paths.output_dir", str(output_dir)) - # I can't find where work_dir is actually used in cyto_dl, do we need to support this? self._set_cfg("paths.work_dir", str(output_dir)) - # a lot of these keys are duplicated across the im2im experiment types (but not present in top-level - # train.yaml or eval.yaml) - should we move these into the top-level configs and move these setters and - # getters accordingly? - def set_spatial_dims(self, spatial_dims: int) -> None: - self._set_cfg("spatial_dims", spatial_dims) - - def get_spatial_dims(self) -> int: - return self._get_cfg("spatial_dims") - def set_input_channel(self, input_channel: int) -> None: self._set_cfg("input_channel", input_channel) @@ -98,9 +87,6 @@ def remove_split_column(self) -> None: del existing_cols[-1] self._has_split_column = False - # is patch_shape required in order to run training/prediction? - # if so, it should be an argument to train/predict/__init__, or a default - # should be set in the config def set_patch_size(self, patch_size: PatchSize) -> None: self._set_cfg("data._aux.patch_shape", patch_size.value) diff --git a/tests/api/cyto_dl_model/test_segmentation_plugin_model.py b/tests/api/cyto_dl_model/test_segmentation_plugin_model.py index 8a76111e9..d2190d1ea 100644 --- a/tests/api/cyto_dl_model/test_segmentation_plugin_model.py +++ b/tests/api/cyto_dl_model/test_segmentation_plugin_model.py @@ -16,12 +16,12 @@ # are actually used how we expect by Cyto-DL, just that they exist in the default config. @pytest.fixture def model_with_default_config() -> SegmentationPluginModel: - return SegmentationPluginModel() + return SegmentationPluginModel.from_default_config(3) @pytest.fixture def model_with_bad_config() -> SegmentationPluginModel: - return SegmentationPluginModel(config_filepath=Path(__file__).parent / "bad_config.yaml") + return SegmentationPluginModel.from_existing_config(Path(__file__).parent / "bad_config.yaml") class TestDefaultConfig: @@ -46,11 +46,6 @@ def test_predict( model_with_default_config.predict("manifest", "output_dir", Path("ckpt")) evaluate_model_mock.assert_called_once() - def test_spatial_dims(self, model_with_default_config: SegmentationPluginModel): - assert model_with_default_config.get_spatial_dims() is not None - model_with_default_config.set_spatial_dims(24) - assert model_with_default_config.get_spatial_dims() == 24 - def test_input_channel(self, model_with_default_config: SegmentationPluginModel): assert model_with_default_config.get_input_channel() is not None model_with_default_config.set_input_channel(785) @@ -103,12 +98,6 @@ def test_predict( with pytest.raises(KeyError): model_with_bad_config.predict("manifest", "output_dir", Path("ckpt")) - def test_spatial_dims(self, model_with_bad_config: SegmentationPluginModel): - with pytest.raises(KeyError): - model_with_bad_config.get_spatial_dims() - with pytest.raises(KeyError): - model_with_bad_config.set_spatial_dims(24) - def test_input_channel(self, model_with_bad_config: SegmentationPluginModel): with pytest.raises(KeyError): model_with_bad_config.get_input_channel()