Merge pull request #17 from Hendrik-code/cpu
CPU support
Hendrik-code authored May 7, 2024
2 parents 56f2ac6 + 4eb9e5b, commit 5f317c6
Showing 4 changed files with 36 additions and 25 deletions.
spineps/entrypoint.py: 17 changes (9 additions, 8 deletions)
@@ -63,6 +63,7 @@ def parser_arguments(parser: argparse.ArgumentParser):
         help="Does not apply n4 bias field correction",
     )
     #
+    parser.add_argument("-cpu", action="store_true", help="Use CPU instead of GPU (will take way longer)")
     parser.add_argument("-run_cprofiler", "-rcp", action="store_true", help="Runs a cprofiler over the entire action")
     parser.add_argument("-verbose", "-v", action="store_true", help="Prints much more stuff, may fully clutter your terminal")
     return parser
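
Because `-cpu` is an `action="store_true"` flag, `opt.cpu` always exists on the parsed namespace and defaults to `False`, which is why the call sites below can pass `use_cpu=opt.cpu` unconditionally. A minimal, self-contained sketch of that behavior (plain argparse, nothing repo-specific):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-cpu", action="store_true", help="Use CPU instead of GPU")

    assert parser.parse_args([]).cpu is False        # flag absent  -> False
    assert parser.parse_args(["-cpu"]).cpu is True   # flag present -> True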
@@ -183,13 +184,13 @@ def run_sample(opt: Namespace):
 
     if "/" in str(opt.model_semantic):
         # given path
-        model_semantic = get_segmentation_model(opt.model_semantic).load()
+        model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
     else:
-        model_semantic = get_semantic_model(opt.model_semantic).load()
+        model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()
     if "/" in str(opt.model_instance):
-        model_instance = get_segmentation_model(opt.model_instance).load()
+        model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
     else:
-        model_instance = get_instance_model(opt.model_instance).load()
+        model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()
 
     bids_sample = BIDS_FILE(input_path, dataset=dataset, verbose=True)
 
@@ -247,17 +248,17 @@ def run_dataset(opt: Namespace):
     if opt.model_semantic == "auto":
         model_semantic = None
     elif "/" in str(opt.model_semantic):
-        model_semantic = get_segmentation_model(opt.model_semantic).load()
+        model_semantic = get_segmentation_model(opt.model_semantic, use_cpu=opt.cpu).load()
     else:
-        model_semantic = get_semantic_model(opt.model_semantic).load()
+        model_semantic = get_semantic_model(opt.model_semantic, use_cpu=opt.cpu).load()
 
     # Model Instance
     if opt.model_instance == "auto":
         model_instance = None
     elif "/" in str(opt.model_instance):
-        model_instance = get_segmentation_model(opt.model_instance).load()
+        model_instance = get_segmentation_model(opt.model_instance, use_cpu=opt.cpu).load()
     else:
-        model_instance = get_instance_model(opt.model_instance).load()
+        model_instance = get_instance_model(opt.model_instance, use_cpu=opt.cpu).load()
 
     assert model_instance is not None, "-model_vert was None"
 
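Taken together, the two entrypoints above now thread the flag from the CLI down to model construction. A hedged usage sketch of the resulting API (the model ids "t2w" and "instance" are hypothetical placeholders, not names this diff guarantees):

    # Hypothetical ids; substitute whatever get_semantic_model/get_instance_model list as options.
    from spineps.models import get_instance_model, get_semantic_model

    model_semantic = get_semantic_model("t2w", use_cpu=True).load()
    model_instance = get_instance_model("instance", use_cpu=True).load()
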
spineps/models.py: 8 changes (4 additions, 4 deletions)
@@ -12,7 +12,7 @@
 logger.override_prefix = "Models"
 
 
-def get_semantic_model(model_name: str) -> Segmentation_Model:
+def get_semantic_model(model_name: str, **kwargs) -> Segmentation_Model:
     """Finds and returns a semantic model by name
 
     Args:
@@ -33,10 +33,10 @@ def get_semantic_model(model_name: str) -> Segmentation_Model:
     if model_name not in possible_keys:
         logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
         raise KeyError(model_name)
-    return get_segmentation_model(_modelid2folder_subreg[model_name])
+    return get_segmentation_model(_modelid2folder_subreg[model_name], **kwargs)
 
 
-def get_instance_model(model_name: str) -> Segmentation_Model:
+def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model:
     """Finds and returns an instance model by name
 
     Args:
@@ -57,7 +57,7 @@ def get_instance_model(model_name: str) -> Segmentation_Model:
     if model_name not in possible_keys:
         logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL)
         raise KeyError(model_name)
-    return get_segmentation_model(_modelid2folder_vert[model_name])
+    return get_segmentation_model(_modelid2folder_vert[model_name], **kwargs)
 
 
 _modelid2folder_semantic: dict[str, Path] | None = None
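The `**kwargs` change is plain keyword forwarding: the name-lookup helpers stay agnostic of construction options, so `use_cpu` (and any future option) flows through to `get_segmentation_model` without these functions changing again. A generic sketch of the pattern with toy names (not repo code):

    def build_model(folder: str, use_cpu: bool = False) -> str:
        return f"model from {folder}, cpu={use_cpu}"

    def build_by_name(name: str, **kwargs) -> str:
        folders = {"demo": "/models/demo"}
        # Unknown keyword options pass through untouched.
        return build_model(folders[name], **kwargs)

    print(build_by_name("demo", use_cpu=True))  # model from /models/demo, cpu=True
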
spineps/seg_model.py: 19 changes (13 additions, 6 deletions)
@@ -29,6 +29,7 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,  # type:ignore
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
@@ -44,6 +45,7 @@ def __init__(
         assert os.path.exists(str(model_folder)), f"model_folder doesnt exist, got {model_folder}"  # noqa: PTH110
 
         self.logger = No_Logger()
+        self.use_cpu = use_cpu
 
         if inference_config is None:
             json_dir = Path(model_folder).joinpath("inference_config.json")
@@ -61,7 +63,10 @@ def __init__(
         self.print("initialized with inference config", self.inference_config)
 
     @abstractmethod
-    def load(self, folds: tuple[str, ...] | None = None) -> Self:
+    def load(
+        self,
+        folds: tuple[str, ...] | None = None,
+    ) -> Self:
         """Loads the weights from disk
 
         Returns:
@@ -110,7 +115,7 @@ def segment_scan(
         Args:
             input (Image_Reference | dict[InputType, Image_Reference]): input
             pad_size (int, optional): Padding in each dimension (times two more pixels in each dim). Defaults to 4.
-            step_size (float | None, optional): _description_. Defaults to 0.5.
+            step_size (float | None, optional): _description_. Defaults to None.
             resample_to_recommended (bool, optional): _description_. Defaults to True.
             verbose (bool, optional): _description_. Defaults to False.
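
For background (nnUNet's sliding-window convention, not something this diff states): `step_size` is the window step as a fraction of the patch size, so 0.5 means adjacent patches overlap by half, and `None` defers to the configured default. A simplified tiling sketch along one axis (not nnUNet's exact evenly-spread step computation):

    def window_starts(axis_len: int, patch: int, step_size: float) -> list[int]:
        # Move by patch * step_size; clamp the final window to the image edge.
        step = max(1, int(patch * step_size))
        starts = list(range(0, max(axis_len - patch, 0) + 1, step))
        if axis_len > patch and starts[-1] != axis_len - patch:
            starts.append(axis_len - patch)
        return starts

    print(window_starts(100, 64, 0.5))  # [0, 32, 36]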
@@ -266,10 +271,11 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
-        super().__init__(model_folder, inference_config, default_verbose, default_allow_tqdm)
+        super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm)
 
     def load(self, folds: tuple[str, ...] | None = None) -> Self:
         global threads_started  # noqa: PLW0603
@@ -283,7 +289,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:
             init_threads=not threads_started,
             allow_non_final=True,
             verbose=False,
-            ddevice="cuda",
+            ddevice="cuda" if not self.use_cpu else "cpu",
         )
         threads_started = True
         self.predictor.allow_tqdm = self.default_allow_tqdm
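
Note the ternary only honors the explicit flag: if CUDA is unavailable and `-cpu` was not passed, `ddevice` is still `"cuda"` and nnUNet will fail downstream. A small defensive variant, offered as an assumption rather than part of this PR:

    import torch

    def pick_ddevice(use_cpu: bool) -> str:
        # Prefer CPU when explicitly requested, or when no CUDA device exists.
        return "cpu" if use_cpu or not torch.cuda.is_available() else "cuda"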
@@ -311,10 +317,11 @@ def __init__(
         self,
         model_folder: str | Path,
         inference_config: Segmentation_Inference_Config | None = None,
+        use_cpu: bool = False,
         default_verbose: bool = False,
         default_allow_tqdm: bool = True,
     ):
-        super().__init__(model_folder, inference_config, default_verbose, default_allow_tqdm)
+        super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm)
         assert len(self.inference_config.expected_inputs) == 1, "Unet3D cannot expect more than one input"
 
     def load(self, folds: tuple[str, ...] | None = None) -> Self:  # noqa: ARG002
@@ -324,7 +331,7 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self:  # noqa: ARG002
         assert len(chktpath) == 1
         model = PLNet.load_from_checkpoint(checkpoint_path=chktpath[0])
         model.eval()
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() and not self.use_cpu else "cpu")
         model.to(self.device)
         self.predictor = model
         self.print("Model loaded from", self.model_folder, verbose=True)
spineps/utils/predictor.py: 17 changes (10 additions, 7 deletions)
@@ -56,6 +56,7 @@ def __init__(
         self.use_gaussian = use_gaussian
         self.use_mirroring = use_mirroring
         if device.type == "cuda":
+            torch.backends.cudnn.benchmark = True
             device = torch.device(type="cuda", index=0)  # set the desired GPU with CUDA_VISIBLE_DEVICES!
         if device.type != "cuda" and perform_everything_on_gpu:
             print("perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False")
@@ -288,7 +289,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
             new_prediction = self.predict_sliding_window_return_logits(data, network=network)
             prediction += new_prediction
             prediction_stacked.append(new_prediction.to("cpu"))
 
         if len(self.list_of_parameters) > 1:
             prediction /= len(self.list_of_parameters)
 
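This loop is a standard fold ensemble: each set of weights contributes logits that are summed and then divided by the number of folds, while `prediction_stacked` keeps per-fold copies on the CPU. A stripped-down sketch of the same arithmetic:

    import torch

    def ensemble_logits(per_fold: list[torch.Tensor]) -> torch.Tensor:
        prediction = torch.zeros_like(per_fold[0])
        for logits in per_fold:
            prediction += logits
        if len(per_fold) > 1:
            prediction /= len(per_fold)  # mean over folds
        return prediction

    print(ensemble_logits([torch.ones(2, 2), torch.zeros(2, 2)]))  # 0.5 everywhere
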
@@ -436,18 +437,20 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, network
 
         slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
 
+        precision = torch.half if self.perform_everything_on_gpu else torch.float32
+
         # preallocate results and num_predictions
         results_device = self.device if self.perform_everything_on_gpu else torch.device("cpu")
         if self.verbose:
             print("preallocating arrays")
         try:
-            data = data.to(self.device)
+            data = data.to(self.device, dtype=precision)
             predicted_logits = torch.zeros(
                 (self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                dtype=torch.half,
+                dtype=precision,
                 device=results_device,
             )
-            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+            n_predictions = torch.zeros(data.shape[1:], dtype=precision, device=results_device)
             if self.use_gaussian:
                 gaussian = compute_gaussian(
                     tuple(self.configuration_manager.patch_size),
@@ -458,13 +461,13 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, network
         except RuntimeError:
             # sometimes the stuff is too large for GPUs. In that case fall back to CPU
             results_device = torch.device("cpu")
-            data = data.to(results_device)
+            data = data.to(results_device, dtype=precision)
             predicted_logits = torch.zeros(
                 (self.label_manager.num_segmentation_heads, *data.shape[1:]),
-                dtype=torch.half,
+                dtype=precision,
                 device=results_device,
             )
-            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)
+            n_predictions = torch.zeros(data.shape[1:], dtype=precision, device=results_device)
             if self.use_gaussian:
                 gaussian = compute_gaussian(
                     tuple(self.configuration_manager.patch_size),
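The new `precision` switch is the heart of the CPU fallback in this file: half precision is standard on CUDA, but many CPU kernels are slow or unimplemented for `torch.half`, so the CPU path allocates everything in `float32`. A minimal illustration of the allocation choice (an isolated sketch, not repo code):

    import torch

    def alloc_logits(shape: tuple[int, ...], on_gpu: bool) -> torch.Tensor:
        precision = torch.half if on_gpu else torch.float32
        device = torch.device("cuda") if on_gpu else torch.device("cpu")
        return torch.zeros(shape, dtype=precision, device=device)

    logits = alloc_logits((4, 64, 64, 64), on_gpu=False)  # float32 on CPU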
