Skip to content

Commit

Permalink
move metatensor creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Morris committed Jul 18, 2024
1 parent 528df83 commit 31217d5
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions cyto_dl/datamodules/multidim_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from bioio import BioImage
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, apply_transform
from monai.transforms import Compose, apply_transform, ToTensor
from omegaconf import ListConfig


Expand Down Expand Up @@ -128,6 +128,8 @@ def _ensure_channel_first(self, img):
return img

def create_metatensor(self, img, meta):
if isinstance(img, np.ndarray):
img = torch.from_numpy(img.astype(float))
if isinstance(img, MetaTensor):
img.meta.update(meta)
return img
Expand Down Expand Up @@ -156,15 +158,17 @@ def _transform(self, index: int):
img_data["scene"] = scene
img_data["original_path"] = original_path
data_i = self._ensure_channel_first(data_i)
data_i = self.create_metatensor(data_i, img_data)

output_img = (
apply_transform(self.transform, data_i) if self.transform is not None else data_i
)
# some monai transforms return a batch. When collated, the batch dimension gets moved to the channel dimension
if self.is_batch(output_img):
return [
{self.out_key: self.create_metatensor(img, meta=img_data)} for img in output_img
{self.out_key: img} for img in output_img
]
return {self.out_key: self.create_metatensor(output_img, meta=img_data)}
return {self.out_key: img}

def __len__(self):
return len(self.img_data)
Expand Down

0 comments on commit 31217d5

Please sign in to comment.