diff --git a/src/inference/inference.py b/src/inference/inference.py index 62e5c30..d8dd3c1 100644 --- a/src/inference/inference.py +++ b/src/inference/inference.py @@ -1,30 +1,29 @@ import numpy as np +from numpy import typing as npt from src.utils.settings import IMAGE_SIZE, PADDING_SIZE -from .model import ModelProtocol # noqa: F401 (used for type hinting) +from .model import ModelProtocol class Inference: def __init__(self, - model): + model: ModelProtocol) -> None: """ | Initializer method - :param ModelProtocol model: model + :param model: model :returns: None - :rtype: None """ self.model = model @staticmethod - def remove_padding(mask): + def remove_padding(mask: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]: """ | Returns the mask without padding. - :param np.ndarray[np.uint8] mask: mask + :param mask: mask :returns: mask without padding - :rtype: np.ndarray[np.uint8] """ assert isinstance(mask, np.ndarray) assert mask.dtype == np.uint8 @@ -34,15 +33,14 @@ def remove_padding(mask): return np.array(mask[PADDING_SIZE:-PADDING_SIZE, PADDING_SIZE:-PADDING_SIZE]) def predict_mask(self, - image, - apply_padding=False): + image: npt.NDArray[np.float32], + apply_padding: bool = False) -> npt.NDArray[np.uint8]: """ | Returns the mask. - :param np.ndarray[np.float32] image: image - :param bool apply_padding: if True, the padding of the mask is removed + :param image: image + :param apply_padding: if True, the padding of the mask is removed :returns: mask - :rtype: np.ndarray[np.uint8] """ assert isinstance(image, np.ndarray) assert image.dtype == np.float32 diff --git a/src/inference/model.py b/src/inference/model.py index 0e7e26b..d497910 100644 --- a/src/inference/model.py +++ b/src/inference/model.py @@ -1,20 +1,20 @@ -from pathlib import Path # noqa: F401 (used for type hinting) +from pathlib import Path from typing import Protocol import numpy as np import onnxruntime as ort +from numpy import typing as npt class ModelProtocol(Protocol): def run(self, - image): + image: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]: """ | Returns the mask. - :param np.ndarray[np.float32] image: image + :param image: image :returns: mask - :rtype: np.ndarray[np.uint8] """ ... @@ -22,13 +22,12 @@ def run(self, class ONNXModel: def __init__(self, - path): + path: Path) -> None: """ | Initializer method - :param Path path: path to the onnx model + :param path: path to the onnx model :returns: None - :rtype: None """ assert isinstance(path, Path) @@ -36,13 +35,12 @@ def __init__(self, self._session = ort.InferenceSession(str(self.path)) def run(self, - image): + image: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]: """ | Returns the mask. - :param np.ndarray[np.float32] image: image + :param image: image :returns: mask - :rtype: np.ndarray[np.uint8] """ assert isinstance(image, np.ndarray) assert image.dtype == np.float32