Skip to content

Commit

Permalink
Add EventSender to resolve dependency cycle
Browse files Browse the repository at this point in the history
`Scheduler` has a reference to `Job` and `Job` has a reference to
`Scheduler`. Adding `EventSender` lets us resolve this cycle as now
`Scheduler` has a reference to `Job`s and `EventSender`, but each `Job`
only refers to `EventSender`.
  • Loading branch information
pinkwah committed Jan 11, 2024
1 parent f27f7f8 commit e3682ca
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 101 deletions.
13 changes: 13 additions & 0 deletions src/ert/ensemble_evaluator/identifiers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

ACTIVE = "active"
CURRENT_MEMORY_USAGE = "current_memory_usage"
DATA = "data"
Expand Down Expand Up @@ -31,6 +33,17 @@
EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"


# Closed set of realization-level CloudEvent "type" strings. Used as the static
# type of the `type` argument of `EventSender.send`, so a type checker rejects
# any other string.
# NOTE(review): presumably these literals mirror the values of the
# EVTYPE_REALIZATION_* constants in this module -- confirm they stay in sync.
EvGroupRealizationType = Literal[
"com.equinor.ert.realization.failure",
"com.equinor.ert.realization.pending",
"com.equinor.ert.realization.running",
"com.equinor.ert.realization.success",
"com.equinor.ert.realization.unknown",
"com.equinor.ert.realization.waiting",
"com.equinor.ert.realization.timeout",
]


EVGROUP_REALIZATION = {
EVTYPE_REALIZATION_FAILURE,
EVTYPE_REALIZATION_PENDING,
Expand Down
69 changes: 69 additions & 0 deletions src/ert/scheduler/event_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import asyncio
import ssl
from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from websockets import ConnectionClosed, Headers, connect

if TYPE_CHECKING:
from ert.ensemble_evaluator.identifiers import EvGroupRealizationType


class EventSender:
    """Queue CloudEvents for one ensemble and publish them over a websocket.

    Producers call :meth:`send` to enqueue an event; a single long-running
    :meth:`publisher` task drains the queue and forwards each event, serialized
    with ``to_json``, to the endpoint at ``ee_uri``.

    Attributes are public because other code reads the connection settings
    (``ens_id``, ``ee_uri``, ``ee_cert``, ``ee_token``) back from the sender.
    """

    def __init__(
        self,
        ens_id: Optional[str],
        ee_uri: Optional[str],
        ee_cert: Optional[str],
        ee_token: Optional[str],
    ) -> None:
        """Store the connection settings and create an empty event queue.

        Args:
            ens_id: Ensemble id interpolated into every event's ``source``.
            ee_uri: Websocket URI to publish to; if falsy, :meth:`publisher`
                returns immediately.
            ee_cert: Certificate data passed as ``cadata`` when building the
                TLS context; no TLS context is created when falsy.
            ee_token: Value sent in the ``token`` header when truthy.
        """
        self.ens_id = ens_id
        self.ee_uri = ee_uri
        self.ee_cert = ee_cert
        self.ee_token = ee_token
        self.events: asyncio.Queue[CloudEvent] = asyncio.Queue()

    async def send(
        self,
        type: EvGroupRealizationType,
        source: str,
        attributes: Optional[Mapping[str, Any]] = None,
        data: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Build a CloudEvent and put it on the queue.

        Args:
            type: CloudEvent ``type`` attribute.
            source: Path of the event origin relative to the ensemble, e.g.
                ``"real/3"`` or ``"/real/3"``; both yield
                ``/ert/ensemble/<ens_id>/real/3``.
            attributes: Extra CloudEvent attributes. They are merged last, so
                they override ``type`` and ``source`` on key collision.
            data: Optional event payload.
        """
        # Joining an absolute path onto a Path discards the left-hand side, so
        # a source such as "/real/3" would silently drop the ensemble prefix.
        # Strip leading slashes to always nest the source under the ensemble.
        full_source = Path(f"/ert/ensemble/{self.ens_id}") / source.lstrip("/")
        event = CloudEvent(
            {
                "type": type,
                "source": str(full_source),
                **(attributes or {}),
            },
            data,
        )
        await self.events.put(event)

    async def publisher(self) -> None:
        """Forward queued events to ``ee_uri`` until cancelled.

        Returns immediately when no URI is configured. Otherwise this loops
        forever: it is meant to run as a background task that the owner
        cancels. When the connection is closed, the event that was being sent
        is kept and retried on the next connection instead of being lost.
        """
        if not self.ee_uri:
            return

        tls: Optional[ssl.SSLContext] = None
        if self.ee_cert:
            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            tls.load_verify_locations(cadata=self.ee_cert)

        headers = Headers()
        if self.ee_token:
            headers["token"] = self.ee_token

        # Event taken from the queue but not yet confirmed sent. It survives a
        # reconnect so a dropped connection does not drop the event.
        unsent: Optional[CloudEvent] = None

        # Iterating over connect() yields a fresh connection after each
        # failure, but only if the loop body swallows ConnectionClosed and
        # continues; an uncaught exception would end the publisher.
        async for conn in connect(
            self.ee_uri,
            ssl=tls,
            extra_headers=headers,
            open_timeout=60,
            ping_timeout=60,
            ping_interval=60,
            close_timeout=60,
        ):
            try:
                while True:
                    if unsent is None:
                        unsent = await self.events.get()
                    await conn.send(to_json(unsent))
                    unsent = None
            except ConnectionClosed:
                continue
66 changes: 31 additions & 35 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,20 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from lxml import etree

from ert.callbacks import forward_model_ok
from ert.constant_filenames import ERROR_file
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver
from ert.scheduler.event_sender import EventSender
from ert.storage.realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)

# Duplicated to avoid circular imports
EVTYPE_REALIZATION_TIMEOUT = "com.equinor.ert.realization.timeout"


class State(str, Enum):
WAITING = "WAITING"
Expand Down Expand Up @@ -59,13 +53,13 @@ class Job:
(LSF, PBS, SLURM, etc.)
"""

def __init__(self, scheduler: Scheduler, real: Realization) -> None:
def __init__(self, real: Realization) -> None:
self.real = real
self.state = State.WAITING
self.started = asyncio.Event()
self.returncode: asyncio.Future[int] = asyncio.Future()
self.aborted = asyncio.Event()
self._scheduler: Scheduler = scheduler
self._event_sender: Optional[EventSender] = None
self._callback_status_msg: str = ""
self._requested_max_submit: Optional[int] = None
self._start_time: Optional[float] = None
Expand All @@ -75,10 +69,6 @@ def __init__(self, scheduler: Scheduler, real: Realization) -> None:
def iens(self) -> int:
return self.real.iens

@property
def driver(self) -> Driver:
return self._scheduler.driver

@property
def running_duration(self) -> float:
if self._start_time:
Expand All @@ -87,13 +77,15 @@ def running_duration(self) -> float:
return time.time() - self._start_time
return 0

async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
async def _submit_and_run_once(
self, sem: asyncio.BoundedSemaphore, driver: Driver
) -> None:
await sem.acquire()
timeout_task: Optional[asyncio.Task[None]] = None

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
await driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

Expand Down Expand Up @@ -133,7 +125,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)
await driver.kill(self.iens)
await self.aborted.wait()
await self._send(State.ABORTED)
finally:
Expand All @@ -142,13 +134,17 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
sem.release()

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
self,
sem: asyncio.BoundedSemaphore,
event_sender: EventSender,
driver: Driver,
max_submit: int = 2,
) -> None:
self._event_sender = event_sender
self._requested_max_submit = max_submit
await start.wait()

for attempt in range(max_submit):
await self._submit_and_run_once(sem)
await self._submit_and_run_once(sem, driver)

if self.returncode.done() or self.aborted.is_set():
break
Expand All @@ -159,17 +155,16 @@ async def __call__(
async def _max_runtime_task(self) -> None:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
timeout_event = CloudEvent(
{
"type": EVTYPE_REALIZATION_TIMEOUT,
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
"id": str(uuid.uuid1()),
}
)
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(timeout_event))

self.returncode.cancel() # Triggers CancelledError
self.returncode.cancel()

if self._event_sender is None:
return
await self._event_sender.send(
"com.equinor.ert.realization.timeout",
f"/real/{self.iens}",
attributes={"id": str(uuid.uuid1())},
)

async def _handle_failure(self) -> None:
assert self._requested_max_submit is not None
Expand All @@ -187,6 +182,8 @@ async def _handle_failure(self) -> None:
log_info_from_exit_file(Path(self.real.run_arg.runpath) / ERROR_file)

async def _send(self, state: State) -> None:
if self._event_sender is None:
return
self.state = state
if state in (State.FAILED, State.ABORTED):
await self._handle_failure()
Expand All @@ -196,17 +193,16 @@ async def _send(self, state: State) -> None:
await self._scheduler.completed_jobs.put(self.iens)

Check failure on line 193 in src/ert/scheduler/job.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

"Job" has no attribute "_scheduler"

status = STATE_TO_LEGACY[state]
event = CloudEvent(
{
"type": _queue_state_event_type(status),
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
await self._event_sender.send(
status,

Check failure on line 197 in src/ert/scheduler/job.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

Argument 1 to "send" of "EventSender" has incompatible type "str"; expected "Literal['com.equinor.ert.realization.failure', 'com.equinor.ert.realization.pending', 'com.equinor.ert.realization.running', 'com.equinor.ert.realization.success', 'com.equinor.ert.realization.unknown', 'com.equinor.ert.realization.waiting', 'com.equinor.ert.realization.timeout']"
f"/real/{self.iens}",
attributes={
"datacontenttype": "application/json",
},
{
data={
"queue_event_type": status,
},
)
await self._scheduler._events.put(to_json(event))


def log_info_from_exit_file(exit_file_path: Path) -> None:
Expand Down
70 changes: 26 additions & 44 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,26 @@
import json
import logging
import os
import ssl
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence
from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
MutableMapping,
Optional,
Sequence,
)

from pydantic.dataclasses import dataclass
from websockets import Headers
from websockets.client import connect

from ert.async_utils import background_tasks
from ert.constant_filenames import CERT_FILE
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver, JobEvent
from ert.scheduler.event_sender import EventSender
from ert.scheduler.job import Job
from ert.scheduler.job import State as JobState
from ert.scheduler.local_driver import LocalDriver
Expand Down Expand Up @@ -57,8 +63,8 @@ def __init__(
self.driver = driver
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

self._jobs: MutableMapping[int, Job] = {
real.iens: Job(self, real) for real in (realizations or [])
self._jobs: Mapping[int, Job] = {
real.iens: Job(real) for real in (realizations or [])
}

self._events: asyncio.Queue[Any] = asyncio.Queue()
Expand All @@ -70,10 +76,12 @@ def __init__(
self._max_submit = max_submit
self._max_running = max_running

self._ee_uri = ee_uri
self._ens_id = ens_id
self._ee_cert = ee_cert
self._ee_token = ee_token
self.event_sender = EventSender(
ens_id=ens_id,
ee_uri=ee_uri,
ee_cert=ee_cert,
ee_token=ee_token,
)

def kill_all_jobs(self) -> None:
self._cancelled = True
Expand Down Expand Up @@ -116,30 +124,6 @@ def count_states(self) -> Dict[JobState, int]:
counts[job.state] += 1
return counts

async def _publisher(self) -> None:
if not self._ee_uri:
return
tls: Optional[ssl.SSLContext] = None
if self._ee_cert:
tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
tls.load_verify_locations(cadata=self._ee_cert)
headers = Headers()
if self._ee_token:
headers["token"] = self._ee_token

async with connect(
self._ee_uri,
ssl=tls,
extra_headers=headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
) as conn:
while True:
event = await self._events.get()
await conn.send(event)

def add_dispatch_information_to_jobs_file(self) -> None:
for job in self._jobs.values():
self._update_jobs_json(job.iens, job.real.run_arg.runpath)
Expand All @@ -149,7 +133,7 @@ async def execute(
min_required_realizations: int = 0,
) -> str:
async with background_tasks() as cancel_when_execute_is_done:
cancel_when_execute_is_done(self._publisher())
cancel_when_execute_is_done(self.event_sender.publisher())
cancel_when_execute_is_done(self._process_event_queue())
cancel_when_execute_is_done(self.driver.poll())
if min_required_realizations > 0:
Expand All @@ -158,14 +142,12 @@ async def execute(
)
cancel_when_execute_is_done(self._update_avg_job_runtime())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs))
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(
job(start, sem, self._max_submit)
job(sem, self.event_sender, self.driver, self._max_submit)
)

start.set()
for task in self._tasks.values():
await task

Expand All @@ -190,15 +172,15 @@ async def _process_event_queue(self) -> None:

def _update_jobs_json(self, iens: int, runpath: str) -> None:
cert_path = f"{runpath}/{CERT_FILE}"
if self._ee_cert is not None:
Path(cert_path).write_text(self._ee_cert, encoding="utf-8")
if self.event_sender.ee_cert is not None:
Path(cert_path).write_text(self.event_sender.ee_cert, encoding="utf-8")
jobs = _JobsJson(
experiment_id=None,
ens_id=self._ens_id,
ens_id=self.event_sender.ens_id,
real_id=iens,
dispatch_url=self._ee_uri,
ee_token=self._ee_token,
ee_cert_path=cert_path if self._ee_cert is not None else None,
dispatch_url=self.event_sender.ee_uri,
ee_token=self.event_sender.ee_token,
ee_cert_path=self.event_sender.ee_cert and cert_path,
)
jobs_path = os.path.join(runpath, "jobs.json")
with open(jobs_path, "r") as fp:
Expand Down
Loading

0 comments on commit e3682ca

Please sign in to comment.