Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add event serialization testing #9573

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from ert.namespace import Namespace
from ert.plugins import ErtPluginManager
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.event import StatusEvents
from ert.run_models.model_factory import create_model
from ert.storage import open_storage
from ert.storage.local_storage import local_storage_set_ert_config
Expand Down
2 changes: 1 addition & 1 deletion src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
FORWARD_MODEL_STATE_FAILURE,
REAL_STATE_TO_COLOR,
)
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.event import (
RunModelDataEvent,
RunModelErrorEvent,
RunModelUpdateEndEvent,
StatusEvents,
)
from ert.shared.status.utils import format_running_time

Expand Down
29 changes: 25 additions & 4 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def __init__(self) -> None:
sorted_fm_step_ids=defaultdict(list),
)

def __eq__(self, other: object) -> bool:
    """Compare two snapshots by their ``to_dict()`` output.

    Returns ``NotImplemented`` when ``other`` is not an ``EnsembleSnapshot``
    so that Python can fall back to the reflected comparison.
    """
    if isinstance(other, EnsembleSnapshot):
        return self.to_dict() == other.to_dict()
    return NotImplemented

@classmethod
def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot:
ensemble = EnsembleSnapshot()
Expand Down Expand Up @@ -156,13 +161,17 @@ def to_dict(self) -> dict[str, Any]:
if self._ensemble_state:
dict_["status"] = self._ensemble_state
if self._realization_snapshots:
dict_["reals"] = self._realization_snapshots
dict_["reals"] = {
k: _filter_nones(v) for k, v in self._realization_snapshots.items()
}

for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items():
if "reals" not in dict_:
dict_["reals"] = {}
if real_id not in dict_["reals"]:
dict_["reals"][real_id] = RealizationSnapshot(fm_steps={})
dict_["reals"][real_id] = _filter_nones(
RealizationSnapshot(fm_steps={})
)
if "fm_steps" not in dict_["reals"][real_id]:
dict_["reals"][real_id]["fm_steps"] = {}

Expand Down Expand Up @@ -392,15 +401,27 @@ class RealizationSnapshot(TypedDict, total=False):
def _realization_dict_to_realization_snapshot(
source: dict[str, Any],
) -> RealizationSnapshot:
start_time = source.get("start_time")
if start_time and isinstance(start_time, str):
start_time = datetime.fromisoformat(start_time)
end_time = source.get("end_time")
if end_time and isinstance(end_time, str):
end_time = datetime.fromisoformat(end_time)

Comment on lines +404 to +410
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this not happen? Seems like we mostly do:

def convert_iso8601_to_datetime(
    timestamp: datetime | str,
) -> datetime:
    if isinstance(timestamp, datetime):
        return timestamp

    return datetime.fromisoformat(timestamp)

realization = RealizationSnapshot(
status=source.get("status"),
active=source.get("active"),
start_time=source.get("start_time"),
end_time=source.get("end_time"),
start_time=start_time,
end_time=end_time,
exec_hosts=source.get("exec_hosts"),
message=source.get("message"),
fm_steps=source.get("fm_steps", {}),
)
for step in realization["fm_steps"].values():
if "start_time" in step and isinstance(step["start_time"], str):
step["start_time"] = datetime.fromisoformat(step["start_time"])
if "end_time" in step and isinstance(step["end_time"], str):
step["end_time"] = datetime.fromisoformat(step["end_time"])
Comment on lines +420 to +424
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here: perhaps this conversion should be done on the sender side?

return _filter_nones(realization)


Expand Down
32 changes: 8 additions & 24 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@

from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event
from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
ErtAnalysisError,
smoother_update,
)
Expand All @@ -40,11 +37,6 @@
Monitor,
Realization,
)
from ert.ensemble_evaluator.event import (
EndEvent,
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.identifiers import STATUS
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot
from ert.ensemble_evaluator.state import (
Expand All @@ -63,34 +55,26 @@
from ..config.analysis_config import UpdateSettings
from ..run_arg import RunArg
from .event import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
EndEvent,
FullSnapshotEvent,
RunModelDataEvent,
RunModelErrorEvent,
RunModelStatusEvent,
RunModelTimeEvent,
RunModelUpdateBeginEvent,
RunModelUpdateEndEvent,
SnapshotUpdateEvent,
StatusEvents,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from ert.config import QueueConfig

StatusEvents = (
FullSnapshotEvent
| SnapshotUpdateEvent
| EndEvent
| AnalysisEvent
| AnalysisStatusEvent
| AnalysisTimeEvent
| RunModelErrorEvent
| RunModelStatusEvent
| RunModelTimeEvent
| RunModelUpdateBeginEvent
| RunModelDataEvent
| RunModelUpdateEndEvent
)


class OutOfOrderSnapshotUpdateException(ValueError):
pass
Expand Down Expand Up @@ -398,7 +382,7 @@ def get_current_status(self) -> dict[str, int]:
status["Finished"] += (
self._get_number_of_finished_realizations_from_reruns()
)
return status
return dict(status)

def _get_number_of_finished_realizations_from_reruns(self) -> int:
return self.active_realizations.count(
Expand Down
100 changes: 99 additions & 1 deletion src/ert/run_models/event.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from uuid import UUID

from ert.analysis import (
AnalysisEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
)
from ert.analysis.event import DataSection
from ert.ensemble_evaluator.event import (
EndEvent,
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot


@dataclass
Expand Down Expand Up @@ -56,3 +69,88 @@ class RunModelErrorEvent(RunModelEvent):
def write_as_csv(self, output_path: Path | None) -> None:
if output_path and self.data:
self.data.to_csv("Report", output_path / str(self.run_id))


# Union of every event class treated as a status event in this module. It is
# the return annotation of status_event_from_json and the parameter annotation
# of status_event_to_json; EVENT_MAPPING lists the same classes by name.
StatusEvents = (
AnalysisEvent
| AnalysisStatusEvent
| AnalysisTimeEvent
| EndEvent
| FullSnapshotEvent
| SnapshotUpdateEvent
| RunModelErrorEvent
| RunModelStatusEvent
| RunModelTimeEvent
| RunModelUpdateBeginEvent
| RunModelDataEvent
| RunModelUpdateEndEvent
)


# Maps the "event_type" string stored in a serialized event to the class used
# to rebuild it in status_event_from_json.
EVENT_MAPPING = dict(
    AnalysisEvent=AnalysisEvent,
    AnalysisStatusEvent=AnalysisStatusEvent,
    AnalysisTimeEvent=AnalysisTimeEvent,
    EndEvent=EndEvent,
    FullSnapshotEvent=FullSnapshotEvent,
    SnapshotUpdateEvent=SnapshotUpdateEvent,
    RunModelErrorEvent=RunModelErrorEvent,
    RunModelStatusEvent=RunModelStatusEvent,
    RunModelTimeEvent=RunModelTimeEvent,
    RunModelUpdateBeginEvent=RunModelUpdateBeginEvent,
    RunModelDataEvent=RunModelDataEvent,
    RunModelUpdateEndEvent=RunModelUpdateEndEvent,
)


def status_event_from_json(json_str: str) -> StatusEvents:
    """Rebuild a status event from its JSON representation.

    The payload must be a JSON object whose ``event_type`` entry is a key of
    ``EVENT_MAPPING``. The remaining entries are passed to the matching event
    class as keyword arguments after these conversions:

    * ``snapshot`` (``FullSnapshotEvent`` / ``SnapshotUpdateEvent``) is rebuilt
      with ``EnsembleSnapshot.from_nested_dict``. A missing or null snapshot
      is passed through untouched instead of crashing inside the converter.
    * ``run_id`` strings are turned into ``UUID`` for every non-snapshot event.
    * ``data`` (``RunModelDataEvent`` / ``RunModelUpdateEndEvent``) is wrapped
      in a ``DataSection`` when it is non-empty.

    Raises:
        TypeError: if the payload is not a JSON object, or if its
            ``event_type`` is not a known status event.
    """
    json_dict = json.loads(json_str)
    if not isinstance(json_dict, dict):
        # Fail with the same exception type as an unknown event instead of an
        # unrelated AttributeError from ``.pop`` on a list/str/number.
        raise TypeError(
            "Status event payload must be a JSON object, "
            f"got {type(json_dict).__name__}"
        )

    event_type = json_dict.pop("event_type", None)
    if event_type not in EVENT_MAPPING:
        raise TypeError(f"Unknown status event type {event_type}")
    event_cls = EVENT_MAPPING[event_type]

    if event_type in (FullSnapshotEvent.__name__, SnapshotUpdateEvent.__name__):
        if json_dict.get("snapshot") is not None:
            json_dict["snapshot"] = EnsembleSnapshot.from_nested_dict(
                json_dict["snapshot"]
            )
        return event_cls(**json_dict)

    # UUIDs are serialized as strings; restore them for all remaining events.
    if isinstance(json_dict.get("run_id"), str):
        json_dict["run_id"] = UUID(json_dict["run_id"])

    carries_data = event_type in (
        RunModelDataEvent.__name__,
        RunModelUpdateEndEvent.__name__,
    )
    if carries_data and json_dict.get("data"):
        json_dict["data"] = DataSection(**json_dict["data"])

    return event_cls(**json_dict)
Comment on lines +110 to +131
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could also be done by pydantic directly.
for example:

class EventWrapper(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    event: Annotated[StatusEvents, Field(discriminator='event_type')]



def status_event_to_json(event: StatusEvents) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my PR I did:

await websocket.send_json(event)

so I did not need a custom function; is one needed here? It just takes StatusEvents:

StatusEvents = Union[
    FullSnapshotEvent,
    SnapshotUpdateEvent,
    EndEvent,
    AnalysisStatusEvent,
    AnalysisTimeEvent,
    AnalysisReportEvent,
    RunModelErrorEvent,
    RunModelStatusEvent,
    RunModelTimeEvent,
    RunModelUpdateBeginEvent,
    RunModelDataEvent,
    RunModelUpdateEndEvent,
]

match event:
case FullSnapshotEvent() | SnapshotUpdateEvent():
assert event.snapshot is not None
event_dict = asdict(event)
event_dict.update(
{
"snapshot": event.snapshot.to_dict(),
"event_type": event.__class__.__name__,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be better to add this directly to the events?

event_type: Literal['FullSnapshotEvent'] = 'FullSnapshotEvent'

the pro is it is more explicit, but on the flip side it is more verbose and has to be kept in sync for all events 🤔

}
)
return json.dumps(
event_dict,
default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S")
if isinstance(o, datetime)
else None,
)
case StatusEvents:
event_dict = asdict(event)
event_dict["event_type"] = StatusEvents.__class__.__name__
return json.dumps(
event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None
)
2 changes: 1 addition & 1 deletion src/ert/run_models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import numpy.typing as npt

from ert.namespace import Namespace
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.event import StatusEvents
from ert.storage import Storage


Expand Down
2 changes: 2 additions & 0 deletions tests/ert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def build(
exec_hosts: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
message: str | None = None,
) -> EnsembleSnapshot:
snapshot = EnsembleSnapshot()
snapshot._ensemble_state = status
Expand All @@ -53,6 +54,7 @@ def build(
end_time=end_time,
exec_hosts=exec_hosts,
status=status,
message=message,
),
)
return snapshot
Expand Down
Loading
Loading