diff --git a/cyto_dl/callbacks/csv_saver.py b/cyto_dl/callbacks/csv_saver.py
new file mode 100644
index 000000000..75fab6af2
--- /dev/null
+++ b/cyto_dl/callbacks/csv_saver.py
@@ -0,0 +1,22 @@
+from pathlib import Path
+
+import pandas as pd
+from lightning.pytorch.callbacks import Callback
+
+
+class CSVSaver(Callback):
+    def __init__(self, save_dir, meta_keys=None):
+        self.save_dir = Path(save_dir)
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        self.meta_keys = meta_keys or []
+
+    def on_predict_epoch_end(self, trainer, pl_module):
+        # collect the (prediction, metadata) pairs returned by each predict_step
+        predictions = trainer.predict_loop.predictions
+        feats = []
+        for pred, meta in predictions:
+            batch_feats = pd.DataFrame(pred)
+            for k in self.meta_keys:
+                batch_feats[k] = meta[k]
+            feats.append(batch_feats)
+        pd.concat(feats).to_csv(self.save_dir / "predictions.csv", index=False)
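For context, a minimal sketch of how the new callback might be wired up (the `model` and `predict_dataloader` names are hypothetical placeholders; the callback only assumes each `predict_step` returns a `(prediction, metadata)` tuple and that the metadata contains the requested keys):

    from lightning.pytorch import Trainer

    from cyto_dl.callbacks.csv_saver import CSVSaver

    # all (prediction, metadata) pairs are written to ./predictions/predictions.csv
    trainer = Trainer(
        callbacks=[CSVSaver(save_dir="./predictions", meta_keys=["filename_or_obj"])]
    )
    trainer.predict(model, dataloaders=predict_dataloader)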
diff --git a/cyto_dl/datamodules/czi.py b/cyto_dl/datamodules/multidim_image.py
similarity index 61%
rename from cyto_dl/datamodules/czi.py
rename to cyto_dl/datamodules/multidim_image.py
index 1e0d749bc..a5599d672 100644
--- a/cyto_dl/datamodules/czi.py
+++ b/cyto_dl/datamodules/multidim_image.py
@@ -1,17 +1,18 @@
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
+import torch
 from bioio import BioImage
 from monai.data import DataLoader, Dataset, MetaTensor
-from monai.transforms import Compose, apply_transform
+from monai.transforms import Compose, ToTensor, apply_transform
 from omegaconf import ListConfig
 
 
-class CZIDataset(Dataset):
-    """Dataset converting a `.csv` file listing CZI files and some metadata into batches of single-
-    scene, single-timepoint, single-channel images."""
+class MultiDimImageDataset(Dataset):
+    """Dataset converting a `.csv` file listing multidimensional (timelapse or multi-scene) files
+    and some metadata into batches of single-scene, single-timepoint, single-channel images."""
 
     def __init__(
         self,
@@ -24,6 +25,7 @@ def __init__(
         time_start_column: str = "start",
         time_stop_column: str = "stop",
         time_step_column: str = "step",
+        dict_meta: Optional[Dict] = None,
         transform: Optional[Callable] = None,
         dask_load: bool = True,
     ):
@@ -33,32 +35,35 @@ def __init__(
         csv_path: Union[Path, str]
             path to csv
         img_path_column: str
-            column in `csv_path` that contains path to CZI file
+            column in `csv_path` that contains path to multidimensional (timelapse or multi-scene) file
         channel_column:str
-            Column in `csv_path` that contains which channel to extract from CZI file. Should be an integer.
+            Column in `csv_path` that contains which channel to extract from multidimensional (timelapse or multi-scene) file. Should be an integer.
         out_key:str
             Key where single-scene/timepoint/channel is saved in output dictionary
         spatial_dims:int=3
             Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
         scene_column:str="scene",
-            Column in `csv_path` that contains scenes to extract from CZI file. If not specified, all scenes will
+            Column in `csv_path` that contains scenes to extract from multi-scene file. If not specified, all scenes will
             be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
         time_start_column:str="start"
-            Column in `csv_path` specifying which timepoint in timelapse CZI to start extracting. If any of `start_column`, `stop_column`, or `step_column`
+            Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
             are not specified, all timepoints are extracted.
         time_stop_column:str="stop"
-            Column in `csv_path` specifying which timepoint in timelapse CZI to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
+            Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
            are not specified, all timepoints are extracted.
         time_step_column:str="step"
             Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
             If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
+        dict_meta: Optional[Dict]
+            Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
         transform: Optional[Callable] = None
            Callable that accepts a numpy array. For example, image normalization functions could be passed here.
         dask_load: bool = True
            Whether to use dask to load images. If False, full images are loaded into memory before extracting specified scenes/timepoints.
         """
         super().__init__(None, transform)
-        df = pd.read_csv(csv_path)
+        df = pd.read_csv(csv_path) if csv_path is not None else pd.DataFrame([dict_meta])
+
         self.img_path_column = img_path_column
         self.channel_column = channel_column
         self.scene_column = scene_column
@@ -87,13 +92,11 @@ def _get_scenes(self, row, img):
         return scenes
 
     def _get_timepoints(self, row, img):
-        start = row.get(self.time_start_column, -1)
+        start = row.get(self.time_start_column, 0)
         stop = row.get(self.time_stop_column, -1)
         step = row.get(self.time_step_column, 1)
-        timepoints = list(range(start, stop + 1, step))
-        if np.any(np.array((start, stop, step)) == -1):
-            timepoints = list(range(img.dims.T))
-        return timepoints
+        timepoints = range(start, stop + 1, step) if stop > 0 else range(img.dims.T)
+        return list(timepoints)
 
     def get_per_file_args(self, df):
         img_data = []
@@ -124,81 +127,105 @@ def _ensure_channel_first(self, img):
             img = np.expand_dims(img, 0)
         return img
 
+    def create_metatensor(self, img, meta):
+        if isinstance(img, np.ndarray):
+            img = torch.from_numpy(img.astype(float))
+        if isinstance(img, MetaTensor):
+            img.meta.update(meta)
+            return img
+        elif isinstance(img, torch.Tensor):
+            return MetaTensor(
+                img,
+                meta=meta,
+            )
+        raise ValueError(f"Expected img to be MetaTensor or torch.Tensor, got {type(img)}")
+
+    def is_batch(self, x):
+        return isinstance(x, list) or len(x.shape) == self.spatial_dims + 2
+
     def _transform(self, index: int):
         img_data = self.img_data.pop()
         img = img_data.pop("img")
         original_path = img_data.pop("original_path")
         scene = img_data.pop("scene")
         img.set_scene(scene)
+
         if self.dask_load:
             data_i = img.get_image_dask_data(**img_data).compute()
         else:
             data_i = img.get_image_data(**img_data)
+
+        # add scene and path information back to metadata
         img_data["scene"] = scene
+        img_data["original_path"] = original_path
         data_i = self._ensure_channel_first(data_i)
+        data_i = self.create_metatensor(data_i, img_data)
+
         output_img = (
             apply_transform(self.transform, data_i) if self.transform is not None else data_i
         )
-
-        return {
-            self.out_key: MetaTensor(
-                output_img,
-                meta={
-                    "filename_or_obj": original_path.replace(".", self._metadata_to_str(img_data))
-                },
-            )
-        }
+        # some monai transforms return a batch; when collated, the batch dimension gets moved to the channel dimension
+        if self.is_batch(output_img):
+            return [{self.out_key: img} for img in output_img]
+        return {self.out_key: output_img}
 
     def __len__(self):
         return len(self.img_data)
 
 
-def make_CZI_dataloader(
-    csv_path,
-    img_path_column,
-    channel_column,
-    out_key,
-    spatial_dims=3,
-    scene_column="scene",
-    time_start_column="start",
-    time_stop_column="stop",
-    time_step_column="step",
-    transforms=None,
+def make_multidim_image_dataloader(
+    csv_path: Optional[Union[Path, str]] = None,
+    img_path_column: str = "path",
+    channel_column: str = "channel",
+    out_key: str = "image",
+    spatial_dims: int = 3,
+    scene_column: str = "scene",
+    time_start_column: str = "start",
+    time_stop_column: str = "stop",
+    time_step_column: str = "step",
+    dict_meta: Optional[Dict] = None,
+    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]] = None,
     **dataloader_kwargs,
-):
-    """Function to create a CZI Dataset. Currently, this dataset is only useful during prediction
-    and cannot be used for training or testing.
+) -> DataLoader:
+    """Function to create a MultiDimImage DataLoader. Currently, this dataset is only useful during
+    prediction and cannot be used for training or testing.
 
     Parameters
     ----------
-    csv_path: Union[Path, str]
+    csv_path: Optional[Union[Path, str]]
         path to csv
     img_path_column: str
-        column in `csv_path` that contains path to CZI file
-    channel_column:str
-        Column in `csv_path` that contains which channel to extract from CZI file. Should be an integer.
-    out_key:str
+        column in `csv_path` that contains path to multidimensional (timelapse or multi-scene) file
+    channel_column: str
+        Column in `csv_path` that contains which channel to extract from multidimensional image file. Should be an integer.
+    out_key: str
         Key where single-scene/timepoint/channel is saved in output dictionary
-    spatial_dims:int=3
+    spatial_dims: int
         Spatial dimension of output image. Must be 2 for YX or 3 for ZYX
-    scene_column:str="scene",
-        Column in `csv_path` that contains scenes to extract from CZI file. If not specified, all scenes will
+    scene_column: str
+        Column in `csv_path` that contains scenes to extract from multi-scene file. If not specified, all scenes will
         be extracted. If multiple scenes are specified, they should be separated by a comma (e.g. `scene1,scene2`)
-    time_start_column:str="start"
-        Column in `csv_path` specifying which timepoint in timelapse CZI to start extracting. If any of `start_column`, `stop_column`, or `step_column`
+    time_start_column: str
+        Column in `csv_path` specifying which timepoint in timelapse image to start extracting. If any of `start_column`, `stop_column`, or `step_column`
         are not specified, all timepoints are extracted.
-    time_stop_column:str="stop"
-        Column in `csv_path` specifying which timepoint in timelapse CZI to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
+    time_stop_column: str
+        Column in `csv_path` specifying which timepoint in timelapse image to stop extracting. If any of `start_column`, `stop_column`, or `step_column`
        are not specified, all timepoints are extracted.
-    time_step_column:str="step"
+    time_step_column: str
         Column in `csv_path` specifying step between timepoints. For example, values in this column should be `2` if every other timepoint should be run.
         If any of `start_column`, `stop_column`, or `step_column` are not specified, all timepoints are extracted.
-    transform: Optional[Callable] = None
-        Callable to that accepts numpy array. For example, image normalization functions could be passed here.
+    dict_meta: Optional[Dict]
+        Dictionary version of CSV file. If not provided, CSV file is read from `csv_path`.
+    transforms: Optional[Union[List[Callable], Tuple[Callable], ListConfig]]
+        Callable or list of callables that accept a numpy array. For example, image normalization functions could be passed here. Dataloading is already handled by the dataset.
+
+    Returns
+    -------
+    DataLoader
+        The DataLoader object for the MultiDimImage dataset.
     """
     if isinstance(transforms, (list, tuple, ListConfig)):
         transforms = Compose(transforms)
-    dataset = CZIDataset(
+    dataset = MultiDimImageDataset(
         csv_path,
         img_path_column,
         channel_column,
@@ -208,6 +235,9 @@ def make_CZI_dataloader(
         time_start_column=time_start_column,
         time_stop_column=time_stop_column,
         time_step_column=time_step_column,
+        dict_meta=dict_meta,
         transform=transforms,
     )
-    return DataLoader(dataset, **dataloader_kwargs)
+    # currently only supports 0 or 1 workers
+    num_workers = min(dataloader_kwargs.pop("num_workers", 0), 1)
+    return DataLoader(dataset, num_workers=num_workers, **dataloader_kwargs)
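A sketch of the renamed factory in use, with the default column names and an in-memory `dict_meta` row in place of an on-disk csv; the file path, scene, and timepoint values are illustrative:

    from cyto_dl.datamodules.multidim_image import make_multidim_image_dataloader

    dataloader = make_multidim_image_dataloader(
        dict_meta={"path": "movie.ome.tiff", "channel": 0, "scene": "0", "start": 0, "stop": 10, "step": 2},
        out_key="image",
        spatial_dims=3,
        num_workers=1,  # anything above 1 is clamped by the factory
    )
    for batch in dataloader:
        image = batch["image"]  # single-scene, single-timepoint, single-channel image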
diff --git a/cyto_dl/models/classification/classification.py b/cyto_dl/models/classification/classification.py
index c4730f838..2a21c550a 100644
--- a/cyto_dl/models/classification/classification.py
+++ b/cyto_dl/models/classification/classification.py
@@ -118,10 +118,7 @@ def model_step(self, stage, batch, batch_idx):
         return loss, logits.argmax(dim=1), labels
 
     def predict_step(self, batch, batch_idx):
-        logits = self(batch[self.hparams.x_key]).squeeze(0)
+        x = batch[self.hparams.anchor_key]
+        logits = self(x).squeeze(0)
         preds = torch.argmax(logits, dim=1).cpu().numpy()
-        if self.hparams.write_batch_predictions:
-            pd.DataFrame([preds]).to_csv(
-                Path(self.hparams.save_dir) / f"predictions_batch={batch_idx}.csv", index=False
-            )
-        return preds
+        return preds, x.meta
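With this change, `predict_step` returns a `(predictions, metadata)` pair for `CSVSaver` to collect (the contrastive model below follows the same pattern). A self-contained sketch with synthetic values showing what the callback does with one such pair (`filename_or_obj` is MONAI's conventional metadata key):

    import numpy as np
    import pandas as pd

    # synthetic stand-in for one (preds, meta) pair from predict_step
    predictions = [(np.array([0, 2, 1]), {"filename_or_obj": ["a.tif", "b.tif", "c.tif"]})]
    feats = []
    for pred, meta in predictions:
        batch_feats = pd.DataFrame(pred)
        batch_feats["filename_or_obj"] = meta["filename_or_obj"]
        feats.append(batch_feats)
    print(pd.concat(feats))  # one row per image: predicted class plus filename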
diff --git a/cyto_dl/models/contrastive/contrastive.py b/cyto_dl/models/contrastive/contrastive.py
index 55b0f81c3..8c148b1e7 100644
--- a/cyto_dl/models/contrastive/contrastive.py
+++ b/cyto_dl/models/contrastive/contrastive.py
@@ -2,7 +2,6 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
-import pandas as pd
 import torch
 import torch.nn as nn
 from sklearn.decomposition import PCA
@@ -130,12 +129,6 @@ def model_step(self, stage, batch, batch_idx):
         return out["loss"], None, None
 
     def predict_step(self, batch, batch_idx):
-        x = batch[self.hparams.anchor_key].as_tensor()
-        embeddings = self.backbone(x)
-        preds = pd.DataFrame(
-            embeddings.detach().cpu().numpy(), columns=[str(i) for i in range(embeddings.shape[1])]
-        )
-        for key in self.hparams.meta_keys:
-            preds[key] = batch[key]
-        preds.to_csv(Path(self.hparams.save_dir) / f"{batch_idx}_predictions.csv")
-        return None, None, None
+        x = batch[self.hparams.anchor_key]
+        embeddings = self.backbone(x.as_tensor() if hasattr(x, "as_tensor") else x)
+        return embeddings.detach().cpu().numpy(), x.meta
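One subtlety behind the `hasattr` guard above: MONAI's `MetaTensor` subclasses `torch.Tensor`, so an `isinstance(x, torch.Tensor)` check cannot distinguish the two; `as_tensor()` is what strips the metadata before the backbone runs. A small sketch:

    import torch
    from monai.data import MetaTensor

    x = MetaTensor(torch.zeros(2, 1, 16, 16), meta={"filename_or_obj": ["a.tif", "b.tif"]})
    assert isinstance(x, torch.Tensor)  # a MetaTensor passes the plain-Tensor check
    plain = x.as_tensor()  # plain tensor, metadata stripped
    assert not isinstance(plain, MetaTensor)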
diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py
index 1cc906dea..c05a6b311 100644
--- a/cyto_dl/nn/vits/mae.py
+++ b/cyto_dl/nn/vits/mae.py
@@ -220,6 +220,7 @@ def __init__(
         use_crossmae: Optional[bool] = False,
         context_pixels: Optional[List[int]] = [0, 0, 0],
         input_channels: Optional[int] = 1,
+        features_only: Optional[bool] = False,
     ) -> None:
         """
         Parameters
@@ -247,6 +248,9 @@ def __init__(
         context_pixels: List[int]
             Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
         input_channels: int
+            Number of input channels
+        features_only: bool
+            If True, only run the encoder and return its features from `forward`
         """
         super().__init__()
         assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"
@@ -262,6 +266,7 @@ def __init__(
         ), "base_patch_size must be of length spatial_dims"
 
         self.mask_ratio = mask_ratio
+        self.features_only = features_only
 
         self.encoder = MAE_Encoder(
             num_patches,
@@ -290,5 +295,7 @@ def __init__(
 
     def forward(self, img):
         features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio)
+        if self.features_only:
+            return features
         predicted_img = self.decoder(features, forward_indexes, backward_indexes)
         return predicted_img, mask
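Finally, a sketch of the two forward modes; the class name `MAE_ViT` and the constructor arguments shown are illustrative stand-ins for the surrounding class, and only `features_only` comes from this diff:

    import torch

    img = torch.zeros(1, 1, 16, 32, 32)  # hypothetical ZYX input

    # reconstruction mode (default): forward returns (predicted_img, mask)
    model = MAE_ViT(spatial_dims=3, num_patches=[4, 8, 8], base_patch_size=[4, 4, 4])
    predicted_img, mask = model(img)

    # feature-extraction mode: forward returns encoder features only,
    # e.g. for pairing with the CSVSaver callback during prediction
    feature_extractor = MAE_ViT(
        spatial_dims=3, num_patches=[4, 8, 8], base_patch_size=[4, 4, 4], features_only=True
    )
    features = feature_extractor(img)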