Skip to content

Commit

Permalink
first version of a labeling step done. ALPHA
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrik-code committed Nov 26, 2024
1 parent cd15a70 commit f5ae024
Show file tree
Hide file tree
Showing 19 changed files with 446 additions and 218 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ nnunetv2 = "2.4.2"
TPTBox = "^0.2.1"
antspyx = "0.4.2"
rich = "^13.6.0"
monai="^1.3.0"


[tool.poetry.dev-dependencies]
Expand Down
2 changes: 0 additions & 2 deletions spineps/architectures/pl_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from torch import nn
from TypeSaveArgParse import Class_to_ArgParse

from spineps.architectures.read_labels import Objectives


@dataclass
class ARGS_MODEL(Class_to_ArgParse):
Expand Down
2 changes: 0 additions & 2 deletions spineps/architectures/pl_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def _shared_metric_append(self, metrics, outputs):
def _shared_cat_metrics(self, outputs):
results = {}
for m, v in outputs.items():
# v = np.asarray(v)
# print(m, v.shape)
stacked = torch.stack(v)
results[m] = torch.mean(stacked) if m != "dice_p_cls" else torch.mean(stacked, dim=0)
return results
Expand Down
27 changes: 0 additions & 27 deletions spineps/architectures/read_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,30 +195,3 @@ def flatten(a: list[str | int | list[str] | list[int]]):
else:
for b in a:
yield from flatten(b)


###

# Eval-pipeline zuerst
# sensitivity, recall, AUC, ROC, F1, MCC
# dann MONAI baseline bauen mit Resnet, Densenet, ViT
if __name__ == "__main__":
objectives = Objectives(
[
Target.FULLYVISIBLE,
Target.REGION,
Target.VERTREL,
Target.VERT,
],
as_group=True,
)

entry_dict = {
"vert_exact": VertExact.L1,
"vert_region": VertRegion.LWS,
"vert_rel": VertRel.FIRST_LWK,
"vert_cut": True,
}

label = objectives(entry_dict)
print(label)
54 changes: 47 additions & 7 deletions spineps/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from TPTBox import BIDS_FILE, Log_Type, No_Logger

from spineps.get_models import (
get_actual_model,
get_instance_model,
get_segmentation_model,
get_labeling_model,
get_semantic_model,
modelid2folder_instance,
modelid2folder_labeling,
modelid2folder_semantic,
)
from spineps.seg_run import process_dataset, process_img_nii
Expand Down Expand Up @@ -74,6 +76,7 @@ def parser_arguments(parser: argparse.ArgumentParser):
def entry_point():
modelids_semantic = list(modelid2folder_semantic().keys())
modelids_instance = list(modelid2folder_instance().keys())
modelids_labeling = [*list(modelid2folder_labeling().keys()), "none"]
###########################
###########################
main_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand Down Expand Up @@ -107,6 +110,16 @@ def entry_point():
metavar="",
help=f"The model used for the vertebra instance segmentation. Choices are {modelids_instance} or a string absolute path the model folder",
)
parser_sample.add_argument(
"-model_labeling",
"-ml",
# type=str.lower,
default="labeling",
# required=True,
# choices=modelids_instance,
metavar="",
help=f"The model used for the vertebra labeling classification. Choices are {modelids_labeling} or a string absolute path the model folder",
)
parser_sample = parser_arguments(parser_sample)

###########################
Expand Down Expand Up @@ -135,6 +148,16 @@ def entry_point():
metavar="",
help=f"The model used for the vertebra segmentation. Choices are {model_vert_choices} or a string absolute path the model folder",
)
parser_dataset.add_argument(
"-model_labeling",
"-ml",
# type=str.lower,
default="labeling",
# required=True,
# choices=modelids_instance,
metavar="",
help=f"The model used for the vertebra labeling classification. Choices are {modelids_labeling} or a string absolute path the model folder",
)
parser_dataset.add_argument(
"-ignore_bids_filter",
"-ibf",
Expand Down Expand Up @@ -180,23 +203,31 @@ def run_sample(opt: Namespace):
if not input_path.endswith(".nii.gz"):
input_path += ".nii.gz"
assert os.path.isfile(input_path), f"-input does not exist or is not a file, got {input_path}" # noqa: PTH113

# model semantic
if "/" in str(opt.model_semantic):
# given path
model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
model_semantic = get_actual_model(opt.model_semantic, use_cpu=opt.cpu).load()
else:
model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()
# model instance
if "/" in str(opt.model_instance):
model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
model_instance = get_actual_model(opt.model_instance, use_cpu=opt.cpu).load()
else:
model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()
# model labeling
if opt.model_labeling == "none":
model_labeling = None
elif "/" in str(opt.model_labeling):
model_labeling = get_actual_model(opt.model_labeling, use_cpu=opt.cpu).load()
else:
model_labeling = get_labeling_model(opt.model_labeling, use_cpu=opt.cpu).load()

bids_sample = BIDS_FILE(input_path, dataset=dataset, verbose=True)

kwargs = {
"img_ref": bids_sample,
"model_semantic": model_semantic,
"model_instance": model_instance,
"model_labeling": model_labeling,
"derivative_name": opt.der_name,
#
# "save_uncertainty_image": opt.save_unc_img,
Expand Down Expand Up @@ -245,24 +276,33 @@ def run_dataset(opt: Namespace):
if opt.model_semantic == "auto":
model_semantic = None
elif "/" in str(opt.model_semantic):
model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
model_semantic = get_actual_model(opt.model_semantic, use_cpu=opt.cpu).load()
else:
model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()

# Model Instance
if opt.model_instance == "auto":
model_instance = None
elif "/" in str(opt.model_instance):
model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
model_instance = get_actual_model(opt.model_instance, use_cpu=opt.cpu).load()
else:
model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()

# Model Labeling
if opt.model_labeling == "none":
model_labeling = None
elif "/" in str(opt.model_labeling):
model_labeling = get_actual_model(opt.model_labeling, use_cpu=opt.cpu).load()
else:
model_labeling = get_labeling_model(opt.model_labeling, use_cpu=opt.cpu).load()

assert model_instance is not None, "-model_vert was None"

kwargs = {
"dataset_path": input_dir,
"model_semantic": model_semantic,
"model_instance": model_instance,
"model_labeling": model_labeling,
"rawdata_name": opt.raw_name,
"derivative_name": opt.der_name,
#
Expand Down
94 changes: 81 additions & 13 deletions spineps/get_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

from TPTBox import Log_Type, No_Logger

from spineps.seg_enums import Modality
from spineps.seg_model import Segmentation_Model, modeltype2class
from spineps.utils.auto_download import download_if_missing, instances, semantic
from spineps.lab_model import VertLabelingClassifier
from spineps.seg_enums import Modality, ModelType, SpinepsPhase
from spineps.seg_model import Segmentation_Model, Segmentation_Model_NNunet, Segmentation_Model_Unet3D
from spineps.utils.auto_download import download_if_missing, instances, labeling, semantic
from spineps.utils.filepaths import get_mri_segmentor_models_dir, search_path
from spineps.utils.seg_modelconfig import load_inference_config

Expand Down Expand Up @@ -38,8 +39,8 @@ def get_semantic_model(model_name: str, **kwargs) -> Segmentation_Model:
config_path = _modelid2folder_subreg[model_name]
if str(config_path).startswith("http"):
# Resolve HTTP
config_path = download_if_missing(model_name, config_path, is_instance=False)
return get_segmentation_model(config_path, **kwargs)
config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.SEMANTIC)
return get_actual_model(config_path, **kwargs)


def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model:
Expand All @@ -66,13 +67,43 @@ def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model:
config_path = _modelid2folder_vert[model_name]
if str(config_path).startswith("http"):
# Resolve HTTP
config_path = download_if_missing(model_name, config_path, is_instance=True)
config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.INSTANCE)

return get_segmentation_model(config_path, **kwargs)
return get_actual_model(config_path, **kwargs)


def get_labeling_model(model_name: str, **kwargs) -> VertLabelingClassifier:
"""Finds and returns an instance model by name
Args:
model_name (str): _description_
Returns:
Segmentation_Model: _description_
"""
model_name = model_name.lower()
_modelid2folder_labeling = modelid2folder_labeling()
possible_keys = list(_modelid2folder_labeling.keys())
if len(possible_keys) == 0:
logger.print(
"Found no available labeling models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?",
Log_Type.FAIL,
)
raise KeyError(model_name)
if model_name not in possible_keys:
logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
raise KeyError(model_name)
config_path = _modelid2folder_labeling[model_name]
if str(config_path).startswith("http"):
# Resolve HTTP
config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.LABELING)

return get_actual_model(config_path, **kwargs)


_modelid2folder_semantic: dict[str, Path | str] | None = None
_modelid2folder_instance: dict[str, Path | str] | None = None
_modelid2folder_labeling: dict[str, Path | str] | None = None


def modelid2folder_semantic() -> dict[str, Path | str]:
Expand All @@ -99,6 +130,18 @@ def modelid2folder_instance() -> dict[str, Path | str]:
return check_available_models(get_mri_segmentor_models_dir())[1]


def modelid2folder_labeling() -> dict[str, Path | str]:
"""Returns the dictionary mapping labeling model ids to their corresponding path
Returns:
_type_: _description_
"""
if _modelid2folder_labeling is not None:
return _modelid2folder_labeling
else:
return check_available_models(get_mri_segmentor_models_dir())[2]


def check_available_models(models_folder: str | Path, verbose: bool = False) -> tuple[dict[str, Path | int], dict[str, Path | int]]:
"""Searches through the specified directories and finds models, sorting them into the dictionaries mapping to instance or semantic models
Expand All @@ -115,26 +158,51 @@ def check_available_models(models_folder: str | Path, verbose: bool = False) ->
assert models_folder.exists(), f"models_folder {models_folder} does not exist"

config_paths = search_path(models_folder, query="**/inference_config.json", suppress=True)
global _modelid2folder_semantic, _modelid2folder_instance # noqa: PLW0603
global _modelid2folder_semantic, _modelid2folder_instance, _modelid2folder_labeling # noqa: PLW0603
_modelid2folder_semantic = semantic # id to model_folder
_modelid2folder_instance = instances # id to model_folder
_modelid2folder_labeling = labeling
for cp in config_paths:
model_folder = cp.parent
model_folder_name = model_folder.name.lower()
try:
inference_config = load_inference_config(str(cp))
if Modality.SEG in inference_config.modalities:
if inference_config.modeltype == ModelType.classifier:
_modelid2folder_labeling[model_folder_name] = model_folder
elif Modality.SEG in inference_config.modalities:
_modelid2folder_instance[model_folder_name] = model_folder
else:
_modelid2folder_semantic[model_folder_name] = model_folder
except Exception as e:
logger.print(f"Modelfolder '{model_folder_name}' ignored, caused by '{e}'", Log_Type.STRANGE, verbose=verbose)
# raise e #

return _modelid2folder_semantic, _modelid2folder_instance
return _modelid2folder_semantic, _modelid2folder_instance, _modelid2folder_labeling


def modeltype2class(modeltype: ModelType):
"""Maps ModelType to actual Segmentation_Model Subclass
Args:
type (ModelType): _description_
Raises:
NotImplementedError: _description_
Returns:
_type_: _description_
"""
if modeltype == ModelType.nnunet:
return Segmentation_Model_NNunet
elif modeltype == ModelType.unet:
return Segmentation_Model_Unet3D
elif modeltype == ModelType.classifier:
return VertLabelingClassifier
else:
raise NotImplementedError(modeltype)


def get_segmentation_model(in_config: str | Path, **kwargs) -> Segmentation_Model:
def get_actual_model(in_config: str | Path, **kwargs) -> Segmentation_Model | VertLabelingClassifier:
"""Creates the Model class from given path
Args:
Expand All @@ -154,13 +222,13 @@ def get_segmentation_model(in_config: str | Path, **kwargs) -> Segmentation_Mode
path_search = search_path(in_dir, "**/*inference_config.json", suppress=True)
if len(path_search) == 0:
logger.print(
f"get_segmentation_model: did not find a singular inference_config.json in {in_dir}/**/*inference_config.json. Is this the correct folder?",
f"get_actual_model: did not find a singular inference_config.json in {in_dir}/**/*inference_config.json. Is this the correct folder?",
Log_Type.FAIL,
)
raise FileNotFoundError(f"{in_dir}/**/*inference_config.json")
assert (
len(path_search) == 1
), f"get_segmentation_model: found more than one inference_config.json in {in_dir}/**/*inference_config.json. Ambigous behavior, please manually correct this by removing one of these.\nFound {path_search}"
), f"get_actual_model: found more than one inference_config.json in {in_dir}/**/*inference_config.json. Ambigous behavior, please manually correct this by removing one of these.\nFound {path_search}"
in_dir = path_search[0]
# else:
# base = filepath_model(in_config, model_dir=None)
Expand Down
Loading

0 comments on commit f5ae024

Please sign in to comment.