diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index 547ec323cd0..7552851de7a 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -24,7 +24,7 @@ ) from ert.namespace import Namespace from ert.plugins import ErtPluginManager -from ert.run_models.base_run_model import StatusEvents +from ert.run_models.event import StatusEvents from ert.run_models.model_factory import create_model from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index 921b584f222..429b997ca53 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -22,11 +22,11 @@ FORWARD_MODEL_STATE_FAILURE, REAL_STATE_TO_COLOR, ) -from ert.run_models.base_run_model import StatusEvents from ert.run_models.event import ( RunModelDataEvent, RunModelErrorEvent, RunModelUpdateEndEvent, + StatusEvents, ) from ert.shared.status.utils import format_running_time diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 1b4c5ebb6de..15918b4907b 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -113,6 +113,11 @@ def __init__(self) -> None: sorted_fm_step_ids=defaultdict(list), ) + def __eq__(self, other: object) -> bool: + if not isinstance(other, EnsembleSnapshot): + return NotImplemented + return self.to_dict() == other.to_dict() + @classmethod def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: ensemble = EnsembleSnapshot() @@ -156,13 +161,17 @@ def to_dict(self) -> dict[str, Any]: if self._ensemble_state: dict_["status"] = self._ensemble_state if self._realization_snapshots: - dict_["reals"] = self._realization_snapshots + dict_["reals"] = { + k: _filter_nones(v) for k, v in self._realization_snapshots.items() + } for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): if "reals" not in dict_: dict_["reals"] = {} if real_id not in dict_["reals"]: - dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) + dict_["reals"][real_id] = _filter_nones( + RealizationSnapshot(fm_steps={}) + ) if "fm_steps" not in dict_["reals"][real_id]: dict_["reals"][real_id]["fm_steps"] = {} @@ -392,15 +401,27 @@ class RealizationSnapshot(TypedDict, total=False): def _realization_dict_to_realization_snapshot( source: dict[str, Any], ) -> RealizationSnapshot: + start_time = source.get("start_time") + if start_time and isinstance(start_time, str): + start_time = datetime.fromisoformat(start_time) + end_time = source.get("end_time") + if end_time and isinstance(end_time, str): + end_time = datetime.fromisoformat(end_time) + realization = RealizationSnapshot( status=source.get("status"), active=source.get("active"), - start_time=source.get("start_time"), - end_time=source.get("end_time"), + start_time=start_time, + end_time=end_time, exec_hosts=source.get("exec_hosts"), message=source.get("message"), fm_steps=source.get("fm_steps", {}), ) + for step in realization["fm_steps"].values(): + if "start_time" in step and isinstance(step["start_time"], str): + step["start_time"] = datetime.fromisoformat(step["start_time"]) + if "end_time" in step and isinstance(step["end_time"], str): + step["end_time"] = datetime.fromisoformat(step["end_time"]) return _filter_nones(realization) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index bd6392647de..37c46ba0352 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -20,9 +20,6 @@ from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( - AnalysisEvent, - AnalysisStatusEvent, - AnalysisTimeEvent, ErtAnalysisError, smoother_update, ) @@ -40,11 +37,6 @@ Monitor, Realization, ) -from ert.ensemble_evaluator.event import ( - EndEvent, - FullSnapshotEvent, - SnapshotUpdateEvent, -) from ert.ensemble_evaluator.identifiers import STATUS from ert.ensemble_evaluator.snapshot import EnsembleSnapshot from ert.ensemble_evaluator.state import ( @@ -63,12 +55,19 @@ from ..config.analysis_config import UpdateSettings from ..run_arg import RunArg from .event import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, RunModelDataEvent, RunModelErrorEvent, RunModelStatusEvent, RunModelTimeEvent, RunModelUpdateBeginEvent, RunModelUpdateEndEvent, + SnapshotUpdateEvent, + StatusEvents, ) logger = logging.getLogger(__name__) @@ -76,21 +75,6 @@ if TYPE_CHECKING: from ert.config import QueueConfig -StatusEvents = ( - FullSnapshotEvent - | SnapshotUpdateEvent - | EndEvent - | AnalysisEvent - | AnalysisStatusEvent - | AnalysisTimeEvent - | RunModelErrorEvent - | RunModelStatusEvent - | RunModelTimeEvent - | RunModelUpdateBeginEvent - | RunModelDataEvent - | RunModelUpdateEndEvent -) - class OutOfOrderSnapshotUpdateException(ValueError): pass @@ -398,7 +382,7 @@ def get_current_status(self) -> dict[str, int]: status["Finished"] += ( self._get_number_of_finished_realizations_from_reruns() ) - return status + return dict(status) def _get_number_of_finished_realizations_from_reruns(self) -> int: return self.active_realizations.count( diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 424c8ec9235..153a509b9c1 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -1,10 +1,23 @@ from __future__ import annotations -from dataclasses import dataclass +import json +from dataclasses import asdict, dataclass +from datetime import datetime from pathlib import Path from uuid import UUID +from ert.analysis import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, +) from ert.analysis.event import DataSection +from ert.ensemble_evaluator.event import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, +) +from ert.ensemble_evaluator.snapshot import EnsembleSnapshot @dataclass @@ -56,3 +69,88 @@ class RunModelErrorEvent(RunModelEvent): def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv("Report", output_path / str(self.run_id)) + + +StatusEvents = ( + AnalysisEvent + | AnalysisStatusEvent + | AnalysisTimeEvent + | EndEvent + | FullSnapshotEvent + | SnapshotUpdateEvent + | RunModelErrorEvent + | RunModelStatusEvent + | RunModelTimeEvent + | RunModelUpdateBeginEvent + | RunModelDataEvent + | RunModelUpdateEndEvent +) + + +EVENT_MAPPING = { + "AnalysisEvent": AnalysisEvent, + "AnalysisStatusEvent": AnalysisStatusEvent, + "AnalysisTimeEvent": AnalysisTimeEvent, + "EndEvent": EndEvent, + "FullSnapshotEvent": FullSnapshotEvent, + "SnapshotUpdateEvent": SnapshotUpdateEvent, + "RunModelErrorEvent": RunModelErrorEvent, + "RunModelStatusEvent": RunModelStatusEvent, + "RunModelTimeEvent": RunModelTimeEvent, + "RunModelUpdateBeginEvent": RunModelUpdateBeginEvent, + "RunModelDataEvent": RunModelDataEvent, + "RunModelUpdateEndEvent": RunModelUpdateEndEvent, +} + + +def status_event_from_json(json_str: str) -> StatusEvents: + json_dict = json.loads(json_str) + event_type = json_dict.pop("event_type", None) + + match event_type: + case FullSnapshotEvent.__name__: + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + return FullSnapshotEvent(**json_dict) + case SnapshotUpdateEvent.__name__: + snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) + json_dict["snapshot"] = snapshot + return SnapshotUpdateEvent(**json_dict) + case RunModelDataEvent.__name__ | RunModelUpdateEndEvent.__name__: + if "run_id" in json_dict and isinstance(json_dict["run_id"], str): + json_dict["run_id"] = UUID(json_dict["run_id"]) + if json_dict.get("data"): + json_dict["data"] = DataSection(**json_dict["data"]) + return EVENT_MAPPING[event_type](**json_dict) + case _: + if event_type in EVENT_MAPPING: + if "run_id" in json_dict and isinstance(json_dict["run_id"], str): + json_dict["run_id"] = UUID(json_dict["run_id"]) + return EVENT_MAPPING[event_type](**json_dict) + else: + raise TypeError(f"Unknown status event type {event_type}") + + +def status_event_to_json(event: StatusEvents) -> str: + match event: + case FullSnapshotEvent() | SnapshotUpdateEvent(): + assert event.snapshot is not None + event_dict = asdict(event) + event_dict.update( + { + "snapshot": event.snapshot.to_dict(), + "event_type": event.__class__.__name__, + } + ) + return json.dumps( + event_dict, + default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S") + if isinstance(o, datetime) + else None, + ) + case StatusEvents: + event_dict = asdict(event) + event_dict["event_type"] = StatusEvents.__class__.__name__ + return json.dumps( + event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None + ) diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py index e1df99a9d5f..0cede97e4d2 100644 --- a/src/ert/run_models/model_factory.py +++ b/src/ert/run_models/model_factory.py @@ -32,7 +32,7 @@ import numpy.typing as npt from ert.namespace import Namespace - from ert.run_models.base_run_model import StatusEvents + from ert.run_models.event import StatusEvents from ert.storage import Storage diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index ee05dbc228c..e3db923e86d 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -38,6 +38,7 @@ def build( exec_hosts: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + message: str | None = None, ) -> EnsembleSnapshot: snapshot = EnsembleSnapshot() snapshot._ensemble_state = status @@ -53,6 +54,7 @@ def build( end_time=end_time, exec_hosts=exec_hosts, status=status, + message=message, ), ) return snapshot diff --git a/tests/ert/unit_tests/run_models/test_status_events_serialization.py b/tests/ert/unit_tests/run_models/test_status_events_serialization.py new file mode 100644 index 00000000000..758b8546965 --- /dev/null +++ b/tests/ert/unit_tests/run_models/test_status_events_serialization.py @@ -0,0 +1,160 @@ +import uuid +from collections import defaultdict +from datetime import datetime as dt + +import pytest + +from ert.analysis.event import DataSection +from ert.ensemble_evaluator import state +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.run_models.event import ( + AnalysisEvent, + AnalysisStatusEvent, + AnalysisTimeEvent, + EndEvent, + FullSnapshotEvent, + RunModelDataEvent, + RunModelStatusEvent, + RunModelTimeEvent, + RunModelUpdateBeginEvent, + RunModelUpdateEndEvent, + SnapshotUpdateEvent, + status_event_from_json, + status_event_to_json, +) +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +@pytest.mark.parametrize( + "events", + [ + pytest.param( + [ + FullSnapshotEvent( + snapshot=( + SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=dt(1999, 1, 1), + ) + .add_fm_step( + fm_step_id="1", + index="1", + name="fm_step_1", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_1.stdout", + stderr="job_fm_step_1.stderr", + start_time=dt(1999, 1, 1), + end_time=None, + ) + .build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_UNKNOWN, + start_time=dt(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ) + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={"Finished": 1, "Pending": 2, "Unknown": 1}, + iteration=0, + ), + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + index="0", + status=state.FORWARD_MODEL_STATE_FINISHED, + name="fm_step_0", + end_time=dt(2019, 1, 1), + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_RUNNING, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Running": 1, "Unknown": 1}, + iteration=0, + ), + SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="1", + index="1", + status=state.FORWARD_MODEL_STATE_FAILURE, + name="fm_step_1", + ) + .build( + real_ids=["0"], + status=state.REALIZATION_STATE_FAILED, + end_time=dt(2019, 1, 1), + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 2, "Failed": 1, "Unknown": 1}, + iteration=0, + ), + AnalysisEvent(), + AnalysisStatusEvent(msg="hello"), + AnalysisTimeEvent(remaining_time=22.2, elapsed_time=200.42), + EndEvent(failed=False, msg=""), + RunModelStatusEvent(iteration=1, run_id=uuid.uuid1(), msg="Hello"), + RunModelTimeEvent( + iteration=1, + run_id=uuid.uuid1(), + remaining_time=10.42, + elapsed_time=100.42, + ), + RunModelUpdateBeginEvent(iteration=2, run_id=uuid.uuid1()), + RunModelDataEvent( + iteration=1, + run_id=uuid.uuid1(), + name="Micky", + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + RunModelUpdateEndEvent( + iteration=3, + run_id=uuid.uuid1(), + data=DataSection( + header=["Some", "string", "elements"], + data=[["a", 1.1, "b"], ["c", 3]], + extra={"a": "b", "c": "d"}, + ), + ), + ], + ), + ], +) +def test_status_event_serialization(events): + for event in events: + json_res = status_event_to_json(event) + round_trip_event = status_event_from_json(json_res) + assert event == round_trip_event