-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #361 from AllenCellModeling/feature/api_config_class
Initial new API
- Loading branch information
Showing
7 changed files
with
399 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .cyto_dl_base_model import CytoDLBaseModel | ||
from .segmentation_plugin_model import SegmentationPluginModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from abc import ABC, abstractmethod | ||
from copy import deepcopy | ||
from pathlib import Path | ||
from typing import Any, List, Optional, Union | ||
|
||
import pyrootutils | ||
from hydra import compose, initialize_config_dir | ||
from hydra.core.global_hydra import GlobalHydra | ||
from omegaconf import DictConfig, OmegaConf, open_dict | ||
|
||
from cyto_dl.api.data import ExperimentType | ||
from cyto_dl.eval import evaluate as evaluate_model | ||
from cyto_dl.train import train as train_model | ||
|
||
# TODO: encapsulate experiment management (file system) details here, will require passing output_dir | ||
# into the factory methods, maybe | ||
|
||
|
||
class CytoDLBaseModel(ABC):
    """A CytoDLBaseModel is used to configure, train, and run predictions on a cyto-dl model.

    The model wraps a single OmegaConf ``DictConfig``. Subclasses declare which
    experiment they represent (``_get_experiment_type``) and where the epoch count,
    manifest path and output directory are stored in that config (the abstract
    ``_set_*`` hooks). Instances are meant to be built with the classmethod factories
    ``from_default_config`` and ``from_existing_config``.
    """

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients.

        Please see the classmethod factory methods instead.

        :param cfg: the config this model reads from and writes to
        """
        self._cfg: DictConfig = cfg

    @classmethod
    @abstractmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return experiment type for this config (e.g. segmentation_plugin, gan, etc)"""
        pass

    @classmethod
    def from_existing_config(cls, config_filepath: Path):
        """Returns a model from an existing config.

        :param config_filepath: path to a .yaml config file that will be used as the basis
        for this CytoDLBaseModel (must be generated by the CytoDLBaseModel subclass that wants
        to use it).
        """
        return cls(OmegaConf.load(config_filepath))

    # TODO: if spatial_dims is only ever 2 or 3, create an enum for it
    @classmethod
    def from_default_config(cls, spatial_dims: int):
        """Returns a model from the default config.

        The config is composed by Hydra from ``train.yaml`` in the ``configs`` directory
        under the project root, with the experiment overridden to
        ``im2im/<experiment type name, lower-cased>``.

        :param spatial_dims: dimensions for the model (e.g. 2)
        """
        # The project root is the nearest ancestor holding both indicator files.
        cfg_dir: Path = (
            pyrootutils.find_root(search_from=__file__, indicator=("pyproject.toml", "README.md"))
            / "configs"
        )
        # Hydra refuses to initialize twice, so drop any previous global state first.
        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base="1.2", config_dir=str(cfg_dir)):
            cfg: DictConfig = compose(
                config_name="train.yaml",  # train.yaml can work for prediction too
                return_hydra_config=True,
                overrides=[
                    f"experiment=im2im/{cls._get_experiment_type().name.lower()}",
                    f"spatial_dims={spatial_dims}",
                ],
            )
        with open_dict(cfg):
            # The hydra node was only requested for composition; it is not kept.
            del cfg["hydra"]
            cfg.extras.enforce_tags = False
            cfg.extras.print_config = False
        return cls(cfg)

    @abstractmethod
    def _set_max_epochs(self, max_epochs: int) -> None:
        """Store the number of training epochs in the config."""
        pass

    @abstractmethod
    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        """Store the path of the data manifest in the config."""
        pass

    @abstractmethod
    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        """Store the output directory in the config."""
        pass

    def _key_exists(self, k: str) -> bool:
        """Return True if the dotted key ``k`` (e.g. ``"trainer.max_epochs"``) is in the config.

        :param k: dot-separated path into the config
        :return: False as soon as one path component is missing, or when the path tries
        to descend through a node that is not a DictConfig (a scalar, None or a list)
        """
        node: Any = self._cfg
        for part in k.split("."):
            # A non-mapping node has no children; `part in node` would raise TypeError
            # for scalars/None, so treat that as "key not found" instead.
            if not isinstance(node, DictConfig) or part not in node:
                return False
            node = node[part]
        return True

    def _set_cfg(self, k: str, v: Any) -> None:
        """Overwrite the value stored at the existing dotted key ``k``.

        :param k: dot-separated path into the config
        :param v: new value
        :raises KeyError: if ``k`` is not already present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        OmegaConf.update(self._cfg, k, v)

    def _get_cfg(self, k: str) -> Any:
        """Return the value stored at the dotted key ``k``.

        :param k: dot-separated path into the config
        :raises KeyError: if ``k`` is not present in the config
        """
        if not self._key_exists(k):
            raise KeyError(f"{k} not found in config dict")
        return OmegaConf.select(self._cfg, k)

    def _set_training_config(self, train: bool) -> None:
        """Switch the config between training and prediction.

        :param train: value written to both the ``train`` and ``test`` keys; also selects
        ``task_name`` ("train" when True, "predict" when False)
        """
        self._set_cfg("train", train)
        self._set_cfg("test", train)
        # afaik, task_name isn't used outside of template_utils.py - do we need to support this?
        self._set_cfg("task_name", "train" if train else "predict")

    def _set_ckpt(self, ckpt: Optional[Union[str, Path]]) -> None:
        """Store the checkpoint path under ``ckpt_path``.

        :param ckpt: checkpoint location; stored as a resolved (absolute) string. A falsy
        value (e.g. None) is stored unchanged.
        """
        # Path(...) lets callers pass a plain string, like the other path parameters.
        self._set_cfg("ckpt_path", str(Path(ckpt).resolve()) if ckpt else ckpt)

    # does experiment name have any effect?
    def set_experiment_name(self, name: str) -> None:
        """Write ``name`` to the ``experiment_name`` config key."""
        self._set_cfg("experiment_name", name)

    def get_experiment_name(self) -> str:
        """Return the value of the ``experiment_name`` config key."""
        return self._get_cfg("experiment_name")

    def get_config(self) -> DictConfig:
        """Return a deep copy of the config, so callers cannot mutate the model's own copy."""
        return deepcopy(self._cfg)

    def save_config(self, path: Path) -> None:
        """Write the config to ``path`` using ``OmegaConf.save``."""
        OmegaConf.save(self._cfg, path)

    def train(
        self,
        max_epochs: int,
        manifest_path: Union[str, Path],
        output_dir: Union[str, Path],
        checkpoint: Optional[Path] = None,
    ) -> None:
        """Configure the model for training and run cyto-dl's ``train`` entry point.

        :param max_epochs: passed to ``_set_max_epochs``
        :param manifest_path: passed to ``_set_manifest_path``
        :param output_dir: passed to ``_set_output_dir``
        :param checkpoint: optional checkpoint path stored under ``ckpt_path``
        :raises KeyError: if any config key that has to be set does not exist
        """
        self._set_training_config(True)
        self._set_max_epochs(max_epochs)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        train_model(self._cfg)

    def predict(
        self, manifest_path: Union[str, Path], output_dir: Union[str, Path], checkpoint: Path
    ) -> None:
        """Configure the model for prediction and run cyto-dl's ``evaluate`` entry point.

        :param manifest_path: passed to ``_set_manifest_path``
        :param output_dir: passed to ``_set_output_dir``
        :param checkpoint: checkpoint path stored under ``ckpt_path``
        :raises KeyError: if any config key that has to be set does not exist
        """
        self._set_training_config(False)
        self._set_manifest_path(manifest_path)
        self._set_output_dir(output_dir)
        self._set_ckpt(checkpoint)
        evaluate_model(self._cfg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from pathlib import Path | ||
from typing import List, Optional, Tuple, Union | ||
|
||
from omegaconf import DictConfig, ListConfig | ||
|
||
from cyto_dl.api.cyto_dl_model import CytoDLBaseModel | ||
from cyto_dl.api.data import ExperimentType, HardwareType, PatchSize | ||
|
||
|
||
class SegmentationPluginModel(CytoDLBaseModel):
    """A SegmentationPluginModel handles configuration, training, and prediction using the default
    segmentation_plugin experiment from CytoDL."""

    def __init__(self, cfg: DictConfig):
        """Not intended for direct use by clients; see the factory classmethods instead.

        :param cfg: the config this model reads from and writes to
        """
        super().__init__(cfg)
        # Set once set_split_column() has put the split column at the end of data.columns.
        self._has_split_column = False

    @classmethod
    def _get_experiment_type(cls) -> ExperimentType:
        """Return ``ExperimentType.SEGMENTATION_PLUGIN``."""
        return ExperimentType.SEGMENTATION_PLUGIN

    def _set_max_epochs(self, max_epochs: int) -> None:
        """Write ``max_epochs`` to ``trainer.max_epochs``."""
        self._set_cfg("trainer.max_epochs", max_epochs)

    def _set_manifest_path(self, manifest_path: Union[str, Path]) -> None:
        """Write the manifest path (as a string) to ``data.path``."""
        self._set_cfg("data.path", str(manifest_path))

    def _set_output_dir(self, output_dir: Union[str, Path]) -> None:
        """Write the output directory (as a string) to ``paths.output_dir`` and ``paths.work_dir``."""
        self._set_cfg("paths.output_dir", str(output_dir))
        self._set_cfg("paths.work_dir", str(output_dir))

    def set_input_channel(self, input_channel: int) -> None:
        """Write ``input_channel`` to the ``input_channel`` config key."""
        self._set_cfg("input_channel", input_channel)

    def get_input_channel(self) -> int:
        """Return the value of the ``input_channel`` config key."""
        return self._get_cfg("input_channel")

    def set_raw_image_channels(self, channels: int) -> None:
        """Write ``channels`` to the ``raw_im_channels`` config key."""
        self._set_cfg("raw_im_channels", channels)

    def get_raw_image_channels(self) -> int:
        """Return the value of the ``raw_im_channels`` config key."""
        return self._get_cfg("raw_im_channels")

    # TODO: is there a better way to deal with column names + split columns?
    def set_manifest_column_names(
        self,
        source: str,
        target1: str,
        target2: str,
        merge_mask: str,
        exclude_mask: str,
        base_image: str,
    ) -> None:
        """Set the manifest column names used by the experiment.

        Each argument is written to its own config key: ``source_col``, ``target_col1``,
        ``target_col2``, ``merge_mask_col``, ``exclude_mask_col`` and ``base_image_col``.
        """
        self._set_cfg("source_col", source)
        self._set_cfg("target_col1", target1)
        self._set_cfg("target_col2", target2)
        self._set_cfg("merge_mask_col", merge_mask)
        self._set_cfg("exclude_mask_col", exclude_mask)
        self._set_cfg("base_image_col", base_image)

    def get_manifest_column_names(self) -> Tuple[str, str, str, str, str, str]:
        """Return the manifest column names.

        :return: ``(source, target1, target2, merge_mask, exclude_mask, base_image)``, in the
        same order as the parameters of ``set_manifest_column_names``
        """
        return (
            self._get_cfg("source_col"),
            self._get_cfg("target_col1"),
            self._get_cfg("target_col2"),
            self._get_cfg("merge_mask_col"),
            self._get_cfg("exclude_mask_col"),
            self._get_cfg("base_image_col"),
        )

    def _split_column_is_listed(self) -> bool:
        """Return True if the last entry of ``data.columns`` is the split column.

        The in-memory flag covers columns added through this instance. The config is also
        inspected so that a config saved after ``set_split_column`` and loaded again (where
        the flag starts out False) is still recognised.
        """
        if self._has_split_column:
            return True
        if not (self._key_exists("data.split_column") and self._key_exists("data.columns")):
            return False
        split_column = self._get_cfg("data.split_column")
        columns = self._get_cfg("data.columns")
        if split_column is None or not isinstance(columns, ListConfig) or len(columns) == 0:
            return False
        return columns[-1] == split_column

    def set_split_column(self, split_column: str) -> None:
        """Set ``data.split_column`` and keep it as the last entry of ``data.columns``.

        :param split_column: name of the manifest column holding the split
        :raises KeyError: if ``data.split_column`` or ``data.columns`` is not in the config
        """
        # Must be evaluated before data.split_column is overwritten below.
        replace_last: bool = self._split_column_is_listed()
        self._set_cfg("data.split_column", split_column)
        existing_cols: ListConfig = self._get_cfg("data.columns")
        if replace_last:
            # A split column is already the last entry: overwrite rather than duplicate.
            existing_cols[-1] = split_column
        else:
            existing_cols.append(split_column)
        self._has_split_column = True

    def get_split_column(self) -> Optional[str]:
        """Return the value of ``data.split_column`` (None if no split column is set)."""
        return self._get_cfg("data.split_column")

    def remove_split_column(self) -> None:
        """Clear ``data.split_column`` and drop it from the end of ``data.columns``.

        Does nothing when no split column is currently listed.
        """
        if not self._split_column_is_listed():
            return
        self._set_cfg("data.split_column", None)
        existing_cols: ListConfig = self._get_cfg("data.columns")
        del existing_cols[-1]
        self._has_split_column = False

    def set_patch_size(self, patch_size: PatchSize) -> None:
        """Write ``patch_size.value`` to ``data._aux.patch_shape``."""
        self._set_cfg("data._aux.patch_shape", patch_size.value)

    def get_patch_size(self) -> Optional[PatchSize]:
        """Return the ``PatchSize`` matching ``data._aux.patch_shape``, or None if it is unset/empty."""
        p_shape: ListConfig = self._get_cfg("data._aux.patch_shape")
        # Convert to a plain list so the lookup compares against the enum's values.
        return PatchSize(list(p_shape)) if p_shape else None

    def set_hardware_type(self, hardware_type: HardwareType) -> None:
        """Write ``hardware_type.value`` to ``trainer.accelerator``."""
        self._set_cfg("trainer.accelerator", hardware_type.value)

    def get_hardware_type(self) -> HardwareType:
        """Return the ``HardwareType`` matching ``trainer.accelerator``."""
        return HardwareType(self._get_cfg("trainer.accelerator"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
this: 1 | ||
is: | ||
a: 2 | ||
bad: 3 | ||
config: 4 |
Oops, something went wrong.