diff --git a/pyroengine/engine.py b/pyroengine/engine.py index d5de215d..d50935d2 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -17,6 +17,7 @@ import cv2 # type: ignore[import-untyped] import numpy as np +from numpy.typing import NDArray from PIL import Image from pyroclient import client from requests.exceptions import ConnectionError @@ -267,7 +268,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> return conf - def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray[Any, Any]: + def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray: """Computes the confidence that the image contains wildfire cues Args: