From c4bead4f53c883d9e07ab3ff55c287a4ed26fa22 Mon Sep 17 00:00:00 2001
From: John Wilkie <124276291+JBWilkie@users.noreply.github.com>
Date: Tue, 31 Oct 2023 15:28:29 +0000
Subject: [PATCH] [IO-2040][external] Fix for NIfTI exports that include
 filenames with dots (#705)

* Fix for NIfTI file exports with dots in the filename

* Fixed incorrectly pruned characters from filenames
---
 darwin/exporter/formats/darwin.py | 15 +++--
 darwin/exporter/formats/nifti.py  | 97 ++++++++++++++++++++++---------
 2 files changed, 79 insertions(+), 33 deletions(-)

diff --git a/darwin/exporter/formats/darwin.py b/darwin/exporter/formats/darwin.py
index 2e6193b4c..eadab47ad 100644
--- a/darwin/exporter/formats/darwin.py
+++ b/darwin/exporter/formats/darwin.py
@@ -50,12 +50,15 @@ def build_image_annotation(annotation_file: dt.AnnotationFile) -> Dict[str, Any]
     print(annotations)
     for annotation in annotation_file.annotations:
         payload = {
-            annotation.annotation_class.annotation_type: _build_annotation_data(annotation),
+            annotation.annotation_class.annotation_type: _build_annotation_data(
+                annotation
+            ),
             "name": annotation.annotation_class.name,
         }
 
         if (
-            annotation.annotation_class.annotation_type == "complex_polygon" or annotation.annotation_class.annotation_type == "polygon"
+            annotation.annotation_class.annotation_type == "complex_polygon"
+            or annotation.annotation_class.annotation_type == "polygon"
         ) and "bounding_box" in annotation.data:
             payload["bounding_box"] = annotation.data["bounding_box"]
 
@@ -83,7 +86,9 @@ def build_annotation_data(annotation: dt.Annotation) -> Dict[str, Any]:
         return {"path": annotation.data["paths"]}
 
     if annotation.annotation_class.annotation_type == "polygon":
-        return dict(filter(lambda item: item[0] != "bounding_box", annotation.data.items()))
+        return dict(
+            filter(lambda item: item[0] != "bounding_box", annotation.data.items())
+        )
 
     return dict(annotation.data)
 
@@ -93,6 +98,8 @@ def _build_annotation_data(annotation: dt.Annotation) -> Dict[str, Any]:
         return {"path": annotation.data["paths"]}
 
     if annotation.annotation_class.annotation_type == "polygon":
-        return dict(filter(lambda item: item[0] != "bounding_box", annotation.data.items()))
+        return dict(
+            filter(lambda item: item[0] != "bounding_box", annotation.data.items())
+        )
 
     return dict(annotation.data)
diff --git a/darwin/exporter/formats/nifti.py b/darwin/exporter/formats/nifti.py
index 115f7afa1..dacf21b7a 100644
--- a/darwin/exporter/formats/nifti.py
+++ b/darwin/exporter/formats/nifti.py
@@ -63,8 +63,12 @@ def export(annotation_files: Iterable[dt.AnnotationFile], output_dir: Path) -> N
         output_volumes = build_output_volumes(video_annotation)
         slot_map = {slot.name: slot for slot in video_annotation.slots}
         for annotation in video_annotation.annotations:
-            populate_output_volumes(annotation, output_dir, slot_map, output_volumes, image_id)
-        write_output_volume_to_disk(output_volumes, image_id=image_id, output_dir=output_dir)
+            populate_output_volumes(
+                annotation, output_dir, slot_map, output_volumes, image_id
+            )
+        write_output_volume_to_disk(
+            output_volumes, image_id=image_id, output_dir=output_dir
+        )
 
 
 def build_output_volumes(video_annotation: dt.AnnotationFile) -> Dict:
@@ -97,7 +101,9 @@ def build_output_volumes(video_annotation: dt.AnnotationFile) -> Dict:
     for slot in video_annotation.slots:
         slot_metadata = slot.metadata
         assert slot_metadata is not None
-        series_instance_uid = slot_metadata.get("SeriesInstanceUID", "SeriesIntanceUIDNotProvided")
+        series_instance_uid = slot_metadata.get(
+            "SeriesInstanceUID", "SeriesIntanceUIDNotProvided"
+        )
         # Builds output volumes per class
         volume_dims, pixdims, affine, original_affine = process_metadata(slot.metadata)
         output_volumes[series_instance_uid] = {
@@ -115,7 +121,9 @@ def build_output_volumes(video_annotation: dt.AnnotationFile) -> Dict:
     return output_volumes
 
 
-def check_for_error_and_return_imageid(video_annotation: dt.AnnotationFile, output_dir: Path):
+def check_for_error_and_return_imageid(
+    video_annotation: dt.AnnotationFile, output_dir: Path
+):
     """
     Given the video_annotation file and the output directory, checks for a range of errors and
     returns messages accordingly.
@@ -135,16 +143,19 @@ def check_for_error_and_return_imageid(video_annotation: dt.AnnotationFile, outp
 
     output_volumes = None
     filename = Path(video_annotation.filename)
-    suffixes = filename.suffixes
-    if len(suffixes) > 2:
-        return create_error_message_json(
-            "Misconfigured filename, contains too many suffixes", output_dir, str(filename)
-        )
-    elif len(suffixes) == 2:
+    try:
+        suffixes = filename.suffixes[-2:]
+    except IndexError:
+        suffixes = filename.suffixes
+    if len(suffixes) == 2:
         if suffixes[0] == ".nii" and suffixes[1] == ".gz":
-            image_id = str(filename).strip("".join(suffixes))
+            image_id = str(filename).rstrip("".join(suffixes))
         else:
-            return create_error_message_json("Two suffixes found but not ending in .nii.gz", output_dir, str(filename))
+            return create_error_message_json(
+                "Two suffixes found but not ending in .nii.gz",
+                output_dir,
+                str(filename),
+            )
     elif len(suffixes) == 1:
         if suffixes[0] == ".nii" or suffixes[0] == ".dcm":
             image_id = filename.stem
@@ -162,14 +173,18 @@ def check_for_error_and_return_imageid(video_annotation: dt.AnnotationFile, outp
             str(filename),
         )
     if video_annotation is None:
-        return create_error_message_json("video_annotation not found", output_dir, image_id)
+        return create_error_message_json(
+            "video_annotation not found", output_dir, image_id
+        )
 
     for slot in video_annotation.slots:
         # Pick the first slot to take the metadata from. We assume that all slots have the same metadata.
         metadata = slot.metadata
         if metadata is None:
             return create_error_message_json(
-                f"No metadata found for {str(filename)}, are you sure this is medical data?", output_dir, image_id
+                f"No metadata found for {str(filename)}, are you sure this is medical data?",
+                output_dir,
+                image_id,
             )
 
         volume_dims, pixdim, affine, _ = process_metadata(metadata)
@@ -214,7 +229,9 @@ def populate_output_volumes(
 
     slot_name = annotation.slot_names[0]
     slot = slot_map[slot_name]
-    series_instance_uid = slot.metadata.get("SeriesInstanceUID", "SeriesIntanceUIDNotProvided")
+    series_instance_uid = slot.metadata.get(
+        "SeriesInstanceUID", "SeriesIntanceUIDNotProvided"
+    )
     volume = output_volumes.get(series_instance_uid)
     frames = annotation.frames
     frame_new = {}
@@ -226,7 +243,9 @@ def populate_output_volumes(
 
     for frame_idx in frames.keys():
         frame_new[frame_idx] = frames
-        view_idx = get_view_idx_from_slot_name(slot_name, slot.metadata.get("orientation"))
+        view_idx = get_view_idx_from_slot_name(
+            slot_name, slot.metadata.get("orientation")
+        )
         if view_idx == XYPLANE:
             height, width = (
                 volume[annotation.annotation_class.name].dims[0],
@@ -245,13 +264,16 @@ def populate_output_volumes(
         if "paths" in frames[frame_idx].data:
             # Dealing with a complex polygon
             polygons = [
-                shift_polygon_coords(polygon_path, volume[annotation.annotation_class.name].pixdims)
+                shift_polygon_coords(
+                    polygon_path, volume[annotation.annotation_class.name].pixdims
+                )
                 for polygon_path in frames[frame_idx].data["paths"]
             ]
         elif "path" in frames[frame_idx].data:
             # Dealing with a simple polygon
             polygons = shift_polygon_coords(
-                frames[frame_idx].data["path"], volume[annotation.annotation_class.name].pixdims
+                frames[frame_idx].data["path"],
+                volume[annotation.annotation_class.name].pixdims,
             )
         else:
             continue
@@ -259,20 +281,31 @@ def populate_output_volumes(
         im_mask = convert_polygons_to_mask(polygons, height=height, width=width)
         volume = output_volumes[series_instance_uid]
         if view_idx == 0:
-            volume[annotation.annotation_class.name].pixel_array[:, :, frame_idx] = np.logical_or(
-                im_mask, volume[annotation.annotation_class.name].pixel_array[:, :, frame_idx]
+            volume[annotation.annotation_class.name].pixel_array[
+                :, :, frame_idx
+            ] = np.logical_or(
+                im_mask,
+                volume[annotation.annotation_class.name].pixel_array[:, :, frame_idx],
             )
         elif view_idx == 1:
-            volume[annotation.annotation_class.name].pixel_array[:, frame_idx, :] = np.logical_or(
-                im_mask, volume[annotation.annotation_class.name].pixel_array[:, frame_idx, :]
+            volume[annotation.annotation_class.name].pixel_array[
+                :, frame_idx, :
+            ] = np.logical_or(
+                im_mask,
+                volume[annotation.annotation_class.name].pixel_array[:, frame_idx, :],
             )
         elif view_idx == 2:
-            volume[annotation.annotation_class.name].pixel_array[frame_idx, :, :] = np.logical_or(
-                im_mask, volume[annotation.annotation_class.name].pixel_array[frame_idx, :, :]
+            volume[annotation.annotation_class.name].pixel_array[
+                frame_idx, :, :
+            ] = np.logical_or(
+                im_mask,
+                volume[annotation.annotation_class.name].pixel_array[frame_idx, :, :],
             )
 
 
-def write_output_volume_to_disk(output_volumes: Dict, image_id: str, output_dir: Union[str, Path]) -> None:
+def write_output_volume_to_disk(
+    output_volumes: Dict, image_id: str, output_dir: Union[str, Path]
+) -> None:
     # volumes are the values of this nested dict
     def unnest_dict_to_list(d: Dict) -> List:
         result = []
@@ -290,11 +323,15 @@ def unnest_dict_to_list(d: Dict) -> List:
             affine=volume.affine,
         )
         if volume.original_affine is not None:
-            orig_ornt = io_orientation(volume.original_affine)  # Get orientation of current affine
+            orig_ornt = io_orientation(
+                volume.original_affine
+            )  # Get orientation of current affine
             img_ornt = io_orientation(volume.affine)  # Get orientation of RAS affine
-            from_canonical = ornt_transform(img_ornt, orig_ornt)  # Get transform from RAS to current affine
+            from_canonical = ornt_transform(
+                img_ornt, orig_ornt
+            )  # Get transform from RAS to current affine
             img = img.as_reoriented(from_canonical)
-        output_path = Path(output_dir) / f"{image_id}_{volume.series_instance_uid}_{volume.class_name}.nii.gz"
+        output_path = Path(output_dir) / f"{image_id}_{volume.class_name}.nii.gz"
         if not output_path.parent.exists():
             output_path.parent.mkdir(parents=True)
         nib.save(img=img, filename=output_path)
@@ -362,7 +399,9 @@ def process_affine(affine):
         return affine
 
 
-def create_error_message_json(error_message: str, output_dir: Union[str, Path], image_id: str) -> bool:
+def create_error_message_json(
+    error_message: str, output_dir: Union[str, Path], image_id: str
+) -> bool:
     output_path = Path(output_dir) / f"{image_id}_error.json"
     if not output_path.parent.exists():
         output_path.parent.mkdir(parents=True)