diff --git a/cyto_dl/callbacks/outlier_detection.py b/cyto_dl/callbacks/outlier_detection.py index ef706ff0e..789ca4090 100644 --- a/cyto_dl/callbacks/outlier_detection.py +++ b/cyto_dl/callbacks/outlier_detection.py @@ -118,7 +118,7 @@ def on_predict_epoch_start(self, trainer, pl_module): def _inference_batch_end(self, batch): if self._run: - batch_names = batch["raw_meta_dict"]["filename_or_obj"] + batch_names = batch["raw"].meta["filename_or_obj"] # activations are saved per-patch distances_per_image = len(self.mahalanobis_distances[self.layer_names[0]]) // len( batch_names diff --git a/cyto_dl/datamodules/czi.py b/cyto_dl/datamodules/czi.py index 30ba4f9ed..dc1f582f2 100644 --- a/cyto_dl/datamodules/czi.py +++ b/cyto_dl/datamodules/czi.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from aicsimageio.aics_image import AICSImage -from monai.data import DataLoader, Dataset +from monai.data import DataLoader, Dataset, MetaTensor from monai.transforms import Compose, apply_transform from omegaconf import ListConfig @@ -129,13 +129,17 @@ def _transform(self, index: int): img.set_scene(img_data.pop("scene")) data_i = img.get_image_dask_data(**img_data).compute() data_i = self._ensure_channel_first(data_i) + output_img = ( + apply_transform(self.transform, data_i) if self.transform is not None else data_i + ) + return { - self.out_key: apply_transform(self.transform, data_i) - if self.transform is not None - else data_i, - f"{self.out_key}_meta_dict": { - "filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data)) - }, + self.out_key: MetaTensor( + output_img, + meta={ + "filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data)) + }, + ) } def __len__(self): diff --git a/cyto_dl/image/io/aicsimage_loader.py b/cyto_dl/image/io/aicsimage_loader.py index 43ce6d0ac..eb40a1277 100644 --- a/cyto_dl/image/io/aicsimage_loader.py +++ b/cyto_dl/image/io/aicsimage_loader.py @@ -1,6 +1,7 @@ from typing import List from aicsimageio import AICSImage +from monai.data import MetaTensor from monai.transforms import Transform @@ -52,7 +53,6 @@ def __call__(self, data): img.set_scene(data[self.scene_key]) kwargs = {k: data[k] for k in self.kwargs_keys} img = img.get_image_dask_data(**kwargs).compute() - data[self.out_key] = img - data[f"{self.out_key}_meta_dict"] = {"filename_or_obj": path, "kwargs": kwargs} + data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs}) return data diff --git a/cyto_dl/models/im2im/gan.py b/cyto_dl/models/im2im/gan.py index 12a2d5f94..b1e40b7a0 100644 --- a/cyto_dl/models/im2im/gan.py +++ b/cyto_dl/models/im2im/gan.py @@ -174,6 +174,7 @@ def _extract_loss(self, outs, loss_type): return self._sum_losses(loss) def model_step(self, stage, batch, batch_idx): + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): @@ -210,6 +211,7 @@ def model_step(self, stage, batch, batch_idx): return loss_dict, None, None def predict_step(self, batch, batch_idx): + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): diff --git a/cyto_dl/models/im2im/multi_task.py b/cyto_dl/models/im2im/multi_task.py index 88ff6ddc5..24a9180c3 100644 --- a/cyto_dl/models/im2im/multi_task.py +++ b/cyto_dl/models/im2im/multi_task.py @@ -168,6 +168,7 @@ def _get_run_heads(self, batch, stage): def model_step(self, stage, batch, batch_idx): # convert monai metatensors to tensors + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] for k, v in batch.items(): if isinstance(v, MetaTensor): batch[k] = v.as_tensor() @@ -175,17 +176,13 @@ def model_step(self, stage, batch, batch_idx): run_heads = self._get_run_heads(batch, stage) outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads) - if stage != "predict": - losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} - losses = self._sum_losses(losses) - return losses, None, None - - preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()} - - return None, preds, None + losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()} + losses = self._sum_losses(losses) + return losses, None, None def predict_step(self, batch, batch_idx): stage = "predict" + batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"] # convert monai metatensors to tensors for k, v in batch.items(): if isinstance(v, MetaTensor): diff --git a/cyto_dl/nn/head/base_head.py b/cyto_dl/nn/head/base_head.py index d65e7242c..179f5ad30 100644 --- a/cyto_dl/nn/head/base_head.py +++ b/cyto_dl/nn/head/base_head.py @@ -68,15 +68,13 @@ def save_image(self, y_hat, batch, stage, global_step): if self.save_raw: raw_out = self._postprocess(batch[self.x_key], img_type="input") try: - metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"] + metadata_filenames = batch["filenames"] filename_map = {"input": metadata_filenames, "output": []} metadata_filenames = [ f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames ] except KeyError: - raise ValueError( - f"Please ensure your batches contain key `{self.x_key}_meta_dict['filename_or_obj']`" - ) + raise ValueError("Please ensure your batches have key `filenames`") save_name = ( [f"{global_step}_{self.head_name}.tif"] if stage in ("train", "val")