From 45ec6ef172b4ab302f6a86e14a6a65cbe6f740e5 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 21 Oct 2024 16:59:36 +0100 Subject: [PATCH 1/6] Added plots module --- brainglobe_template_builder/plots.py | 331 +++++++++++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 brainglobe_template_builder/plots.py diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py new file mode 100644 index 0000000..8cf515b --- /dev/null +++ b/brainglobe_template_builder/plots.py @@ -0,0 +1,331 @@ +from pathlib import Path +from typing import Literal + +import numpy as np +from brainglobe_space import AnatomicalSpace +from matplotlib import pyplot as plt + + +def plot_orthographic( + img: np.ndarray, + anat_space: str = "ASR", + voxel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0), + show_slices: tuple[int, int, int] | None = None, + mip_attenuation: float = 0.01, + save_path: Path | None = None, + **kwargs, +) -> tuple[plt.Figure, np.ndarray]: + """Plot image volume in three orthogonal views, plus a surface rendering. + + The surface rendering is a maximum intensity projection (MIP) along the + vertical (superior-inferior) axis and is shown from the top. + + Parameters + ---------- + img : np.ndarray + Image volume to plot. + anat_space : str, optional + Anatomical space of of the image volume according to the Brainglobe + definition (origin and order of axes), by default "ASR". + voxel_sizes : tuple, optional + Voxels sizes in micrometers per dimension, by default (1.0, 1.0, 1.0). + The relative sizes of the axes will be preserved in the plot. + show_slices : tuple, optional + Which slice to show per dimension. If None (default), show the middle + slice along each dimension. + mip_attenuation : float, optional + Attenuation factor for the MIP, by default 0.01. + A value of 0 means no attenuation. + save_path : Path, optional + Path to save the plot, by default None (no saving). + **kwargs + Additional keyword arguments to pass to ``matplotlib.pyplot.imshow``. + + Returns + ------- + tuple[plt.Figure, np.ndarray] + Matplotlib figure and axes objects + + """ + + space = AnatomicalSpace(anat_space) + vertical_axis = space.get_axis_idx("vertical") + + # Get middle slices if not specified + if show_slices is None: + slices_list = [s // 2 for s in img.shape] + else: + slices_list = list(show_slices) + + # Pad the image with zeros to make it cubic + img, pad_sizes = _pad_with_zeros(img, target=max(img.shape)) + slices_list = [s + pad_sizes[i] for i, s in enumerate(slices_list)] + + # Compute (attenuated) MIP along the vertical axis + mip, mip_label = _compute_attenuated_mip( + img, vertical_axis, mip_attenuation + ) + + # Create figure with 4 subplots (3 orthogonal views + MIP) + fig, axs = plt.subplots(1, 4, figsize=(14, 4)) + views = [img.take(slc, axis=i) for i, slc in enumerate(slices_list)] + views.append(mip) + axis_labels = [*space.axis_labels, space.axis_labels[vertical_axis]] + section_names = [s.capitalize() for s in space.sections] + [mip_label] + + kwargs = _set_imshow_defaults(img, kwargs) + + for j, (section, labels) in enumerate(zip(section_names, axis_labels)): + ax = axs[j] + ax.imshow(views[j], **kwargs) + ax.set_title(section) + ax.set_ylabel(labels[0]) + ax.set_xlabel(labels[1]) + ax = _clear_spines_and_ticks(ax) + plt.tight_layout() + + if save_path: + _save_and_close_figure( + fig, save_path.parent, save_path.name.split(".")[0] + ) + return fig, axs + + +def plot_grid( + img: np.ndarray, + anat_space="ASR", + section: Literal["frontal", "horizontal", "sagittal"] = "frontal", + n_slices: int = 12, + save_path: Path | None = None, + **kwargs, +) -> tuple[plt.Figure, np.ndarray]: + """Plot image volume as a grid of slices along a given anatomical section. + + Parameters + ---------- + img : np.ndarray + Image volume to plot. + anat_space : str, optional + Anatomical space of of the image volume according to the Brainglobe + definition (origin and order of axes), by default "ASR". + section : str, optional + Section to show, must be one of "frontal", "horizontal", or "sagittal", + by default "frontal". + n_slices : int, optional + Number of slices to show, by default 12. Slices will be evenly spaced, + starting from the first and ending with the last slice. + save_path : Path, optional + Path to save the plot, by default None (no saving). + **kwargs + Additional keyword arguments to pass to ``matplotlib.pyplot.imshow``. + + Returns + ------- + tuple[plt.Figure, np.ndarray] + Matplotlib figure and axes objects + + """ + space = AnatomicalSpace(anat_space) + section_to_axis = { # Mapping of section names to space axes + "frontal": "sagittal", + "horizontal": "vertical", + "sagittal": "frontal", + } + axis_idx = space.get_axis_idx(section_to_axis[section]) + + # Ensure n_slices is not greater than the number of slices in the image + n_slices = min(n_slices, img.shape[axis_idx]) + # ensure first and last slices are included + show_slices = np.linspace(0, img.shape[axis_idx] - 1, n_slices, dtype=int) + + # Get slices along the specified axis and arrange them in a grid + grid_img = _grid_from_slices( + [img.take(slc, axis=axis_idx) for slc in show_slices] + ) + + # Plot the grid image + fig, ax = plt.subplots(1, 1, figsize=(12, 12)) + kwargs = _set_imshow_defaults(img, kwargs) + ax.imshow(grid_img, **kwargs) + + section_name = section.capitalize() + ax.set_title(f"{section_name} slices") + ax.set_xlabel(space.axis_labels[axis_idx][1]) + ax.set_ylabel(space.axis_labels[axis_idx][0]) + ax = _clear_spines_and_ticks(ax) + plt.tight_layout() + + if save_path: + _save_and_close_figure( + fig, save_path.parent, save_path.name.split(".")[0] + ) + return fig, ax + + +def _compute_attenuated_mip( + img: np.ndarray, axis: int, attenuation_factor: float +) -> tuple[np.ndarray, str]: + """Compute the maximum intensity projection (MIP) with attenuation. + + If the image is zero-padded, attenuation is only applied within the + non-zero region along the specified axis. + + Parameters + ---------- + img : np.ndarray + Image volume. + axis : int + Axis along which to compute the MIP. + attenuation_factor : float + Attenuation factor for the MIP. 0 means no attenuation. + + Returns + ------- + tuple[np.ndarray, str] + MIP image and label. The label is "MIP" if no attenuation is applied, + and "MIP (attenuated)" otherwise. + """ + + mip_label = "MIP" + + if attenuation_factor < 0: + raise ValueError("Attenuation factor must be non-negative.") + + if attenuation_factor < 1e-6: + # If the factor is too small, skip attenuation + mip = np.max(img, axis=axis) + return mip, mip_label + + # Find the non-zero bounding box along the specified axis + other_axes = tuple(i for i in range(img.ndim) if i != axis) + non_zero_mask = np.any(img != 0, axis=other_axes) + non_zero_indices = np.nonzero(non_zero_mask)[0] + start, end = non_zero_indices[0], non_zero_indices[-1] + 1 + + # Trim the image along the attenuation axis (get rid of zero-padding) + slices = [slice(None)] * img.ndim + slices[axis] = slice(start, end) + trimmed_img = img[tuple(slices)] + + # Apply attenuation to the trimmed image + attenuation = np.exp( + -attenuation_factor * np.arange(trimmed_img.shape[axis]) + ) + attenuation_shape = [1] * trimmed_img.ndim + attenuation_shape[axis] = trimmed_img.shape[axis] + attenuation = attenuation.reshape(attenuation_shape) + attenuated_img = trimmed_img.astype(np.float32) * attenuation + + # Compute and return the attenuated MIP + mip = np.max(attenuated_img, axis=axis) + mip_label += " (attenuated)" + + return mip, mip_label + + +def _save_and_close_figure(fig: plt.Figure, plots_dir: Path, filename: str): + """Save figure in both PNG and PDF formats and close it.""" + fig.savefig(plots_dir / f"{filename}.png") + fig.savefig(plots_dir / f"{filename}.pdf") + plt.close(fig) + + +def _clear_spines_and_ticks(ax: plt.Axes) -> plt.Axes: + """Clear spines and ticks from a matplotlib axis.""" + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + return ax + + +def _set_imshow_defaults(img: np.ndarray, kwargs: dict) -> dict: + """Set default values for imshow keyword arguments. + + These apply only if the user does not provide them explicitly. + """ + if "vmin" not in kwargs and "vmax" not in kwargs: + vmin, vmax = _auto_adjust_contrast(img) + kwargs.setdefault("vmin", vmin) + kwargs.setdefault("vmax", vmax) + + kwargs.setdefault("cmap", "gray") + kwargs.setdefault("aspect", "equal") + return kwargs + + +def _grid_from_slices(slices: list[np.ndarray]) -> np.ndarray: + """Create a grid image from a list of 2D slices. + + The number of rows is automatically determined based on the square root + of the number of slices, rounded up. + + Parameters + ---------- + slices : list[np.ndarray] + List of 2D slices to concatenate. + + Returns + ------- + np.ndarray + A 2D image, with the input slices arranged in a grid. + + """ + + n_slices = len(slices) + slice_height, slice_width = slices[0].shape + + # Form image mosaic grid by concatenating slices + n_rows = int(np.ceil(np.sqrt(n_slices))) + n_cols = int(np.ceil(n_slices / n_rows)) + grid_img = np.zeros( + (n_rows * slice_height, n_cols * slice_width), + ) + for i, slice in enumerate(slices): + row = i // n_cols + col = i % n_cols + grid_img[ + row * slice_height : (row + 1) * slice_height, + col * slice_width : (col + 1) * slice_width, + ] = slice + + return grid_img + + +def _pad_with_zeros( + img: np.ndarray, target: int = 512 +) -> tuple[np.ndarray, tuple[int, int, int]]: + """Pad the volume with zeros to reach the target size in all dimensions.""" + pad_sizes = [(target - s) // 2 for s in img.shape] + padded_img = np.pad( + img, + ( + (pad_sizes[0], pad_sizes[0]), + (pad_sizes[1], pad_sizes[1]), + (pad_sizes[2], pad_sizes[2]), + ), + mode="constant", + ) + return padded_img, tuple(pad_sizes) + + +def _auto_adjust_contrast(img, lower_percentile=1, upper_percentile=99): + """Adjust contrast of an image using percentile-based scaling.""" + # Mask near-zero voxels to exclude background + if np.issubdtype(img.dtype, np.integer): + background_threshold = 1 + else: + background_threshold = np.finfo(img.dtype).eps + + brain_mask = img > background_threshold + + # Exclude bright artifacts + vmax = np.percentile(img[brain_mask], upper_percentile) + artifact_mask = img <= vmax + combined_mask = brain_mask & artifact_mask + + # Compute vmin and vmax + vmin = np.percentile(img[combined_mask], lower_percentile) + vmax = np.percentile(img[combined_mask], upper_percentile) + + return vmin, vmax From a0aed25f6218e9f5dc269cf8d0bd0d1290806452 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 21 Oct 2024 17:27:18 +0100 Subject: [PATCH 2/6] removed unnecessary voxel_sizes parameter --- brainglobe_template_builder/plots.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py index 8cf515b..d85e323 100644 --- a/brainglobe_template_builder/plots.py +++ b/brainglobe_template_builder/plots.py @@ -9,7 +9,6 @@ def plot_orthographic( img: np.ndarray, anat_space: str = "ASR", - voxel_sizes: tuple[float, float, float] = (1.0, 1.0, 1.0), show_slices: tuple[int, int, int] | None = None, mip_attenuation: float = 0.01, save_path: Path | None = None, @@ -17,8 +16,9 @@ def plot_orthographic( ) -> tuple[plt.Figure, np.ndarray]: """Plot image volume in three orthogonal views, plus a surface rendering. - The surface rendering is a maximum intensity projection (MIP) along the - vertical (superior-inferior) axis and is shown from the top. + The function assumes isotropic voxels (otherwise the proportions of the + image will be distorted). The surface rendering is a maximum intensity + projection (MIP) along the vertical (superior-inferior) axis. Parameters ---------- @@ -27,9 +27,6 @@ def plot_orthographic( anat_space : str, optional Anatomical space of of the image volume according to the Brainglobe definition (origin and order of axes), by default "ASR". - voxel_sizes : tuple, optional - Voxels sizes in micrometers per dimension, by default (1.0, 1.0, 1.0). - The relative sizes of the axes will be preserved in the plot. show_slices : tuple, optional Which slice to show per dimension. If None (default), show the middle slice along each dimension. From 3a0102dd0abbaf10fcb735f8ecb8b723c4668fe0 Mon Sep 17 00:00:00 2001 From: Niko Sirmpilatze Date: Mon, 28 Oct 2024 14:57:14 +0000 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Alessandro Felder --- brainglobe_template_builder/plots.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py index d85e323..a0999cc 100644 --- a/brainglobe_template_builder/plots.py +++ b/brainglobe_template_builder/plots.py @@ -55,6 +55,7 @@ def plot_orthographic( slices_list = list(show_slices) # Pad the image with zeros to make it cubic + # so projections along different axes have the same size img, pad_sizes = _pad_with_zeros(img, target=max(img.shape)) slices_list = [s + pad_sizes[i] for i, s in enumerate(slices_list)] @@ -97,6 +98,9 @@ def plot_grid( **kwargs, ) -> tuple[plt.Figure, np.ndarray]: """Plot image volume as a grid of slices along a given anatomical section. + + Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` are specified + as keyword arguments. Parameters ---------- @@ -110,7 +114,8 @@ def plot_grid( by default "frontal". n_slices : int, optional Number of slices to show, by default 12. Slices will be evenly spaced, - starting from the first and ending with the last slice. + starting from the first and ending with the last slice. If a higher value + than the number of slices in the image is chosen, all slices are shown. save_path : Path, optional Path to save the plot, by default None (no saving). **kwargs From 99a479a2cfe6a362b50ec1b0365064cd49c1844e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:59:52 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- brainglobe_template_builder/plots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py index a0999cc..d215b21 100644 --- a/brainglobe_template_builder/plots.py +++ b/brainglobe_template_builder/plots.py @@ -98,8 +98,8 @@ def plot_grid( **kwargs, ) -> tuple[plt.Figure, np.ndarray]: """Plot image volume as a grid of slices along a given anatomical section. - - Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` are specified + + Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` are specified as keyword arguments. Parameters From 4bd70e13db15c3d07253d234590de7e8819a60ff Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 28 Oct 2024 15:01:35 +0000 Subject: [PATCH 5/6] fix typos and linting --- brainglobe_template_builder/plots.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py index d215b21..89d8860 100644 --- a/brainglobe_template_builder/plots.py +++ b/brainglobe_template_builder/plots.py @@ -25,7 +25,7 @@ def plot_orthographic( img : np.ndarray Image volume to plot. anat_space : str, optional - Anatomical space of of the image volume according to the Brainglobe + Anatomical space of the image volume according to the Brainglobe definition (origin and order of axes), by default "ASR". show_slices : tuple, optional Which slice to show per dimension. If None (default), show the middle @@ -99,23 +99,24 @@ def plot_grid( ) -> tuple[plt.Figure, np.ndarray]: """Plot image volume as a grid of slices along a given anatomical section. - Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` are specified - as keyword arguments. + Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` + are specified as keyword arguments. Parameters ---------- img : np.ndarray Image volume to plot. anat_space : str, optional - Anatomical space of of the image volume according to the Brainglobe + Anatomical space of the image volume according to the Brainglobe definition (origin and order of axes), by default "ASR". section : str, optional Section to show, must be one of "frontal", "horizontal", or "sagittal", by default "frontal". n_slices : int, optional Number of slices to show, by default 12. Slices will be evenly spaced, - starting from the first and ending with the last slice. If a higher value - than the number of slices in the image is chosen, all slices are shown. + starting from the first and ending with the last slice. If a higher + value than the number of slices in the image is chosen, all slices + are shown. save_path : Path, optional Path to save the plot, by default None (no saving). **kwargs From 1bf3f936bbd5cd2aa6c6547f8d3e9e899f867508 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 28 Oct 2024 18:11:15 +0000 Subject: [PATCH 6/6] hard-code 1-99% as vmin-vmax default values --- brainglobe_template_builder/plots.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/brainglobe_template_builder/plots.py b/brainglobe_template_builder/plots.py index 89d8860..b0b1f03 100644 --- a/brainglobe_template_builder/plots.py +++ b/brainglobe_template_builder/plots.py @@ -99,8 +99,8 @@ def plot_grid( ) -> tuple[plt.Figure, np.ndarray]: """Plot image volume as a grid of slices along a given anatomical section. - Image contrast is auto-adjusted to 1-99% of range unless `vmin` and `vmax` - are specified as keyword arguments. + Image contrast is auto-adjusted to 1-99% of range unless overridden by + ``vmin`` and ``vmax`` passed as keyword arguments. Parameters ---------- @@ -247,10 +247,11 @@ def _set_imshow_defaults(img: np.ndarray, kwargs: dict) -> dict: These apply only if the user does not provide them explicitly. """ - if "vmin" not in kwargs and "vmax" not in kwargs: - vmin, vmax = _auto_adjust_contrast(img) - kwargs.setdefault("vmin", vmin) - kwargs.setdefault("vmax", vmax) + missing_keys = [key for key in ("vmin", "vmax") if key not in kwargs] + if missing_keys: + defaults = _auto_adjust_contrast(img) + for key in missing_keys: + kwargs.setdefault(key, defaults[key]) kwargs.setdefault("cmap", "gray") kwargs.setdefault("aspect", "equal") @@ -312,16 +313,21 @@ def _pad_with_zeros( return padded_img, tuple(pad_sizes) -def _auto_adjust_contrast(img, lower_percentile=1, upper_percentile=99): - """Adjust contrast of an image using percentile-based scaling.""" +def _auto_adjust_contrast(img: np.ndarray) -> dict: + """Adjust contrast of an image using percentile-based scaling. + + Uses the 1-99% range of the image intensity values to set vmin and vmax.""" # Mask near-zero voxels to exclude background if np.issubdtype(img.dtype, np.integer): background_threshold = 1 else: background_threshold = np.finfo(img.dtype).eps - brain_mask = img > background_threshold + # Hard-coded percentiles for default contrast adjustment + lower_percentile = 1 + upper_percentile = 99 + # Exclude bright artifacts vmax = np.percentile(img[brain_mask], upper_percentile) artifact_mask = img <= vmax @@ -331,4 +337,4 @@ def _auto_adjust_contrast(img, lower_percentile=1, upper_percentile=99): vmin = np.percentile(img[combined_mask], lower_percentile) vmax = np.percentile(img[combined_mask], upper_percentile) - return vmin, vmax + return {"vmin": vmin, "vmax": vmax}