Skip to content

Commit

Permalink
WIP: change to typeddict
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Jun 20, 2024
1 parent 6409d85 commit 721c25b
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 197 deletions.
5 changes: 2 additions & 3 deletions src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ def _print_job_errors(self) -> None:
for snapshot in self._snapshots.values():
for real in snapshot.reals.values():
for job in real.forward_models.values():
if job["status"] == FORWARD_MODEL_STATE_FAILURE:
err = job["error"]
assert isinstance(err, str)
if job.get("status") == FORWARD_MODEL_STATE_FAILURE:
err = job.get("error")
result = failed_jobs.get(err, 0)
failed_jobs[err] = result + 1
for error, number_of_jobs in failed_jobs.items():
Expand Down
11 changes: 6 additions & 5 deletions src/ert/ensemble_evaluator/_builder/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from _ert_forward_model_runner.client import Client
from ert.ensemble_evaluator import state
from ert.ensemble_evaluator.snapshot import (
ForwardModel,
PartialSnapshot,
RealizationSnapshot,
Snapshot,
Expand Down Expand Up @@ -160,11 +161,11 @@ def _create_snapshot(self) -> Snapshot:
status=state.REALIZATION_STATE_WAITING,
)
for index, forward_model in enumerate(real.forward_models):
reals[str(real.iens)].forward_models[str(index)] = {
"status": state.FORWARD_MODEL_STATE_START,
"index": str(index),
"name": forward_model.name,
}
reals[str(real.iens)].forward_models[str(index)] = ForwardModel(
status=state.FORWARD_MODEL_STATE_START,
index=str(index),
name=forward_model.name,
)
top = SnapshotDict(
reals=reals,
status=state.ENSEMBLE_STATE_UNKNOWN,
Expand Down
115 changes: 61 additions & 54 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)

from typing_extensions import TypedDict

from cloudevents.http import CloudEvent
from dateutil.parser import parse
from pydantic import BaseModel
Expand Down Expand Up @@ -90,25 +91,24 @@ class SnapshotMetadata(TypedDict, total=False):
sorted_forward_model_ids: DefaultDict[RealId, List[FmStepId]]


def _filter_nones(some_dict: Dict[str, Any]) -> Dict[str, Any]:
def _filter_nones(some_dict: Union[Dict[str, Any], "ForwardModel"]) -> Dict[str, Any]:
return {key: value for key, value in some_dict.items() if value is not None}


class PartialSnapshot:
def __init__(self, snapshot: Optional["Snapshot"] = None) -> None:
self._realization_states: Dict[
str, Dict[str, Union[bool, datetime.datetime, str]]
str, Dict[str, Union[bool, datetime.datetime, str, Dict[str, "ForwardModel"]]]
] = defaultdict(dict)
"""A shallow dictionary of realization states. The key is a string with
realization number, pointing to a dict with keys active (bool),
start_time (datetime), end_time (datetime) and status (str)."""

self._forward_model_states: Dict[
Tuple[str, str], Dict[str, Union[str, datetime.datetime]]
] = defaultdict(dict)
Tuple[str, str], "ForwardModel"
] = defaultdict(lambda: ForwardModel())
"""A shallow dictionary of forward_model states. The key is a tuple of two
strings with realization id and forward_model id, pointing to a dict with
the same members as the ForwardModel."""
strings with realization id and forward_model id, pointing to a ForwardModel."""

self._ensemble_state: Optional[str] = None
# TODO not sure about possible values at this point, as GUI hijacks this one as
Expand Down Expand Up @@ -148,22 +148,20 @@ def update_forward_model(
self,
real_id: str,
forward_model_id: str,
forward_model: Dict[str, Union[str, datetime.datetime, None]],
forward_model: "ForwardModel",
) -> "PartialSnapshot":
forward_model_update = _filter_nones(forward_model)

self._forward_model_states[(real_id, forward_model_id)].update(
forward_model_update
forward_model
)
if self._snapshot:
self._snapshot._my_partial._forward_model_states[
(real_id, forward_model_id)
].update(forward_model_update)
].update(forward_model)
return self

def get_all_forward_models(
self,
) -> Mapping[Tuple[str, str], Dict[str, Union[str, datetime.datetime]]]:
) -> Mapping[Tuple[str, str], "ForwardModel"]:
if self._snapshot:
return self._snapshot.get_all_forward_models()
return {}
Expand Down Expand Up @@ -213,17 +211,15 @@ def to_dict(self) -> Dict[str, Any]:
if self._realization_states:
_dict["reals"] = self._realization_states

for fm_tuple, fm_values_dict in self._forward_model_states.items():
real_id = fm_tuple[0]
for (real_id, fm_id), fm_values_dict in self._forward_model_states.items():
if "reals" not in _dict:
_dict["reals"] = {}
if real_id not in _dict["reals"]:
_dict["reals"][real_id] = {}
if "forward_models" not in _dict["reals"][real_id]:
_dict["reals"][real_id]["forward_models"] = {}
_dict["reals"][real_id]["forward_models"] = ForwardModel()

forward_model_id = fm_tuple[1]
_dict["reals"][real_id]["forward_models"][forward_model_id] = fm_values_dict
_dict["reals"][real_id]["forward_models"][fm_id] = fm_values_dict

return _dict

Expand Down Expand Up @@ -274,7 +270,7 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot":
) in self._snapshot.get_forward_models_for_real(
_get_real_id(e_source)
).items():
if forward_model["status"] != state.FORWARD_MODEL_STATE_FINISHED:
if forward_model.get("status") != state.FORWARD_MODEL_STATE_FINISHED:
real_id = _get_real_id(e_source)
forward_model_idx = (real_id, forward_model_id)
if forward_model_idx not in self._forward_model_states:
Expand Down Expand Up @@ -305,25 +301,26 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot":
if event.data is not None:
error = event.data.get(ids.ERROR_MSG)

fm_dict = {
ids.STATUS: status,
ids.START_TIME: start_time,
ids.END_TIME: end_time,
ids.INDEX: _get_forward_model_index(e_source),
ids.ERROR: error,
}
fm = ForwardModel(
status=status,
start_time=start_time,
end_time=end_time,
index=_get_forward_model_index(e_source),
error=error
)

if e_type == ids.EVTYPE_FORWARD_MODEL_RUNNING:
fm_dict[ids.CURRENT_MEMORY_USAGE] = event.data.get(
fm[ids.CURRENT_MEMORY_USAGE] = event.data.get(

Check failure on line 313 in src/ert/ensemble_evaluator/snapshot.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

TypedDict key must be a string literal; expected one of ("status", "start_time", "end_time", "index", "current_memory_usage", ...)
ids.CURRENT_MEMORY_USAGE
)
fm_dict[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE)
fm[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE)

Check failure on line 316 in src/ert/ensemble_evaluator/snapshot.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

TypedDict key must be a string literal; expected one of ("status", "start_time", "end_time", "index", "current_memory_usage", ...)
if e_type == ids.EVTYPE_FORWARD_MODEL_START:
fm_dict[ids.STDOUT] = event.data.get(ids.STDOUT)
fm_dict[ids.STDERR] = event.data.get(ids.STDERR)
fm[ids.STDOUT] = event.data.get(ids.STDOUT)

Check failure on line 318 in src/ert/ensemble_evaluator/snapshot.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

TypedDict key must be a string literal; expected one of ("status", "start_time", "end_time", "index", "current_memory_usage", ...)
fm[ids.STDERR] = event.data.get(ids.STDERR)

Check failure on line 319 in src/ert/ensemble_evaluator/snapshot.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

TypedDict key must be a string literal; expected one of ("status", "start_time", "end_time", "index", "current_memory_usage", ...)
self.update_forward_model(
_get_real_id(e_source),
_get_forward_model_id(e_source),
fm_dict,
fm,
)

elif e_type in ids.EVGROUP_ENSEMBLE:
Expand Down Expand Up @@ -362,15 +359,16 @@ def metadata(self) -> SnapshotMetadata:

def get_all_forward_models(
self,
) -> Mapping[Tuple[str, str], Dict[str, Union[str, datetime.datetime]]]:
return self._my_partial._forward_model_states
) -> Mapping[Tuple[str, str], "ForwardModel"]:
return self._my_partial._forward_model_states.copy()

def get_forward_model_status_for_all_reals(
self,
) -> Mapping[Tuple[str, str], Union[str, datetime.datetime]]:
) -> Mapping[Tuple[str, str], str]:
return {
idx: forward_model_state["status"]
for idx, forward_model_state in self._my_partial._forward_model_states.items()
if "status" in forward_model_state and forward_model_state["status"] is not None
}

@property
Expand All @@ -379,20 +377,20 @@ def reals(self) -> Mapping[str, "RealizationSnapshot"]:

def get_forward_models_for_real(
self, real_id: str
) -> Dict[str, Dict[str, Union[str, datetime.datetime]]]:
) -> Dict[str, "ForwardModel"]:
return {
forward_model_idx[1]: forward_model_data
for forward_model_idx, forward_model_data in self._my_partial._forward_model_states.items()
if forward_model_idx[0] == real_id
fm_idx[1]: forward_model_data.copy()
for fm_idx, forward_model_data in self._my_partial._forward_model_states.items()
if fm_idx[0] == real_id
}

def get_real(self, real_id: str) -> "RealizationSnapshot":
return RealizationSnapshot(**self._my_partial._realization_states[real_id])

def get_job(
self, real_id: str, forward_model_id: str
) -> Dict[str, Union[str, datetime.datetime]]:
return self._my_partial._forward_model_states[(real_id, forward_model_id)]
) -> "ForwardModel":
return self._my_partial._forward_model_states[(real_id, forward_model_id)].copy()

def get_successful_realizations(self) -> typing.List[int]:
return [
Expand All @@ -413,13 +411,24 @@ def data(self) -> Mapping[str, Any]:
# The gui uses this
return self._my_partial.to_dict()

class ForwardModel(TypedDict, total=False):
status: Optional[str]
start_time: Optional[datetime.datetime]
end_time: Optional[datetime.datetime]
index: Optional[str]
current_memory_usage: Optional[str]
max_memory_usage: Optional[str]
name: Optional[str]
error: Optional[str]
stdout: Optional[str]
stderr: Optional[str]

class RealizationSnapshot(BaseModel):
status: Optional[str] = None
active: Optional[bool] = None
start_time: Optional[datetime.datetime] = None
end_time: Optional[datetime.datetime] = None
forward_models: Dict[str, Dict[str, Union[str, datetime.datetime]]] = {}
forward_models: Dict[str, ForwardModel] = {}


class SnapshotDict(BaseModel):
Expand All @@ -429,7 +438,7 @@ class SnapshotDict(BaseModel):


class SnapshotBuilder(BaseModel):
forward_models: Dict[str, Dict[str, Union[str, datetime.datetime]]] = {}
forward_models: Dict[str, ForwardModel] = {}
metadata: Dict[str, Any] = {}

def build(
Expand Down Expand Up @@ -463,18 +472,16 @@ def add_forward_model(
stdout: Optional[str] = None,
stderr: Optional[str] = None,
) -> "SnapshotBuilder":
self.forward_models[forward_model_id] = _filter_nones(
{
"status": status,
"index": index,
"start_time": start_time,
"end_time": end_time,
"name": name,
"stdout": stdout,
"stderr": stderr,
"current_memory_usage": current_memory_usage,
"max_memory_usage": max_memory_usage,
}
self.forward_models[forward_model_id] = ForwardModel(
status=status,
index=index,
start_time=start_time,
end_time=end_time,
name=name,
stdout=stdout,
stderr=stderr,
current_memory_usage=current_memory_usage,
max_memory_usage=max_memory_usage,
)
return self

Expand Down
7 changes: 4 additions & 3 deletions src/ert/gui/model/node.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import datetime
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, Optional, Union
from typing import Any, Optional

from qtpy.QtGui import QColor

from ert.ensemble_evaluator.snapshot import ForwardModel


class NodeType(Enum):
ROOT = auto()
Expand Down Expand Up @@ -135,7 +136,7 @@ def row(self) -> int:
@dataclass
class ForwardModelStepNode(_Node):
parent: RealNode
data: Dict[str, Union[str, datetime.datetime]]
data: ForwardModel

def add_child(self, *args, **kwargs):
pass
7 changes: 3 additions & 4 deletions src/ert/gui/model/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,12 @@ def _add_partial_snapshot(self, partial: PartialSnapshot, iter_: int) -> None:

jobs_changed_by_real[real_id].append(job_node.row())

job_without_nones = {k: v for (k, v) in job.items() if v is not None}
job_node.data.update(job_without_nones)
if job.get("current_memory_usage", None) is not None:
job_node.data.update(job)
if "current_memory_usage" in job and job["current_memory_usage"] is not None:
cur_mem_usage = int(float(job["current_memory_usage"]))
real_node.data.current_memory_usage = cur_mem_usage
self.root.data.current_memory_usage = cur_mem_usage
if job.get("max_memory_usage", None) is not None:
if "max_memory_usage" in job and job["max_memory_usage"] is not None:
max_mem_usage = int(float(job["max_memory_usage"]))
real_node.data.max_memory_usage = max_mem_usage
self.root.data.max_memory_usage = max_mem_usage
Expand Down
11 changes: 6 additions & 5 deletions tests/performance_tests/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ert.ensemble_evaluator import identifiers as ids
from ert.ensemble_evaluator import state
from ert.ensemble_evaluator.snapshot import (
ForwardModel,
PartialSnapshot,
RealizationSnapshot,
Snapshot,
Expand Down Expand Up @@ -74,11 +75,11 @@ def simulate_forward_model_event_handling(
status=state.REALIZATION_STATE_WAITING,
)
for fm_idx in range(forward_models):
reals[f"{real}"].forward_models[str(fm_idx)] = {
"status": state.FORWARD_MODEL_STATE_START,
"index": fm_idx,
"name": f"FM_{fm_idx}",
}
reals[f"{real}"].forward_models[str(fm_idx)] = ForwardModel(
status=state.FORWARD_MODEL_STATE_START,
index=fm_idx,
name=f"FM_{fm_idx}",
)
top = SnapshotDict(
reals=reals, status=state.ENSEMBLE_STATE_UNKNOWN, metadata={"foo": "bar"}
)
Expand Down
Loading

0 comments on commit 721c25b

Please sign in to comment.