Skip to content

Commit

Permalink
Add EventSender to resolve dependency cycle
Browse files Browse the repository at this point in the history
`Scheduler` has a reference to `Job` and `Job` has a reference to
`Scheduler`. Adding `EventSender` lets us resolve this cycle as now
`Scheduler` has a reference to `Job`s and `EventSender`, but each `Job`
only refers to `EventSender`.
  • Loading branch information
pinkwah committed Feb 16, 2024
1 parent 0dd106b commit f34acb9
Show file tree
Hide file tree
Showing 13 changed files with 297 additions and 218 deletions.
26 changes: 13 additions & 13 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,16 +683,16 @@ def adaptive_localization_progress_callback(
start = time.time()
for param_batch_idx in batches:
X_local = temp_storage[param_group.name][param_batch_idx, :]
temp_storage[param_group.name][param_batch_idx, :] = (
smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
)
temp_storage[param_group.name][
param_batch_idx, :
] = smoother_adaptive_es.assimilate(
X=X_local,
Y=S,
D=D,
alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage)
correlation_threshold=module.correlation_threshold,
cov_YY=cov_YY,
progress_callback=adaptive_localization_progress_callback,
)
_logger.info(
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
Expand Down Expand Up @@ -849,9 +849,9 @@ def analysis_IES(
)
if active_parameter_indices := param_group.index_list:
X = temp_storage[param_group.name][active_parameter_indices, :]
temp_storage[param_group.name][active_parameter_indices, :] = (
X + X @ sies_smoother.W / np.sqrt(len(iens_active_index) - 1)
)
temp_storage[param_group.name][
active_parameter_indices, :
] = X + X @ sies_smoother.W / np.sqrt(len(iens_active_index) - 1)
else:
X = temp_storage[param_group.name]
temp_storage[param_group.name] = X + X @ sies_smoother.W / np.sqrt(
Expand Down
6 changes: 4 additions & 2 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,12 @@ def __init__(self, job):
)

@overload
def substitute(self, string: str) -> str: ...
def substitute(self, string: str) -> str:
...

@overload
def substitute(self, string: None) -> None: ...
def substitute(self, string: None) -> None:
...

def substitute(self, string):
if string is None:
Expand Down
4 changes: 3 additions & 1 deletion src/ert/config/parsing/observations_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def parse(filename: str) -> ConfContent:
)


def _parse_content(content: str, filename: str) -> List[
def _parse_content(
content: str, filename: str
) -> List[
Union[
SimpleHistoryDeclaration,
Tuple[ObservationType, FileContextToken, Dict[FileContextToken, Any]],
Expand Down
3 changes: 2 additions & 1 deletion src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class ResponseConfig(ABC):
name: str

@abstractmethod
def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: ...
def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
...

def to_dict(self) -> Dict[str, Any]:
data = dataclasses.asdict(self, dict_factory=CustomDict)
Expand Down
3 changes: 2 additions & 1 deletion src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@


class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None: ...
def kill_all_jobs(self) -> None:
...


class LegacyEnsemble(Ensemble):
Expand Down
11 changes: 11 additions & 0 deletions src/ert/ensemble_evaluator/identifiers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

ACTIVE = "active"
CURRENT_MEMORY_USAGE = "current_memory_usage"
DATA = "data"
Expand Down Expand Up @@ -30,6 +32,15 @@
EVTYPE_FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success"
EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"

EvGroupRealizationType = Literal[
"com.equinor.ert.realization.failure",
"com.equinor.ert.realization.pending",
"com.equinor.ert.realization.running",
"com.equinor.ert.realization.success",
"com.equinor.ert.realization.unknown",
"com.equinor.ert.realization.waiting",
"com.equinor.ert.realization.timeout",
]

EVGROUP_REALIZATION = {
EVTYPE_REALIZATION_FAILURE,
Expand Down
68 changes: 68 additions & 0 deletions src/ert/scheduler/event_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import asyncio
import ssl
from typing import TYPE_CHECKING, Any, Mapping, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from websockets import Headers, connect

if TYPE_CHECKING:
from ert.ensemble_evaluator.identifiers import EvGroupRealizationType


class EventSender:
def __init__(
self,
ens_id: Optional[str],
ee_uri: Optional[str],
ee_cert: Optional[str],
ee_token: Optional[str],
) -> None:
self.ens_id = ens_id
self.ee_uri = ee_uri
self.ee_cert = ee_cert
self.ee_token = ee_token
self.events: asyncio.Queue[CloudEvent] = asyncio.Queue()

async def send(
self,
type: EvGroupRealizationType,
source: str,
attributes: Optional[Mapping[str, Any]] = None,
data: Optional[Mapping[str, Any]] = None,
) -> None:
event = CloudEvent(
{
"type": type,
"source": f"/ert/ensemble/{self.ens_id}/{source}",
**(attributes or {}),
},
data,
)
await self.events.put(event)

async def publisher(self) -> None:
if not self.ee_uri:
return
tls: Optional[ssl.SSLContext] = None
if self.ee_cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self.ee_cert)
headers = Headers()
if self.ee_token:
headers["token"] = self.ee_token

async for conn in connect(
self.ee_uri,
ssl=tls,
extra_headers=headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
):
while True:
event = await self.events.get()
await conn.send(to_json(event))
Loading

0 comments on commit f34acb9

Please sign in to comment.