Skip to content

Commit

Permalink
Merge pull request #3 from Hendrik-code/usability_upgrade
Browse files Browse the repository at this point in the history
Usability upgrade
  • Loading branch information
Hendrik-code authored Feb 29, 2024
2 parents 3dbb54a + 7cd4737 commit 2f3214b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
11 changes: 2 additions & 9 deletions spineps/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def parser_arguments(parser: argparse.ArgumentParser):
"-nocrop",
"-nc",
action="store_true",
help="Does not crop input before semantically segmenting",
help="Does not crop input before semantically segmenting. Can improve the segmentation a little but depending on size costs more computation time",
)
parser.add_argument(
"-non4",
Expand All @@ -68,6 +68,7 @@ def parser_arguments(parser: argparse.ArgumentParser):
return parser


@citation_reminder
def entry_point():
modelids_semantic = list(modelid2folder_semantic().keys())
modelids_instance = list(modelid2folder_instance().keys())
Expand Down Expand Up @@ -160,14 +161,6 @@ def entry_point():
###########################
opt = main_parser.parse_args()

# Print citation
print("###########################")
print("SPINEPS: please cite")
print(
"Hendrik Möller, Robert Graf, Joachim Schmitt, Benjamin Keinert, Matan Atad, Anjany Sekuboyina, Felix Streckenbach, Hanna Sch ̈on, Florian Kofler, Thomas Kroencke, Stefanie Bette, Stefan Willich, Thomas Keil, Thoralf Niendorf, Tobias Pischon, Beate Ende-mann, Bjoern Menze, Daniel Rueckert, and Jan S. Kirschke. Spineps - automatic whole spine segmentation of t2-weighted mr images using a two-phase approach to multi-class semantic and instance segmentation. arXiv preprint arXiv:2402.16368, 2024."
)
print("###########################")

# print(opt)
if opt.cmd == "sample":
run_sample(opt)
Expand Down
41 changes: 34 additions & 7 deletions spineps/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ def get_semantic_model(model_name: str) -> Segmentation_Model:
"""
model_name = model_name.lower()
_modelid2folder_subreg = modelid2folder_semantic()
if model_name not in _modelid2folder_subreg.keys():
logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_subreg.keys()}")
possible_keys = list(_modelid2folder_subreg.keys())
if len(possible_keys) == 0:
logger.print(
"Found no available semantic models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?",
Log_Type.FAIL,
)
raise KeyError(model_name)
if model_name not in possible_keys:
logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
raise KeyError(model_name)
return get_segmentation_model(_modelid2folder_subreg[model_name])


Expand All @@ -39,8 +47,16 @@ def get_instance_model(model_name: str) -> Segmentation_Model:
"""
model_name = model_name.lower()
_modelid2folder_vert = modelid2folder_instance()
if model_name not in _modelid2folder_vert.keys():
logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_vert.keys()}")
possible_keys = list(_modelid2folder_vert.keys())
if len(possible_keys) == 0:
logger.print(
"Found no available instance models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?",
Log_Type.FAIL,
)
raise KeyError(model_name)
if model_name not in possible_keys:
logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
raise KeyError(model_name)
return get_segmentation_model(_modelid2folder_vert[model_name])


Expand Down Expand Up @@ -87,7 +103,7 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
models_folder = Path(models_folder)
assert models_folder.exists(), f"models_folder {models_folder} does not exist"

config_paths = search_path(models_folder, query="**/inference_config.json")
config_paths = search_path(models_folder, query="**/inference_config.json", suppress=True)
global _modelid2folder_semantic, _modelid2folder_instance # noqa: PLW0603
_modelid2folder_semantic = {} # id to model_folder
_modelid2folder_instance = {} # id to model_folder
Expand All @@ -103,6 +119,11 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
except Exception as e:
logger.print(f"Modelfolder '{model_folder_name}' ignored, caused by '{e}'", Log_Type.STRANGE, verbose=verbose)
# raise e #
if len(config_paths) == 0 or len(_modelid2folder_instance.keys()) == 0 or len(_modelid2folder_semantic.keys()) == 0:
logger.print(
"Automatic search for models did not find anything. Did you set the environment variable correctly? Did you download model weights and put them into the specified folder? Ignore this if you specified your model using an absolute path.",
Log_Type.FAIL,
)
return _modelid2folder_semantic, _modelid2folder_instance


Expand All @@ -122,10 +143,16 @@ def get_segmentation_model(in_config: str | Path, **kwargs) -> Segmentation_Mode

if os.path.isdir(str(in_dir)): # noqa: PTH112
# search for config
path_search = search_path(in_dir, "**/*inference_config.json")
path_search = search_path(in_dir, "**/*inference_config.json", suppress=True)
if len(path_search) == 0:
logger.print(
f"get_segmentation_model: did not find a singular inference_config.json in {in_dir}/**/*inference_config.json. Is this the correct folder?",
Log_Type.FAIL,
)
raise FileNotFoundError(f"{in_dir}/**/*inference_config.json")
assert (
len(path_search) == 1
), f"get_segmentation_model: did not found a singular inference_config.json in {in_dir}/**/*inference_config.json"
), f"get_segmentation_model: found more than one inference_config.json in {in_dir}/**/*inference_config.json. Ambigous behavior, please manually correct this by removing one of these.\nFound {path_search}"
in_dir = path_search[0]
# else:
# base = filepath_model(in_config, model_dir=None)
Expand Down
8 changes: 4 additions & 4 deletions spineps/utils/filepaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from itertools import chain
from pathlib import Path

spineps_environment_path_override = Path(
"/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/"
) # None # You can put an absolute path to the model weights here instead of using environment variable
spineps_environment_path_override = None # Path(
# "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/"
# ) # None # You can put an absolute path to the model weights here instead of using environment variable
spineps_environment_path_backup = Path(__file__).parent.parent.joinpath("models") # EDIT this to use this instead of environment variable


Expand All @@ -29,7 +29,7 @@ def get_mri_segmentor_models_dir() -> Path:
folder_path is not None
), "Environment variable 'SPINEPS_SEGMENTOR_MODELS' is not defined. Setup the environment variable as stated in the readme or set the override in utils.filepaths.py"
folder_path = Path(folder_path)
assert folder_path.exists(), f"'SPINEPS_SEGMENTOR_MODELS' path {folder_path} does not exist"
assert folder_path.exists(), f"Environment variable 'SPINEPS_SEGMENTOR_MODELS' = {folder_path} does not exist"
return folder_path


Expand Down

0 comments on commit 2f3214b

Please sign in to comment.