diff --git a/spineps/models.py b/spineps/models.py index 14294d1..314a211 100755 --- a/spineps/models.py +++ b/spineps/models.py @@ -23,8 +23,16 @@ def get_semantic_model(model_name: str) -> Segmentation_Model: """ model_name = model_name.lower() _modelid2folder_subreg = modelid2folder_semantic() - if model_name not in _modelid2folder_subreg.keys(): - logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_subreg.keys()}") + possible_keys = list(_modelid2folder_subreg.keys()) + if len(possible_keys) == 0: + logger.print( + "Found no available semantic models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?", + Log_Type.FAIL, + ) + raise KeyError(model_name) + if model_name not in possible_keys: + logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL) + raise KeyError(model_name) return get_segmentation_model(_modelid2folder_subreg[model_name]) @@ -39,8 +47,16 @@ def get_instance_model(model_name: str) -> Segmentation_Model: """ model_name = model_name.lower() _modelid2folder_vert = modelid2folder_instance() - if model_name not in _modelid2folder_vert.keys(): - logger.print(f"Model with name {model_name} does not exist, options are {_modelid2folder_vert.keys()}") + possible_keys = list(_modelid2folder_vert.keys()) + if len(possible_keys) == 0: + logger.print( + "Found no available instance models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?", + Log_Type.FAIL, + ) + raise KeyError(model_name) + if model_name not in possible_keys: + logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL) + raise KeyError(model_name) return get_segmentation_model(_modelid2folder_vert[model_name]) @@ -87,7 +103,7 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) -> models_folder = Path(models_folder) assert models_folder.exists(), f"models_folder {models_folder} does not exist" - config_paths = search_path(models_folder, query="**/inference_config.json") + config_paths = search_path(models_folder, query="**/inference_config.json", suppress=True) global _modelid2folder_semantic, _modelid2folder_instance # noqa: PLW0603 _modelid2folder_semantic = {} # id to model_folder _modelid2folder_instance = {} # id to model_folder @@ -103,6 +119,11 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) -> except Exception as e: logger.print(f"Modelfolder '{model_folder_name}' ignored, caused by '{e}'", Log_Type.STRANGE, verbose=verbose) # raise e # + if len(config_paths) == 0 or len(_modelid2folder_instance.keys()) == 0 or len(_modelid2folder_semantic.keys()) == 0: + logger.print( + "Automatic search for models did not find anything. Did you set the environment variable correctly? Did you download model weights and put them into the specified folder? Ignore this if you specified your model using an absolute path.", + Log_Type.FAIL, + ) return _modelid2folder_semantic, _modelid2folder_instance @@ -122,10 +143,16 @@ def get_segmentation_model(in_config: str | Path, **kwargs) -> Segmentation_Mode if os.path.isdir(str(in_dir)): # noqa: PTH112 # search for config - path_search = search_path(in_dir, "**/*inference_config.json") + path_search = search_path(in_dir, "**/*inference_config.json", suppress=True) + if len(path_search) == 0: + logger.print( + f"get_segmentation_model: did not find a singular inference_config.json in {in_dir}/**/*inference_config.json. Is this the correct folder?", + Log_Type.FAIL, + ) + raise FileNotFoundError(f"{in_dir}/**/*inference_config.json") assert ( len(path_search) == 1 - ), f"get_segmentation_model: did not found a singular inference_config.json in {in_dir}/**/*inference_config.json" + ), f"get_segmentation_model: found more than one inference_config.json in {in_dir}/**/*inference_config.json. Ambigous behavior, please manually correct this by removing one of these.\nFound {path_search}" in_dir = path_search[0] # else: # base = filepath_model(in_config, model_dir=None) diff --git a/spineps/utils/filepaths.py b/spineps/utils/filepaths.py index 5926268..5e4d7c6 100755 --- a/spineps/utils/filepaths.py +++ b/spineps/utils/filepaths.py @@ -5,9 +5,9 @@ from itertools import chain from pathlib import Path -spineps_environment_path_override = Path( - "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/" -) # None # You can put an absolute path to the model weights here instead of using environment variable +spineps_environment_path_override = None # Path( +# "/DATA/NAS/ongoing_projects/hendrik/mri_usage/models/" +# ) # None # You can put an absolute path to the model weights here instead of using environment variable spineps_environment_path_backup = Path(__file__).parent.parent.joinpath("models") # EDIT this to use this instead of environment variable @@ -29,7 +29,7 @@ def get_mri_segmentor_models_dir() -> Path: folder_path is not None ), "Environment variable 'SPINEPS_SEGMENTOR_MODELS' is not defined. Setup the environment variable as stated in the readme or set the override in utils.filepaths.py" folder_path = Path(folder_path) - assert folder_path.exists(), f"'SPINEPS_SEGMENTOR_MODELS' path {folder_path} does not exist" + assert folder_path.exists(), f"Environment variable 'SPINEPS_SEGMENTOR_MODELS' = {folder_path} does not exist" return folder_path