diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index f86991f3733..ea9157c8865 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -109,9 +109,8 @@ def _print_job_errors(self) -> None: for snapshot in self._snapshots.values(): for real in snapshot.reals.values(): for job in real.forward_models.values(): - if job["status"] == FORWARD_MODEL_STATE_FAILURE: - err = job["error"] - assert isinstance(err, str) + if job.get("status") == FORWARD_MODEL_STATE_FAILURE: + err = job.get("error") result = failed_jobs.get(err, 0) failed_jobs[err] = result + 1 for error, number_of_jobs in failed_jobs.items(): diff --git a/src/ert/ensemble_evaluator/_builder/_ensemble.py b/src/ert/ensemble_evaluator/_builder/_ensemble.py index 38a9a4268a2..819655f287e 100644 --- a/src/ert/ensemble_evaluator/_builder/_ensemble.py +++ b/src/ert/ensemble_evaluator/_builder/_ensemble.py @@ -17,6 +17,7 @@ from _ert_forward_model_runner.client import Client from ert.ensemble_evaluator import state from ert.ensemble_evaluator.snapshot import ( + ForwardModel, PartialSnapshot, RealizationSnapshot, Snapshot, @@ -160,11 +161,11 @@ def _create_snapshot(self) -> Snapshot: status=state.REALIZATION_STATE_WAITING, ) for index, forward_model in enumerate(real.forward_models): - reals[str(real.iens)].forward_models[str(index)] = { - "status": state.FORWARD_MODEL_STATE_START, - "index": str(index), - "name": forward_model.name, - } + reals[str(real.iens)].forward_models[str(index)] = ForwardModel( + status=state.FORWARD_MODEL_STATE_START, + index=str(index), + name=forward_model.name, + ) top = SnapshotDict( reals=reals, status=state.ENSEMBLE_STATE_UNKNOWN, diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 060ea0ea7dd..15c32588159 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -11,10 +11,11 @@ Optional, Sequence, Tuple, - TypedDict, Union, ) +from typing_extensions import TypedDict + from cloudevents.http import CloudEvent from dateutil.parser import parse from pydantic import BaseModel @@ -90,25 +91,24 @@ class SnapshotMetadata(TypedDict, total=False): sorted_forward_model_ids: DefaultDict[RealId, List[FmStepId]] -def _filter_nones(some_dict: Dict[str, Any]) -> Dict[str, Any]: +def _filter_nones(some_dict: Union[Dict[str, Any], "ForwardModel"]) -> Dict[str, Any]: return {key: value for key, value in some_dict.items() if value is not None} class PartialSnapshot: def __init__(self, snapshot: Optional["Snapshot"] = None) -> None: self._realization_states: Dict[ - str, Dict[str, Union[bool, datetime.datetime, str]] + str, Dict[str, Union[bool, datetime.datetime, str, Dict[str, "ForwardModel"]]] ] = defaultdict(dict) """A shallow dictionary of realization states. The key is a string with realization number, pointing to a dict with keys active (bool), start_time (datetime), end_time (datetime) and status (str).""" self._forward_model_states: Dict[ - Tuple[str, str], Dict[str, Union[str, datetime.datetime]] - ] = defaultdict(dict) + Tuple[str, str], "ForwardModel" + ] = defaultdict(lambda: ForwardModel()) """A shallow dictionary of forward_model states. The key is a tuple of two - strings with realization id and forward_model id, pointing to a dict with - the same members as the ForwardModel.""" + strings with realization id and forward_model id, pointing to a ForwardModel.""" self._ensemble_state: Optional[str] = None # TODO not sure about possible values at this point, as GUI hijacks this one as @@ -148,22 +148,20 @@ def update_forward_model( self, real_id: str, forward_model_id: str, - forward_model: Dict[str, Union[str, datetime.datetime, None]], + forward_model: "ForwardModel", ) -> "PartialSnapshot": - forward_model_update = _filter_nones(forward_model) - self._forward_model_states[(real_id, forward_model_id)].update( - forward_model_update + forward_model ) if self._snapshot: self._snapshot._my_partial._forward_model_states[ (real_id, forward_model_id) - ].update(forward_model_update) + ].update(forward_model) return self def get_all_forward_models( self, - ) -> Mapping[Tuple[str, str], Dict[str, Union[str, datetime.datetime]]]: + ) -> Mapping[Tuple[str, str], "ForwardModel"]: if self._snapshot: return self._snapshot.get_all_forward_models() return {} @@ -213,17 +211,15 @@ def to_dict(self) -> Dict[str, Any]: if self._realization_states: _dict["reals"] = self._realization_states - for fm_tuple, fm_values_dict in self._forward_model_states.items(): - real_id = fm_tuple[0] + for (real_id, fm_id), fm_values_dict in self._forward_model_states.items(): if "reals" not in _dict: _dict["reals"] = {} if real_id not in _dict["reals"]: _dict["reals"][real_id] = {} if "forward_models" not in _dict["reals"][real_id]: - _dict["reals"][real_id]["forward_models"] = {} + _dict["reals"][real_id]["forward_models"] = ForwardModel() - forward_model_id = fm_tuple[1] - _dict["reals"][real_id]["forward_models"][forward_model_id] = fm_values_dict + _dict["reals"][real_id]["forward_models"][fm_id] = fm_values_dict return _dict @@ -274,7 +270,7 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot": ) in self._snapshot.get_forward_models_for_real( _get_real_id(e_source) ).items(): - if forward_model["status"] != state.FORWARD_MODEL_STATE_FINISHED: + if forward_model.get("status") != state.FORWARD_MODEL_STATE_FINISHED: real_id = _get_real_id(e_source) forward_model_idx = (real_id, forward_model_id) if forward_model_idx not in self._forward_model_states: @@ -305,25 +301,26 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot": if event.data is not None: error = event.data.get(ids.ERROR_MSG) - fm_dict = { - ids.STATUS: status, - ids.START_TIME: start_time, - ids.END_TIME: end_time, - ids.INDEX: _get_forward_model_index(e_source), - ids.ERROR: error, - } + fm = ForwardModel( + status=status, + start_time=start_time, + end_time=end_time, + index=_get_forward_model_index(e_source), + error=error + ) + if e_type == ids.EVTYPE_FORWARD_MODEL_RUNNING: - fm_dict[ids.CURRENT_MEMORY_USAGE] = event.data.get( + fm[ids.CURRENT_MEMORY_USAGE] = event.data.get( ids.CURRENT_MEMORY_USAGE ) - fm_dict[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE) + fm[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE) if e_type == ids.EVTYPE_FORWARD_MODEL_START: - fm_dict[ids.STDOUT] = event.data.get(ids.STDOUT) - fm_dict[ids.STDERR] = event.data.get(ids.STDERR) + fm[ids.STDOUT] = event.data.get(ids.STDOUT) + fm[ids.STDERR] = event.data.get(ids.STDERR) self.update_forward_model( _get_real_id(e_source), _get_forward_model_id(e_source), - fm_dict, + fm, ) elif e_type in ids.EVGROUP_ENSEMBLE: @@ -362,15 +359,16 @@ def metadata(self) -> SnapshotMetadata: def get_all_forward_models( self, - ) -> Mapping[Tuple[str, str], Dict[str, Union[str, datetime.datetime]]]: - return self._my_partial._forward_model_states + ) -> Mapping[Tuple[str, str], "ForwardModel"]: + return self._my_partial._forward_model_states.copy() def get_forward_model_status_for_all_reals( self, - ) -> Mapping[Tuple[str, str], Union[str, datetime.datetime]]: + ) -> Mapping[Tuple[str, str], str]: return { idx: forward_model_state["status"] for idx, forward_model_state in self._my_partial._forward_model_states.items() + if "status" in forward_model_state and forward_model_state["status"] is not None } @property @@ -379,11 +377,11 @@ def reals(self) -> Mapping[str, "RealizationSnapshot"]: def get_forward_models_for_real( self, real_id: str - ) -> Dict[str, Dict[str, Union[str, datetime.datetime]]]: + ) -> Dict[str, "ForwardModel"]: return { - forward_model_idx[1]: forward_model_data - for forward_model_idx, forward_model_data in self._my_partial._forward_model_states.items() - if forward_model_idx[0] == real_id + fm_idx[1]: forward_model_data.copy() + for fm_idx, forward_model_data in self._my_partial._forward_model_states.items() + if fm_idx[0] == real_id } def get_real(self, real_id: str) -> "RealizationSnapshot": @@ -391,8 +389,8 @@ def get_real(self, real_id: str) -> "RealizationSnapshot": def get_job( self, real_id: str, forward_model_id: str - ) -> Dict[str, Union[str, datetime.datetime]]: - return self._my_partial._forward_model_states[(real_id, forward_model_id)] + ) -> "ForwardModel": + return self._my_partial._forward_model_states[(real_id, forward_model_id)].copy() def get_successful_realizations(self) -> typing.List[int]: return [ @@ -413,13 +411,24 @@ def data(self) -> Mapping[str, Any]: # The gui uses this return self._my_partial.to_dict() +class ForwardModel(TypedDict, total=False): + status: Optional[str] + start_time: Optional[datetime.datetime] + end_time: Optional[datetime.datetime] + index: Optional[str] + current_memory_usage: Optional[str] + max_memory_usage: Optional[str] + name: Optional[str] + error: Optional[str] + stdout: Optional[str] + stderr: Optional[str] class RealizationSnapshot(BaseModel): status: Optional[str] = None active: Optional[bool] = None start_time: Optional[datetime.datetime] = None end_time: Optional[datetime.datetime] = None - forward_models: Dict[str, Dict[str, Union[str, datetime.datetime]]] = {} + forward_models: Dict[str, ForwardModel] = {} class SnapshotDict(BaseModel): @@ -429,7 +438,7 @@ class SnapshotDict(BaseModel): class SnapshotBuilder(BaseModel): - forward_models: Dict[str, Dict[str, Union[str, datetime.datetime]]] = {} + forward_models: Dict[str, ForwardModel] = {} metadata: Dict[str, Any] = {} def build( @@ -463,18 +472,16 @@ def add_forward_model( stdout: Optional[str] = None, stderr: Optional[str] = None, ) -> "SnapshotBuilder": - self.forward_models[forward_model_id] = _filter_nones( - { - "status": status, - "index": index, - "start_time": start_time, - "end_time": end_time, - "name": name, - "stdout": stdout, - "stderr": stderr, - "current_memory_usage": current_memory_usage, - "max_memory_usage": max_memory_usage, - } + self.forward_models[forward_model_id] = ForwardModel( + status=status, + index=index, + start_time=start_time, + end_time=end_time, + name=name, + stdout=stdout, + stderr=stderr, + current_memory_usage=current_memory_usage, + max_memory_usage=max_memory_usage, ) return self diff --git a/src/ert/gui/model/node.py b/src/ert/gui/model/node.py index 5ac342eacd2..431c3dc8480 100644 --- a/src/ert/gui/model/node.py +++ b/src/ert/gui/model/node.py @@ -1,13 +1,14 @@ from __future__ import annotations -import datetime from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, auto -from typing import Any, Dict, Optional, Union +from typing import Any, Optional from qtpy.QtGui import QColor +from ert.ensemble_evaluator.snapshot import ForwardModel + class NodeType(Enum): ROOT = auto() @@ -135,7 +136,7 @@ def row(self) -> int: @dataclass class ForwardModelStepNode(_Node): parent: RealNode - data: Dict[str, Union[str, datetime.datetime]] + data: ForwardModel def add_child(self, *args, **kwargs): pass diff --git a/src/ert/gui/model/snapshot.py b/src/ert/gui/model/snapshot.py index 1a1c964fdd3..aa9c6d44054 100644 --- a/src/ert/gui/model/snapshot.py +++ b/src/ert/gui/model/snapshot.py @@ -218,13 +218,12 @@ def _add_partial_snapshot(self, partial: PartialSnapshot, iter_: int) -> None: jobs_changed_by_real[real_id].append(job_node.row()) - job_without_nones = {k: v for (k, v) in job.items() if v is not None} - job_node.data.update(job_without_nones) - if job.get("current_memory_usage", None) is not None: + job_node.data.update(job) + if "current_memory_usage" in job and job["current_memory_usage"] is not None: cur_mem_usage = int(float(job["current_memory_usage"])) real_node.data.current_memory_usage = cur_mem_usage self.root.data.current_memory_usage = cur_mem_usage - if job.get("max_memory_usage", None) is not None: + if "max_memory_usage" in job and job["max_memory_usage"] is not None: max_mem_usage = int(float(job["max_memory_usage"])) real_node.data.max_memory_usage = max_mem_usage self.root.data.max_memory_usage = max_mem_usage diff --git a/tests/performance_tests/test_snapshot.py b/tests/performance_tests/test_snapshot.py index cc123d506db..38545cc30ca 100644 --- a/tests/performance_tests/test_snapshot.py +++ b/tests/performance_tests/test_snapshot.py @@ -7,6 +7,7 @@ from ert.ensemble_evaluator import identifiers as ids from ert.ensemble_evaluator import state from ert.ensemble_evaluator.snapshot import ( + ForwardModel, PartialSnapshot, RealizationSnapshot, Snapshot, @@ -74,11 +75,11 @@ def simulate_forward_model_event_handling( status=state.REALIZATION_STATE_WAITING, ) for fm_idx in range(forward_models): - reals[f"{real}"].forward_models[str(fm_idx)] = { - "status": state.FORWARD_MODEL_STATE_START, - "index": fm_idx, - "name": f"FM_{fm_idx}", - } + reals[f"{real}"].forward_models[str(fm_idx)] = ForwardModel( + status=state.FORWARD_MODEL_STATE_START, + index=fm_idx, + name=f"FM_{fm_idx}", + ) top = SnapshotDict( reals=reals, status=state.ENSEMBLE_STATE_UNKNOWN, metadata={"foo": "bar"} ) diff --git a/tests/unit_tests/ensemble_evaluator/test_snapshot.py b/tests/unit_tests/ensemble_evaluator/test_snapshot.py index 850b9d3977a..098adb9b3a9 100644 --- a/tests/unit_tests/ensemble_evaluator/test_snapshot.py +++ b/tests/unit_tests/ensemble_evaluator/test_snapshot.py @@ -6,6 +6,7 @@ from ert.ensemble_evaluator import identifiers as ids from ert.ensemble_evaluator import state from ert.ensemble_evaluator.snapshot import ( + ForwardModel, PartialSnapshot, Snapshot, SnapshotBuilder, @@ -19,58 +20,58 @@ def test_snapshot_merge(snapshot: Snapshot): update_event.update_forward_model( real_id="1", forward_model_id="0", - forward_model={ - "status": "Finished", - "index": "0", - "start_time": datetime(year=2020, month=10, day=27), - "end_time": datetime(year=2020, month=10, day=28), - }, + forward_model=ForwardModel( + status="Finished", + index="0", + start_time=datetime(year=2020, month=10, day=27), + end_time=datetime(year=2020, month=10, day=28), + ), ) update_event.update_forward_model( real_id="1", forward_model_id="1", - forward_model={ - "status": "Running", - "index": "1", - "start_time": datetime(year=2020, month=10, day=27), - }, + forward_model=ForwardModel( + status="Running", + index="1", + start_time=datetime(year=2020, month=10, day=27), + ), ) update_event.update_forward_model( real_id="9", forward_model_id="0", - forward_model={ - "status": "Running", - "index": "0", - "start_time": datetime(year=2020, month=10, day=27), - }, + forward_model=ForwardModel( + status="Running", + index="0", + start_time=datetime(year=2020, month=10, day=27), + ), ) snapshot.merge_event(update_event) assert snapshot.status == state.ENSEMBLE_STATE_UNKNOWN - assert snapshot.get_job(real_id="1", forward_model_id="0") == { - "status": "Finished", - "index": "0", - "start_time": datetime(year=2020, month=10, day=27), - "end_time": datetime(year=2020, month=10, day=28), - "name": "forward_model0", - } + assert snapshot.get_job(real_id="1", forward_model_id="0") == ForwardModel( + status="Finished", + index="0", + start_time=datetime(year=2020, month=10, day=27), + end_time=datetime(year=2020, month=10, day=28), + name="forward_model0", + ) - assert snapshot.get_job(real_id="1", forward_model_id="1") == { - "status": "Running", - "index": "1", - "start_time": datetime(year=2020, month=10, day=27), - "name": "forward_model1", - } + assert snapshot.get_job(real_id="1", forward_model_id="1") == ForwardModel( + status="Running", + index="1", + start_time=datetime(year=2020, month=10, day=27), + name="forward_model1", + ) - assert snapshot.get_job(real_id="9", forward_model_id="0")["status"] == "Running" - assert snapshot.get_job(real_id="9", forward_model_id="0") == { - "status": "Running", - "index": "0", - "start_time": datetime(year=2020, month=10, day=27), - "name": "forward_model0", - } + assert snapshot.get_job(real_id="9", forward_model_id="0").status == "Running" + assert snapshot.get_job(real_id="9", forward_model_id="0") == ForwardModel( + status="Running", + index="0", + start_time=datetime(year=2020, month=10, day=27), + name="forward_model0", + ) @pytest.mark.parametrize( diff --git a/tests/unit_tests/gui/conftest.py b/tests/unit_tests/gui/conftest.py index 7dc0baf9ff2..8382efe625e 100644 --- a/tests/unit_tests/gui/conftest.py +++ b/tests/unit_tests/gui/conftest.py @@ -21,6 +21,7 @@ from ert.config import ErtConfig from ert.enkf_main import EnKFMain from ert.ensemble_evaluator.snapshot import ( + ForwardModel, RealizationSnapshot, Snapshot, SnapshotBuilder, @@ -320,52 +321,54 @@ def full_snapshot() -> Snapshot: status=REALIZATION_STATE_RUNNING, active=True, forward_models={ - "0": { - "start_time": dt.now(), - "end_time": dt.now(), - "name": "poly_eval", - "index": "0", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, - "1": { - "start_time": dt.now(), - "end_time": dt.now(), - "name": "poly_postval", - "index": "1", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, - "2": { - "start_time": dt.now(), - "name": "poly_post_mortem", - "index": "2", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, - "3": { - "start_time": dt.now(), - "name": "poly_not_started", - "index": "3", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, + "0": ForwardModel( + start_time=dt.now(), + end_time=dt.now(), + name="poly_eval", + index="0", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), + "1": ForwardModel( + start_time=dt.now(), + end_time=dt.now(), + name="poly_postval", + index="1", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), + "2": ForwardModel( + start_time=dt.now(), + end_time=None, + name="poly_post_mortem", + index="2", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), + "3": ForwardModel( + start_time=dt.now(), + end_time=None, + name="poly_not_started", + index="3", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), }, ) snapshot = SnapshotDict( @@ -384,29 +387,30 @@ def waiting_snapshot() -> Snapshot: status=REALIZATION_STATE_WAITING, active=True, forward_models={ - "0": { - "start_time": dt.now(), - "end_time": dt.now(), - "name": "poly_eval", - "index": "0", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, - "1": { - "start_time": dt.now(), - "name": "poly_postval", - "index": "1", - "status": FORWARD_MODEL_STATE_START, - "error": "error", - "stdout": "std_out_file", - "stderr": "std_err_file", - "current_memory_usage": "123", - "max_memory_usage": "312", - }, + "0": ForwardModel( + start_time=dt.now(), + end_time=dt.now(), + name="poly_eval", + index="0", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), + "1": ForwardModel( + start_time=dt.now(), + end_time=None, + name="poly_postval", + index="1", + status=FORWARD_MODEL_STATE_START, + error="error", + stdout="std_out_file", + stderr="std_err_file", + current_memory_usage="123", + max_memory_usage="312", + ), }, ) snapshot = SnapshotDict( diff --git a/tests/unit_tests/gui/model/gui_models_utils.py b/tests/unit_tests/gui/model/gui_models_utils.py index de491fc08bb..45c5bbeca93 100644 --- a/tests/unit_tests/gui/model/gui_models_utils.py +++ b/tests/unit_tests/gui/model/gui_models_utils.py @@ -1,9 +1,11 @@ -from ert.ensemble_evaluator.snapshot import PartialSnapshot +from ert.ensemble_evaluator.snapshot import ForwardModel, PartialSnapshot from ert.ensemble_evaluator.state import FORWARD_MODEL_STATE_FINISHED def partial_snapshot(snapshot) -> PartialSnapshot: partial = PartialSnapshot(snapshot) partial._realization_states["0"].update({"status": FORWARD_MODEL_STATE_FINISHED}) - partial.update_forward_model("0", "0", {"status": FORWARD_MODEL_STATE_FINISHED}) + partial.update_forward_model( + "0", "0", ForwardModel(status=FORWARD_MODEL_STATE_FINISHED) + ) return partial diff --git a/tests/unit_tests/gui/model/test_job_list.py b/tests/unit_tests/gui/model/test_job_list.py index 3c3a389cf87..c5713bde2b1 100644 --- a/tests/unit_tests/gui/model/test_job_list.py +++ b/tests/unit_tests/gui/model/test_job_list.py @@ -7,7 +7,7 @@ from pytestqt.qt_compat import qt_api from ert.ensemble_evaluator import identifiers as ids -from ert.ensemble_evaluator.snapshot import PartialSnapshot +from ert.ensemble_evaluator.snapshot import ForwardModel, PartialSnapshot from ert.ensemble_evaluator.state import ( FORWARD_MODEL_STATE_FAILURE, FORWARD_MODEL_STATE_RUNNING, @@ -69,11 +69,11 @@ def test_changes(full_snapshot): partial.update_forward_model( "0", "0", - forward_model={ - "status": FORWARD_MODEL_STATE_FAILURE, - "start_time": start_time, - "end_time": end_time, - }, + forward_model=ForwardModel( + status=FORWARD_MODEL_STATE_FAILURE, + start_time=start_time, + end_time=end_time, + ), ) source_model._add_partial_snapshot(SnapshotModel.prerender(partial), 0) assert ( @@ -123,10 +123,10 @@ def test_duration(mock_datetime, timezone, full_snapshot): partial.update_forward_model( "0", "2", - forward_model={ - "status": FORWARD_MODEL_STATE_RUNNING, - "start_time": start_time, - }, + forward_model=ForwardModel( + status=FORWARD_MODEL_STATE_RUNNING, + start_time=start_time, + ), ) source_model._add_partial_snapshot(SnapshotModel.prerender(partial), 0) assert ( @@ -151,7 +151,7 @@ def test_no_cross_talk(full_snapshot): # Test that changes to iter=1 does not bleed into iter=0 partial = PartialSnapshot(full_snapshot) partial.update_forward_model( - "0", "0", forward_model={"status": FORWARD_MODEL_STATE_FAILURE} + "0", "0", forward_model=ForwardModel(status=FORWARD_MODEL_STATE_FAILURE) ) source_model._add_partial_snapshot(SnapshotModel.prerender(partial), 1) assert ( diff --git a/tests/unit_tests/gui/model/test_snapshot.py b/tests/unit_tests/gui/model/test_snapshot.py index e049b322bad..93d9b0a3613 100644 --- a/tests/unit_tests/gui/model/test_snapshot.py +++ b/tests/unit_tests/gui/model/test_snapshot.py @@ -4,7 +4,7 @@ from qtpy.QtCore import QModelIndex from qtpy.QtGui import QColor -from ert.ensemble_evaluator.snapshot import PartialSnapshot +from ert.ensemble_evaluator.snapshot import ForwardModel, PartialSnapshot from ert.ensemble_evaluator.state import ( COLOR_FAILED, COLOR_FINISHED, @@ -58,9 +58,15 @@ def test_realization_state_matches_display_color(full_snapshot): model._add_snapshot(SnapshotModel.prerender(full_snapshot), 0) partial = PartialSnapshot(full_snapshot) - partial.update_forward_model("0", "0", {"status": FORWARD_MODEL_STATE_FINISHED}) - partial.update_forward_model("0", "1", {"status": FORWARD_MODEL_STATE_FAILURE}) - partial.update_forward_model("0", "2", {"status": FORWARD_MODEL_STATE_RUNNING}) + partial.update_forward_model( + "0", "0", ForwardModel(status=FORWARD_MODEL_STATE_FINISHED) + ) + partial.update_forward_model( + "0", "1", ForwardModel(status=FORWARD_MODEL_STATE_FAILURE) + ) + partial.update_forward_model( + "0", "2", ForwardModel(status=FORWARD_MODEL_STATE_RUNNING) + ) model._add_partial_snapshot(SnapshotModel.prerender(partial), 0) first_real = model.index(0, 0, model.index(0, 0)) @@ -113,7 +119,7 @@ def test_display_color_changes_when_realization_state_is_not_running( model._add_snapshot(SnapshotModel.prerender(waiting_snapshot), 0) partial = PartialSnapshot(waiting_snapshot) - partial.update_forward_model("0", "0", {"status": forward_model_state}) + partial.update_forward_model("0", "0", ForwardModel(status=forward_model_state)) model._add_partial_snapshot(SnapshotModel.prerender(partial), 0) first_real = model.index(0, 0, model.index(0, 0))