Skip to content

Commit

Permalink
Merge pull request #377 from AllenCellModeling/feature/top_level_plug…
Browse files Browse the repository at this point in the history
…in_overrides

Move overrides to top-level config
  • Loading branch information
saeliddp authored Apr 24, 2024
2 parents 889775c + 73ef86c commit bb6fa8d
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 74 deletions.
34 changes: 18 additions & 16 deletions configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ defaults:
- override /model: im2im/segmentation_plugin.yaml
- override /callbacks: default.yaml
- override /trainer: gpu.yaml
- override /logger: mlflow.yaml
- override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# parameters with value MUST_OVERRIDE must be overridden before using this config; all other
# parameters have a reasonable default value

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME
ckpt_path: null # must override for prediction

experiment_name: experiment_name
run_name: run_name

# manifest columns
source_col: raw
Expand All @@ -26,24 +31,21 @@ exclude_mask_col: exclude_mask
base_image_col: base_image

# data params
spatial_dims: 3
spatial_dims: MUST_OVERRIDE # int value; required for the first training run and should not change afterwards
input_channel: 0
raw_im_channels: 1

trainer:
max_epochs: 100
max_epochs: 1 # must override for training
accelerator: gpu

data:
path: ${paths.data_dir}/example_experiment_data/s3_data
cache_dir: ${paths.data_dir}/example_experiment_data/cache
path: MUST_OVERRIDE # string path to manifest
split_column: null
batch_size: 1
_aux:
patch_shape:
# small, medium, large
# 32 pix, 64 pix, 128 pix

# OVERRIDE:
# data._aux.patch_shape
# model._aux.strides
# model._aux.kernel_size
# model._aux.upsample_kernel_size
patch_shape: [16, 32, 32]

paths:
output_dir: MUST_OVERRIDE
work_dir: ${paths.output_dir} # TODO: confirm whether work_dir is actually necessary or used
66 changes: 40 additions & 26 deletions cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,60 +12,74 @@
from cyto_dl.eval import evaluate as evaluate_model
from cyto_dl.train import train as train_model

# TODO: encapsulate experiment management (file system) details here; this may require passing
# output_dir into the factory methods


class CytoDLBaseModel(ABC):
"""A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model."""

def __init__(self, config_filepath: Optional[Path] = None):
"""
:param config_filepath: path to a .yaml config file that will be used as the basis
for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants
to use it). If None, a default config will be used instead.
def __init__(self, cfg: DictConfig):
"""Not intended for direct use by clients.
Please see the classmethod factory methods instead.
"""
self._cfg: DictConfig = (
OmegaConf.load(config_filepath) if config_filepath else self._generate_default_config()
)
self._cfg: DictConfig = cfg

@classmethod
@abstractmethod
def _get_experiment_type(self) -> ExperimentType:
def _get_experiment_type(cls) -> ExperimentType:
"""Return experiment type for this config (e.g. segmentation_plugin, gan, etc)"""
pass

@abstractmethod
def _set_max_epochs(self, max_epochs: int) -> None:
pass
@classmethod
def from_existing_config(cls, config_filepath: Path):
"""Returns a model from an existing config.
@abstractmethod
def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
pass
:param config_filepath: path to a .yaml config file that will be used as the basis
for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants
to use it).
"""
return cls(OmegaConf.load(config_filepath))

@abstractmethod
def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
pass
# TODO: if spatial_dims is only ever 2 or 3, create an enum for it
@classmethod
def from_default_config(cls, spatial_dims: int):
"""Returns a model from the default config.
def _generate_default_config(self) -> DictConfig:
:param spatial_dims: dimensions for the model (e.g. 2)
"""
cfg_dir: Path = (
pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md"))
/ "configs"
)
GlobalHydra.instance().clear()
with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)):
# the overrides for 'paths.' are necessary to make later overrides work (for some reason I don't fully understand)
# I'd prefer to do this in experiment configs instead of the overrides here
cfg: DictConfig = compose(
config_name="train.yaml", # only using train.yaml after conversation w/ Benji
config_name="train.yaml", # train.yaml can work for prediction too
return_hydra_config=True,
overrides=[
f"experiment=im2im/{self._get_experiment_type().name.lower()}",
"paths.output_dir=PLACEHOLDER",
"paths.work_dir=PLACEHOLDER",
f"experiment=im2im/{cls._get_experiment_type().name.lower()}",
f"spatial_dims={spatial_dims}",
],
)
with open_dict(cfg):
del cfg["hydra"]
cfg.extras.enforce_tags = False
cfg.extras.print_config = False
return cfg
return cls(cfg)

@abstractmethod
def _set_max_epochs(self, max_epochs: int) -> None:
pass

@abstractmethod
def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
pass

@abstractmethod
def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
pass

def _key_exists(self, k: str) -> bool:
keys: List[str] = k.split(".")
Expand Down
24 changes: 5 additions & 19 deletions cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

from omegaconf import ListConfig
from omegaconf import DictConfig, ListConfig

from cyto_dl.api.cyto_dl_model import CytoDLBaseModel
from cyto_dl.api.data import ExperimentType, HardwareType, PatchSize
Expand All @@ -11,13 +11,12 @@ class SegmentationPluginModel(CytoDLBaseModel):
"""A SegmentationPluginModel handles configuration, training, and prediction using the default
segmentation_plugin experiment from CytoDL."""

def __init__(self, config_filepath: Optional[Path] = None):
super().__init__(config_filepath)
def __init__(self, cfg: DictConfig):
super().__init__(cfg)
self._has_split_column = False

# we currently have an override for ['mode'] in ml-seg, but I can't find top-level 'mode' in the configs,
# do we need to support this?
def _get_experiment_type(self) -> ExperimentType:
@classmethod
def _get_experiment_type(cls) -> ExperimentType:
return ExperimentType.SEGMENTATION_PLUGIN

def _set_max_epochs(self, max_epochs: int) -> None:
Expand All @@ -28,18 +27,8 @@ def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:

def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
self._set_cfg("paths.output_dir", str(output_dir))
# I can't find where work_dir is actually used in cyto_dl, do we need to support this?
self._set_cfg("paths.work_dir", str(output_dir))

# a lot of these keys are duplicated across the im2im experiment types (but not present in top-level
# train.yaml or eval.yaml) - should we move these into the top-level configs and move these setters and
# getters accordingly?
def set_spatial_dims(self, spatial_dims: int) -> None:
self._set_cfg("spatial_dims", spatial_dims)

def get_spatial_dims(self) -> int:
return self._get_cfg("spatial_dims")

def set_input_channel(self, input_channel: int) -> None:
self._set_cfg("input_channel", input_channel)

Expand Down Expand Up @@ -98,9 +87,6 @@ def remove_split_column(self) -> None:
del existing_cols[-1]
self._has_split_column = False

# is patch_shape required in order to run training/prediction?
# if so, it should be an argument to train/predict/__init__, or a default
# should be set in the config
def set_patch_size(self, patch_size: PatchSize) -> None:
self._set_cfg("data._aux.patch_shape", patch_size.value)

Expand Down
15 changes: 2 additions & 13 deletions tests/api/cyto_dl_model/test_segmentation_plugin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
# are actually used how we expect by Cyto-DL, just that they exist in the default config.
@pytest.fixture
def model_with_default_config() -> SegmentationPluginModel:
return SegmentationPluginModel()
return SegmentationPluginModel.from_default_config(3)


@pytest.fixture
def model_with_bad_config() -> SegmentationPluginModel:
return SegmentationPluginModel(config_filepath=Path(__file__).parent / "bad_config.yaml")
return SegmentationPluginModel.from_existing_config(Path(__file__).parent / "bad_config.yaml")


class TestDefaultConfig:
Expand All @@ -46,11 +46,6 @@ def test_predict(
model_with_default_config.predict("manifest", "output_dir", Path("ckpt"))
evaluate_model_mock.assert_called_once()

def test_spatial_dims(self, model_with_default_config: SegmentationPluginModel):
assert model_with_default_config.get_spatial_dims() is not None
model_with_default_config.set_spatial_dims(24)
assert model_with_default_config.get_spatial_dims() == 24

def test_input_channel(self, model_with_default_config: SegmentationPluginModel):
assert model_with_default_config.get_input_channel() is not None
model_with_default_config.set_input_channel(785)
Expand Down Expand Up @@ -103,12 +98,6 @@ def test_predict(
with pytest.raises(KeyError):
model_with_bad_config.predict("manifest", "output_dir", Path("ckpt"))

def test_spatial_dims(self, model_with_bad_config: SegmentationPluginModel):
with pytest.raises(KeyError):
model_with_bad_config.get_spatial_dims()
with pytest.raises(KeyError):
model_with_bad_config.set_spatial_dims(24)

def test_input_channel(self, model_with_bad_config: SegmentationPluginModel):
with pytest.raises(KeyError):
model_with_bad_config.get_input_channel()
Expand Down

0 comments on commit bb6fa8d

Please sign in to comment.