diff --git a/configs/data/local/segmentation_all_cells_mask_from_zarr.yaml b/configs/data/local/segmentation_all_cells_mask_from_zarr.yaml index e34a7724..cfa29d75 100644 --- a/configs/data/local/segmentation_all_cells_mask_from_zarr.yaml +++ b/configs/data/local/segmentation_all_cells_mask_from_zarr.yaml @@ -5,7 +5,6 @@ batch_size: 1 pin_memory: True #persistent_workers: False - csv_path: img_path_column: movie_path channel_column: bf_channel @@ -14,5 +13,3 @@ out_key: ${source_col} transforms: - _target_: monai.transforms.ToTensor - _target_: monai.transforms.NormalizeIntensity - - diff --git a/configs/experiment/local/eval_scale1_new.yaml b/configs/experiment/local/eval_scale1_new.yaml index 552cd824..c7d2d046 100644 --- a/configs/experiment/local/eval_scale1_new.yaml +++ b/configs/experiment/local/eval_scale1_new.yaml @@ -28,17 +28,17 @@ trainer: model: compile: True - #save_dir: /allen/aics/assay-dev/users/Suraj/EMT_Work/image_analysis_test/EMT_image_analysis/Colony_mask_training_inference/eval_whole_movie_multiscale_patch1_zarr_aws + save_dir: /allen/aics/assay-dev/users/Suraj/EMT_Work/image_analysis_test/EMT_image_analysis/Colony_mask_training_inference/eval_whole_movie_multiscale_patch1_zarr_aws data: csv_path: /allen/aics/assay-dev/users/Suraj/EMT_Work/image_analysis_test/EMT_image_analysis/Colony_mask_training_inference/sample_csv/predict_all_cells_mask_zarr_aws_v0.csv #path to csv containing test movies #batch_size: 1 - _aux: + _aux: patch_shape: [16, 128, 128] callbacks: predict_saving: _target_: cyto_dl.callbacks.ImageSaver - save_dir: /allen/aics/assay-dev/users/Suraj/EMT_Work/image_analysis_test/EMT_image_analysis/Colony_mask_training_inference/eval_whole_movie_multiscale_patch1_zarr_aws + save_dir: ${model.save_dir} stages: ["predict"] - save_input: False \ No newline at end of file + save_input: False diff --git a/configs/model/local/segmentation_all_cells_mask.yaml b/configs/model/local/segmentation_all_cells_mask.yaml index c9994ac0..823ab05d 100644 --- a/configs/model/local/segmentation_all_cells_mask.yaml +++ b/configs/model/local/segmentation_all_cells_mask.yaml @@ -41,7 +41,7 @@ _aux: - - ${target_col} - _target_: cyto_dl.nn.BaseHead loss: - _target_: monai.losses.GeneralizedDiceFocalLoss ##Main loss + _target_: monai.losses.GeneralizedDiceFocalLoss ##Main loss sigmoid: True postprocess: input: diff --git a/cyto_dl/callbacks/image_saver.py b/cyto_dl/callbacks/image_saver.py index 4458016f..2b95a1fe 100644 --- a/cyto_dl/callbacks/image_saver.py +++ b/cyto_dl/callbacks/image_saver.py @@ -20,7 +20,8 @@ def __init__( Parameters ---------- save_dir: Union[str, Path] - Directory to save images + Directory to save images. Only testing saves in this directory, prediction + saves in model save directory save_every_n_epochs:int=1 Frequency to save images stages:List[str]=["train", "val"] @@ -49,6 +50,7 @@ def on_predict_batch_end( if outputs is None: # image has already been saved return + for i, head_io_map in enumerate(io_map.values()): for k, save_path in head_io_map.items(): self._save(save_path, outputs[k]["pred"][i]) diff --git a/cyto_dl/datamodules/multidim_image.py b/cyto_dl/datamodules/multidim_image.py index 14643061..8f3c69d8 100644 --- a/cyto_dl/datamodules/multidim_image.py +++ b/cyto_dl/datamodules/multidim_image.py @@ -98,6 +98,9 @@ def _get_timepoints(self, row, img): timepoints = range(start, stop + 1, step) if stop > 0 else range(img.dims.T) return list(timepoints) + def _get_filename(self, image_input_path): + return image_input_path.split("/")[-1].split(".")[0] + def get_per_file_args(self, df): img_data = [] for row in df.itertuples(): @@ -105,6 +108,8 @@ def get_per_file_args(self, df): img = BioImage(row[self.img_path_column]) scenes = self._get_scenes(row, img) timepoints = self._get_timepoints(row, img) + filename = self._get_filename(row[self.img_path_column]) + for scene in scenes: for timepoint in timepoints: img_data.append( @@ -115,6 +120,7 @@ def get_per_file_args(self, df): "scene": scene, "T": timepoint, "original_path": row[self.img_path_column], + "filename_or_obj": filename + f"_{timepoint}", # needs to be part of metadata to generate IO maps } ) img_data.reverse()