Add EventSender to resolve dependency cycle
`Scheduler` has a reference to `Job` and `Job` has a reference to
`Scheduler`. Adding `EventSender` breaks this cycle: `Scheduler` now holds
references to the `Job`s and to the `EventSender`, while each `Job` refers
only to the `EventSender`.
pinkwah committed Dec 29, 2023
1 parent 5e2df1f commit 23f0fdb
Showing 4 changed files with 140 additions and 97 deletions.
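To illustrate the change, here is a minimal sketch of the dependency direction after this commit. It is not the actual implementation (names and constructors are simplified); see the diffs below for the real classes.

# Illustrative sketch only; simplified from src/ert/scheduler/.
class EventSender:
    """Owns the queue of outgoing events; refers to neither Scheduler nor Job."""


class Job:
    """Holds a Realization; receives an EventSender when it runs, never a Scheduler."""

    def __init__(self, real):
        self.real = real
        self._event_sender = None  # set when the job is called


class Scheduler:
    """Refers to its Jobs and to one shared EventSender, so the cycle is gone."""

    def __init__(self, realizations):
        self.event_sender = EventSender()
        self._jobs = {real.iens: Job(real) for real in realizations}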
66 changes: 66 additions & 0 deletions src/ert/scheduler/event_sender.py
@@ -0,0 +1,66 @@
from __future__ import annotations

import asyncio
import ssl
from pathlib import Path
from typing import Any, Mapping, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
from websockets import Headers, connect


class EventSender:
    def __init__(
        self,
        ens_id: Optional[str],
        ee_uri: Optional[str],
        ee_cert: Optional[str],
        ee_token: Optional[str],
    ) -> None:
        self.ens_id = ens_id
        self.ee_uri = ee_uri
        self.ee_cert = ee_cert
        self.ee_token = ee_token
        self.events: asyncio.Queue[CloudEvent] = asyncio.Queue()

    async def send(
        self,
        type: str,
        source: str,
        attributes: Optional[Mapping[str, Any]] = None,
        data: Optional[Mapping[str, Any]] = None,
    ) -> None:
        event = CloudEvent(
            {
                "type": type,
                "source": str(Path(f"/ert/ensemble/{self.ens_id}") / source),
                **(attributes or {}),
            },
            data,
        )
        await self.events.put(event)

    async def publisher(self) -> None:
        if not self.ee_uri:
            return
        tls: Optional[ssl.SSLContext] = None
        if self.ee_cert:
            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            tls.load_verify_locations(cadata=self.ee_cert)
        headers = Headers()
        if self.ee_token:
            headers["token"] = self.ee_token

        async for conn in connect(
            self.ee_uri,
            ssl=tls,
            extra_headers=headers,
            open_timeout=60,
            ping_timeout=60,
            ping_interval=60,
            close_timeout=60,
        ):
            while True:
                event = await self.events.get()
                await conn.send(to_json(event))
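
As a rough usage sketch (not part of the commit; the ensemble id, event type, and payload below are invented), events are queued by send() and drained by publisher():

import asyncio

from ert.scheduler.event_sender import EventSender


async def demo() -> None:
    # With no ee_uri the publisher returns immediately, so nothing is sent anywhere.
    sender = EventSender(ens_id="0", ee_uri=None, ee_cert=None, ee_token=None)
    publish_task = asyncio.create_task(sender.publisher())

    # send() only enqueues; a running publisher() forwards queued events over websocket.
    await sender.send("com.example.test", "real/0", data={"queue_event_type": "WAITING"})
    assert sender.events.qsize() == 1

    await publish_task


asyncio.run(demo())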
65 changes: 32 additions & 33 deletions src/ert/scheduler/job.py
@@ -6,18 +6,14 @@
from enum import Enum
from typing import TYPE_CHECKING, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent

from ert.callbacks import forward_model_ok
from ert.ensemble_evaluator.identifiers import EVTYPE_REALIZATION_TIMEOUT
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver
from ert.scheduler.event_sender import EventSender

if TYPE_CHECKING:
    from ert.ensemble_evaluator._builder._realization import Realization
    from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)

@@ -52,28 +48,26 @@ class Job:
    (LSF, PBS, SLURM, etc.)
    """

    def __init__(self, scheduler: Scheduler, real: Realization) -> None:
    def __init__(self, real: Realization) -> None:
        self.real = real
        self.started = asyncio.Event()
        self.returncode: asyncio.Future[int] = asyncio.Future()
        self.aborted = asyncio.Event()
        self._scheduler = scheduler
        self._event_sender: Optional[EventSender] = None

    @property
    def iens(self) -> int:
        return self.real.iens

    @property
    def driver(self) -> Driver:
        return self._scheduler.driver

    async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
    async def _submit_and_run_once(
        self, sem: asyncio.BoundedSemaphore, driver: Driver
    ) -> None:
        await sem.acquire()
        timeout_task: Optional[asyncio.Task[None]] = None

        try:
            await self._send(State.SUBMITTING)
            await self.driver.submit(
            await driver.submit(
                self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
            )

@@ -100,7 +94,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:

        except asyncio.CancelledError:
            await self._send(State.ABORTING)
            await self.driver.kill(self.iens)
            await driver.kill(self.iens)
            await self.aborted.wait()
            await self._send(State.ABORTED)
        finally:
@@ -109,12 +103,16 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
            sem.release()

    async def __call__(
        self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
        self,
        sem: asyncio.BoundedSemaphore,
        event_sender: EventSender,
        driver: Driver,
        max_submit: int = 2,
    ) -> None:
        await start.wait()
        self._event_sender = event_sender

        for attempt in range(max_submit):
            await self._submit_and_run_once(sem)
            await self._submit_and_run_once(sem, driver)

            if self.returncode.done() or self.aborted.is_set():
                break
@@ -128,28 +126,29 @@ async def __call__
    async def _max_runtime_task(self) -> None:
        assert self.real.max_runtime is not None
        await asyncio.sleep(self.real.max_runtime)
        timeout_event = CloudEvent(
            {
                "type": EVTYPE_REALIZATION_TIMEOUT,
                "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
                "id": str(uuid.uuid1()),
            }
        )
        assert self._scheduler._events is not None
        await self._scheduler._events.put(to_json(timeout_event))

        self.returncode.cancel()  # Triggers CancelledError
        self.returncode.cancel()

        if self._event_sender is None:
            return
        await self._event_sender.send(
            EVTYPE_REALIZATION_TIMEOUT,
            f"/real/{self.iens}",
            attributes={"id": str(uuid.uuid1())},
        )

    async def _send(self, state: State) -> None:
        if self._event_sender is None:
            return

        status = STATE_TO_LEGACY[state]
        event = CloudEvent(
            {
                "type": _queue_state_event_type(status),
                "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
        await self._event_sender.send(
            status,
            f"/real/{self.iens}",
            attributes={
                "datacontenttype": "application/json",
            },
            {
            data={
                "queue_event_type": status,
            },
        )
        await self._scheduler._events.put(to_json(event))
62 changes: 17 additions & 45 deletions src/ert/scheduler/scheduler.py
@@ -4,24 +4,21 @@
import json
import logging
import os
import ssl
from dataclasses import asdict
from typing import (
    TYPE_CHECKING,
    Any,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
)

from pydantic.dataclasses import dataclass
from websockets import Headers
from websockets.client import connect

from ert.async_utils import background_tasks
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver, JobEvent
from ert.scheduler.event_sender import EventSender
from ert.scheduler.job import Job
from ert.scheduler.local_driver import LocalDriver

@@ -60,19 +57,20 @@ def __init__(
        self.driver = driver
        self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

        self._jobs: Mapping[int, Job] = {
            real.iens: Job(self, real) for real in (realizations or [])
        }

        self._events: asyncio.Queue[Any] = asyncio.Queue()
        self._cancelled = False
        self._max_submit = max_submit
        self._max_running = max_running

        self._ee_uri = ee_uri
        self._ens_id = ens_id
        self._ee_cert = ee_cert
        self._ee_token = ee_token
        self.event_sender = EventSender(
            ens_id=ens_id,
            ee_uri=ee_uri,
            ee_cert=ee_cert,
            ee_token=ee_token,
        )

        self._jobs: Mapping[int, Job] = {
            real.iens: Job(real) for real in (realizations or [])
        }

    def kill_all_jobs(self) -> None:
        self._cancelled = True
@@ -82,30 +80,6 @@ def kill_all_jobs(self) -> None:
    def stop_long_running_jobs(self, minimum_required_realizations: int) -> None:
        pass

    async def _publisher(self) -> None:
        if not self._ee_uri:
            return
        tls: Optional[ssl.SSLContext] = None
        if self._ee_cert:
            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            tls.load_verify_locations(cadata=self._ee_cert)
        headers = Headers()
        if self._ee_token:
            headers["token"] = self._ee_token

        async with connect(
            self._ee_uri,
            ssl=tls,
            extra_headers=headers,
            open_timeout=60,
            ping_timeout=60,
            ping_interval=60,
            close_timeout=60,
        ) as conn:
            while True:
                event = await self._events.get()
                await conn.send(event)

    def add_dispatch_information_to_jobs_file(self) -> None:
        for job in self._jobs.values():
            self._update_jobs_json(job.iens, job.real.run_arg.runpath)
@@ -114,18 +88,16 @@ async def execute(
        self,
    ) -> str:
        async with background_tasks() as cancel_when_execute_is_done:
            cancel_when_execute_is_done(self._publisher())
            cancel_when_execute_is_done(self.event_sender.publisher())
            cancel_when_execute_is_done(self._process_event_queue())
            cancel_when_execute_is_done(self.driver.poll())

            start = asyncio.Event()
            sem = asyncio.BoundedSemaphore(self._max_running)
            for iens, job in self._jobs.items():
                self._tasks[iens] = asyncio.create_task(
                    job(start, sem, self._max_submit)
                    job(sem, self.event_sender, self.driver, self._max_submit)
                )

            start.set()
            for task in self._tasks.values():
                await task

@@ -151,11 +123,11 @@ async def _process_event_queue(self) -> None:
    def _update_jobs_json(self, iens: int, runpath: str) -> None:
        jobs = _JobsJson(
            experiment_id="_",
            ens_id=self._ens_id,
            ens_id=self.event_sender.ens_id,
            real_id=str(iens),
            dispatch_url=self._ee_uri,
            ee_token=self._ee_token,
            ee_cert_path=self._ee_cert,
            dispatch_url=self.event_sender.ee_uri,
            ee_token=self.event_sender.ee_token,
            ee_cert_path=self.event_sender.ee_cert,
        )
        jobs_path = os.path.join(runpath, "jobs.json")
        with open(jobs_path, "r") as fp: