From 760dd6141ec13940ffb6df13a99350bd2489dfeb Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Tue, 5 Mar 2024 18:06:18 -0800 Subject: [PATCH 1/6] added data module and tests for model --- cyto_dl/api/data.py | 45 +++++++++++++++++++++++++++++++++++++++++ cyto_dl/api/model.py | 11 +++++----- tests/api/test_model.py | 24 ++++++++++++++++++++++ 3 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 cyto_dl/api/data.py create mode 100644 tests/api/test_model.py diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py new file mode 100644 index 000000000..99ce6199f --- /dev/null +++ b/cyto_dl/api/data.py @@ -0,0 +1,45 @@ +# goal: allow the API client to interact with CytoDLModel without +# worrying about the actual values being used in config files +from enum import Enum +import skimage + + +class ExperimentType(Enum): + GAN = "gan" + INSTANCE_SEG = "instance_seg" + LABEL_FREE = "labelfree" + SEGMENTATION_PLUGIN = "segmentation_plugin" + SEGMENTATION = "segmentation" + + +class HardwareType(Enum): + CPU = "cpu" + GPU = "gpu" + MPS = "mps" + # other hardware types available, but require more complicated config + + +class PatchSize(Enum): + """ + Patch size for training, and their respective patch shapes. + TODO: get from benji + """ + + SMALL = [1, 3, 3] + MEDIUM = [16, 32, 32] + LARGE = [20, 40, 40] + + +# importing skimage takes a while. +# could speed it up by hardcoding this enum, but would have to update +# if skimage adds/deletes threshold filters + +ThresholdFilters = Enum( + "ThresholdFilters", + { + func_name.split("threshold_")[-1].upper(): func_name + for func_name in filter( + lambda x: x.startswith("threshold_"), dir(skimage.filters) + ) + }, +) diff --git a/cyto_dl/api/model.py b/cyto_dl/api/model.py index 90211df79..6d79941f5 100644 --- a/cyto_dl/api/model.py +++ b/cyto_dl/api/model.py @@ -10,8 +10,7 @@ from cyto_dl.train import train as train_model from cyto_dl.utils.download_test_data import download_test_data from cyto_dl.utils.rich_utils import print_config_tree - -DEFAULT_EXPERIMENTS = ["gan", "instance_seg", "labelfree", "segmentation_plugin", "segmentation"] +from cyto_dl.api.data import ExperimentType class CytoDLModel: @@ -45,11 +44,13 @@ def load_config_from_dict(self, config: dict): """Load configuration from dictionary.""" self.cfg = config + # TODO: replace experiment_type str with api.data.ExperimentType -> will + # require corresponding changes in ml-segmenter def load_default_experiment( - self, experiment_name: str, output_dir: str, train=True, overrides: List = [] + self, experiment_type: str, output_dir: str, train=True, overrides: List = [] ): """Load configuration from directory.""" - assert experiment_name in DEFAULT_EXPERIMENTS + assert experiment_type in {exp_type.value for exp_type in ExperimentType} config_dir = self.root / "configs" GlobalHydra.instance().clear() @@ -57,7 +58,7 @@ def load_default_experiment( cfg = compose( config_name="train.yaml" if train else "eval.yaml", return_hydra_config=True, - overrides=[f"experiment=im2im/{experiment_name}"] + overrides, + overrides=[f"experiment=im2im/{experiment_type}"] + overrides, ) with open_dict(cfg): diff --git a/tests/api/test_model.py b/tests/api/test_model.py new file mode 100644 index 000000000..5056fe2bd --- /dev/null +++ b/tests/api/test_model.py @@ -0,0 +1,24 @@ +import cyto_dl.api.model +from cyto_dl.api.model import CytoDLModel +from cyto_dl.api.data import ExperimentType +from unittest.mock import patch +import pytest + +# mock these functions to avoid attempts to write to file system +@patch("cyto_dl.api.model.OmegaConf.save") +@patch("cyto_dl.api.model.Path.mkdir") +def test_load_default_experiment_valid_exp_type(MockMkdir, MockSave): + model: CytoDLModel = CytoDLModel() + model.load_default_experiment(ExperimentType.SEGMENTATION.value, "fake_dir") + MockMkdir.assert_called() + MockSave.assert_called() + +@patch("cyto_dl.api.model.OmegaConf.save") +@patch("cyto_dl.api.model.Path.mkdir") +def test_load_default_experiment_invalid_exp_type(MockMkdir, MockSave): + model: CytoDLModel = CytoDLModel() + with pytest.raises(AssertionError): + model.load_default_experiment("invalid_exp_type", "fake_dir") + MockMkdir.assert_not_called() + MockSave.assert_not_called() + From 1c7f4517199ddda03f761dff184aff49494c07a9 Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Wed, 6 Mar 2024 14:38:02 -0800 Subject: [PATCH 2/6] changed enum name --- cyto_dl/api/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index 99ce6199f..a52d63414 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -34,8 +34,8 @@ class PatchSize(Enum): # could speed it up by hardcoding this enum, but would have to update # if skimage adds/deletes threshold filters -ThresholdFilters = Enum( - "ThresholdFilters", +AutoThresholdMethod = Enum( + "AutoThresholdMethod", { func_name.split("threshold_")[-1].upper(): func_name for func_name in filter( From 8ef0da2709e2c1a7dadf1bf6193fe558ca1c91ee Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Wed, 6 Mar 2024 15:12:42 -0800 Subject: [PATCH 3/6] get rid of mac gpu option --- cyto_dl/api/data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index a52d63414..067c173f8 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -15,7 +15,6 @@ class ExperimentType(Enum): class HardwareType(Enum): CPU = "cpu" GPU = "gpu" - MPS = "mps" # other hardware types available, but require more complicated config From 222e1f14ae9600dc65e5fa0c12ee5a9010f74b08 Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Thu, 7 Mar 2024 09:51:51 -0800 Subject: [PATCH 4/6] ran pre-commit hooks --- cyto_dl/api/data.py | 9 ++++----- cyto_dl/api/model.py | 2 +- tests/api/test_model.py | 11 +++++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index 067c173f8..e9e29475d 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -1,6 +1,7 @@ # goal: allow the API client to interact with CytoDLModel without # worrying about the actual values being used in config files from enum import Enum + import skimage @@ -19,8 +20,8 @@ class HardwareType(Enum): class PatchSize(Enum): - """ - Patch size for training, and their respective patch shapes. + """Patch size for training, and their respective patch shapes. + TODO: get from benji """ @@ -37,8 +38,6 @@ class PatchSize(Enum): "AutoThresholdMethod", { func_name.split("threshold_")[-1].upper(): func_name - for func_name in filter( - lambda x: x.startswith("threshold_"), dir(skimage.filters) - ) + for func_name in filter(lambda x: x.startswith("threshold_"), dir(skimage.filters)) }, ) diff --git a/cyto_dl/api/model.py b/cyto_dl/api/model.py index 6d79941f5..d9010a7f8 100644 --- a/cyto_dl/api/model.py +++ b/cyto_dl/api/model.py @@ -6,11 +6,11 @@ from hydra.core.global_hydra import GlobalHydra from omegaconf import OmegaConf, open_dict +from cyto_dl.api.data import ExperimentType from cyto_dl.eval import evaluate from cyto_dl.train import train as train_model from cyto_dl.utils.download_test_data import download_test_data from cyto_dl.utils.rich_utils import print_config_tree -from cyto_dl.api.data import ExperimentType class CytoDLModel: diff --git a/tests/api/test_model.py b/tests/api/test_model.py index 5056fe2bd..f21c01875 100644 --- a/tests/api/test_model.py +++ b/tests/api/test_model.py @@ -1,9 +1,12 @@ -import cyto_dl.api.model -from cyto_dl.api.model import CytoDLModel -from cyto_dl.api.data import ExperimentType from unittest.mock import patch + import pytest +import cyto_dl.api.model +from cyto_dl.api.data import ExperimentType +from cyto_dl.api.model import CytoDLModel + + # mock these functions to avoid attempts to write to file system @patch("cyto_dl.api.model.OmegaConf.save") @patch("cyto_dl.api.model.Path.mkdir") @@ -13,6 +16,7 @@ def test_load_default_experiment_valid_exp_type(MockMkdir, MockSave): MockMkdir.assert_called() MockSave.assert_called() + @patch("cyto_dl.api.model.OmegaConf.save") @patch("cyto_dl.api.model.Path.mkdir") def test_load_default_experiment_invalid_exp_type(MockMkdir, MockSave): @@ -21,4 +25,3 @@ def test_load_default_experiment_invalid_exp_type(MockMkdir, MockSave): model.load_default_experiment("invalid_exp_type", "fake_dir") MockMkdir.assert_not_called() MockSave.assert_not_called() - From 0e3828a151966ff1ae2550a4db25cdf285382d02 Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Thu, 7 Mar 2024 10:38:36 -0800 Subject: [PATCH 5/6] update patch sizes based on Benji's review --- cyto_dl/api/data.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index e9e29475d..a6b5f0c6f 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -20,14 +20,13 @@ class HardwareType(Enum): class PatchSize(Enum): - """Patch size for training, and their respective patch shapes. - - TODO: get from benji + """ + Patch size for training, and their respective patch shapes. """ - SMALL = [1, 3, 3] - MEDIUM = [16, 32, 32] - LARGE = [20, 40, 40] + SMALL = [16, 32, 32] + MEDIUM = [16, 64, 64] + LARGE = [16, 128, 128] # importing skimage takes a while. From 448e7fc9851f72cc6ad02fc8e21b76cc484a0f68 Mon Sep 17 00:00:00 2001 From: Daniel Saelid Date: Thu, 7 Mar 2024 10:42:18 -0800 Subject: [PATCH 6/6] pre-commit --- cyto_dl/api/data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cyto_dl/api/data.py b/cyto_dl/api/data.py index a6b5f0c6f..24e9648fa 100644 --- a/cyto_dl/api/data.py +++ b/cyto_dl/api/data.py @@ -20,9 +20,7 @@ class HardwareType(Enum): class PatchSize(Enum): - """ - Patch size for training, and their respective patch shapes. - """ + """Patch size for training, and their respective patch shapes.""" SMALL = [16, 32, 32] MEDIUM = [16, 64, 64]