Inference Updates (#405)
* update saving

* refactor czi dataloader

* check if output is a batch

* some models don't like metatensors

* add features only option

* precommit

* remove channel contrastive

* remove oopsie committed files

* incorporate optional metadata keys

* move metatensor creation

* precommit and update docstring

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
3 people authored Jul 19, 2024
1 parent 323ff26 commit c65dbbf
Showing 5 changed files with 121 additions and 72 deletions.
22 changes: 22 additions & 0 deletions cyto_dl/callbacks/csv_saver.py
@@ -0,0 +1,22 @@
from pathlib import Path

import pandas as pd
from lightning.pytorch.callbacks import Callback


class CSVSaver(Callback):
def __init__(self, save_dir, meta_keys=[]):
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.meta_keys = meta_keys

def on_predict_epoch_end(self, trainer, pl_module):
# Access the list of predictions from all predict_steps
predictions = trainer.predict_loop.predictions
feats = []
for pred, meta in predictions:
batch_feats = pd.DataFrame(pred)
for k in self.meta_keys:
batch_feats[k] = meta[k]
feats.append(batch_feats)
pd.concat(feats).to_csv(self.save_dir / "predictions.csv", index=False)
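
The callback pairs with `predict_step`s that return `(prediction, metadata)` tuples, as the models updated below now do. A minimal usage sketch (hedged: `model` and `dataloader` stand in for a configured cyto_dl LightningModule and predict dataloader, and the `meta_keys` value is illustrative):

from lightning.pytorch import Trainer

from cyto_dl.callbacks.csv_saver import CSVSaver

# meta_keys must name keys present in the metadata returned by predict_step
saver = CSVSaver(save_dir="./predictions", meta_keys=["filename_or_obj"])
trainer = Trainer(callbacks=[saver])

# on_predict_epoch_end collects (prediction, metadata) pairs from
# trainer.predict_loop.predictions and writes a single predictions.csv
trainer.predict(model, dataloaders=dataloader)
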
142 changes: 86 additions & 56 deletions cyto_dl/datamodules/czi.py → cyto_dl/datamodules/multidim_image.py
@@ -1,17 +1,18 @@
from pathlib import Path
from typing import Callable, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from bioio import BioImage
from monai.data import DataLoader, Dataset, MetaTensor
from monai.transforms import Compose, apply_transform
from monai.transforms import Compose, ToTensor, apply_transform
from omegaconf import ListConfig


class CZIDataset(Dataset):
"""Dataset converting a `.csv` file listing CZI files and some metadata into batches of single-
scene, single-timepoint, single-channel images."""
class MultiDimImageDataset(Dataset):
"""Dataset converting a `.csv` file listing multi dimensional (timelapse or multi-scene) files
and some metadata into batches of single- scene, single-timepoint, single-channel images."""

def __init__(
self,
@@ -24,6 +25,7 @@ def __init__(
time_start_column: str = "start",
time_stop_column: str = "stop",
time_step_column: str = "step",
dict_meta: Optional[Dict] = None,
transform: Optional[Callable] = None,
dask_load: bool = True,
):
@@ -33,32 +35,35 @@ def __init__(
csv_path: Union[Path, str]
path to csv
img_path_column: str
column in `csv_path` that contains path to CZI file
column in `csv_path` that contains path to multi-dimensional (timelapse or multi-scene) file
channel_column:str
Column in `csv_path` that contains which channel to extract from CZI file. Should be an integer.
Column in `csv_path` that contains which channel to extract from multi-dimensional (timelapse or multi-scene) file. Should be an integer.
out_key:str
Key where single-scene/timepoint/channel is saved in output dictionary
spatial_dims:int=3
Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
scene_column:str="scene",
Column in `csv_path` that contains scenes to extract from CZI file. If not specified, all scenes will
Column in `csv_path` that contains scenes to extract from multi-scene file. If not specified, all scenes will
be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
time_start_column:str="start"
Column in `csv_path` specifying which timepoint in timelapse CZI to start extracting. If any of `start_column`, `stop_column`, or `step_column`
Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
are not specified, all timepoints are extracted.
time_stop_column:str="stop"
Column in `csv_path` specifying which timepoint in timelapse CZI to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
are not specified, all timepoints are extracted.
time_step_column:str="step"
Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
dict_meta: Optional[Dict]
Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
transform: Optional[Callable] = None
Callable that accepts a numpy array. For example, image normalization functions could be passed here.
dask_load: bool = True
Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
"""
super().__init__(None, transform)
df = pd.read_csv(csv_path)
df = pd.read_csv(csv_path) if csv_path is not None else pd.DataFrame([dict_meta])

self.img_path_column = img_path_column
self.channel_column = channel_column
self.scene_column = scene_column
@@ -87,13 +92,11 @@ def _get_scenes(self, row, img):
return scenes

def _get_timepoints(self, row, img):
start = row.get(self.time_start_column, -1)
start = row.get(self.time_start_column, 0)
stop = row.get(self.time_stop_column, -1)
step = row.get(self.time_step_column, 1)
timepoints = list(range(start, stop + 1, step))
if np.any(np.array((start, stop, step)) == -1):
timepoints = list(range(img.dims.T))
return timepoints
timepoints = range(start, stop + 1, step) if stop > 0 else range(img.dims.T)
return list(timepoints)

def get_per_file_args(self, df):
img_data = []
@@ -124,81 +127,105 @@ def _ensure_channel_first(self, img):
img = np.expand_dims(img, 0)
return img

def create_metatensor(self, img, meta):
if isinstance(img, np.ndarray):
img = torch.from_numpy(img.astype(float))
if isinstance(img, MetaTensor):
img.meta.update(meta)
return img
elif isinstance(img, torch.Tensor):
return MetaTensor(
img,
meta=meta,
)
raise ValueError(f"Expected img to be MetaTensor or torch.Tensor, got {type(img)}")

def is_batch(self, x):
return isinstance(x, list) or len(x.shape) == self.spatial_dims + 2

def _transform(self, index: int):
img_data = self.img_data.pop()
img = img_data.pop("img")
original_path = img_data.pop("original_path")
scene = img_data.pop("scene")
img.set_scene(scene)

if self.dask_load:
data_i = img.get_image_dask_data(**img_data).compute()
else:
data_i = img.get_image_data(**img_data)
# add scene and path information back to metadata
img_data["scene"] = scene
img_data["original_path"] = original_path
data_i = self._ensure_channel_first(data_i)
data_i = self.create_metatensor(data_i, img_data)

output_img = (
apply_transform(self.transform, data_i) if self.transform is not None else data_i
)

return {
self.out_key: MetaTensor(
output_img,
meta={
"filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
},
)
}
# some MONAI transforms return a batch; when collated, the batch dimension gets moved to the channel dimension
if self.is_batch(output_img):
return [{self.out_key: img} for img in output_img]
return {self.out_key: output_img}

def __len__(self):
return len(self.img_data)


def make_CZI_dataloader(
csv_path,
img_path_column,
channel_column,
out_key,
spatial_dims=3,
scene_column="scene",
time_start_column="start",
time_stop_column="stop",
time_step_column="step",
transforms=None,
def make_multidim_image_dataloader(
csv_path: Optional[Union[Path, str]] = None,
img_path_column: str = "path",
channel_column: str = "channel",
out_key: str = "image",
spatial_dims: int = 3,
scene_column: str = "scene",
time_start_column: str = "start",
time_stop_column: str = "stop",
time_step_column: str = "step",
dict_meta: Optional[Dict] = None,
transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] = None,
**dataloader_kwargs,
):
"""Function to create a CZI Dataset. Currently, this dataset is only useful during prediction
and cannot be used for training or testing.
) -> DataLoader:
"""Function to create a MultiDimImage DataLoader. Currently, this dataset is only useful during
prediction and cannot be used for training or testing.
Parameters
----------
csv_path: Union[Path, str]
csv_path: Optional[Union[Path, str]]
path to csv
img_path_column: str
column in `csv_path` that contains path to CZI file
channel_column:str
Column in `csv_path` that contains which channel to extract from CZI file. Should be an integer.
out_key:str
column in `csv_path` that contains path to multi-dimensional (timelapse or multi-scene) file
channel_column: str
Column in `csv_path` that contains which channel to extract from multi-dimensional image file. Should be an integer.
out_key: str
Key where single-scene/timepoint/channel is saved in output dictionary
spatial_dims:int=3
spatial_dims: int
Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
scene_column:str="scene",
Column in `csv_path` that contains scenes to extract from CZI file. If not specified, all scenes will
scene_column: str
Column in `csv_path` that contains scenes to extract from multi-scene file. If not specified, all scenes will
be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
time_start_column:str="start"
Column in `csv_path` specifying which timepoint in timelapse CZI to start extracting. If any of `start_column`, `stop_column`, or `step_column`
time_start_column: str
Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
are not specified, all timepoints are extracted.
time_stop_column:str="stop"
Column in `csv_path` specifying which timepoint in timelapse CZI to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
time_stop_column: str
Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
are not specified, all timepoints are extracted.
time_step_column:str="step"
time_step_column: str
Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
transform: Optional[Callable] = None
Callable to that accepts numpy array. For example, image normalization functions could be passed here.
dict_meta: Optional[Dict]
Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]]
Callable or list of callables that accept a numpy array. For example, image normalization functions could be passed here. Data loading is already handled by the dataset.
Returns
-------
DataLoader
The DataLoader object for the MultiDimImage dataset.
"""
if isinstance(transforms, (list, tuple, ListConfig)):
transforms = Compose(transforms)
dataset = CZIDataset(
dataset = MultiDimImageDataset(
csv_path,
img_path_column,
channel_column,
@@ -208,6 +235,9 @@ def make_CZI_dataloader(
time_start_column=time_start_column,
time_stop_column=time_stop_column,
time_step_column=time_step_column,
dict_meta=dict_meta,
transform=transforms,
)
return DataLoader(dataset, **dataloader_kwargs)
# currently only supports 0 or 1 workers
num_workers = min(dataloader_kwargs.pop("num_workers", 0), 1)
return DataLoader(dataset, num_workers=num_workers, **dataloader_kwargs)
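
For illustration, a sketch of the new `dict_meta` path, which skips the CSV entirely (the file path and scene/timepoint values below are placeholders; column names follow the defaults documented above):

from cyto_dl.datamodules.multidim_image import make_multidim_image_dataloader

dataloader = make_multidim_image_dataloader(
    csv_path=None,  # no CSV; a single row of metadata comes from dict_meta
    dict_meta={
        "path": "movie.ome.tiff",  # placeholder multi-scene/timelapse file
        "channel": 0,
        "scene": "scene1,scene2",  # comma-separated scenes
        "start": 0,                # timepoints 0, 2, 4, ..., 10
        "stop": 10,
        "step": 2,
    },
    out_key="image",
    spatial_dims=3,
    batch_size=1,
    num_workers=0,  # the helper caps workers at 1
)
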
9 changes: 3 additions & 6 deletions cyto_dl/models/classification/classification.py
@@ -118,10 +118,7 @@ def model_step(self, stage, batch, batch_idx):
return loss, logits.argmax(dim=1), labels

def predict_step(self, batch, batch_idx):
logits = self(batch[self.hparams.x_key]).squeeze(0)
x = batch[self.hparams.anchor_key]
logits = self(x).squeeze(0)
preds = torch.argmax(logits, dim=1).cpu().numpy()
if self.hparams.write_batch_predictions:
pd.DataFrame([preds]).to_csv(
Path(self.hparams.save_dir) / f"predictions_batch={batch_idx}.csv", index=False
)
return preds
return preds, x.meta
13 changes: 3 additions & 10 deletions cyto_dl/models/contrastive/contrastive.py
@@ -2,7 +2,6 @@

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
@@ -130,12 +129,6 @@ def model_step(self, stage, batch, batch_idx):
return out["loss"], None, None

def predict_step(self, batch, batch_idx):
x = batch[self.hparams.anchor_key].as_tensor()
embeddings = self.backbone(x)
preds = pd.DataFrame(
embeddings.detach().cpu().numpy(), columns=[str(i) for i in range(embeddings.shape[1])]
)
for key in self.hparams.meta_keys:
preds[key] = batch[key]
preds.to_csv(Path(self.hparams.save_dir) / f"{batch_idx}_predictions.csv")
return None, None, None
x = batch[self.hparams.anchor_key]
embeddings = self.backbone(x if isinstance(x, torch.Tensor) else x.as_tensor())
return embeddings.detach().cpu().numpy(), x.meta
7 changes: 7 additions & 0 deletions cyto_dl/nn/vits/mae.py
@@ -220,6 +220,7 @@ def __init__(
use_crossmae: Optional[bool] = False,
context_pixels: Optional[List[int]] = [0, 0, 0],
input_channels: Optional[int] = 1,
features_only: Optional[bool] = False,
) -> None:
"""
Parameters
@@ -247,6 +248,9 @@ def __init__(
context_pixels: List[int]
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
input_channels: int
Number of input channels
features_only: bool
If True, only use the encoder to extract features
"""
super().__init__()
assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"
@@ -262,6 +266,7 @@ def __init__(
), "base_patch_size must be of length spatial_dims"

self.mask_ratio = mask_ratio
self.features_only = features_only

self.encoder = MAE_Encoder(
num_patches,
@@ -290,5 +295,7 @@ def __init__(

def forward(self, img):
features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio)
if self.features_only:
return features
predicted_img = self.decoder(features, forward_indexes, backward_indexes)
return predicted_img, mask
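
To make the new control flow concrete, a toy stand-in (not the real MAE class; layers and shapes are placeholders chosen only to show the features_only branch):

import torch
import torch.nn as nn


class TinyMAE(nn.Module):
    """Toy model mirroring the features_only branch added to mae.py."""

    def __init__(self, features_only: bool = False):
        super().__init__()
        self.features_only = features_only
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Linear(4, 8)

    def forward(self, x):
        features = self.encoder(x)
        if self.features_only:
            return features  # skip decoding and return embeddings directly
        return self.decoder(features)  # full reconstruction path


x = torch.randn(2, 8)
print(TinyMAE(features_only=True)(x).shape)   # torch.Size([2, 4])
print(TinyMAE(features_only=False)(x).shape)  # torch.Size([2, 8])
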
