Skip to content

Commit

Permalink
Merge pull request #28 from Hendrik-code/axial_vibe
Browse files Browse the repository at this point in the history
Axial vibe
  • Loading branch information
Hendrik-code authored Aug 29, 2024
2 parents dfb9655 + f1d98a6 commit 5a56cbc
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 113 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,31 @@ This is a segmentation pipeline to automatically, and robustly, segment the whol
If you are using SPINEPS, please cite the following:

```
SPINEPS:
Hendrik Möller, Robert Graf, Joachim Schmitt, Benjamin Keinert, Matan Atad, Anjany
Sekuboyina, Felix Streckenbach, Hanna Schon, Florian Kofler, Thomas Kroencke, Ste-
fanie Bette, Stefan Willich, Thomas Keil, Thoralf Niendorf, Tobias Pischon, Beate Ende-
mann, Bjoern Menze, Daniel Rueckert, and Jan S. Kirschke. Spineps – automatic whole
spine segmentation of t2-weighted mr images using a two-phase approach to multi-class
semantic and instance segmentation. arXiv preprint arXiv:2402.16368, 2024.
Source of the T2w/T1w Segmentation:
Robert Graf, Joachim Schmitt, Sarah Schlaeger, Hendrik Kristian Möller, Vasiliki
Sideri-Lampretsa, Anjany Sekuboyina, Sandro Manuel Krieg, Benedikt Wiestler, Bjoern
Menze, Daniel Rueckert, Jan Stefan Kirschke. Denoising diffusion-based MRI to CT image
translation enables automated spinal segmentation. Eur Radiol Exp 7, 70 (2023).
https://doi.org/10.1186/s41747-023-00385-2
```
SPINEPS:

ArXiv link: <a href="https://arxiv.org/abs/2402.16368">https://arxiv.org/abs/2402.16368</a>

Source of the T2w/T1w Segmentation:

Open Access link: <a href="https://doi.org/10.1186/s41747-023-00385-2">https://doi.org/10.1186/s41747-023-00385-2</a>

BibTeX citation:
```
@article{moeller2024,
Expand All @@ -34,6 +49,17 @@ BibTeX citation:
archivePrefix={arXiv},
primaryClass={eess.IV},
}
@article{graf2023denoising,
title={Denoising diffusion-based MRI to CT image translation enables automated spinal segmentation},
author={Graf, Robert and Schmitt, Joachim and Schlaeger, Sarah and M{\"o}ller, Hendrik Kristian and Sideri-Lampretsa, Vasiliki and Sekuboyina, Anjany and Krieg, Sandro Manuel and Wiestler, Benedikt and Menze, Bjoern and Rueckert, Daniel and others},
journal={European Radiology Experimental},
volume={7},
number={1},
pages={70},
year={2023},
publisher={Springer}
}
```

## Installation (Ubuntu)
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ exclude = ["models", "examples"]
[tool.poetry.scripts]
spineps = 'spineps.entrypoint:entry_point'

spineps_ = 'spineps.entrypoint:entrypoint_no_checks'

[tool.poetry.dependencies]
python = "^3.10 || ^3.11"
connected-components-3d = "^3.12.3"
Expand All @@ -28,9 +30,9 @@ SciPy = "^1.11.2"
torchmetrics = "^1.1.2"
tqdm = "^4.66.1"
einops= "^0.6.1"
nnunetv2 = "2.2"
nnunetv2 = "2.4.2"
tptbox = "^0.1.4"
antspyx = "*"
antspyx = "0.4.2"
rich = "^13.6.0"


Expand Down
2 changes: 2 additions & 0 deletions spineps/auto_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
semantic: dict[str, Path | str] = {
"t2w": link + current_highest_version + "/t2w.zip",
"t1w": link + current_highest_version + "/t1w.zip",
"vibe": link + current_highest_version + "/vibe.zip",
}


download_names = {
"instance": "instance_sagittal",
"t2w": "T2w_semantic",
"t1w": "T1w_semantic",
"vibe": "Vibe_semantic",
}


Expand Down
4 changes: 2 additions & 2 deletions spineps/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def entry_point():
"-ms",
# type=str.lower,
default=None,
# required=True,
required=True,
# choices=modelids_semantic,
metavar="",
help=f"The model used for the semantic segmentation. Choices are {modelids_semantic} or a string absolute path the model folder",
Expand Down Expand Up @@ -165,7 +165,7 @@ def entry_point():

@citation_reminder
def run_sample(opt: Namespace):
input_path = Path(opt.input)
input_path = Path(opt.input).absolute()
dataset = str(input_path.parent)
assert os.path.exists(dataset), f"-input parent does not exist, got {dataset}" # noqa: PTH110
assert dataset not in ("", "."), f"-input you only gave a filename, not a direction to the file, got {input_path}"
Expand Down
42 changes: 42 additions & 0 deletions spineps/example/get_gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import time # noqa: INP001

import GPUtil
from TPTBox import Log_Type, No_Logger

logger = No_Logger()


def get_gpu(verbose: bool = False, max_load: float = 0.3, max_memory: float = 0.4):
GPUtil.showUtilization() if verbose else None
device_ids = GPUtil.getAvailable(
order="load",
limit=4,
maxLoad=max_load,
maxMemory=max_memory,
includeNan=False,
excludeID=[],
excludeUUID=[],
)
return device_ids


def intersection(lst1, lst2):
return set(lst1).intersection(lst2)


def get_free_gpus(blocked_gpus=None, max_load: float = 0.3, max_memory: float = 0.4):
# print("get_free_gpus")
if blocked_gpus is None:
blocked_gpus = {0: False, 1: False, 2: False, 3: False}
cached_list = get_gpu(max_load=max_load, max_memory=max_memory)
for _ in range(15):
time.sleep(0.25)
cached_list = intersection(cached_list, get_gpu())
# print("result:", list(cached_list))
gpulist = [i for i in list(cached_list) if i not in blocked_gpus or blocked_gpus[i] is False]
# print("result:", gpulist)
return gpulist


def thread_print(fold, *text):
logger.print(f"Fold [{fold}]: ", *text)
25 changes: 16 additions & 9 deletions spineps/example/helper_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,29 @@

from TPTBox import BIDS_FILE # noqa: E402

from spineps.models import get_segmentation_model # noqa: E402
from spineps.models import get_instance_model, get_semantic_model # noqa: E402
from spineps.seg_run import process_img_nii # noqa: E402
from spineps.utils.filepaths import filepath_model # noqa: E402

# Example
# python /spineps/example/helper_parallel.py -i PATH/TO/IMG.nii.gz -ds DATASET-PATH -der derivatives -ms [t1w,t2w,vibe] -mv instance

if __name__ == "__main__":
main_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
main_parser.add_argument("-i", type=str)
main_parser.add_argument("-ds", type=str)
main_parser.add_argument("-der", type=str)
main_parser.add_argument("-ms", type=str)
main_parser.add_argument("-mv", type=str)
main_parser.add_argument("-der", default="derivatives", type=str)
main_parser.add_argument("-ms", default="t2w", type=str)
main_parser.add_argument("-mv", default="instance", type=str)
main_parser.add_argument("-snap", default=None, type=str)

opt = main_parser.parse_args()

input_bids_file = BIDS_FILE(file=opt.i, dataset=opt.ds)

model_dir = "/DATA/NAS/ongoing_projects/hendrik/nako-segmentation/nnUNet/"
ms = get_segmentation_model(in_config=filepath_model(opt.ms, model_dir=model_dir))
mv = get_segmentation_model(in_config=filepath_model(opt.mv, model_dir=model_dir))

ms = get_semantic_model(opt.ms)
mv = get_instance_model(opt.mv)
if opt.snap is not None:
Path(opt.snap).mkdir(exist_ok=True, parents=True)
process_img_nii(
img_ref=input_bids_file,
derivative_name=opt.der,
Expand All @@ -38,4 +41,8 @@
override_instance=False,
save_debug_data=False,
verbose=False,
ignore_compatibility_issues=False, # If true, we do not check if the file ending match like _T2w.nii.gz for T2w images
ignore_bids_filter=False, # If true, we do not check if BIDS compliant
save_raw=False, # Save output as they are produced by the model
snapshot_copy_folder=opt.snap,
)
5 changes: 0 additions & 5 deletions spineps/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,6 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
except Exception as e:
logger.print(f"Modelfolder '{model_folder_name}' ignored, caused by '{e}'", Log_Type.STRANGE, verbose=verbose)
# raise e #
if len(config_paths) == 0 or len(_modelid2folder_instance.keys()) == 0 or len(_modelid2folder_semantic.keys()) == 0:
logger.print(
"Automatic search for models did not find anything. Did you set the environment variable correctly? Did you download model weights and put them into the specified folder? Ignore this if you specified your model using an absolute path.",
Log_Type.FAIL,
)

return _modelid2folder_semantic, _modelid2folder_instance

Expand Down
13 changes: 5 additions & 8 deletions spineps/phase_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_corpus_coms(
seg_nii.assert_affine(orientation=["P", "I", "R"])
#
# Extract Corpus region and try to find all coms naively (some skips shouldnt matter)
corpus_nii = seg_nii.extract_label(Location.Vertebra_Corpus_border.value)
corpus_nii = seg_nii.extract_label([Location.Vertebra_Corpus_border, Location.Vertebra_Corpus])
corpus_nii.erode_msk_(mm=2, connectivity=2, verbose=False)
if 1 in corpus_nii.unique() and corpus_size_cleaning > 0:
corpus_nii.set_array_(
Expand All @@ -192,7 +192,7 @@ def get_corpus_coms(
)

if 1 not in corpus_nii.unique():
logger.print("No 1 in corpus nifty, cannot make vertebra mask", Log_Type.FAIL)
logger.print(f"No corpus found after get_corpus_coms post process, cannot make vertebra mask. {corpus_nii.unique()}", Log_Type.FAIL)
return None

if not process_detect_and_solve_merged_corpi:
Expand Down Expand Up @@ -256,19 +256,15 @@ def get_corpus_coms(
stats_by_height.pop(vl)
stats_by_height = dict(sorted(stats_by_height.items(), key=lambda x: x[1][0]))
stats_by_height_keys = list(stats_by_height.keys())
print(stats_by_height_keys)
continue

logger.print("Merged corpi, try to fix it", verbose=verbose)
neighbor_verts = {
stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]]
for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]
if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99
if (idx + i) < len(stats_by_height_keys) and (idx + i) >= 0 and stats_by_height_keys[idx + i] < 99
}
# stats_by_height_keys[idx + i]
# for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]
# if (idx + i) in stats_by_height_keys and stats_by_height_keys[idx + i] < 99
# ] # (+-3)

logger.print("neighbor_vert_labels", neighbor_verts, verbose=verbose)
if len(neighbor_verts) == 0:
logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL)
Expand Down Expand Up @@ -505,6 +501,7 @@ def collect_vertebra_predictions(
47: 7,
48: 8,
49: 9,
50: 9,
Location.Spinal_Cord.value: 0,
Location.Spinal_Canal.value: 0,
Location.Vertebra_Disc.value: 0,
Expand Down
8 changes: 4 additions & 4 deletions spineps/seg_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class MetaEnum(EnumMeta):
def __contains__(cls, item): # noqa: N805
def __contains__(cls, item):
try:
cls[item]
except ValueError:
Expand Down Expand Up @@ -62,11 +62,11 @@ def format_keys(cls, modalities: Self | list[Self]) -> list[str]:
elif modality == Modality.SEG:
result += ["msk", "seg"]
elif modality == Modality.T1w:
result += ["T1w", "t1", "T1"]
result += ["T1w", "t1", "T1", "T1c"]
elif modality == Modality.T2w:
result += ["T2w", "dixon", "mr", "t2", "T2"]
result += ["T2w", "dixon", "mr", "t2", "T2", "T2c"]
elif modality == Modality.Vibe:
result += ["t1dixon", "vibe"]
result += ["t1dixon", "vibe", "mevibe", "GRE"]
elif modality == Modality.MPR:
result += ["mpr", "MPR", "Mpr"]
else:
Expand Down
8 changes: 4 additions & 4 deletions spineps/seg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,17 @@ def modelid(self, include_log_name: bool = False):
return name
return self.inference_config.log_name

def dict_representation(self, input_zms: ZOOMS | None):
def dict_representation(self):
info = {
"name": self.modelid(), # self.inference_config.__repr__()
"model_path": str(self.model_folder),
"modality": str(self.modalities()),
"aquisition": str(self.acquisition()),
"resolution_range": str(self.inference_config.resolution_range),
}
if input_zms is not None:
proc_zms = self.calc_recommended_resampling_zoom(input_zms)
info["resolution_processed"] = str(proc_zms)
# if input_zms is not None:
# proc_zms = self.calc_recommended_resampling_zoom(input_zms)
# info["resolution_processed"] = str(proc_zms)
return info

def __str__(self):
Expand Down
3 changes: 1 addition & 2 deletions spineps/seg_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def predict_centroids_from_both(
seg_nii: NII,
models: list[Segmentation_Model],
parameter: dict[str, Any],
input_zms_pir: ZOOMS | None = None,
):
"""Calculates the centroids of each vertebra corpus by using both semantic and instance mask
Expand All @@ -65,7 +64,7 @@ def predict_centroids_from_both(

models_repr = {}
for idx, m in enumerate(models):
models_repr[idx] = m.dict_representation(input_zms_pir)
models_repr[idx] = m.dict_representation()
ctd.info["source"] = "MRI Segmentation Pipeline"
ctd.info["version"] = pipeline_version()
ctd.info["models"] = models_repr
Expand Down
Loading

0 comments on commit 5a56cbc

Please sign in to comment.