diff --git a/spineps/entrypoint.py b/spineps/entrypoint.py
index 0d47f92..d5da75f 100755
--- a/spineps/entrypoint.py
+++ b/spineps/entrypoint.py
@@ -63,6 +63,7 @@ def parser_arguments(parser: argparse.ArgumentParser):
         help="Does not apply n4 bias field correction",
     )
     #
+    parser.add_argument("-cpu", action="store_true", help="Use CPU instead of GPU (will take way longer)")
     parser.add_argument("-run_cprofiler", "-rcp", action="store_true", help="Runs a cprofiler over the entire action")
     parser.add_argument("-verbose", "-v", action="store_true", help="Prints much more stuff, may fully clutter your terminal")
     return parser
@@ -183,13 +184,13 @@ def run_sample(opt: Namespace):

     if "/" in str(opt.model_semantic):
         # given path
-        model_semantic = get_segmentation_model(opt.model_semantic).load()
+        model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
     else:
-        model_semantic = get_semantic_model(opt.model_semantic).load()
+        model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()
     if "/" in str(opt.model_instance):
-        model_instance = get_segmentation_model(opt.model_instance).load()
+        model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
     else:
-        model_instance = get_instance_model(opt.model_instance).load()
+        model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()

     bids_sample = BIDS_FILE(input_path, dataset=dataset, verbose=True)
@@ -247,17 +248,17 @@ def run_dataset(opt: Namespace):
     if opt.model_semantic == "auto":
         model_semantic = None
     elif "/" in str(opt.model_semantic):
-        model_semantic = get_segmentation_model(opt.model_semantic).load()
+        model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
     else:
-        model_semantic = get_semantic_model(opt.model_semantic).load()
+        model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()

     # Model Instance
     if opt.model_instance == "auto":
         model_instance = None
     elif "/" in str(opt.model_instance):
-        model_instance = get_segmentation_model(opt.model_instance).load()
+        model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
     else:
-        model_instance = get_instance_model(opt.model_instance).load()
+        model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()

     assert model_instance is not None, "-model_vert was None"

diff --git a/spineps/models.py b/spineps/models.py
index 314a211..2e0d861 100755
--- a/spineps/models.py
+++ b/spineps/models.py
@@ -12,7 +12,7 @@
 logger.override_prefix = "Models"


-def get_semantic_model(model_name: str) -> Segmentation_Model:
+def get_semantic_model(model_name: str, **kwargs) -> Segmentation_Model:
     """Finds and returns a semantic model by name

     Args:
@@ -33,10 +33,10 @@ def get_semantic_model(model_name: str) -> Segmentation_Model:
     if model_name not in possible_keys:
         logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
         raise KeyError(model_name)
-    return get_segmentation_model(_modelid2folder_subreg[model_name])
+    return get_segmentation_model(_modelid2folder_subreg[model_name], **kwargs)


-def get_instance_model(model_name: str) -> Segmentation_Model:
+def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model:
     """Finds and returns an instance model by name

     Args:
@@ -57,7 +57,7 @@ def get_instance_model(model_name: str) -> Segmentation_Model:
     if model_name not in possible_keys:
         logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
         raise KeyError(model_name)
-    return get_segmentation_model(_modelid2folder_vert[model_name])
+    return get_segmentation_model(_modelid2folder_vert[model_name], **kwargs)


 _modelid2folder_semantic: dict[str, Path] | None = None
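Since `get_semantic_model` and `get_instance_model` now forward `**kwargs` into `get_segmentation_model`, the CPU switch is also available when a model is resolved by name rather than by path. A minimal usage sketch (the model name "t2w" is illustrative, not taken from this patch):

```python
from spineps.models import get_semantic_model

# Hypothetical model name; any key known to get_semantic_model works the same way.
# use_cpu travels through **kwargs down to the Segmentation_Model constructor.
model = get_semantic_model("t2w", use_cpu=True).load()
```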
diff --git a/spineps/seg_model.py b/spineps/seg_model.py
index f97d122..12c151e 100755
--- a/spineps/seg_model.py
+++ b/spineps/seg_model.py
@@ -29,6 +29,7 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,  # type:ignore
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
@@ -44,6 +45,7 @@ def __init__(
         assert os.path.exists(str(model_folder)), f"model_folder doesnt exist, got {model_folder}"  # noqa: PTH110

         self.logger = No_Logger()
+        self.use_cpu = use_cpu

         if inference_config is None:
             json_dir = Path(model_folder).joinpath("inference_config.json")
@@ -61,7 +63,10 @@ def __init__(
         self.print("initialized with inference config", self.inference_config)

     @abstractmethod
-    def load(self, folds: tuple[str, ...] | None = None) -> Self:
+    def load(
+        self,
+        folds: tuple[str, ...] | None = None,
+    ) -> Self:
         """Loads the weights from disk

         Returns:
@@ -110,7 +115,7 @@ def segment_scan(
         Args:
             input (Image_Reference | dict[InputType, Image_Reference]): input
             pad_size (int, optional): Padding in each dimension (times two more pixels in each dim). Defaults to 4.
-            step_size (float | None, optional): _description_. Defaults to 0.5.
+            step_size (float | None, optional): _description_. Defaults to None.
             resample_to_recommended (bool, optional): _description_. Defaults to True.
             verbose (bool, optional): _description_. Defaults to False.

@@ -266,10 +271,11 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
-        super().__init__(model_folder, inference_config, default_verbose, default_allow_tqdm)
+        super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm)

     def load(self, folds: tuple[str, ...] | None = None) -> Self:
         global threads_started  # noqa: PLW0603
@@ -283,7 +289,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:
             init_threads=not threads_started,
             allow_non_final=True,
             verbose=False,
-            ddevice="cuda",
+            ddevice="cuda" if not self.use_cpu else "cpu",
         )
         threads_started = True
         self.predictor.allow_tqdm = self.default_allow_tqdm
@@ -311,10 +317,11 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
-        super().__init__(model_folder, inference_config, default_verbose, default_allow_tqdm)
+        super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm)
         assert len(self.inference_config.expected_inputs) == 1, "Unet3D cannot expect more than one input"

     def load(self, folds: tuple[str, ...] | None = None) -> Self:  # noqa: ARG002
@@ -324,7 +331,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:  # noqa: ARG002
         assert len(chktpath) == 1
         model = PLNet.load_from_checkpoint(checkpoint_path=chktpath[0])
         model.eval()
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() and not self.use_cpu else "cpu")
         model.to(self.device)
         self.predictor = model
         self.print("Model loaded from", self.model_folder, verbose=True)
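Both subclasses now thread `use_cpu` into the base class, so one flag drives the nnUNet predictor's `ddevice` and the 3D U-Net's `torch.device`. A standalone sketch of the resulting device resolution, mirroring the patched `Segmentation_Model_Unet3D.load()` (only `torch` assumed; the helper name is hypothetical):

```python
import torch

def resolve_device(use_cpu: bool) -> torch.device:
    # CUDA is chosen only when it is available AND the caller did not force CPU.
    return torch.device("cuda:0" if torch.cuda.is_available() and not use_cpu else "cpu")

print(resolve_device(use_cpu=True))   # always cpu
print(resolve_device(use_cpu=False))  # cuda:0 when a GPU is visible, else cpu
```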
diff --git a/spineps/utils/predictor.py b/spineps/utils/predictor.py
index 7900696..1a902b5 100755
--- a/spineps/utils/predictor.py
+++ b/spineps/utils/predictor.py
@@ -56,6 +56,7 @@ def __init__(
         self.use_gaussian = use_gaussian
         self.use_mirroring = use_mirroring
         if device.type == "cuda":
+            torch.backends.cudnn.benchmark = True
             device = torch.device(type="cuda", index=0)  # set the desired GPU with CUDA_VISIBLE_DEVICES!
         if device.type != "cuda" and perform_everything_on_gpu:
             print("perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False")
@@ -288,7 +289,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
                     new_prediction = self.predict_sliding_window_return_logits(data, network=network)
                     prediction += new_prediction
                     prediction_stacked.append(new_prediction.to("cpu"))
-
+
                 if len(self.list_of_parameters) > 1:
                     prediction /= len(self.list_of_parameters)
@@ -436,18 +437,20 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, network
         slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

+        precision = torch.half if self.perform_everything_on_gpu else torch.float32
+
         # preallocate results and num_predictions
         results_device = self.device if self.perform_everything_on_gpu else torch.device("cpu")
         if self.verbose:
             print("preallocating arrays")
         try:
-            data = data.to(self.device)
+            data = data.to(self.device, dtype=precision)
             predicted_logits = torch.zeros(
                 (self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                dtype=torch.half,
+                dtype=precision,
                 device=results_device,
             )
-            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+            n_predictions = torch.zeros(data.shape[1:], dtype=precision, device=results_device)
             if self.use_gaussian:
                 gaussian = compute_gaussian(
                     tuple(self.configuration_manager.patch_size),
@@ -458,13 +461,13 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, network
         except RuntimeError:
             # sometimes the stuff is too large for GPUs. In that case fall back to CPU
             results_device = torch.device("cpu")
-            data = data.to(results_device)
+            data = data.to(results_device, dtype=precision)
             predicted_logits = torch.zeros(
                 (self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                dtype=torch.half,
+                dtype=precision,
                 device=results_device,
             )
-            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+            n_predictions = torch.zeros(data.shape[1:], dtype=precision, device=results_device)
             if self.use_gaussian:
                 gaussian = compute_gaussian(
                     tuple(self.configuration_manager.patch_size),
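The new `precision` switch keeps half precision as a GPU-only optimization: on CPU the buffers fall back to float32, since half-precision kernels are slow or unavailable for many CPU ops. A minimal sketch of the preallocation rule (function name and shapes are illustrative, not part of the patch):

```python
import torch

def preallocate_logits(num_heads: int, spatial: tuple[int, ...], on_gpu: bool):
    # Same rule as the patch: fp16 buffers only when everything stays on the GPU,
    # full float32 otherwise.
    precision = torch.half if on_gpu else torch.float32
    device = torch.device("cuda") if on_gpu else torch.device("cpu")
    predicted_logits = torch.zeros((num_heads, *spatial), dtype=precision, device=device)
    n_predictions = torch.zeros(spatial, dtype=precision, device=device)
    return predicted_logits, n_predictions

logits, counts = preallocate_logits(num_heads=9, spatial=(64, 64, 64), on_gpu=False)
print(logits.dtype, counts.dtype)  # torch.float32 torch.float32
```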