Skip to content

Commit

Permalink
refactor [skip linting]
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsmrynk committed Dec 21, 2023
1 parent e28b17f commit 51d80e4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
22 changes: 10 additions & 12 deletions src/inference/inference.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
import numpy as np
from numpy import typing as npt

from src.utils.settings import IMAGE_SIZE, PADDING_SIZE
from .model import ModelProtocol # noqa: F401 (used for type hinting)
from .model import ModelProtocol


class Inference:

def __init__(self,
model):
model: ModelProtocol) -> None:
"""
| Initializer method
:param ModelProtocol model: model
:param model: model
:returns: None
:rtype: None
"""
self.model = model

@staticmethod
def remove_padding(mask):
def remove_padding(mask: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
"""
| Returns the mask without padding.
:param np.ndarray[np.uint8] mask: mask
:param mask: mask
:returns: mask without padding
:rtype: np.ndarray[np.uint8]
"""
assert isinstance(mask, np.ndarray)
assert mask.dtype == np.uint8
Expand All @@ -34,15 +33,14 @@ def remove_padding(mask):
return np.array(mask[PADDING_SIZE:-PADDING_SIZE, PADDING_SIZE:-PADDING_SIZE])

def predict_mask(self,
image,
apply_padding=False):
image: npt.NDArray[np.float32],
apply_padding: bool = False) -> npt.NDArray[np.uint8]:
"""
| Returns the mask.
:param np.ndarray[np.float32] image: image
:param bool apply_padding: if True, the padding of the mask is removed
:param image: image
:param apply_padding: if True, the padding of the mask is removed
:returns: mask
:rtype: np.ndarray[np.uint8]
"""
assert isinstance(image, np.ndarray)
assert image.dtype == np.float32
Expand Down
18 changes: 8 additions & 10 deletions src/inference/model.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,46 @@
from pathlib import Path # noqa: F401 (used for type hinting)
from pathlib import Path
from typing import Protocol

import numpy as np
import onnxruntime as ort
from numpy import typing as npt


class ModelProtocol(Protocol):

def run(self,
image):
image: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
"""
| Returns the mask.
:param np.ndarray[np.float32] image: image
:param image: image
:returns: mask
:rtype: np.ndarray[np.uint8]
"""
...


class ONNXModel:

def __init__(self,
path):
path: Path) -> None:
"""
| Initializer method
:param Path path: path to the onnx model
:param path: path to the onnx model
:returns: None
:rtype: None
"""
assert isinstance(path, Path)

self.path = path
self._session = ort.InferenceSession(str(self.path))

def run(self,
image):
image: npt.NDArray[np.float32]) -> npt.NDArray[np.uint8]:
"""
| Returns the mask.
:param np.ndarray[np.float32] image: image
:param image: image
:returns: mask
:rtype: np.ndarray[np.uint8]
"""
assert isinstance(image, np.ndarray)
assert image.dtype == np.float32
Expand Down

0 comments on commit 51d80e4

Please sign in to comment.