From 68ff24b5c218a1731f455354029eed42d29bb81a Mon Sep 17 00:00:00 2001 From: Shern Shiou Tan Date: Thu, 1 Feb 2024 09:15:08 +0100 Subject: [PATCH] refactor: Use enum to represent plane --- darwin/exporter/formats/nifti.py | 263 ++++++++++++++++++++++--------- 1 file changed, 188 insertions(+), 75 deletions(-) diff --git a/darwin/exporter/formats/nifti.py b/darwin/exporter/formats/nifti.py index b28c10dcb..da4d1b8bd 100644 --- a/darwin/exporter/formats/nifti.py +++ b/darwin/exporter/formats/nifti.py @@ -2,6 +2,8 @@ import json as native_json import re from dataclasses import dataclass +from enum import Enum +from numbers import Number from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -10,7 +12,7 @@ console = Console() try: import nibabel as nib - from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform + from nibabel.orientations import io_orientation, ornt_transform except ImportError: import_fail_string = """ You must install darwin-py with pip install darwin-py\[medical] @@ -24,6 +26,12 @@ from darwin.utils import convert_polygons_to_mask +class Plane(Enum): + XY = 0 + XZ = 1 + YZ = 2 + + @dataclass class Volume: pixel_array: np.ndarray @@ -169,9 +177,8 @@ def build_output_volumes( def check_for_error_and_return_imageid( video_annotation: dt.AnnotationFile, output_dir: Path -): - """ - Given the video_annotation file and the output directory, checks for a range of errors and +) -> Union[str, bool]: + """Given the video_annotation file and the output directory, checks for a range of errors and returns messages accordingly. Parameters @@ -183,8 +190,8 @@ def check_for_error_and_return_imageid( Returns ------- - image_id : str - + Union[str, bool] + Returns the image_id if no errors are found, otherwise returns False """ # Check if all item slots have the correct file-extension for slot in video_annotation.slots: @@ -240,6 +247,47 @@ def check_for_error_and_return_imageid( return image_id +def update_pixel_array( + volume: Dict, + annotation_class_name: str, + im_mask: np.ndarray, + plane: Plane, + frame_idx: int, +) -> Dict: + """Updates the pixel array of the given volume with the given mask. + + Parameters + ---------- + volume : Dict + Volume with pixel array to be updated + annotation_class_name : str + Name of the annotation class + im_mask : np.ndarray + Mask to be added to the pixel array + plane : Plane + Plane of the mask + frame_idx : int + Frame index of the mask + + Returns + ------- + Dict + Updated volume + """ + plane_to_slice = { + Plane.XY: np.s_[:, :, frame_idx], + Plane.XZ: np.s_[:, frame_idx, :], + Plane.YZ: np.s_[frame_idx, :, :], + } + if plane in plane_to_slice: + slice_ = plane_to_slice[plane] + volume[annotation_class_name].pixel_array[slice_] = np.logical_or( + im_mask, + volume[annotation_class_name].pixel_array[slice_], + ) + return volume + + def populate_output_volumes_from_polygons( annotations: List[Union[dt.Annotation, dt.VideoAnnotation]], slot_map: Dict, @@ -273,76 +321,41 @@ def populate_output_volumes_from_polygons( volume = output_volumes.get(series_instance_uid) frames = annotation.frames - # define the different planes - XYPLANE = 0 - XZPLANE = 1 - YZPLANE = 2 - for frame_idx in frames.keys(): - view_idx = get_view_idx_from_slot_name( + plane = get_plane_from_slot_name( slot_name, slot.metadata.get("orientation") ) - if view_idx == XYPLANE: - height, width = ( - volume[annotation.annotation_class.name].dims[0], - volume[annotation.annotation_class.name].dims[1], - ) - elif view_idx == XZPLANE: - height, width = ( - volume[annotation.annotation_class.name].dims[0], - volume[annotation.annotation_class.name].dims[2], - ) - elif view_idx == YZPLANE: - height, width = ( - volume[annotation.annotation_class.name].dims[1], - volume[annotation.annotation_class.name].dims[2], - ) - if "paths" in frames[frame_idx].data: + dims = volume[annotation.annotation_class.name].dims + if plane == Plane.XY: + height, width = dims[0], dims[1] + elif plane == Plane.XZ: + height, width = dims[0], dims[2] + elif plane == Plane.YZ: + height, width = dims[1], dims[2] + pixdims = volume[annotation.annotation_class.name].pixdims + frame_data = frames[frame_idx].data + if "paths" in frame_data: # Dealing with a complex polygon polygons = [ - shift_polygon_coords( - polygon_path, volume[annotation.annotation_class.name].pixdims - ) - for polygon_path in frames[frame_idx].data["paths"] + shift_polygon_coords(polygon_path, pixdims) + for polygon_path in frame_data["paths"] ] - elif "path" in frames[frame_idx].data: + elif "path" in frame_data: # Dealing with a simple polygon polygons = shift_polygon_coords( - frames[frame_idx].data["path"], - volume[annotation.annotation_class.name].pixdims, + frame_data["path"], + pixdims, ) else: continue - frames[frame_idx].annotation_class.name im_mask = convert_polygons_to_mask(polygons, height=height, width=width) - volume = output_volumes[series_instance_uid] - if view_idx == 0: - volume[annotation.annotation_class.name].pixel_array[ - :, :, frame_idx - ] = np.logical_or( - im_mask, - volume[annotation.annotation_class.name].pixel_array[ - :, :, frame_idx - ], - ) - elif view_idx == 1: - volume[annotation.annotation_class.name].pixel_array[ - :, frame_idx, : - ] = np.logical_or( - im_mask, - volume[annotation.annotation_class.name].pixel_array[ - :, frame_idx, : - ], - ) - elif view_idx == 2: - volume[annotation.annotation_class.name].pixel_array[ - frame_idx, :, : - ] = np.logical_or( - im_mask, - volume[annotation.annotation_class.name].pixel_array[ - frame_idx, :, : - ], - ) + volume = update_pixel_array( + output_volumes[series_instance_uid], + annotation.annotation_class.name, + im_mask, + plane, + frame_idx, + ) return volume @@ -404,6 +417,22 @@ def populate_output_volumes_from_raster_layer( def write_output_volume_to_disk( output_volumes: Dict, image_id: str, output_dir: Union[str, Path] ) -> None: + """Writes the given output volumes to disk. + + Parameters + ---------- + output_volumes : Dict + Output volumes to be written to disk + image_id : str + The specific image id + output_dir : Union[str, Path] + The output directory to write the volumes to + + Returns + ------- + None + """ + # volumes are the values of this nested dict def unnest_dict_to_list(d: Dict) -> List: result = [] @@ -438,7 +467,7 @@ def unnest_dict_to_list(d: Dict) -> List: nib.save(img=img, filename=output_path) -def shift_polygon_coords(polygon, pixdim): +def shift_polygon_coords(polygon: List[Dict], pixdim: List[Number]) -> List: # Need to make it clear that we flip x/y because we need to take the transpose later. if pixdim[1] > pixdim[0]: return [{"x": p["y"], "y": p["x"] * pixdim[1] / pixdim[0]} for p in polygon] @@ -448,7 +477,21 @@ def shift_polygon_coords(polygon, pixdim): return [{"x": p["y"], "y": p["x"]} for p in polygon] -def get_view_idx(frame_idx, groups): +def get_view_idx(frame_idx: int, groups: List) -> int: + """Returns the view index for the given frame index and groups. + + Parameters + ---------- + frame_idx : int + Frame index + groups : List + List of groups + + Returns + ------- + int + View index + """ if groups is None: return 0 for view_idx, group in enumerate(groups): @@ -456,16 +499,41 @@ def get_view_idx(frame_idx, groups): return view_idx -def get_view_idx_from_slot_name(slot_name: str, orientation: Union[str, None]) -> int: +def get_plane_from_slot_name(slot_name: str, orientation: Union[str, None]) -> Plane: + """Returns the plane from the given slot name and orientation. + + Parameters + ---------- + slot_name : str + Slot name + orientation : Union[str, None] + Orientation + + Returns + ------- + Plane + Enum representing the plane + """ if orientation is None: orientation_dict = {"0.1": 0, "0.2": 1, "0.3": 2} - return orientation_dict.get(slot_name, 0) - else: - orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2} - return orientation_dict.get(orientation, 0) + return Plane(orientation_dict.get(slot_name, 0)) + orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2} + return Plane(orientation_dict.get(orientation, 0)) def process_metadata(metadata: Dict) -> Tuple: + """Processes the metadata and returns the volume dimensions, pixel dimensions, affine and original affine. + + Parameters + ---------- + metadata : Dict + Metadata to be processed + + Returns + ------- + Tuple + Tuple containing volume dimensions, pixel dimensions, affine and original affine + """ volume_dims = metadata.get("shape") pixdim = metadata.get("pixdim") affine = process_affine(metadata.get("affine")) @@ -489,9 +557,23 @@ def process_metadata(metadata: Dict) -> Tuple: return volume_dims, pixdim, affine, original_affine -def process_affine(affine): +def process_affine(affine: Union[str, List, np.ndarray]) -> Optional[np.ndarray]: + """Converts affine to numpy array if it is not already. + + Parameters + ---------- + affine : Union[str, List, np.ndarray] + affine object to be converted + + Returns + ------- + Optional[np.ndarray] + affine as numpy array + """ if isinstance(affine, str): - affine = np.squeeze(np.array([ast.literal_eval(l) for l in affine.split("\n")])) + affine = np.squeeze( + np.array([ast.literal_eval(lst) for lst in affine.split("\n")]) + ) elif isinstance(affine, list): affine = np.array(affine).astype(float) else: @@ -503,6 +585,22 @@ def process_affine(affine): def create_error_message_json( error_message: str, output_dir: Union[str, Path], image_id: str ) -> bool: + """Creates a json file with the given error message. + + Parameters + ---------- + error_message : str + Error message to be written to the file + output_dir : Union[str, Path] + Output directory + image_id : str + Associated image id + + Returns + ------- + bool + Always returns False + """ output_path = Path(output_dir) / f"{image_id}_error.json" if not output_path.parent.exists(): output_path.parent.mkdir(parents=True) @@ -512,8 +610,23 @@ def create_error_message_json( return False -def decode_rle(rle_data, width, height): - """Decodes run-length encoding (RLE) data into a mask array.""" +def decode_rle(rle_data: List[int], width: int, height: int) -> np.ndarray: + """Decodes run-length encoding (RLE) data into a mask array. + + Parameters + ---------- + rle_data : List[int] + List of RLE data + width : int + Width of the data + height : int + Height of the data + + Returns + ------- + np.ndarray + RLE data + """ total_pixels = width * height mask = np.zeros(total_pixels, dtype=np.uint8) pos = 0