refactor: Use enum to represent plane
shernshiou committed Feb 19, 2024
1 parent 45e3716 commit 68ff24b
Showing 1 changed file with 188 additions and 75 deletions.
263 changes: 188 additions & 75 deletions darwin/exporter/formats/nifti.py
@@ -2,6 +2,8 @@
import json as native_json
import re
from dataclasses import dataclass
from enum import Enum
from numbers import Number
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

@@ -10,7 +12,7 @@
console = Console()
try:
import nibabel as nib
from nibabel.orientations import axcodes2ornt, io_orientation, ornt_transform
from nibabel.orientations import io_orientation, ornt_transform
except ImportError:
import_fail_string = """
You must install darwin-py with pip install darwin-py\[medical]
@@ -24,6 +26,12 @@
from darwin.utils import convert_polygons_to_mask


class Plane(Enum):
XY = 0
XZ = 1
YZ = 2

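As a quick point of reference, a minimal sketch (assuming the Plane enum above is in scope) of how its members map onto the integer view indices previously kept as module-level constants (XYPLANE = 0, XZPLANE = 1, YZPLANE = 2):

# Plane keeps the old integer codes, so a legacy view index converts
# directly into the corresponding enum member.
assert Plane(0) is Plane.XY
assert Plane(1) is Plane.XZ
assert Plane(2) is Plane.YZ
assert Plane.YZ.value == 2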

@dataclass
class Volume:
pixel_array: np.ndarray
@@ -169,9 +177,8 @@ def build_output_volumes(

def check_for_error_and_return_imageid(
video_annotation: dt.AnnotationFile, output_dir: Path
):
"""
Given the video_annotation file and the output directory, checks for a range of errors and
) -> Union[str, bool]:
"""Given the video_annotation file and the output directory, checks for a range of errors and
returns messages accordingly.
Parameters
@@ -183,8 +190,8 @@ def check_for_error_and_return_imageid(
Returns
-------
image_id : str
Union[str, bool]
Returns the image_id if no errors are found, otherwise returns False
"""
# Check if all item slots have the correct file-extension
for slot in video_annotation.slots:
@@ -240,6 +247,47 @@ def check_for_error_and_return_imageid(
return image_id


def update_pixel_array(
volume: Dict,
annotation_class_name: str,
im_mask: np.ndarray,
plane: Plane,
frame_idx: int,
) -> Dict:
"""Updates the pixel array of the given volume with the given mask.
Parameters
----------
volume : Dict
Volume with pixel array to be updated
annotation_class_name : str
Name of the annotation class
im_mask : np.ndarray
Mask to be added to the pixel array
plane : Plane
Plane of the mask
frame_idx : int
Frame index of the mask
Returns
-------
Dict
Updated volume
"""
plane_to_slice = {
Plane.XY: np.s_[:, :, frame_idx],
Plane.XZ: np.s_[:, frame_idx, :],
Plane.YZ: np.s_[frame_idx, :, :],
}
if plane in plane_to_slice:
slice_ = plane_to_slice[plane]
volume[annotation_class_name].pixel_array[slice_] = np.logical_or(
im_mask,
volume[annotation_class_name].pixel_array[slice_],
)
return volume

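For illustration, a minimal usage sketch of update_pixel_array against a hypothetical single-class volume; the bare-bones _StubVolume class and the "lesion" class name are assumptions made up for the example, not part of the exporter:

import numpy as np

class _StubVolume:
    # Stand-in exposing only the pixel_array attribute used by update_pixel_array.
    def __init__(self, shape):
        self.pixel_array = np.zeros(shape, dtype=np.uint8)

volumes = {"lesion": _StubVolume((64, 64, 32))}
mask = np.ones((64, 64), dtype=np.uint8)  # 2-D mask covering one XY slice
volumes = update_pixel_array(volumes, "lesion", mask, Plane.XY, frame_idx=10)
assert volumes["lesion"].pixel_array[:, :, 10].all()  # slice 10 is now filled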

def populate_output_volumes_from_polygons(
annotations: List[Union[dt.Annotation, dt.VideoAnnotation]],
slot_map: Dict,
@@ -273,76 +321,41 @@ def populate_output_volumes_from_polygons(
volume = output_volumes.get(series_instance_uid)
frames = annotation.frames

# define the different planes
XYPLANE = 0
XZPLANE = 1
YZPLANE = 2

for frame_idx in frames.keys():
view_idx = get_view_idx_from_slot_name(
plane = get_plane_from_slot_name(
slot_name, slot.metadata.get("orientation")
)
if view_idx == XYPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[0],
volume[annotation.annotation_class.name].dims[1],
)
elif view_idx == XZPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[0],
volume[annotation.annotation_class.name].dims[2],
)
elif view_idx == YZPLANE:
height, width = (
volume[annotation.annotation_class.name].dims[1],
volume[annotation.annotation_class.name].dims[2],
)
if "paths" in frames[frame_idx].data:
dims = volume[annotation.annotation_class.name].dims
if plane == Plane.XY:
height, width = dims[0], dims[1]
elif plane == Plane.XZ:
height, width = dims[0], dims[2]
elif plane == Plane.YZ:
height, width = dims[1], dims[2]
pixdims = volume[annotation.annotation_class.name].pixdims
frame_data = frames[frame_idx].data
if "paths" in frame_data:
# Dealing with a complex polygon
polygons = [
shift_polygon_coords(
polygon_path, volume[annotation.annotation_class.name].pixdims
)
for polygon_path in frames[frame_idx].data["paths"]
shift_polygon_coords(polygon_path, pixdims)
for polygon_path in frame_data["paths"]
]
elif "path" in frames[frame_idx].data:
elif "path" in frame_data:
# Dealing with a simple polygon
polygons = shift_polygon_coords(
frames[frame_idx].data["path"],
volume[annotation.annotation_class.name].pixdims,
frame_data["path"],
pixdims,
)
else:
continue
frames[frame_idx].annotation_class.name
im_mask = convert_polygons_to_mask(polygons, height=height, width=width)
volume = output_volumes[series_instance_uid]
if view_idx == 0:
volume[annotation.annotation_class.name].pixel_array[
:, :, frame_idx
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
:, :, frame_idx
],
)
elif view_idx == 1:
volume[annotation.annotation_class.name].pixel_array[
:, frame_idx, :
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
:, frame_idx, :
],
)
elif view_idx == 2:
volume[annotation.annotation_class.name].pixel_array[
frame_idx, :, :
] = np.logical_or(
im_mask,
volume[annotation.annotation_class.name].pixel_array[
frame_idx, :, :
],
)
volume = update_pixel_array(
output_volumes[series_instance_uid],
annotation.annotation_class.name,
im_mask,
plane,
frame_idx,
)
return volume


@@ -404,6 +417,22 @@ def populate_output_volumes_from_raster_layer(
def write_output_volume_to_disk(
output_volumes: Dict, image_id: str, output_dir: Union[str, Path]
) -> None:
"""Writes the given output volumes to disk.
Parameters
----------
output_volumes : Dict
Output volumes to be written to disk
image_id : str
The specific image id
output_dir : Union[str, Path]
The output directory to write the volumes to
Returns
-------
None
"""

# volumes are the values of this nested dict
def unnest_dict_to_list(d: Dict) -> List:
result = []
@@ -438,7 +467,7 @@ def unnest_dict_to_list(d: Dict) -> List:
nib.save(img=img, filename=output_path)


def shift_polygon_coords(polygon, pixdim):
def shift_polygon_coords(polygon: List[Dict], pixdim: List[Number]) -> List:
# Need to make it clear that we flip x/y because we need to take the transpose later.
if pixdim[1] > pixdim[0]:
return [{"x": p["y"], "y": p["x"] * pixdim[1] / pixdim[0]} for p in polygon]
@@ -448,24 +477,63 @@ def shift_polygon_coords(polygon, pixdim):
return [{"x": p["y"], "y": p["x"]} for p in polygon]

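As a concrete illustration of the flip-and-scale in the branch shown above, with made-up polygon points and pixel dimensions:

# With anisotropic pixel dimensions where pixdim[1] > pixdim[0],
# x and y are swapped and the new y is rescaled by pixdim[1] / pixdim[0].
points = [{"x": 10.0, "y": 4.0}]
shifted = shift_polygon_coords(points, [0.5, 1.0, 1.0])
assert shifted == [{"x": 4.0, "y": 20.0}]  # y = 10.0 * 1.0 / 0.5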

def get_view_idx(frame_idx, groups):
def get_view_idx(frame_idx: int, groups: List) -> int:
"""Returns the view index for the given frame index and groups.
Parameters
----------
frame_idx : int
Frame index
groups : List
List of groups
Returns
-------
int
View index
"""
if groups is None:
return 0
for view_idx, group in enumerate(groups):
if frame_idx in group:
return view_idx

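For example, with a hypothetical grouping of frame indices:

groups = [[0, 1, 2], [3, 4, 5]]
assert get_view_idx(4, groups) == 1   # frame 4 belongs to the second group
assert get_view_idx(4, None) == 0     # no groups defaults to view 0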

def get_view_idx_from_slot_name(slot_name: str, orientation: Union[str, None]) -> int:
def get_plane_from_slot_name(slot_name: str, orientation: Union[str, None]) -> Plane:
"""Returns the plane from the given slot name and orientation.
Parameters
----------
slot_name : str
Slot name
orientation : Union[str, None]
Orientation
Returns
-------
Plane
Enum representing the plane
"""
if orientation is None:
orientation_dict = {"0.1": 0, "0.2": 1, "0.3": 2}
return orientation_dict.get(slot_name, 0)
else:
orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2}
return orientation_dict.get(orientation, 0)
return Plane(orientation_dict.get(slot_name, 0))
orientation_dict = {"AXIAL": 0, "SAGITTAL": 1, "CORONAL": 2}
return Plane(orientation_dict.get(orientation, 0))

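For illustration, how the lookup behaves for a few representative inputs, following the "0.1"/"0.2"/"0.3" slot-name convention handled above:

# Without orientation metadata the slot name selects the plane;
# an explicit orientation string takes precedence, and anything
# unrecognised falls back to the axial (XY) plane.
assert get_plane_from_slot_name("0.2", None) is Plane.XZ
assert get_plane_from_slot_name("0.1", "CORONAL") is Plane.YZ
assert get_plane_from_slot_name("unknown-slot", None) is Plane.XY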

def process_metadata(metadata: Dict) -> Tuple:
"""Processes the metadata and returns the volume dimensions, pixel dimensions, affine and original affine.
Parameters
----------
metadata : Dict
Metadata to be processed
Returns
-------
Tuple
Tuple containing volume dimensions, pixel dimensions, affine and original affine
"""
volume_dims = metadata.get("shape")
pixdim = metadata.get("pixdim")
affine = process_affine(metadata.get("affine"))
@@ -489,9 +557,23 @@ def process_metadata(metadata: Dict) -> Tuple:
return volume_dims, pixdim, affine, original_affine


def process_affine(affine):
def process_affine(affine: Union[str, List, np.ndarray]) -> Optional[np.ndarray]:
"""Converts affine to numpy array if it is not already.
Parameters
----------
affine : Union[str, List, np.ndarray]
affine object to be converted
Returns
-------
Optional[np.ndarray]
affine as numpy array
"""
if isinstance(affine, str):
affine = np.squeeze(np.array([ast.literal_eval(l) for l in affine.split("\n")]))
affine = np.squeeze(
np.array([ast.literal_eval(lst) for lst in affine.split("\n")])
)
elif isinstance(affine, list):
affine = np.array(affine).astype(float)
else:
@@ -503,6 +585,22 @@ def create_error_message_json(
def create_error_message_json(
error_message: str, output_dir: Union[str, Path], image_id: str
) -> bool:
"""Creates a json file with the given error message.
Parameters
----------
error_message : str
Error message to be written to the file
output_dir : Union[str, Path]
Output directory
image_id : str
Associated image id
Returns
-------
bool
Always returns False
"""
output_path = Path(output_dir) / f"{image_id}_error.json"
if not output_path.parent.exists():
output_path.parent.mkdir(parents=True)
@@ -512,8 +610,23 @@ def create_error_message_json(
return False


def decode_rle(rle_data, width, height):
"""Decodes run-length encoding (RLE) data into a mask array."""
def decode_rle(rle_data: List[int], width: int, height: int) -> np.ndarray:
"""Decodes run-length encoding (RLE) data into a mask array.
Parameters
----------
rle_data : List[int]
List of RLE data
width : int
Width of the data
height : int
Height of the data
Returns
-------
np.ndarray
Decoded mask array
"""
total_pixels = width * height
mask = np.zeros(total_pixels, dtype=np.uint8)
pos = 0
