-
Notifications
You must be signed in to change notification settings - Fork 110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add event serialization testing #9573
base: main
Are you sure you want to change the base?
Changes from all commits
002662f
9c6ec15
46e184d
ca2d82b
9cfb9d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,6 +113,11 @@ def __init__(self) -> None: | |
sorted_fm_step_ids=defaultdict(list), | ||
) | ||
|
||
def __eq__(self, other: object) -> bool: | ||
if not isinstance(other, EnsembleSnapshot): | ||
return NotImplemented | ||
return self.to_dict() == other.to_dict() | ||
|
||
@classmethod | ||
def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: | ||
ensemble = EnsembleSnapshot() | ||
|
@@ -156,13 +161,17 @@ def to_dict(self) -> dict[str, Any]: | |
if self._ensemble_state: | ||
dict_["status"] = self._ensemble_state | ||
if self._realization_snapshots: | ||
dict_["reals"] = self._realization_snapshots | ||
dict_["reals"] = { | ||
k: _filter_nones(v) for k, v in self._realization_snapshots.items() | ||
} | ||
|
||
for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): | ||
if "reals" not in dict_: | ||
dict_["reals"] = {} | ||
if real_id not in dict_["reals"]: | ||
dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) | ||
dict_["reals"][real_id] = _filter_nones( | ||
RealizationSnapshot(fm_steps={}) | ||
) | ||
if "fm_steps" not in dict_["reals"][real_id]: | ||
dict_["reals"][real_id]["fm_steps"] = {} | ||
|
||
|
@@ -392,15 +401,27 @@ class RealizationSnapshot(TypedDict, total=False): | |
def _realization_dict_to_realization_snapshot( | ||
source: dict[str, Any], | ||
) -> RealizationSnapshot: | ||
start_time = source.get("start_time") | ||
if start_time and isinstance(start_time, str): | ||
start_time = datetime.fromisoformat(start_time) | ||
end_time = source.get("end_time") | ||
if end_time and isinstance(end_time, str): | ||
end_time = datetime.fromisoformat(end_time) | ||
|
||
realization = RealizationSnapshot( | ||
status=source.get("status"), | ||
active=source.get("active"), | ||
start_time=source.get("start_time"), | ||
end_time=source.get("end_time"), | ||
start_time=start_time, | ||
end_time=end_time, | ||
exec_hosts=source.get("exec_hosts"), | ||
message=source.get("message"), | ||
fm_steps=source.get("fm_steps", {}), | ||
) | ||
for step in realization["fm_steps"].values(): | ||
if "start_time" in step and isinstance(step["start_time"], str): | ||
step["start_time"] = datetime.fromisoformat(step["start_time"]) | ||
if "end_time" in step and isinstance(step["end_time"], str): | ||
step["end_time"] = datetime.fromisoformat(step["end_time"]) | ||
Comment on lines
+420
to
+424
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here, should perhaps do it on the sender side? |
||
return _filter_nones(realization) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,23 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
import json | ||
from dataclasses import asdict, dataclass | ||
from datetime import datetime | ||
from pathlib import Path | ||
from uuid import UUID | ||
|
||
from ert.analysis import ( | ||
AnalysisEvent, | ||
AnalysisStatusEvent, | ||
AnalysisTimeEvent, | ||
) | ||
from ert.analysis.event import DataSection | ||
from ert.ensemble_evaluator.event import ( | ||
EndEvent, | ||
FullSnapshotEvent, | ||
SnapshotUpdateEvent, | ||
) | ||
from ert.ensemble_evaluator.snapshot import EnsembleSnapshot | ||
|
||
|
||
@dataclass | ||
|
@@ -56,3 +69,88 @@ class RunModelErrorEvent(RunModelEvent): | |
def write_as_csv(self, output_path: Path | None) -> None: | ||
if output_path and self.data: | ||
self.data.to_csv("Report", output_path / str(self.run_id)) | ||
|
||
|
||
StatusEvents = ( | ||
AnalysisEvent | ||
| AnalysisStatusEvent | ||
| AnalysisTimeEvent | ||
| EndEvent | ||
| FullSnapshotEvent | ||
| SnapshotUpdateEvent | ||
| RunModelErrorEvent | ||
| RunModelStatusEvent | ||
| RunModelTimeEvent | ||
| RunModelUpdateBeginEvent | ||
| RunModelDataEvent | ||
| RunModelUpdateEndEvent | ||
) | ||
|
||
|
||
EVENT_MAPPING = { | ||
"AnalysisEvent": AnalysisEvent, | ||
"AnalysisStatusEvent": AnalysisStatusEvent, | ||
"AnalysisTimeEvent": AnalysisTimeEvent, | ||
"EndEvent": EndEvent, | ||
"FullSnapshotEvent": FullSnapshotEvent, | ||
"SnapshotUpdateEvent": SnapshotUpdateEvent, | ||
"RunModelErrorEvent": RunModelErrorEvent, | ||
"RunModelStatusEvent": RunModelStatusEvent, | ||
"RunModelTimeEvent": RunModelTimeEvent, | ||
"RunModelUpdateBeginEvent": RunModelUpdateBeginEvent, | ||
"RunModelDataEvent": RunModelDataEvent, | ||
"RunModelUpdateEndEvent": RunModelUpdateEndEvent, | ||
} | ||
|
||
|
||
def status_event_from_json(json_str: str) -> StatusEvents: | ||
json_dict = json.loads(json_str) | ||
event_type = json_dict.pop("event_type", None) | ||
|
||
match event_type: | ||
case FullSnapshotEvent.__name__: | ||
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) | ||
json_dict["snapshot"] = snapshot | ||
return FullSnapshotEvent(**json_dict) | ||
case SnapshotUpdateEvent.__name__: | ||
snapshot = EnsembleSnapshot.from_nested_dict(json_dict["snapshot"]) | ||
json_dict["snapshot"] = snapshot | ||
return SnapshotUpdateEvent(**json_dict) | ||
case RunModelDataEvent.__name__ | RunModelUpdateEndEvent.__name__: | ||
if "run_id" in json_dict and isinstance(json_dict["run_id"], str): | ||
json_dict["run_id"] = UUID(json_dict["run_id"]) | ||
if json_dict.get("data"): | ||
json_dict["data"] = DataSection(**json_dict["data"]) | ||
return EVENT_MAPPING[event_type](**json_dict) | ||
case _: | ||
if event_type in EVENT_MAPPING: | ||
if "run_id" in json_dict and isinstance(json_dict["run_id"], str): | ||
json_dict["run_id"] = UUID(json_dict["run_id"]) | ||
return EVENT_MAPPING[event_type](**json_dict) | ||
else: | ||
raise TypeError(f"Unknown status event type {event_type}") | ||
Comment on lines
+110
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think this could also be done by pydantic directly? class EventWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
event: Annotated[StatusEvents, Field(discriminator='event_type')] |
||
|
||
|
||
def status_event_to_json(event: StatusEvents) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my PR I did: await websocket.send_json(event) so did not need a custom function, is that needed? It just takes StatusEvents = Union[
FullSnapshotEvent,
SnapshotUpdateEvent,
EndEvent,
AnalysisStatusEvent,
AnalysisTimeEvent,
AnalysisReportEvent,
RunModelErrorEvent,
RunModelStatusEvent,
RunModelTimeEvent,
RunModelUpdateBeginEvent,
RunModelDataEvent,
RunModelUpdateEndEvent,
] |
||
match event: | ||
case FullSnapshotEvent() | SnapshotUpdateEvent(): | ||
assert event.snapshot is not None | ||
event_dict = asdict(event) | ||
event_dict.update( | ||
{ | ||
"snapshot": event.snapshot.to_dict(), | ||
"event_type": event.__class__.__name__, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be better to add this directly to the events? event_type: Literal['FullSnapshotEvent'] = 'FullSnapshotEvent' the pro is it is more explicit, but on the flip side it is more verbose and has to be kept in sync for all events 🤔 |
||
} | ||
) | ||
return json.dumps( | ||
event_dict, | ||
default=lambda o: o.strftime("%Y-%m-%dT%H:%M:%S") | ||
if isinstance(o, datetime) | ||
else None, | ||
) | ||
case StatusEvents: | ||
event_dict = asdict(event) | ||
event_dict["event_type"] = StatusEvents.__class__.__name__ | ||
return json.dumps( | ||
event_dict, default=lambda o: str(o) if isinstance(o, UUID) else None | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where does this not happen? Seems like we mostly do: