From dda127162ef59b89942ffbc5aa0ae7c48a4ba7a0 Mon Sep 17 00:00:00 2001 From: mateo Date: Fri, 19 Apr 2024 16:32:25 +0200 Subject: [PATCH 01/10] try mp --- pyroengine/core.py | 99 +++++++++++++++++++++---------------------- pyroengine/engine.py | 34 ++++++--------- pyroengine/sensors.py | 2 +- src/run.py | 4 +- 4 files changed, 63 insertions(+), 76 deletions(-) diff --git a/pyroengine/core.py b/pyroengine/core.py index 776af058..74e3055a 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -1,62 +1,59 @@ -# Copyright (C) 2022-2024, Pyronear. - -# This program is licensed under the Apache License 2.0. -# See LICENSE or go to for full license details. - import logging -import signal -from types import FrameType -from typing import List, Optional - -import urllib3 - -from .engine import Engine -from .sensors import ReolinkCamera - -__all__ = ["SystemController"] - -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) +import time +from multiprocessing import Process, Queue +from typing import Optional -logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) - - -def handler(signum: int, frame: Optional[FrameType]) -> None: - raise Exception("Analyze stream timeout") +import numpy as np +from PIL import Image class SystemController: - """Implements the full system controller - - Args: - engine: the image analyzer - cameras: the cameras to get the visual streams from - """ - - def __init__(self, engine: Engine, cameras: List[ReolinkCamera]) -> None: + def __init__(self, engine, cameras): self.engine = engine self.cameras = cameras - - def analyze_stream(self, idx: int) -> None: - assert 0 <= idx < len(self.cameras) - try: - img = self.cameras[idx].capture() - try: - self.engine.predict(img, self.cameras[idx].ip_address) - except Exception: - logging.warning(f"Unable to analyze stream from camera {self.cameras[idx]}") - except Exception: - logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}") - - def run(self, period=30): - """Analyzes all camera streams""" + self.prediction_results = Queue() # Queue for handling results + + def capture_and_predict(self, idx): + """Capture an image from the camera and perform prediction in a single function.""" + img = self.cameras[idx].capture() + if img is not None: + preds = self.engine.predict(img, self.cameras[idx].ip_address) + print("pred", preds, idx) + # Send the result along with the image and camera ID for further processing + self.prediction_results.put((preds, img, self.cameras[idx].ip_address)) + else: + logging.error(f"Failed to capture image from camera {self.cameras[idx].ip_address}") + + def process_results(self, start_time): + """Process results sequentially from the results queue.""" + while not self.prediction_results.empty(): + preds, frame, cam_id = self.prediction_results.get() + print(cam_id, preds) + self.engine.process_prediction(preds, frame, cam_id) + + print(f"remain {30-(time.time()-start_time)} for sending") + + # Uploading pending alerts + if len(self.engine._alerts) > 0: + self.engine._process_alerts() + + def run(self): + """Create a process for each camera to capture and predict simultaneously.""" + start_time = time.time() + processes = [] for idx in range(len(self.cameras)): - try: - signal.signal(signal.SIGALRM, handler) - signal.alarm(int(period / len(self.cameras))) - self.analyze_stream(idx) - signal.alarm(0) - except Exception: - logging.warning(f"Analyze stream timeout on {self.cameras[idx]}") + process = Process(target=self.capture_and_predict, args=(idx,)) + processes.append(process) + process.start() + print(f"done capture and process after {time.time()-start_time}") + + # # Ensure all processes complete + # for process in processes: + # process.join() + + # Process all collected results + self.process_results(start_time) + print(f"done all {time.time()-start_time}") def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}(" diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 7e71f3bb..aa7c255c 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -234,6 +234,8 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> if box.shape[0] > 0: boxes = np.concatenate([boxes, box]) + print("boxes", cam_key, boxes, preds) + conf = 0 output_predictions = np.zeros((0, 5)) # Get the best ones @@ -243,6 +245,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T]) combine_predictions = best_boxes[best_boxes_scores > conf_th, :] conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds + print("preds", conf) if len(combine_predictions): @@ -284,6 +287,15 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: except ConnectionError: logging.warning(f"Unable to reach the pyro-api with {cam_id}") + cam_key = cam_id or "-1" + if is_day_time(self._cache, frame, self.day_time_strategy): + # Inference with ONNX + return self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) + + else: + return np.zeros((0, 5)) + + def process_prediction(self, preds: np.ndarray, frame: Image.Image, cam_id: Optional[str] = None): cam_key = cam_id or "-1" # Reduce image size to save bandwidth if isinstance(self.frame_size, tuple): @@ -292,8 +304,6 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: frame_resize = frame if is_day_time(self._cache, frame, self.day_time_strategy): - # Inference with ONNX - preds = self.model(frame.convert("RGB"), self.occlusion_masks[cam_key]) conf = self._update_states(frame_resize, preds, cam_key) # Log analysis result @@ -314,32 +324,12 @@ def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: else: conf = 0 # return default value - # Uploading pending alerts - if len(self._alerts) > 0: - self._process_alerts() - # Check if it's time to backup pending alerts ts = datetime.utcnow() if ts > self.last_cache_dump + timedelta(minutes=self.cache_backup_period): self._dump_cache() self.last_cache_dump = ts - # save frame - if len(self.api_client) > 0 and isinstance(self.frame_saving_period, int) and isinstance(cam_id, str): - self._states[cam_key]["frame_count"] += 1 - if self._states[cam_key]["frame_count"] == self.frame_saving_period: - # Save frame on device - self._local_backup(frame_resize, cam_id, is_alert=False) - # Send frame to the api - stream = io.BytesIO() - frame_resize.save(stream, format="JPEG", quality=self.jpeg_quality) - try: - self._upload_frame(cam_id, stream.getvalue()) - # Reset frame counter - self._states[cam_key]["frame_count"] = 0 - except ConnectionError: - stream.seek(0) # "Rewind" the stream to the beginning so we can read its content - return float(conf) def _upload_frame(self, cam_id: str, media_data: bytes) -> Response: diff --git a/pyroengine/sensors.py b/pyroengine/sensors.py index b91e78bc..2deb5ed7 100644 --- a/pyroengine/sensors.py +++ b/pyroengine/sensors.py @@ -13,7 +13,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -CAM_URL = "https://{ip_address}/cgi-bin/api.cgi?cmd=Snap&channel=0&rs=wuuPhkmUCeI9WG7C&user={login}&password={password}" +CAM_URL = "http://{ip_address}/cgi-bin/api.cgi?cmd=Snap&channel=0&rs=wuuPhkmUCeI9WG7C&user={login}&password={password}" class ReolinkCamera: diff --git a/src/run.py b/src/run.py index 88f2151c..f6ff9d6c 100644 --- a/src/run.py +++ b/src/run.py @@ -70,7 +70,7 @@ def main(args): while True: start_ts = time.time() - sys_controller.run(args.period) + sys_controller.run() # Sleep only once all images are processed time.sleep(max(args.period - time.time() + start_ts, 0)) @@ -107,7 +107,7 @@ def main(args): parser.add_argument("--backup-size", type=int, default=10000, help="Local backup can't be bigger than 10Go") # Time config - parser.add_argument("--period", type=int, default=30, help="Number of seconds between each camera stream analysis") + parser.add_argument("--period", type=int, default=5, help="Number of seconds between each camera stream analysis") parser.add_argument("--save-period", type=int, default=3600, help="Number of seconds between each media save") args = parser.parse_args() From 148a4e178d920785052bf9eae0e89d80f189cce5 Mon Sep 17 00:00:00 2001 From: mateo Date: Fri, 19 Apr 2024 16:56:39 +0200 Subject: [PATCH 02/10] remove debugging --- pyroengine/core.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pyroengine/core.py b/pyroengine/core.py index 74e3055a..63621018 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -18,7 +18,6 @@ def capture_and_predict(self, idx): img = self.cameras[idx].capture() if img is not None: preds = self.engine.predict(img, self.cameras[idx].ip_address) - print("pred", preds, idx) # Send the result along with the image and camera ID for further processing self.prediction_results.put((preds, img, self.cameras[idx].ip_address)) else: @@ -28,11 +27,8 @@ def process_results(self, start_time): """Process results sequentially from the results queue.""" while not self.prediction_results.empty(): preds, frame, cam_id = self.prediction_results.get() - print(cam_id, preds) self.engine.process_prediction(preds, frame, cam_id) - print(f"remain {30-(time.time()-start_time)} for sending") - # Uploading pending alerts if len(self.engine._alerts) > 0: self.engine._process_alerts() @@ -45,15 +41,9 @@ def run(self): process = Process(target=self.capture_and_predict, args=(idx,)) processes.append(process) process.start() - print(f"done capture and process after {time.time()-start_time}") - - # # Ensure all processes complete - # for process in processes: - # process.join() # Process all collected results self.process_results(start_time) - print(f"done all {time.time()-start_time}") def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}(" From c85b2b71a870d8bdbca867d008eb17e82c42f0a6 Mon Sep 17 00:00:00 2001 From: mateo Date: Fri, 19 Apr 2024 17:21:04 +0200 Subject: [PATCH 03/10] add try catch --- pyroengine/core.py | 75 ++++++++++++++++++++++++++++++-------------- pyroengine/engine.py | 3 -- src/run.py | 4 +-- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/pyroengine/core.py b/pyroengine/core.py index 63621018..1d82bbd4 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -1,10 +1,22 @@ import logging +import signal +import threading import time from multiprocessing import Process, Queue +from types import FrameType from typing import Optional -import numpy as np -from PIL import Image +import urllib3 + +__all__ = ["SystemController"] + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) + + +def handler(signum: int, frame: Optional[FrameType]) -> None: + raise Exception("Analyze stream timeout") class SystemController: @@ -15,35 +27,52 @@ def __init__(self, engine, cameras): def capture_and_predict(self, idx): """Capture an image from the camera and perform prediction in a single function.""" - img = self.cameras[idx].capture() + try: + img = self.cameras[idx].capture() + except Exception: + logging.warning(f"Unable to fetch stream from camera {self.cameras[idx]}") if img is not None: - preds = self.engine.predict(img, self.cameras[idx].ip_address) - # Send the result along with the image and camera ID for further processing - self.prediction_results.put((preds, img, self.cameras[idx].ip_address)) + try: + preds = self.engine.predict(img, self.cameras[idx].ip_address) + # Send the result along with the image and camera ID for further processing + self.prediction_results.put((preds, img, self.cameras[idx].ip_address)) + except Exception: + logging.warning(f"Unable to analyze stream from camera {self.cameras[idx]}") else: logging.error(f"Failed to capture image from camera {self.cameras[idx].ip_address}") - def process_results(self, start_time): + def process_results(self): """Process results sequentially from the results queue.""" while not self.prediction_results.empty(): - preds, frame, cam_id = self.prediction_results.get() - self.engine.process_prediction(preds, frame, cam_id) - - # Uploading pending alerts - if len(self.engine._alerts) > 0: - self.engine._process_alerts() + try: + preds, frame, cam_id = self.prediction_results.get() + self.engine.process_prediction(preds, frame, cam_id) + except Exception: + logging.warning(f"Unable to process prediction from camera {cam_id}") + try: + # Uploading pending alerts + if len(self.engine._alerts) > 0: + self.engine._process_alerts() + except Exception: + logging.warning(f"Unable to process alerts") - def run(self): + def run(self, period=30): """Create a process for each camera to capture and predict simultaneously.""" - start_time = time.time() - processes = [] - for idx in range(len(self.cameras)): - process = Process(target=self.capture_and_predict, args=(idx,)) - processes.append(process) - process.start() - - # Process all collected results - self.process_results(start_time) + try: + signal.signal(signal.SIGALRM, handler) + signal.alarm(int(period)) + processes = [] + for idx in range(len(self.cameras)): + process = Process(target=self.capture_and_predict, args=(idx,)) + processes.append(process) + process.start() + + # Process all collected results + self.process_results() + + signal.alarm(0) + except Exception: + logging.warning(f"Analyze stream timeout on {self.cameras[idx]}") def __repr__(self) -> str: repr_str = f"{self.__class__.__name__}(" diff --git a/pyroengine/engine.py b/pyroengine/engine.py index aa7c255c..dbde69b6 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -234,8 +234,6 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> if box.shape[0] > 0: boxes = np.concatenate([boxes, box]) - print("boxes", cam_key, boxes, preds) - conf = 0 output_predictions = np.zeros((0, 5)) # Get the best ones @@ -245,7 +243,6 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> best_boxes_scores = np.array([sum(boxes[iou > 0, 4]) for iou in ious.T]) combine_predictions = best_boxes[best_boxes_scores > conf_th, :] conf = np.max(best_boxes_scores) / (self.nb_consecutive_frames + 1) # memory + preds - print("preds", conf) if len(combine_predictions): diff --git a/src/run.py b/src/run.py index f6ff9d6c..88f2151c 100644 --- a/src/run.py +++ b/src/run.py @@ -70,7 +70,7 @@ def main(args): while True: start_ts = time.time() - sys_controller.run() + sys_controller.run(args.period) # Sleep only once all images are processed time.sleep(max(args.period - time.time() + start_ts, 0)) @@ -107,7 +107,7 @@ def main(args): parser.add_argument("--backup-size", type=int, default=10000, help="Local backup can't be bigger than 10Go") # Time config - parser.add_argument("--period", type=int, default=5, help="Number of seconds between each camera stream analysis") + parser.add_argument("--period", type=int, default=30, help="Number of seconds between each camera stream analysis") parser.add_argument("--save-period", type=int, default=3600, help="Number of seconds between each media save") args = parser.parse_args() From 6e85489be762b16d0e95b054e80a3504744171ac Mon Sep 17 00:00:00 2001 From: mateo Date: Fri, 19 Apr 2024 17:26:54 +0200 Subject: [PATCH 04/10] style --- pyroengine/core.py | 16 +++++++++++----- pyroengine/engine.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pyroengine/core.py b/pyroengine/core.py index 1d82bbd4..a202e987 100644 --- a/pyroengine/core.py +++ b/pyroengine/core.py @@ -1,18 +1,24 @@ +# Copyright (C) 2022-2024, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + import logging import signal -import threading -import time from multiprocessing import Process, Queue from types import FrameType -from typing import Optional +from typing import Optional, Tuple +import numpy as np import urllib3 +from PIL import Image __all__ = ["SystemController"] urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) +PredictionResult = Tuple[np.ndarray, Image.Image, str] def handler(signum: int, frame: Optional[FrameType]) -> None: @@ -23,7 +29,7 @@ class SystemController: def __init__(self, engine, cameras): self.engine = engine self.cameras = cameras - self.prediction_results = Queue() # Queue for handling results + self.prediction_results: Queue[PredictionResult] = Queue() # Queue for handling results def capture_and_predict(self, idx): """Capture an image from the camera and perform prediction in a single function.""" @@ -54,7 +60,7 @@ def process_results(self): if len(self.engine._alerts) > 0: self.engine._process_alerts() except Exception: - logging.warning(f"Unable to process alerts") + logging.warning("Unable to process alerts") def run(self, period=30): """Create a process for each camera to capture and predict simultaneously.""" diff --git a/pyroengine/engine.py b/pyroengine/engine.py index dbde69b6..d5de215d 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -267,7 +267,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> return conf - def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> float: + def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray[Any, Any]: """Computes the confidence that the image contains wildfire cues Args: From 93c1950a5cd9de9470ff6b83ef2ac0e853aa4053 Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 10:13:03 +0200 Subject: [PATCH 05/10] fix typeerrror --- pyroengine/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyroengine/engine.py b/pyroengine/engine.py index d5de215d..d50935d2 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -17,6 +17,7 @@ import cv2 # type: ignore[import-untyped] import numpy as np +from numpy.typing import NDArray from PIL import Image from pyroclient import client from requests.exceptions import ConnectionError @@ -267,7 +268,7 @@ def _update_states(self, frame: Image.Image, preds: np.ndarray, cam_key: str) -> return conf - def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray[Any, Any]: + def predict(self, frame: Image.Image, cam_id: Optional[str] = None) -> np.ndarray: """Computes the confidence that the image contains wildfire cues Args: From 48bb7f1a8457eb738601832958743412a777033e Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 10:16:10 +0200 Subject: [PATCH 06/10] unused import --- pyroengine/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyroengine/engine.py b/pyroengine/engine.py index d50935d2..9d1d43a2 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -17,7 +17,6 @@ import cv2 # type: ignore[import-untyped] import numpy as np -from numpy.typing import NDArray from PIL import Image from pyroclient import client from requests.exceptions import ConnectionError From 489d32beaecc29a307a79299b969b370cd49df0c Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 10:28:12 +0200 Subject: [PATCH 07/10] not in this pr --- pyroengine/engine.py | 16 ++++++++++++++++ pyroengine/sensors.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pyroengine/engine.py b/pyroengine/engine.py index 9d1d43a2..cf926925 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -327,6 +327,22 @@ def process_prediction(self, preds: np.ndarray, frame: Image.Image, cam_id: Opti self._dump_cache() self.last_cache_dump = ts + # save frame + if len(self.api_client) > 0 and isinstance(self.frame_saving_period, int) and isinstance(cam_id, str): + self._states[cam_key]["frame_count"] += 1 + if self._states[cam_key]["frame_count"] == self.frame_saving_period: + # Save frame on device + self._local_backup(frame_resize, cam_id, is_alert=False) + # Send frame to the api + stream = io.BytesIO() + frame_resize.save(stream, format="JPEG", quality=self.jpeg_quality) + try: + self._upload_frame(cam_id, stream.getvalue()) + # Reset frame counter + self._states[cam_key]["frame_count"] = 0 + except ConnectionError: + stream.seek(0) # "Rewind" the stream to the beginning so we can read its content + return float(conf) def _upload_frame(self, cam_id: str, media_data: bytes) -> Response: diff --git a/pyroengine/sensors.py b/pyroengine/sensors.py index 2deb5ed7..b91e78bc 100644 --- a/pyroengine/sensors.py +++ b/pyroengine/sensors.py @@ -13,7 +13,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -CAM_URL = "http://{ip_address}/cgi-bin/api.cgi?cmd=Snap&channel=0&rs=wuuPhkmUCeI9WG7C&user={login}&password={password}" +CAM_URL = "https://{ip_address}/cgi-bin/api.cgi?cmd=Snap&channel=0&rs=wuuPhkmUCeI9WG7C&user={login}&password={password}" class ReolinkCamera: From 84ebef9ac2537834202a559bd7828141ffafa84c Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 11:35:06 +0200 Subject: [PATCH 08/10] fix tests --- tests/test_core.py | 19 ++++++++++--------- tests/test_engine.py | 24 ++++++++++++++++++------ 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 1f89285b..3404bf27 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,18 +1,19 @@ -import pytest +from unittest.mock import MagicMock from pyroengine.core import SystemController from pyroengine.engine import Engine -def test_systemcontroller(tmpdir_factory, mock_wildfire_image): - # Cache +def test_systemcontroller_with_mock_camera(tmpdir_factory): + # Setup folder = str(tmpdir_factory.mktemp("engine_cache")) - engine = Engine(cache_folder=folder) - cams = [] + # Creating a mock camera + mock_camera = MagicMock() + mock_camera.capture.return_value = "Mock Image" + mock_camera.ip_address = "192.168.1.1" + cams = [mock_camera] controller = SystemController(engine, cams) - with pytest.raises(AssertionError): - controller.analyze_stream(0) - - assert len(repr(controller).split("\n")) == 2 + # This should not raise an error as the camera is mocked + controller.capture_and_predict(0) diff --git a/tests/test_engine.py b/tests/test_engine.py index f050ba78..3ac30082 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -3,6 +3,7 @@ from datetime import datetime from pathlib import Path +import numpy as np from dotenv import load_dotenv from PIL import Image @@ -45,7 +46,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): # inference engine = Engine(nb_consecutive_frames=4, cache_folder=folder) - out = engine.predict(mock_forest_image) + preds = engine.predict(mock_forest_image) + assert isinstance(preds, np.ndarray) + assert preds.shape == (0, 5) + out = engine.process_prediction(preds, mock_forest_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 1 assert engine._states["-1"]["frame_count"] == 0 @@ -57,7 +61,10 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): assert engine._states["-1"]["last_predictions"][0][3] < datetime.utcnow().isoformat() assert engine._states["-1"]["last_predictions"][0][4] is False - out = engine.predict(mock_wildfire_image) + preds = engine.predict(mock_wildfire_image) + assert isinstance(preds, np.ndarray) + assert preds.shape == (1, 5) + out = engine.process_prediction(preds, mock_wildfire_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 2 assert engine._states["-1"]["ongoing"] is False @@ -68,7 +75,8 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): assert engine._states["-1"]["last_predictions"][1][3] < datetime.utcnow().isoformat() assert engine._states["-1"]["last_predictions"][1][4] is False - out = engine.predict(mock_wildfire_image) + preds = engine.predict(mock_wildfire_image) + out = engine.process_prediction(preds, mock_wildfire_image) assert isinstance(out, float) and 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 3 assert engine._states["-1"]["ongoing"] is True @@ -110,12 +118,15 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image json_respone = response.json() assert start_ts < json_respone["last_ping"] < ts # Send an alert - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") assert len(engine._states["dummy_cam"]["last_predictions"]) == 1 assert len(engine._alerts) == 0 assert engine._states["dummy_cam"]["ongoing"] is False - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") + print(engine._states["dummy_cam"]) assert len(engine._states["dummy_cam"]["last_predictions"]) == 2 assert engine._states["dummy_cam"]["ongoing"] is True @@ -126,6 +137,7 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image response = engine._upload_frame("dummy_cam", mock_wildfire_stream) assert response.status_code // 100 == 2 # Upload frame in process - engine.predict(mock_wildfire_image, "dummy_cam") + preds = engine.predict(mock_wildfire_image, "dummy_cam") + engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") # Check that a new media has been created & uploaded assert engine._states["dummy_cam"]["frame_count"] == 0 From e974c8adc7adda7a08ad051b59f32f000d4c9d33 Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 11:35:34 +0200 Subject: [PATCH 09/10] remove debug --- tests/test_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 3ac30082..2cd11252 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -126,7 +126,6 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image preds = engine.predict(mock_wildfire_image, "dummy_cam") engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") - print(engine._states["dummy_cam"]) assert len(engine._states["dummy_cam"]["last_predictions"]) == 2 assert engine._states["dummy_cam"]["ongoing"] is True From 7c14434487c8b3bf5a038fa02f9937aef84ffea7 Mon Sep 17 00:00:00 2001 From: mateo Date: Mon, 22 Apr 2024 11:41:37 +0200 Subject: [PATCH 10/10] missing process alerts --- tests/test_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_engine.py b/tests/test_engine.py index 2cd11252..7ffdbb45 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -126,6 +126,8 @@ def test_engine_online(tmpdir_factory, mock_wildfire_stream, mock_wildfire_image preds = engine.predict(mock_wildfire_image, "dummy_cam") engine.process_prediction(preds, mock_wildfire_image, "dummy_cam") + if len(engine._alerts) > 0: + engine._process_alerts() assert len(engine._states["dummy_cam"]["last_predictions"]) == 2 assert engine._states["dummy_cam"]["ongoing"] is True