Skip to content

Commit

Permalink
Add EventSender to resolve dependency cycle
Browse files Browse the repository at this point in the history
`Scheduler` has a reference to `Job` and `Job` has a reference to
`Scheduler`. Adding `EventSender` lets us resolve this cycle as now
`Scheduler` has a reference to `Job`s and `EventSender`, but each `Job`
only refers to `EventSender`.
  • Loading branch information
pinkwah committed Jan 17, 2024
1 parent 8a6fc5f commit f7f4a56
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 189 deletions.
11 changes: 11 additions & 0 deletions src/ert/ensemble_evaluator/identifiers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

ACTIVE = "active"
CURRENT_MEMORY_USAGE = "current_memory_usage"
DATA = "data"
Expand Down Expand Up @@ -30,6 +32,15 @@
EVTYPE_FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success"
EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"

# Closed set of event "type" strings for realization-level events, expressed
# as a Literal so static type checkers can reject any other string where a
# realization event type is expected.
EvGroupRealizationType = Literal[
"com.equinor.ert.realization.failure",
"com.equinor.ert.realization.pending",
"com.equinor.ert.realization.running",
"com.equinor.ert.realization.success",
"com.equinor.ert.realization.unknown",
"com.equinor.ert.realization.waiting",
"com.equinor.ert.realization.timeout",
]

EVGROUP_REALIZATION = {
EVTYPE_REALIZATION_FAILURE,
Expand Down
69 changes: 69 additions & 0 deletions src/ert/scheduler/event_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import asyncio
import ssl
from typing import TYPE_CHECKING, Any, Mapping, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from websockets import Headers, connect

if TYPE_CHECKING:
from ert.ensemble_evaluator.identifiers import EvGroupRealizationType


class EventSender:
def __init__(
self,
ens_id: Optional[str],
ee_uri: Optional[str],
ee_cert: Optional[str],
ee_token: Optional[str],
) -> None:
self.ens_id = ens_id
self.ee_uri = ee_uri
self.ee_cert = ee_cert
self.ee_token = ee_token
self.events: asyncio.Queue[CloudEvent] = asyncio.Queue()

async def send(
self,
type: EvGroupRealizationType,
source: str,
attributes: Optional[Mapping[str, Any]] = None,
data: Optional[Mapping[str, Any]] = None,
) -> None:
event = CloudEvent(
{
"type": type,
"source": f"/ert/ensemble/{self.ens_id}/{source}",
**(attributes or {}),
},
data,
)
await self.events.put(event)

async def publisher(self) -> None:
if not self.ee_uri:
return
tls: Optional[ssl.SSLContext] = None
if self.ee_cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self.ee_cert)
headers = Headers()
if self.ee_token:
headers["token"] = self.ee_token

async for conn in connect(
self.ee_uri,
ssl=tls,
extra_headers=headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
):
while True:
event = await self.events.get()
print(f"==SENDING {event=}")
await conn.send(to_json(event))
107 changes: 60 additions & 47 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,23 @@
import uuid
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Callable, Coroutine, List, Mapping, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from lxml import etree

from ert.callbacks import forward_model_ok
from ert.constant_filenames import ERROR_file
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver
from ert.scheduler.event_sender import EventSender
from ert.storage.realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler
from ert.ensemble_evaluator.identifiers import EvGroupRealizationType

logger = logging.getLogger(__name__)

# Duplicated to avoid circular imports
EVTYPE_REALIZATION_TIMEOUT = "com.equinor.ert.realization.timeout"


class State(str, Enum):
WAITING = "WAITING"
Expand All @@ -40,6 +35,17 @@ class State(str, Enum):
ABORTED = "ABORTED"


# Maps a job State to the realization event type string reported for it.
# The mapping is many-to-one: SUBMITTING is reported as "waiting" (same as
# WAITING) and ABORTED as "failure" (same as FAILED). A State with no entry
# here has no corresponding realization event type.
STATE_TO_EE: Mapping[State, EvGroupRealizationType] = {
State.WAITING: "com.equinor.ert.realization.waiting",
State.SUBMITTING: "com.equinor.ert.realization.waiting",
State.PENDING: "com.equinor.ert.realization.pending",
State.RUNNING: "com.equinor.ert.realization.running",
State.COMPLETED: "com.equinor.ert.realization.success",
State.FAILED: "com.equinor.ert.realization.failure",
State.ABORTED: "com.equinor.ert.realization.failure",
}


STATE_TO_LEGACY = {
State.WAITING: "WAITING",
State.SUBMITTING: "SUBMITTED",
Expand All @@ -59,13 +65,19 @@ class Job:
(LSF, PBS, SLURM, etc.)
"""

def __init__(self, scheduler: Scheduler, real: Realization) -> None:
def __init__(
self,
real: Realization,
*,
on_complete: Optional[Callable[[int], Coroutine[None, None, None]]] = None,
) -> None:
self.real = real
self.state = State.WAITING
self.started = asyncio.Event()
self.returncode: asyncio.Future[int] = asyncio.Future()
self.aborted = asyncio.Event()
self._scheduler: Scheduler = scheduler
self.on_complete = on_complete
self._event_sender: Optional[EventSender] = None
self._callback_status_msg: str = ""
self._requested_max_submit: Optional[int] = None
self._start_time: Optional[float] = None
Expand All @@ -75,10 +87,6 @@ def __init__(self, scheduler: Scheduler, real: Realization) -> None:
def iens(self) -> int:
return self.real.iens

@property
def driver(self) -> Driver:
return self._scheduler.driver

@property
def running_duration(self) -> float:
if self._start_time:
Expand All @@ -87,13 +95,15 @@ def running_duration(self) -> float:
return time.time() - self._start_time
return 0

async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
async def _submit_and_run_once(
self, sem: asyncio.BoundedSemaphore, driver: Driver
) -> None:
await sem.acquire()
timeout_task: Optional[asyncio.Task[None]] = None

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
await driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

Expand All @@ -102,7 +112,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
self._start_time = time.time()

await self._send(State.RUNNING)
if self.real.max_runtime is not None and self.real.max_runtime > 0:
if (self.real.max_runtime or 0) > 0:
timeout_task = asyncio.create_task(self._max_runtime_task())
while not self.returncode.done():
await asyncio.sleep(0.01)
Expand All @@ -119,12 +129,16 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:

if callback_status == LoadStatus.LOAD_SUCCESSFUL:
await self._send(State.COMPLETED)
self._end_time = time.time()
if self.on_complete is not None:
await self.on_complete(self.iens)
else:
assert callback_status in (
LoadStatus.LOAD_FAILURE,
LoadStatus.TIME_MAP_FAILURE,
)
await self._send(State.FAILED)
await self._handle_failure()

else:
await self._send(State.FAILED)
Expand All @@ -133,22 +147,27 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)
await driver.kill(self.iens)
await self.aborted.wait()
await self._send(State.ABORTED)
await self._handle_aborted()
finally:
if timeout_task and not timeout_task.done():
timeout_task.cancel()
sem.release()

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
self,
sem: asyncio.BoundedSemaphore,
event_sender: EventSender,
driver: Driver,
max_submit: int = 2,
) -> None:
self._event_sender = event_sender
self._requested_max_submit = max_submit
await start.wait()

for attempt in range(max_submit):
await self._submit_and_run_once(sem)
await self._submit_and_run_once(sem, driver)

if self.returncode.done() or self.aborted.is_set():
break
Expand All @@ -159,17 +178,16 @@ async def __call__(
async def _max_runtime_task(self) -> None:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
timeout_event = CloudEvent(
{
"type": EVTYPE_REALIZATION_TIMEOUT,
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
"id": str(uuid.uuid1()),
}
)
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(timeout_event))

self.returncode.cancel() # Triggers CancelledError
if self._event_sender is not None:
await self._event_sender.send(
"com.equinor.ert.realization.timeout",
f"real/{self.iens}",
attributes={"id": str(uuid.uuid1())},
)

self._event_sender = None
self.returncode.cancel()

async def _handle_failure(self) -> None:
assert self._requested_max_submit is not None
Expand All @@ -196,28 +214,23 @@ async def _handle_aborted(self) -> None:

async def _send(self, state: State) -> None:
self.state = state
if state == State.FAILED:
await self._handle_failure()

elif state == State.ABORTED:
await self._handle_aborted()
if self._event_sender is None:
return

elif state == State.COMPLETED:
self._end_time = time.time()
await self._scheduler.completed_jobs.put(self.iens)
if (status := STATE_TO_EE.get(state)) is None:
# This message does not need to be propagated to the user
return

status = STATE_TO_LEGACY[state]
event = CloudEvent(
{
"type": _queue_state_event_type(status),
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
await self._event_sender.send(
status,
f"real/{self.iens}",
attributes={
"datacontenttype": "application/json",
},
{
"queue_event_type": status,
data={
"queue_event_type": STATE_TO_LEGACY[state],
},
)
await self._scheduler._events.put(to_json(event))


def log_info_from_exit_file(exit_file_path: Path) -> None:
Expand Down
Loading

0 comments on commit f7f4a56

Please sign in to comment.