Skip to content

Commit

Permalink
update metatensor metadata handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Morris committed Oct 18, 2023
1 parent 1816699 commit bf19a86
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cyto_dl/callbacks/outlier_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def on_predict_epoch_start(self, trainer, pl_module):

def _inference_batch_end(self, batch):
if self._run:
batch_names = batch["raw_meta_dict"]["filename_or_obj"]
batch_names = batch["raw"].meta["filename_or_obj"]
# activations are saved per-patch
distances_per_image = len(self.mahalanobis_distances[self.layer_names[0]]) // len(
batch_names
Expand Down
18 changes: 11 additions & 7 deletions cyto_dl/datamodules/czi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
from aicsimageio.aics_image import AICSImage
from monai.data import DataLoader, Dataset
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, apply_transform
from omegaconf import ListConfig

Expand Down Expand Up @@ -129,13 +129,17 @@ def _transform(self, index: int):
img.set_scene(img_data.pop("scene"))
data_i = img.get_image_dask_data(**img_data).compute()
data_i = self._ensure_channel_first(data_i)
output_img = (
apply_transform(self.transform, data_i) if self.transform is not None else data_i
)

return {
self.out_key: apply_transform(self.transform, data_i)
if self.transform is not None
else data_i,
f"{self.out_key}_meta_dict": {
"filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
},
self.out_key: MetaTensor(
output_img,
meta={
"filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
},
)
}

def __len__(self):
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/image/io/aicsimage_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

from aicsimageio import AICSImage
from monai.data import MetaTensor
from monai.transforms import Transform


Expand Down Expand Up @@ -52,7 +53,6 @@ def __call__(self, data):
img.set_scene(data[self.scene_key])
kwargs = {k: data[k] for k in self.kwargs_keys}
img = img.get_image_dask_data(**kwargs).compute()
data[self.out_key] = img
data[f"{self.out_key}_meta_dict"] = {"filename_or_obj": path, "kwargs": kwargs}
data[self.out_key] = MetaTensor(img, meta={"filename_or_obj": path, "kwargs": kwargs})

return data
2 changes: 2 additions & 0 deletions cyto_dl/models/im2im/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def _extract_loss(self, outs, loss_type):
return self._sum_losses(loss)

def model_step(self, stage, batch, batch_idx):
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down Expand Up @@ -210,6 +211,7 @@ def model_step(self, stage, batch, batch_idx):
return loss_dict, None, None

def predict_step(self, batch, batch_idx):
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down
13 changes: 5 additions & 8 deletions cyto_dl/models/im2im/multi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,21 @@ def _get_run_heads(self, batch, stage):

def model_step(self, stage, batch, batch_idx):
# convert monai metatensors to tensors
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
for k, v in batch.items():
if isinstance(v, MetaTensor):
batch[k] = v.as_tensor()

run_heads = self._get_run_heads(batch, stage)
outs = self.run_forward(batch, stage, self.should_save_image(batch_idx, stage), run_heads)

if stage != "predict":
losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()}
losses = self._sum_losses(losses)
return losses, None, None

preds = {head_name: head_result["y_hat_out"] for head_name, head_result in outs.items()}

return None, preds, None
losses = {head_name: head_result["loss"] for head_name, head_result in outs.items()}
losses = self._sum_losses(losses)
return losses, None, None

def predict_step(self, batch, batch_idx):
stage = "predict"
batch["filenames"] = batch[self.hparams.x_key].meta["filename_or_obj"]
# convert monai metatensors to tensors
for k, v in batch.items():
if isinstance(v, MetaTensor):
Expand Down
6 changes: 2 additions & 4 deletions cyto_dl/nn/head/base_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,13 @@ def save_image(self, y_hat, batch, stage, global_step):
if self.save_raw:
raw_out = self._postprocess(batch[self.x_key], img_type="input")
try:
metadata_filenames = batch[f"{self.x_key}_meta_dict"]["filename_or_obj"]
metadata_filenames = batch["filenames"]
filename_map = {"input": metadata_filenames, "output": []}
metadata_filenames = [
f"{Path(fn).stem}_{self.head_name}.tif" for fn in metadata_filenames
]
except KeyError:
raise ValueError(
f"Please ensure your batches contain key `{self.x_key}_meta_dict['filename_or_obj']`"
)
raise ValueError("Please ensure your batches have key `filenames`")
save_name = (
[f"{global_step}_{self.head_name}.tif"]
if stage in ("train", "val")
Expand Down

0 comments on commit bf19a86

Please sign in to comment.