From 9281d76da1ff28496a58084c2a8be2e4e506fe59 Mon Sep 17 00:00:00 2001 From: DanSava Date: Mon, 6 Jan 2025 16:51:43 +0200 Subject: [PATCH] Make sure FullSnapshotEvent and SnapshotUpdateEvent is json serializable --- src/ert/ensemble_evaluator/event.py | 30 ++++++++++++++++++++++++++ src/ert/ensemble_evaluator/snapshot.py | 3 +++ src/ert/run_models/base_run_model.py | 2 +- tests/everest/test_everserver.py | 21 +++++++++++++++--- 4 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index 5639269fd6f..deabab3c7c7 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass from .snapshot import EnsembleSnapshot @@ -27,3 +28,32 @@ class SnapshotUpdateEvent(_UpdateEvent): class EndEvent: failed: bool msg: str | None = None + + +def snapshot_event_from_json(json_str: str) -> FullSnapshotEvent | SnapshotUpdateEvent: + json_dict = json.loads(json_str) + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + match json_dict.pop("type"): + case "FullSnapshotEvent": + return FullSnapshotEvent(**json_dict) + case "SnapshotUpdateEvent": + return SnapshotUpdateEvent(**json_dict) + case unknown: + raise TypeError(f"Unknown snapshot update event type {unknown}") + + +def snapshot_event_to_json(event: FullSnapshotEvent | SnapshotUpdateEvent) -> str: + assert event.snapshot is not None + return json.dumps( + { + "iteration_label": event.iteration_label, + "total_iterations": event.total_iterations, + "progress": event.progress, + "realization_count": event.realization_count, + "status_count": event.status_count, + "iteration": event.iteration, + "snapshot": event.snapshot.to_dict(), + "type": event.__class__.__name__, + } + ) diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 1b4c5ebb6de..e6f7b661a08 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -113,6 +113,9 @@ def __init__(self) -> None: sorted_fm_step_ids=defaultdict(list), ) + def __eq__(self, other: EnsembleSnapshot) -> bool: + return self.to_dict() == other.to_dict() + @classmethod def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: ensemble = EnsembleSnapshot() diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index bd6392647de..46542d4bb15 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -398,7 +398,7 @@ def get_current_status(self) -> dict[str, int]: status["Finished"] += ( self._get_number_of_finished_realizations_from_reruns() ) - return status + return dict(status) def _get_number_of_finished_realizations_from_reruns(self) -> int: return self.active_realizations.count( diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 3793eab37f2..68f8af189a5 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -8,6 +8,11 @@ from seba_sqlite.snapshot import SebaSnapshot from _ert.events import event_from_json, event_to_json +from ert.ensemble_evaluator import FullSnapshotEvent, SnapshotUpdateEvent +from ert.ensemble_evaluator.event import ( + snapshot_event_from_json, + snapshot_event_to_json, +) from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel from everest.config import EverestConfig, OptimizationConfig, ServerConfig from everest.detached import ServerStatus, everserver_status @@ -265,18 +270,28 @@ def check_status_round_tripping(status): config, simulation_callback=check_status_round_tripping, ) - + send_event = run_model.send_event send_snapshot_event = run_model.send_snapshot_event - def check_event_serialization_round_trip(*args, **_): + def check_snapshot_event_serialization_round_trip(*args, **_): event, _ = args event_json = event_to_json(event) round_trip_event = event_from_json(str(event_json)) assert event == round_trip_event send_snapshot_event(*args) - run_model.send_snapshot_event = check_event_serialization_round_trip + def check_event_serialization_round_trip(event): + if isinstance(event, (FullSnapshotEvent | SnapshotUpdateEvent)): + json_str = snapshot_event_to_json(event) + round_trip = snapshot_event_from_json(json_str) + assert event == round_trip + send_event(event) + + run_model.send_event = check_event_serialization_round_trip + run_model.send_snapshot_event = check_snapshot_event_serialization_round_trip evaluator_server_config = evaluator_server_config_generator(run_model) run_model.run_experiment(evaluator_server_config) + + assert run_model.result is not None