-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wrap algorithm in cli to allow proper capture within mhub framework
- Loading branch information
1 parent
2bafc52
commit fb17009
Showing
2 changed files
with
72 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,14 +8,14 @@ | |
Email: [email protected] | ||
------------------------------------------------------------- | ||
""" | ||
import SimpleITK | ||
import json | ||
import sys | ||
from pathlib import Path | ||
|
||
from mhubio.core import Instance, InstanceData, IO, Module, Meta, ValueOutput, OutputDataCollection | ||
|
||
# Import Node21 baseline nodule detection algorithm from the node21_detection_baseline repo | ||
from process import Noduledetection | ||
|
||
CLI_PATH = Path(__file__).parent.absolute() / "cli.py" | ||
|
||
|
||
@ValueOutput.Name('noduleprob') | ||
|
@@ -44,16 +44,25 @@ class Node21BaselineRunner(Module): | |
@IO.OutputDatas('nodule_probs', NoduleProbability) | ||
@IO.OutputDatas('nodule_bounding_boxes', NoduleBoundingBox) | ||
def task(self, instance: Instance, in_data: InstanceData, out_data: InstanceData, nodule_probs: OutputDataCollection, nodule_bounding_boxes: OutputDataCollection) -> None: | ||
# Read input image | ||
input_image = SimpleITK.ReadImage(in_data.abspath) | ||
|
||
# Run nodule detection algorithm on the input image and generate predictions | ||
tmp_path = Path("/app/tmp") | ||
predictions = Noduledetection(input_dir=tmp_path, output_dir=tmp_path).predict(input_image=input_image) | ||
|
||
# Export the predictions to a JSON file | ||
with open(out_data.abspath, "w") as f: | ||
json.dump(predictions, f, indent=4) | ||
# build command (order matters!) | ||
cmd = [ | ||
sys.executable, | ||
str(CLI_PATH), | ||
in_data.abspath, | ||
out_data.abspath | ||
] | ||
|
||
# run the command as subprocess | ||
self.subprocess(cmd, text=True) | ||
|
||
# Confirm the expected output file was generated | ||
if not Path(out_data.abspath).is_file(): | ||
raise FileNotFoundError(f"Node21BaseLineRunner - Could not find the expected " | ||
f"output file: {out_data.abspath}, something went wrong running the CLI.") | ||
|
||
# Read the predictions to a JSON file | ||
with open(out_data.abspath, "r") as f: | ||
predictions = json.load(f) | ||
|
||
# Export the relevant data | ||
for nodule_idx, box in enumerate(predictions["boxes"]): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
""" | ||
--------------------------------------------------- | ||
Mhub / DIAG - CLI for the Node21 baseline Algorithm | ||
The model algorith was wrapped in a CLI to ensure | ||
the mhub framework is able to properly capture | ||
the stdout generated by the algorithm | ||
--------------------------------------------------- | ||
--------------------------------------------------- | ||
Author: Sil van de Leemput | ||
Email: [email protected] | ||
--------------------------------------------------- | ||
""" | ||
|
||
import argparse | ||
from pathlib import Path | ||
import json | ||
|
||
import SimpleITK | ||
|
||
# Import Node21 baseline nodule detection algorithm from the node21_detection_baseline repo | ||
from process import Noduledetection | ||
|
||
|
||
def run_classifier(input_cxr: Path, output_json_file: Path): | ||
# Read input image | ||
input_image = SimpleITK.ReadImage(str(input_cxr)) | ||
|
||
# Run nodule detection algorithm on the input image and generate predictions | ||
tmp_path = Path("/app/tmp") | ||
predictions = Noduledetection(input_dir=tmp_path, output_dir=tmp_path).predict(input_image=input_image) | ||
|
||
# Export the predictions to a JSON file | ||
with open(output_json_file, "w") as f: | ||
json.dump(predictions, f, indent=4) | ||
|
||
|
||
def run_classifier_cli(): | ||
parser = argparse.ArgumentParser("CLI to run the Node21 baseline classifier") | ||
parser.add_argument("input_cxr", type=str, help="input CXR image (MHA)") | ||
parser.add_argument("output_json_file", type=str, help="Output nodule bounding boxes and probabilities predictions (JSON)") | ||
args = parser.parse_args() | ||
run_classifier( | ||
input_cxr=Path(args.input_cxr), | ||
output_json_file=Path(args.output_json_file) | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_classifier_cli() |