diff --git a/cyto_dl/api/model.py b/cyto_dl/api/model.py index 4be1eb79..21e7d9a1 100644 --- a/cyto_dl/api/model.py +++ b/cyto_dl/api/model.py @@ -75,7 +75,6 @@ def load_default_experiment( output_dir.mkdir(parents=True, exist_ok=True) self.cfg = cfg - OmegaConf.save(self.cfg, output_dir / f'{"train" if train else "eval"}_config.yaml') def print_config(self): print_config_tree(self.cfg, resolve=True) @@ -88,6 +87,13 @@ def override_config(self, overrides: Dict[str, Union[str, int, float, bool]]): for k, v in overrides.items(): OmegaConf.update(self.cfg, k, v) + def save_config(self, path: Path) -> None: + """Save current config to provided path. + + :param path: path at which to save config + """ + OmegaConf.save(self.cfg, path) + async def _train_async(self): return train_model(self.cfg) diff --git a/pyproject.toml b/pyproject.toml index 7456c717..4a25c4e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "astropy>=5.2", "rich", "boto3", - "bioio", + "bioio>=1.0.1", "bioio-czi", "bioio-ome-tiff", "bioio-tifffile" diff --git a/tests/api/test_model.py b/tests/api/test_model.py index f21c0187..39f51911 100644 --- a/tests/api/test_model.py +++ b/tests/api/test_model.py @@ -14,7 +14,7 @@ def test_load_default_experiment_valid_exp_type(MockMkdir, MockSave): model: CytoDLModel = CytoDLModel() model.load_default_experiment(ExperimentType.SEGMENTATION.value, "fake_dir") MockMkdir.assert_called() - MockSave.assert_called() + MockSave.assert_not_called() @patch("cyto_dl.api.model.OmegaConf.save")