diff --git a/README.md b/README.md index e6c9f38..d97070a 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,8 @@ This is a segmentation pipeline to automatically, and robustly, segment the whol If you are using SPINEPS, please cite the following: ``` -Hendrik M ̈oller, Robert Graf, Joachim Schmitt, Benjamin Keinert, Matan Atad, Anjany -Sekuboyina, Felix Streckenbach, Hanna Sch ̈on, Florian Kofler, Thomas Kroencke, Ste- +Hendrik Möller, Robert Graf, Joachim Schmitt, Benjamin Keinert, Matan Atad, Anjany +Sekuboyina, Felix Streckenbach, Hanna Schon, Florian Kofler, Thomas Kroencke, Ste- fanie Bette, Stefan Willich, Thomas Keil, Thoralf Niendorf, Tobias Pischon, Beate Ende- mann, Bjoern Menze, Daniel Rueckert, and Jan S. Kirschke. Spineps – automatic whole spine segmentation of t2-weighted mr images using a two-phase approach to multi-class @@ -24,7 +24,7 @@ ArXiv link: https://arxiv.org/abs/240 BibTeX citation: ``` -@article{moller2024, +@article{moeller2024, title={SPINEPS -- Automatic Whole Spine Segmentation of T2-weighted MR images using a Two-Phase Approach to Multi-class Semantic and Instance Segmentation}, author={Hendrik Möller and Robert Graf and Joachim Schmitt and Benjamin Keinert and Matan Atad and Anjany Sekuboyina and Felix Streckenbach and Hanna Schön and Florian Kofler and Thomas Kroencke and Stefanie Bette and Stefan Willich and Thomas Keil and Thoralf Niendorf and Tobias Pischon and Beate Endemann and Bjoern Menze and Daniel Rueckert and Jan S. Kirschke}, journal={arXiv preprint arXiv:2402.16368}, diff --git a/pyproject.toml b/pyproject.toml index ad66c52..a9162e0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ einops= "^0.6.1" nnunetv2 = "2.2" tptbox = "*" antspyx = "*" +rich = "^13.6.0" [tool.poetry-dynamic-versioning] diff --git a/spineps/entrypoint.py b/spineps/entrypoint.py index fb9a350..29638f3 100755 --- a/spineps/entrypoint.py +++ b/spineps/entrypoint.py @@ -9,6 +9,7 @@ from spineps.models import get_instance_model, get_segmentation_model, get_semantic_model, modelid2folder_instance, modelid2folder_semantic from spineps.seg_run import process_dataset, process_img_nii +from spineps.utils.citation_reminder import citation_reminder logger = No_Logger() logger.override_prefix = "Init" @@ -159,6 +160,14 @@ def entry_point(): ########################### opt = main_parser.parse_args() + # Print citation + print("###########################") + print("SPINEPS: please cite") + print( + "Hendrik Möller, Robert Graf, Joachim Schmitt, Benjamin Keinert, Matan Atad, Anjany Sekuboyina, Felix Streckenbach, Hanna Sch ̈on, Florian Kofler, Thomas Kroencke, Stefanie Bette, Stefan Willich, Thomas Keil, Thoralf Niendorf, Tobias Pischon, Beate Ende-mann, Bjoern Menze, Daniel Rueckert, and Jan S. Kirschke. Spineps - automatic whole spine segmentation of t2-weighted mr images using a two-phase approach to multi-class semantic and instance segmentation. arXiv preprint arXiv:2402.16368, 2024." + ) + print("###########################") + # print(opt) if opt.cmd == "sample": run_sample(opt) @@ -168,6 +177,7 @@ def entry_point(): raise NotImplementedError("cmd", opt.cmd) +@citation_reminder def run_sample(opt: Namespace): input_path = Path(opt.input) dataset = str(input_path.parent) @@ -234,6 +244,7 @@ def run_sample(opt: Namespace): return 1 +@citation_reminder def run_dataset(opt: Namespace): input_dir = Path(opt.directory) assert input_dir.exists(), f"-input does not exist, {input_dir}" diff --git a/spineps/seg_model.py b/spineps/seg_model.py index db397dc..f18a9c5 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -11,6 +11,7 @@ from spineps.seg_enums import Acquisition, InputType, Modality, ModelType, OutputType from spineps.seg_modelconfig import Segmentation_Inference_Config, load_inference_config from spineps.Unet3D.pl_unet import PLNet +from spineps.utils.citation_reminder import citation_reminder from spineps.utils.filepaths import search_path from spineps.utils.inference_api import load_inf_model, run_inference @@ -94,6 +95,7 @@ def same_modelzoom_as_model(self, model: Self, input_zoom: Zooms) -> bool: match: bool = bool(np.all([self_zms[i] - model_zms[i] < 1e-4 for i in range(3)])) return match + @citation_reminder def segment_scan( self, input_image: Image_Reference | dict[InputType, Image_Reference], diff --git a/spineps/seg_modelconfig.py b/spineps/seg_modelconfig.py index c2aa67d..7c0ecdb 100755 --- a/spineps/seg_modelconfig.py +++ b/spineps/seg_modelconfig.py @@ -22,7 +22,7 @@ def __init__( resolution_range: Zooms | tuple[Zooms, Zooms], default_step_size: float, labels: dict, - expected_inputs: list[InputType] = [InputType.img], # noqa: B006 + expected_inputs: list[InputType | str] = [InputType.img], # noqa: B006 **kwargs, ): if not isinstance(modality, list): @@ -37,7 +37,7 @@ def __init__( self.available_folds: int = int(available_folds) self.inference_augmentation: bool = inference_augmentation self.default_step_size = float(default_step_size) - self.expected_inputs = [InputType[i] for i in expected_inputs] # type: ignore + self.expected_inputs = [InputType[i] if isinstance(i, str) else i for i in expected_inputs] # type: ignore names = [member.name for member in Location] try: self.segmentation_labels = { diff --git a/spineps/seg_run.py b/spineps/seg_run.py index 13efc64..51adbf1 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -24,8 +24,10 @@ check_model_modality_acquisition, find_best_matching_model, ) +from spineps.utils.citation_reminder import citation_reminder +@citation_reminder def process_dataset( dataset_path: Path, model_instance: Segmentation_Model, @@ -226,6 +228,7 @@ def process_dataset( logger.print(not_properly_processed) +@citation_reminder def process_img_nii( # noqa: C901 img_ref: BIDS_FILE, model_semantic: Segmentation_Model, diff --git a/spineps/utils/citation_reminder.py b/spineps/utils/citation_reminder.py new file mode 100644 index 0000000..c88c81b --- /dev/null +++ b/spineps/utils/citation_reminder.py @@ -0,0 +1,42 @@ +import atexit +import os + +from rich.console import Console + +GITHUB_LINK = "https://github.com/Hendrik-code/spineps" + +ARXIV_LINK = "https://arxiv.org/abs/2402.16368" + +has_reminded_citation = False + + +def citation_reminder(func): + """Decorator to remind users to cite SPINEPS.""" + + def wrapper(*args, **kwargs): + global has_reminded_citation # noqa: PLW0603 + if not has_reminded_citation: + print_citation_reminder() + has_reminded_citation = True + func_result = func(*args, **kwargs) + return func_result + + return wrapper + + +def print_citation_reminder(): + console = Console() + console.rule("Thank you for using [bold]SPINEPS[/bold]") + console.print( + "Please support our development by citing", + justify="center", + ) + console.print( + f"GitHub: {GITHUB_LINK}\nArXiv: {ARXIV_LINK}\n Thank you!", + justify="center", + ) + console.rule() + console.line() + + +atexit.register(print_citation_reminder)