Skip to content

Commit

Permalink
improved filepath search so the error messages are more meaningful fo…
Browse files Browse the repository at this point in the history
…r the user
  • Loading branch information
Hendrik-code committed Feb 29, 2024
1 parent 4ae79ff commit 7cd4737
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
41 changes: 34 additions & 7 deletions spineps/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ def get_semantic_model(model_name: str) -> Segmentation_Model:
"""
model_name = model_name.lower()
_modelid2folder_subreg = modelid2folder_semantic()
if model_name not in _modelid2folder_subreg.keys():
logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_subreg.keys()}")
possible_keys = list(_modelid2folder_subreg.keys())
if len(possible_keys) == 0:
logger.print(
"Found no available semantic models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?",
Log_Type.FAIL,
)
raise KeyError(model_name)
if model_name not in possible_keys:
logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
raise KeyError(model_name)
return get_segmentation_model(_modelid2folder_subreg[model_name])


Expand All @@ -39,8 +47,16 @@ def get_instance_model(model_name: str) -> Segmentation_Model:
"""
model_name = model_name.lower()
_modelid2folder_vert = modelid2folder_instance()
if model_name not in _modelid2folder_vert.keys():
logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_vert.keys()}")
possible_keys = list(_modelid2folder_vert.keys())
if len(possible_keys) == 0:
logger.print(
"Found no available instance models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?",
Log_Type.FAIL,
)
raise KeyError(model_name)
if model_name not in possible_keys:
logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
raise KeyError(model_name)
return get_segmentation_model(_modelid2folder_vert[model_name])


Expand Down Expand Up @@ -87,7 +103,7 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
models_folder = Path(models_folder)
assert models_folder.exists(), f"models_folder {models_folder} does not exist"

config_paths = search_path(models_folder, query="**/inference_config.json")
config_paths = search_path(models_folder, query="**/inference_config.json", suppress=True)
global _modelid2folder_semantic, _modelid2folder_instance # noqa: PLW0603
_modelid2folder_semantic = {} # id to model_folder
_modelid2folder_instance = {} # id to model_folder
Expand All @@ -103,6 +119,11 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
except Exception as e:
logger.print(f"Modelfolder '{model_folder_name}' ignored, caused by '{e}'", Log_Type.STRANGE, verbose=verbose)
# raise e #
if len(config_paths) == 0 or len(_modelid2folder_instance.keys()) == 0 or len(_modelid2folder_semantic.keys()) == 0:
logger.print(
"Automatic search for models did not find anything. Did you set the environment variable correctly? Did you download model weights and put them into the specified folder? Ignore this if you specified your model using an absolute path.",
Log_Type.FAIL,
)
return _modelid2folder_semantic, _modelid2folder_instance


Expand All @@ -122,10 +143,16 @@ def get_segmentation_model(in_config: str | Path, **kwargs) -> Segmentation_Mode

if os.path.isdir(str(in_dir)): # noqa: PTH112
# search for config
path_search = search_path(in_dir, "**/*inference_config.json")
path_search = search_path(in_dir, "**/*inference_config.json", suppress=True)
if len(path_search) == 0:
logger.print(
f"get_segmentation_model: did not find a singular inference_config.json in {in_dir}/**/*inference_config.json. Is this the correct folder?",
Log_Type.FAIL,
)
raise FileNotFoundError(f"{in_dir}/**/*inference_config.json")
assert (
len(path_search) == 1
), f"get_segmentation_model: did not found a singular inference_config.json in {in_dir}/**/*inference_config.json"
), f"get_segmentation_model: found more than one inference_config.json in {in_dir}/**/*inference_config.json. Ambigous behavior, please manually correct this by removing one of these.\nFound {path_search}"
in_dir = path_search[0]
# else:
# base = filepath_model(in_config, model_dir=None)
Expand Down
8 changes: 4 additions & 4 deletions spineps/utils/filepaths.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from itertools import chain
from pathlib import Path

spineps_environment_path_override = Path(
"/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/"
) # None # You can put an absolute path to the model weights here instead of using environment variable
spineps_environment_path_override = None # Path(
# "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/"
# ) # None # You can put an absolute path to the model weights here instead of using environment variable
spineps_environment_path_backup = Path(__file__).parent.parent.joinpath("models") # EDIT this to use this instead of environment variable


Expand All @@ -29,7 +29,7 @@ def get_mri_segmentor_models_dir() -> Path:
folder_path is not None
), "Environment variable 'SPINEPS_SEGMENTOR_MODELS' is not defined. Setup the environment variable as stated in the readme or set the override in utils.filepaths.py"
folder_path = Path(folder_path)
assert folder_path.exists(), f"'SPINEPS_SEGMENTOR_MODELS' path {folder_path} does not exist"
assert folder_path.exists(), f"Environment variable 'SPINEPS_SEGMENTOR_MODELS' = {folder_path} does not exist"
return folder_path


Expand Down

0 comments on commit 7cd4737

Please sign in to comment.