Skip to content

Commit

Permalink
Merge pull request #347 from AllenCellModeling/feature/api_dataclass
Browse files Browse the repository at this point in the history
Feature/api dataclass
  • Loading branch information
saeliddp authored Mar 22, 2024
2 parents 02dbdde + 448e7fc commit d7c663e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
40 changes: 40 additions & 0 deletions cyto_dl/api/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# goal: allow the API client to interact with CytoDLModel without
# worrying about the actual values being used in config files
from enum import Enum

import skimage


class ExperimentType(Enum):
GAN = "gan"
INSTANCE_SEG = "instance_seg"
LABEL_FREE = "labelfree"
SEGMENTATION_PLUGIN = "segmentation_plugin"
SEGMENTATION = "segmentation"


class HardwareType(Enum):
CPU = "cpu"
GPU = "gpu"
# other hardware types available, but require more complicated config


class PatchSize(Enum):
"""Patch size for training, and their respective patch shapes."""

SMALL = [16, 32, 32]
MEDIUM = [16, 64, 64]
LARGE = [16, 128, 128]


# importing skimage takes a while.
# could speed it up by hardcoding this enum, but would have to update
# if skimage adds/deletes threshold filters

AutoThresholdMethod = Enum(
"AutoThresholdMethod",
{
func_name.split("threshold_")[-1].upper(): func_name
for func_name in filter(lambda x: x.startswith("threshold_"), dir(skimage.filters))
},
)
11 changes: 6 additions & 5 deletions cyto_dl/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf, open_dict

from cyto_dl.api.data import ExperimentType
from cyto_dl.eval import evaluate
from cyto_dl.train import train as train_model
from cyto_dl.utils.download_test_data import download_test_data
from cyto_dl.utils.rich_utils import print_config_tree

DEFAULT_EXPERIMENTS = ["gan", "instance_seg", "labelfree", "segmentation_plugin", "segmentation"]


class CytoDLModel:
def __init__(self):
Expand Down Expand Up @@ -45,19 +44,21 @@ def load_config_from_dict(self, config: dict):
"""Load configuration from dictionary."""
self.cfg = config

# TODO: replace experiment_type str with api.data.ExperimentType -> will
# require corresponding changes in ml-segmenter
def load_default_experiment(
self, experiment_name: str, output_dir: str, train=True, overrides: List = []
self, experiment_type: str, output_dir: str, train=True, overrides: List = []
):
"""Load configuration from directory."""
assert experiment_name in DEFAULT_EXPERIMENTS
assert experiment_type in {exp_type.value for exp_type in ExperimentType}
config_dir = self.root / "configs"

GlobalHydra.instance().clear()
with initialize_config_dir(version_base="1.2", config_dir=str(config_dir)):
cfg = compose(
config_name="train.yaml" if train else "eval.yaml",
return_hydra_config=True,
overrides=[f"experiment=im2im/{experiment_name}"] + overrides,
overrides=[f"experiment=im2im/{experiment_type}"] + overrides,
)

with open_dict(cfg):
Expand Down
27 changes: 27 additions & 0 deletions tests/api/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import patch

import pytest

import cyto_dl.api.model
from cyto_dl.api.data import ExperimentType
from cyto_dl.api.model import CytoDLModel


# mock these functions to avoid attempts to write to file system
@patch("cyto_dl.api.model.OmegaConf.save")
@patch("cyto_dl.api.model.Path.mkdir")
def test_load_default_experiment_valid_exp_type(MockMkdir, MockSave):
model: CytoDLModel = CytoDLModel()
model.load_default_experiment(ExperimentType.SEGMENTATION.value, "fake_dir")
MockMkdir.assert_called()
MockSave.assert_called()


@patch("cyto_dl.api.model.OmegaConf.save")
@patch("cyto_dl.api.model.Path.mkdir")
def test_load_default_experiment_invalid_exp_type(MockMkdir, MockSave):
model: CytoDLModel = CytoDLModel()
with pytest.raises(AssertionError):
model.load_default_experiment("invalid_exp_type", "fake_dir")
MockMkdir.assert_not_called()
MockSave.assert_not_called()

0 comments on commit d7c663e

Please sign in to comment.