Skip to content

Commit

Permalink
wrap algorithm in cli to allow proper capture within mhub framework
Browse files Browse the repository at this point in the history
  • Loading branch information
silvandeleemput committed Apr 18, 2024
1 parent 2bafc52 commit fb17009
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
35 changes: 22 additions & 13 deletions models/gc_node21_baseline/utils/Node21BaselineRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
Email: [email protected]
-------------------------------------------------------------
"""
import SimpleITK
import json
import sys
from pathlib import Path

from mhubio.core import Instance, InstanceData, IO, Module, Meta, ValueOutput, OutputDataCollection

# Import Node21 baseline nodule detection algorithm from the node21_detection_baseline repo
from process import Noduledetection

CLI_PATH = Path(__file__).parent.absolute() / "cli.py"


@ValueOutput.Name('noduleprob')
Expand Down Expand Up @@ -44,16 +44,25 @@ class Node21BaselineRunner(Module):
@IO.OutputDatas('nodule_probs', NoduleProbability)
@IO.OutputDatas('nodule_bounding_boxes', NoduleBoundingBox)
def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData, nodule_probs: OutputDataCollection, nodule_bounding_boxes: OutputDataCollection) -> None:
# Read input image
input_image = SimpleITK.ReadImage(in_data.abspath)

# Run nodule detection algorithm on the input image and generate predictions
tmp_path = Path("/app/tmp")
predictions = Noduledetection(input_dir=tmp_path, output_dir=tmp_path).predict(input_image=input_image)

# Export the predictions to a JSON file
with open(out_data.abspath, "w") as f:
json.dump(predictions, f, indent=4)
# build command (order matters!)
cmd = [
sys.executable,
str(CLI_PATH),
in_data.abspath,
out_data.abspath
]

# run the command as subprocess
self.subprocess(cmd, text=True)

# Confirm the expected output file was generated
if not Path(out_data.abspath).is_file():
raise FileNotFoundError(f"Node21BaseLineRunner - Could not find the expected "
f"output file: {out_data.abspath}, something went wrong running the CLI.")

# Read the predictions to a JSON file
with open(out_data.abspath, "r") as f:
predictions = json.load(f)

# Export the relevant data
for nodule_idx, box in enumerate(predictions["boxes"]):
Expand Down
50 changes: 50 additions & 0 deletions models/gc_node21_baseline/utils/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
---------------------------------------------------
Mhub / DIAG - CLI for the Node21 baseline Algorithm
The model algorith was wrapped in a CLI to ensure
the mhub framework is able to properly capture
the stdout generated by the algorithm
---------------------------------------------------
---------------------------------------------------
Author: Sil van de Leemput
Email: [email protected]
---------------------------------------------------
"""

import argparse
from pathlib import Path
import json

import SimpleITK

# Import Node21 baseline nodule detection algorithm from the node21_detection_baseline repo
from process import Noduledetection


def run_classifier(input_cxr: Path, output_json_file: Path):
# Read input image
input_image = SimpleITK.ReadImage(str(input_cxr))

# Run nodule detection algorithm on the input image and generate predictions
tmp_path = Path("/app/tmp")
predictions = Noduledetection(input_dir=tmp_path, output_dir=tmp_path).predict(input_image=input_image)

# Export the predictions to a JSON file
with open(output_json_file, "w") as f:
json.dump(predictions, f, indent=4)


def run_classifier_cli():
parser = argparse.ArgumentParser("CLI to run the Node21 baseline classifier")
parser.add_argument("input_cxr", type=str, help="input CXR image (MHA)")
parser.add_argument("output_json_file", type=str, help="Output nodule bounding boxes and probabilities predictions (JSON)")
args = parser.parse_args()
run_classifier(
input_cxr=Path(args.input_cxr),
output_json_file=Path(args.output_json_file)
)


if __name__ == "__main__":
run_classifier_cli()

0 comments on commit fb17009

Please sign in to comment.