Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiprocessing engine #188

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 47 additions & 25 deletions pyroengine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,58 +5,80 @@

import logging
import signal
from multiprocessing import Process, Queue
from types import FrameType
from typing import List, Optional
from typing import Optional, Tuple

import numpy as np
import urllib3

from .engine import Engine
from .sensors import ReolinkCamera
from PIL import Image

__all__ = ["SystemController"]

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)
PredictionResult = Tuple[np.ndarray, Image.Image, str]


def handler(signum: int, frame: Optional[FrameType]) -> None:
raise Exception("Analyze stream timeout")


class SystemController:
"""Implements the full system controller

Args:
engine: the image analyzer
cameras: the cameras to get the visual streams from
"""

def __init__(self, engine: Engine, cameras: List[ReolinkCamera]) -> None:
def __init__(self, engine, cameras):
self.engine = engine
self.cameras = cameras
self.prediction_results: Queue[PredictionResult] = Queue() # Queue for handling results

def analyze_stream(self, idx: int) -> None:
assert 0 <= idx < len(self.cameras)
def capture_and_predict(self, idx):
"""Capture an image from the camera and perform prediction in a single function."""
try:
img = self.cameras[idx].capture()
except Exception:
logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}")
if img is not None:
try:
self.engine.predict(img, self.cameras[idx].ip_address)
preds = self.engine.predict(img, self.cameras[idx].ip_address)
# Send the result along with the image and camera ID for further processing
self.prediction_results.put((preds, img, self.cameras[idx].ip_address))
except Exception:
logging.warning(f"Unable to analyze stream from camera {self.cameras[idx]}")
except Exception:
logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}")
else:
logging.error(f"Failed to capture image from camera {self.cameras[idx].ip_address}")

def run(self, period=30):
"""Analyzes all camera streams"""
for idx in range(len(self.cameras)):
def process_results(self):
"""Process results sequentially from the results queue."""
while not self.prediction_results.empty():
try:
signal.signal(signal.SIGALRM, handler)
signal.alarm(int(period / len(self.cameras)))
self.analyze_stream(idx)
signal.alarm(0)
preds, frame, cam_id = self.prediction_results.get()
self.engine.process_prediction(preds, frame, cam_id)
except Exception:
logging.warning(f"Analyze stream timeout on {self.cameras[idx]}")
logging.warning(f"Unable to process prediction from camera {cam_id}")
try:
# Uploading pending alerts
if len(self.engine._alerts) > 0:
self.engine._process_alerts()
except Exception:
logging.warning("Unable to process alerts")

def run(self, period=30):
"""Create a process for each camera to capture and predict simultaneously."""
try:
signal.signal(signal.SIGALRM, handler)
signal.alarm(int(period))
processes = []
for idx in range(len(self.cameras)):
process = Process(target=self.capture_and_predict, args=(idx,))
processes.append(process)
process.start()

# Process all collected results
self.process_results()

signal.alarm(0)
except Exception:
logging.warning(f"Analyze stream timeout on {self.cameras[idx]}")

def __repr__(self) -> str:
repr_str = f"{self.__class__.__name__}("
Expand Down
17 changes: 10 additions & 7 deletions pyroengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) ->

return conf

def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray:
"""Computes the confidence that the image contains wildfire cues

Args:
Expand All @@ -284,6 +284,15 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
except ConnectionError:
logging.warning(f"Unable to reach the pyro-api with {cam_id}")

cam_key = cam_id or "-1"
if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
return self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])

else:
return np.zeros((0, 5))

def process_prediction(self, preds: np.ndarray, frame: Image.Image, cam_id: Optional[str] = None):
cam_key = cam_id or "-1"
# Reduce image size to save bandwidth
if isinstance(self.frame_size, tuple):
Expand All @@ -292,8 +301,6 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
frame_resize = frame

if is_day_time(self._cache, frame, self.day_time_strategy):
# Inference with ONNX
preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key])
conf = self._update_states(frame_resize, preds, cam_key)

# Log analysis result
Expand All @@ -314,10 +321,6 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float:
else:
conf = 0 # return default value

# Uploading pending alerts
if len(self._alerts) > 0:
self._process_alerts()

# Check if it's time to backup pending alerts
ts = datetime.utcnow()
if ts > self.last_cache_dump + timedelta(minutes=self.cache_backup_period):
Expand Down
19 changes: 10 additions & 9 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import pytest
from unittest.mock import MagicMock

from pyroengine.core import SystemController
from pyroengine.engine import Engine


def test_systemcontroller(tmpdir_factory, mock_wildfire_image):
# Cache
def test_systemcontroller_with_mock_camera(tmpdir_factory):
# Setup
folder = str(tmpdir_factory.mktemp("engine_cache"))

engine = Engine(cache_folder=folder)
cams = []
# Creating a mock camera
mock_camera = MagicMock()
mock_camera.capture.return_value = "Mock Image"
mock_camera.ip_address = "192.168.1.1"
cams = [mock_camera]
controller = SystemController(engine, cams)

with pytest.raises(AssertionError):
controller.analyze_stream(0)

assert len(repr(controller).split("\n")) == 2
# This should not raise an error as the camera is mocked
controller.capture_and_predict(0)
25 changes: 19 additions & 6 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from pathlib import Path

import numpy as np
from dotenv import load_dotenv
from PIL import Image

Expand Down Expand Up @@ -45,7 +46,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):

# inference
engine = Engine(nb_consecutive_frames=4, cache_folder=folder)
out = engine.predict(mock_forest_image)
preds = engine.predict(mock_forest_image)
assert isinstance(preds, np.ndarray)
assert preds.shape == (0, 5)
out = engine.process_prediction(preds, mock_forest_image)
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 1
assert engine._states["-1"]["frame_count"] == 0
Expand All @@ -57,7 +61,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
assert engine._states["-1"]["last_predictions"][0][3] < datetime.utcnow().isoformat()
assert engine._states["-1"]["last_predictions"][0][4] is False

out = engine.predict(mock_wildfire_image)
preds = engine.predict(mock_wildfire_image)
assert isinstance(preds, np.ndarray)
assert preds.shape == (1, 5)
out = engine.process_prediction(preds, mock_wildfire_image)
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 2
assert engine._states["-1"]["ongoing"] is False
Expand All @@ -68,7 +75,8 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image):
assert engine._states["-1"]["last_predictions"][1][3] < datetime.utcnow().isoformat()
assert engine._states["-1"]["last_predictions"][1][4] is False

out = engine.predict(mock_wildfire_image)
preds = engine.predict(mock_wildfire_image)
out = engine.process_prediction(preds, mock_wildfire_image)
assert isinstance(out, float) and 0 <= out <= 1
assert len(engine._states["-1"]["last_predictions"]) == 3
assert engine._states["-1"]["ongoing"] is True
Expand Down Expand Up @@ -110,12 +118,16 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image
json_respone = response.json()
assert start_ts < json_respone["last_ping"] < ts
# Send an alert
engine.predict(mock_wildfire_image, "dummy_cam")
preds = engine.predict(mock_wildfire_image, "dummy_cam")
engine.process_prediction(preds, mock_wildfire_image, "dummy_cam")
assert len(engine._states["dummy_cam"]["last_predictions"]) == 1
assert len(engine._alerts) == 0
assert engine._states["dummy_cam"]["ongoing"] is False

engine.predict(mock_wildfire_image, "dummy_cam")
preds = engine.predict(mock_wildfire_image, "dummy_cam")
engine.process_prediction(preds, mock_wildfire_image, "dummy_cam")
if len(engine._alerts) > 0:
engine._process_alerts()
assert len(engine._states["dummy_cam"]["last_predictions"]) == 2

assert engine._states["dummy_cam"]["ongoing"] is True
Expand All @@ -126,6 +138,7 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image
response = engine._upload_frame("dummy_cam", mock_wildfire_stream)
assert response.status_code // 100 == 2
# Upload frame in process
engine.predict(mock_wildfire_image, "dummy_cam")
preds = engine.predict(mock_wildfire_image, "dummy_cam")
engine.process_prediction(preds, mock_wildfire_image, "dummy_cam")
# Check that a new media has been created & uploaded
assert engine._states["dummy_cam"]["frame_count"] == 0
Loading