diff --git a/pyroengine/core.py b/pyroengine/core.py index 776af058..a202e987 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -5,19 +5,20 @@ import logging import signal +from multiprocessing import Process, Queue from types import FrameType -from typing import List, Optional +from typing import Optional, Tuple +import numpy as np import urllib3 - -from .engine import Engine -from .sensors import ReolinkCamera +from PIL import Image __all__ = ["SystemController"] urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) +PredictionResult = Tuple[np.ndarray, Image.Image, str] def handler(signum: int, frame: Optional[FrameType]) -> None: @@ -25,38 +26,59 @@ def handler(signum: int, frame: Optional[FrameType]) -> None: class SystemController: - """Implements the full system controller - - Args: - engine: the image analyzer - cameras: the cameras to get the visual streams from - """ - - def __init__(self, engine: Engine, cameras: List[ReolinkCamera]) -> None: + def __init__(self, engine, cameras): self.engine = engine self.cameras = cameras + self.prediction_results: Queue[PredictionResult] = Queue() # Queue for handling results - def analyze_stream(self, idx: int) -> None: - assert 0 <= idx < len(self.cameras) + def capture_and_predict(self, idx): + """Capture an image from the camera and perform prediction in a single function.""" try: img = self.cameras[idx].capture() + except Exception: + logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}") + if img is not None: try: - self.engine.predict(img, self.cameras[idx].ip_address) + preds = self.engine.predict(img, self.cameras[idx].ip_address) + # Send the result along with the image and camera ID for further processing + self.prediction_results.put((preds, img, self.cameras[idx].ip_address)) except Exception: logging.warning(f"Unable to analyze stream from camera {self.cameras[idx]}") - except Exception: - logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}") + else: + logging.error(f"Failed to capture image from camera {self.cameras[idx].ip_address}") - def run(self, period=30): - """Analyzes all camera streams""" - for idx in range(len(self.cameras)): + def process_results(self): + """Process results sequentially from the results queue.""" + while not self.prediction_results.empty(): try: - signal.signal(signal.SIGALRM, handler) - signal.alarm(int(period / len(self.cameras))) - self.analyze_stream(idx) - signal.alarm(0) + preds, frame, cam_id = self.prediction_results.get() + self.engine.process_prediction(preds, frame, cam_id) except Exception: - logging.warning(f"Analyze stream timeout on {self.cameras[idx]}") + logging.warning(f"Unable to process prediction from camera {cam_id}") + try: + # Uploading pending alerts + if len(self.engine._alerts) > 0: + self.engine._process_alerts() + except Exception: + logging.warning("Unable to process alerts") + + def run(self, period=30): + """Create a process for each camera to capture and predict simultaneously.""" + try: + signal.signal(signal.SIGALRM, handler) + signal.alarm(int(period)) + processes = [] + for idx in range(len(self.cameras)): + process = Process(target=self.capture_and_predict, args=(idx,)) + processes.append(process) + process.start() + + # Process all collected results + self.process_results() + + signal.alarm(0) + except Exception: + logging.warning(f"Analyze stream timeout on {self.cameras[idx]}") def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}(" diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 7e71f3bb..cf926925 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -267,7 +267,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> return conf - def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: + def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray: """Computes the confidence that the image contains wildfire cues Args: @@ -284,6 +284,15 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: except ConnectionError: logging.warning(f"Unable to reach the pyro-api with {cam_id}") + cam_key = cam_id or "-1" + if is_day_time(self._cache, frame, self.day_time_strategy): + # Inference with ONNX + return self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) + + else: + return np.zeros((0, 5)) + + def process_prediction(self, preds: np.ndarray, frame: Image.Image, cam_id: Optional[str] = None): cam_key = cam_id or "-1" # Reduce image size to save bandwidth if isinstance(self.frame_size, tuple): @@ -292,8 +301,6 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: frame_resize = frame if is_day_time(self._cache, frame, self.day_time_strategy): - # Inference with ONNX - preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) conf = self._update_states(frame_resize, preds, cam_key) # Log analysis result @@ -314,10 +321,6 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: else: conf = 0 # return default value - # Uploading pending alerts - if len(self._alerts) > 0: - self._process_alerts() - # Check if it's time to backup pending alerts ts = datetime.utcnow() if ts > self.last_cache_dump + timedelta(minutes=self.cache_backup_period): diff --git a/tests/test_core.py b/tests/test_core.py index 1f89285b..3404bf27 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,18 +1,19 @@ -import pytest +from unittest.mock import MagicMock from pyroengine.core import SystemController from pyroengine.engine import Engine -def test_systemcontroller(tmpdir_factory, mock_wildfire_image): - # Cache +def test_systemcontroller_with_mock_camera(tmpdir_factory): + # Setup folder = str(tmpdir_factory.mktemp("engine_cache")) - engine = Engine(cache_folder=folder) - cams = [] + # Creating a mock camera + mock_camera = MagicMock() + mock_camera.capture.return_value = "Mock Image" + mock_camera.ip_address = "192.168.1.1" + cams = [mock_camera] controller = SystemController(engine, cams) - with pytest.raises(AssertionError): - controller.analyze_stream(0) - - assert len(repr(controller).split("\n")) == 2 + # This should not raise an error as the camera is mocked + controller.capture_and_predict(0) diff --git a/tests/test_engine.py b/tests/test_engine.py index f050ba78..7ffdbb45 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -3,6 +3,7 @@ from datetime import datetime from pathlib import Path +import numpy as np from dotenv import load_dotenv from PIL import Image @@ -45,7 +46,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): # inference engine = Engine(nb_consecutive_frames=4, cache_folder=folder) - out = engine.predict(mock_forest_image) + preds = engine.predict(mock_forest_image) + assert isinstance(preds, np.ndarray) + assert preds.shape == (0, 5) + out = engine.process_prediction(preds, mock_forest_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 1 assert engine._states["-1"]["frame_count"] == 0 @@ -57,7 +61,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): assert engine._states["-1"]["last_predictions"][0][3] < datetime.utcnow().isoformat() assert engine._states["-1"]["last_predictions"][0][4] is False - out = engine.predict(mock_wildfire_image) + preds = engine.predict(mock_wildfire_image) + assert isinstance(preds, np.ndarray) + assert preds.shape == (1, 5) + out = engine.process_prediction(preds, mock_wildfire_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 2 assert engine._states["-1"]["ongoing"] is False @@ -68,7 +75,8 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): assert engine._states["-1"]["last_predictions"][1][3] < datetime.utcnow().isoformat() assert engine._states["-1"]["last_predictions"][1][4] is False - out = engine.predict(mock_wildfire_image) + preds = engine.predict(mock_wildfire_image) + out = engine.process_prediction(preds, mock_wildfire_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 3 assert engine._states["-1"]["ongoing"] is True @@ -110,12 +118,16 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image json_respone = response.json() assert start_ts < json_respone["last_ping"] < ts # Send an alert - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") assert len(engine._states["dummy_cam"]["last_predictions"]) == 1 assert len(engine._alerts) == 0 assert engine._states["dummy_cam"]["ongoing"] is False - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") + if len(engine._alerts) > 0: + engine._process_alerts() assert len(engine._states["dummy_cam"]["last_predictions"]) == 2 assert engine._states["dummy_cam"]["ongoing"] is True @@ -126,6 +138,7 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image response = engine._upload_frame("dummy_cam", mock_wildfire_stream) assert response.status_code // 100 == 2 # Upload frame in process - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") # Check that a new media has been created & uploaded assert engine._states["dummy_cam"]["frame_count"] == 0