Skip to content

Commit

Permalink
Merge pull request #361 from AllenCellModeling/feature/api_config_class
Browse files Browse the repository at this point in the history
Initial new API
  • Loading branch information
saeliddp authored Apr 24, 2024
2 parents aa54da5 + bb6fa8d commit d3107f9
Show file tree
Hide file tree
Showing 7 changed files with 399 additions and 16 deletions.
34 changes: 18 additions & 16 deletions configs/experiment/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ defaults:
- override /model: im2im/segmentation_plugin.yaml
- override /callbacks: default.yaml
- override /trainer: gpu.yaml
- override /logger: mlflow.yaml
- override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

# parameters with value MUST_OVERRIDE must be overridden before using this config, all other
# parameters have a reasonable default value

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME
ckpt_path: null # must override for prediction

experiment_name: experiment_name
run_name: run_name

# manifest columns
source_col: raw
Expand All @@ -26,24 +31,21 @@ exclude_mask_col: exclude_mask
base_image_col: base_image

# data params
spatial_dims: 3
spatial_dims: MUST_OVERRIDE # integer; required for the first training run and must not change afterwards
input_channel: 0
raw_im_channels: 1

trainer:
max_epochs: 100
max_epochs: 1 # must override for training
accelerator: gpu

data:
path: ${paths.data_dir}/example_experiment_data/s3_data
cache_dir: ${paths.data_dir}/example_experiment_data/cache
path: MUST_OVERRIDE # string path to manifest
split_column: null
batch_size: 1
_aux:
patch_shape:
# small, medium, large
# 32 pix, 64 pix, 128 pix

# OVERRIDE:
# data._aux.patch_shape
# model._aux.strides
# model._aux.kernel_size
# model._aux.upsample_kernel_size
patch_shape: [16, 32, 32]

paths:
output_dir: MUST_OVERRIDE
work_dir: ${paths.output_dir} # NOTE: it is unclear whether work_dir is necessary or used anywhere
2 changes: 2 additions & 0 deletions cyto_dl/api/cyto_dl_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .cyto_dl_base_model import CytoDLBaseModel
from .segmentation_plugin_model import SegmentationPluginModel
147 changes: 147 additions & 0 deletions cyto_dl/api/cyto_dl_model/cyto_dl_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, List, Optional, Union

import pyrootutils
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf, open_dict

from cyto_dl.api.data import ExperimentType
from cyto_dl.eval import evaluate as evaluate_model
from cyto_dl.train import train as train_model

# TODO: encapsulate experiment management (file system) details here, will require passing output_dir
# into the factory methods, maybe


class CytoDLBaseModel(ABC):
    """A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model.

    Each instance wraps one OmegaConf config. Setters mutate that config in place, and
    ``train``/``predict`` pass it to the cyto-dl train/evaluate entry points. Subclasses
    supply the experiment type and the config keys for the max-epochs, manifest-path and
    output-dir settings.
    """

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients.

        Please see the classmethod factory methods instead.

        :param cfg: config that this model stores and mutates in place
        """
        self._cfg: DictConfig = cfg

    @classmethod
    @abstractmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return experiment type for this config (e.g. segmentation_plugin, gan, etc)"""
        pass

    @classmethod
    def from_existing_config(cls, config_filepath: Path):
        """Returns a model from an existing config.

        :param config_filepath: path to a .yaml config file that will be used as the basis
        for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants
        to use it).
        """
        return cls(OmegaConf.load(config_filepath))

    # TODO: if spatial_dims is only ever 2 or 3, create an enum for it
    @classmethod
    def from_default_config(cls, spatial_dims: int):
        """Returns a model from the default config.

        The config is composed from ``train.yaml`` in the repository's ``configs`` directory,
        with the ``experiment`` override chosen from ``_get_experiment_type()``.

        :param spatial_dims: dimensions for the model (e.g. 2)
        """
        cfg_dir: Path = (
            pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md"))
            / "configs"
        )
        # Drop any global Hydra state left over from an earlier initialization so that
        # initialize_config_dir can be entered again (e.g. on a second factory call).
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)):
            cfg: DictConfig = compose(
                config_name="train.yaml",  # train.yaml can work for prediction too
                return_hydra_config=True,
                overrides=[
                    f"experiment=im2im/{cls._get_experiment_type().name.lower()}",
                    f"spatial_dims={spatial_dims}",
                ],
            )
        # open_dict lifts struct mode so the "hydra" node added by return_hydra_config=True
        # can be deleted from the composed config.
        with open_dict(cfg):
            del cfg["hydra"]
            cfg.extras.enforce_tags = False
            cfg.extras.print_config = False
        return cls(cfg)

    @abstractmethod
    def _set_max_epochs(self, max_epochs: int) -> None:
        """Store the maximum number of training epochs in the config."""
        pass

    @abstractmethod
    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        """Store the path to the data manifest in the config."""
        pass

    @abstractmethod
    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        """Store the output directory in the config."""
        pass

    def _key_exists(self, k: str) -> bool:
        """Return True if the dot-separated key ``k`` resolves to a node in the config.

        Only mapping nodes are descended into. If the path continues past a scalar (or any
        other non-mapping value), the key is reported as absent instead of raising
        ``TypeError`` or matching a substring of a string value.

        :param k: dot-separated key, e.g. ``"trainer.max_epochs"``
        """
        node: Any = self._cfg
        for part in k.split("."):
            if not isinstance(node, DictConfig) or part not in node:
                return False
            node = node[part]
        return True

    def _set_cfg(self, k: str, v: Any) -> None:
        """Set an existing config key to ``v``.

        :param k: dot-separated key that must already exist in the config
        :param v: new value
        :raises KeyError: if ``k`` is not present, so that typos cannot create new keys
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        OmegaConf.update(self._cfg, k, v)

    def _get_cfg(self, k: str) -> Any:
        """Return the value stored at an existing config key.

        :param k: dot-separated key that must already exist in the config
        :raises KeyError: if ``k`` is not present
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        return OmegaConf.select(self._cfg, k)

    def _set_training_config(self, train: bool):
        """Switch the config between training and prediction mode.

        The ``train`` and ``test`` flags are both set to ``train``.
        """
        self._set_cfg("train", train)
        self._set_cfg("test", train)
        # afaik, task_name isn't used outside of template_utils.py - do we need to support this?
        self._set_cfg("task_name", "train" if train else "predict")

    def _set_ckpt(self, ckpt: Optional[Union[str, Path]]) -> None:
        """Store the checkpoint path in the config as a resolved (absolute) string.

        :param ckpt: checkpoint location as ``str`` or ``Path``; a falsy value (e.g. ``None``)
        is stored unchanged
        """
        self._set_cfg("ckpt_path", str(Path(ckpt).resolve()) if ckpt else ckpt)

    # does experiment name have any effect?
    def set_experiment_name(self, name: str) -> None:
        """Set the ``experiment_name`` config value."""
        self._set_cfg("experiment_name", name)

    def get_experiment_name(self) -> str:
        """Return the ``experiment_name`` config value."""
        return self._get_cfg("experiment_name")

    def get_config(self) -> DictConfig:
        """Return a deep copy of the config, so callers cannot mutate this model's state."""
        return deepcopy(self._cfg)

    def save_config(self, path: Path) -> None:
        """Write the current config to ``path`` using ``OmegaConf.save``."""
        OmegaConf.save(self._cfg, path)

    def train(
        self,
        max_epochs: int,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Optional[Union[str, Path]] = None,
    ) -> None:
        """Configure the model for training and run the cyto-dl training entry point.

        :param max_epochs: maximum number of epochs to train for
        :param manifest_path: path to the data manifest
        :param output_dir: directory that outputs are configured to go to
        :param checkpoint: optional checkpoint path stored as ``ckpt_path``
        :raises KeyError: if the config lacks a key that one of the setters needs
        """
        self._set_training_config(True)
        self._set_max_epochs(max_epochs)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        train_model(self._cfg)

    def predict(
        self,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Union[str, Path],
    ) -> None:
        """Configure the model for prediction and run the cyto-dl evaluation entry point.

        :param manifest_path: path to the data manifest
        :param output_dir: directory that outputs are configured to go to
        :param checkpoint: checkpoint path stored as ``ckpt_path``
        :raises KeyError: if the config lacks a key that one of the setters needs
        """
        self._set_training_config(False)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        evaluate_model(self._cfg)
101 changes: 101 additions & 0 deletions cyto_dl/api/cyto_dl_model/segmentation_plugin_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

from omegaconf import DictConfig, ListConfig

from cyto_dl.api.cyto_dl_model import CytoDLBaseModel
from cyto_dl.api.data import ExperimentType, HardwareType, PatchSize


class SegmentationPluginModel(CytoDLBaseModel):
    """A SegmentationPluginModel handles configuration, training, and prediction using the default
    segmentation_plugin experiment from CytoDL."""

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients; use the classmethod factory methods.

        :param cfg: config that this model stores and mutates in place
        """
        super().__init__(cfg)
        # Derived from the config rather than assumed False, so that a model created from a
        # previously saved config that already lists a split column stays in sync with it.
        self._has_split_column = self._config_lists_split_column()

    def _config_lists_split_column(self) -> bool:
        """Return True if ``data.split_column`` is set and is the last entry of ``data.columns``.

        Never raises for absent keys: configs lacking the ``data`` section, the split column
        or the column list simply yield False.
        """
        data_cfg = self._cfg.get("data") if isinstance(self._cfg, DictConfig) else None
        if not isinstance(data_cfg, DictConfig):
            return False
        split_column = data_cfg.get("split_column")
        columns = data_cfg.get("columns")
        if split_column is None or not isinstance(columns, ListConfig) or len(columns) == 0:
            return False
        return columns[-1] == split_column

    @classmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return the experiment type handled by this class (segmentation plugin)."""
        return ExperimentType.SEGMENTATION_PLUGIN

    def _set_max_epochs(self, max_epochs: int) -> None:
        """Store ``max_epochs`` under ``trainer.max_epochs``."""
        self._set_cfg("trainer.max_epochs", max_epochs)

    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        """Store the manifest path (as a string) under ``data.path``."""
        self._set_cfg("data.path", str(manifest_path))

    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        """Store the output directory (as a string) as both the output and the work dir."""
        self._set_cfg("paths.output_dir", str(output_dir))
        self._set_cfg("paths.work_dir", str(output_dir))

    def set_input_channel(self, input_channel: int) -> None:
        """Set the ``input_channel`` config value."""
        self._set_cfg("input_channel", input_channel)

    def get_input_channel(self) -> int:
        """Return the ``input_channel`` config value."""
        return self._get_cfg("input_channel")

    def set_raw_image_channels(self, channels: int) -> None:
        """Set the ``raw_im_channels`` config value."""
        self._set_cfg("raw_im_channels", channels)

    def get_raw_image_channels(self) -> int:
        """Return the ``raw_im_channels`` config value."""
        return self._get_cfg("raw_im_channels")

    # TODO: is there a better way to deal with column names + split columns?
    def set_manifest_column_names(
        self,
        source: str,
        target1: str,
        target2: str,
        merge_mask: str,
        exclude_mask: str,
        base_image: str,
    ) -> None:
        """Set the manifest column names used by the experiment.

        Each argument is stored under the config key of the same name with a ``_col``
        suffix (``target1``/``target2`` map to ``target_col1``/``target_col2``).
        """
        self._set_cfg("source_col", source)
        self._set_cfg("target_col1", target1)
        self._set_cfg("target_col2", target2)
        self._set_cfg("merge_mask_col", merge_mask)
        self._set_cfg("exclude_mask_col", exclude_mask)
        self._set_cfg("base_image_col", base_image)

    def get_manifest_column_names(self) -> Tuple[str, str, str, str, str, str]:
        """Return the manifest column names.

        :return: ``(source, target1, target2, merge_mask, exclude_mask, base_image)``, in the
        same order that ``set_manifest_column_names`` accepts them
        """
        return (
            self._get_cfg("source_col"),
            self._get_cfg("target_col1"),
            self._get_cfg("target_col2"),
            self._get_cfg("merge_mask_col"),
            self._get_cfg("exclude_mask_col"),
            self._get_cfg("base_image_col"),
        )

    def set_split_column(self, split_column: str) -> None:
        """Set the split column and keep it as the last entry of ``data.columns``.

        An already tracked split column is replaced in place; otherwise the new name is
        appended to the column list.

        :param split_column: name of the manifest column holding the split assignment
        :raises KeyError: if ``data.columns`` or ``data.split_column`` is missing
        """
        # Both lookups that can raise happen before the column list is touched, so a
        # missing key cannot leave the config half-updated.
        existing_cols: ListConfig = self._get_cfg("data.columns")
        self._set_cfg("data.split_column", split_column)
        if self._has_split_column:
            existing_cols[-1] = split_column
        else:
            existing_cols.append(split_column)
        self._has_split_column = True

    def get_split_column(self) -> Optional[str]:
        """Return the ``data.split_column`` config value (``None`` when unset)."""
        return self._get_cfg("data.split_column")

    def remove_split_column(self) -> None:
        """Unset the split column and drop it from ``data.columns``.

        Does nothing when no split column is currently tracked.

        :raises KeyError: if ``data.columns`` or ``data.split_column`` is missing
        """
        if not self._has_split_column:
            return
        existing_cols: ListConfig = self._get_cfg("data.columns")
        self._set_cfg("data.split_column", None)
        # The tracked split column is always the last entry of the list.
        del existing_cols[-1]
        self._has_split_column = False

    def set_patch_size(self, patch_size: PatchSize) -> None:
        """Store the value of ``patch_size`` under ``data._aux.patch_shape``."""
        self._set_cfg("data._aux.patch_shape", patch_size.value)

    def get_patch_size(self) -> Optional[PatchSize]:
        """Return the configured patch shape as a ``PatchSize``, or ``None`` if it is empty/unset."""
        p_shape: ListConfig = self._get_cfg("data._aux.patch_shape")
        return PatchSize(list(p_shape)) if p_shape else None

    def set_hardware_type(self, hardware_type: HardwareType) -> None:
        """Store the value of ``hardware_type`` under ``trainer.accelerator``."""
        self._set_cfg("trainer.accelerator", hardware_type.value)

    def get_hardware_type(self) -> HardwareType:
        """Return the configured ``trainer.accelerator`` as a ``HardwareType``."""
        return HardwareType(self._get_cfg("trainer.accelerator"))
3 changes: 3 additions & 0 deletions cyto_dl/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@


class CytoDLModel:
# TODO: add optional CytoDLConfig param to init--if client passes a
# CytoDLConfig subtype, config will be initialized in constructor and
# calls to train/predict can be run immediately
def __init__(self):
self.cfg = None
self._training = False
Expand Down
5 changes: 5 additions & 0 deletions tests/api/cyto_dl_model/bad_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
this: 1
is:
a: 2
bad: 3
config: 4
Loading

0 comments on commit d3107f9

Please sign in to comment.