Skip to content

Commit

Permalink
Group ensemble evaluator config variables in job queue
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Nov 21, 2023
1 parent 8e94c28 commit d05e43b
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import ssl
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from cloudevents.conversion import to_json
Expand Down Expand Up @@ -77,6 +78,14 @@ def _queue_state_event_type(state: str) -> str:
return _queue_state_to_event_type_map[state]


@dataclass
class _EnsembleEvaluator:
uri: str
token: Optional[str]
certificate: Optional[Union[str, bytes]] = None
ssl_context: Optional[Union[ssl.SSLContext, bool]] = None


class JobQueue:
"""Represents a queue of realizations (aka Jobs) to be executed on a
cluster."""
Expand All @@ -88,12 +97,8 @@ def __init__(self, queue_config: "QueueConfig"):

self._queue_stopped = False

# Wrap these in a dataclass?
self._ens_id: Optional[str] = None
self._ee_uri: Optional[str] = None
self._ee_cert: Optional[Union[str, bytes]] = None
self._ee_token: Optional[str] = None
self._ee_ssl_context: Optional[Union[ssl.SSLContext, bool]] = None
self._ensemble_id: Optional[str] = None
self._ee: Optional[_EnsembleEvaluator] = None

self._changes_to_publish: Optional[
asyncio.Queue[Union[Dict[int, str], object]]
Expand Down Expand Up @@ -206,18 +211,16 @@ def set_ee_info(
ee_token: Optional[str] = None,
verify_context: bool = True,
) -> None:
self._ens_id = ens_id
self._ee_token = ee_token

self._ee_uri = ee_uri
self._ensemble_id = ens_id
self._ee = _EnsembleEvaluator(token=ee_token, uri=ee_uri)
if ee_cert is not None:
self._ee_cert = ee_cert
self._ee_token = ee_token
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ee.certificate = ee_cert
self._ee.token = ee_token
self._ee.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if verify_context:
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
self._ee.ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri.startswith("wss") else None
self._ee.ssl_context = True if ee_uri.startswith("wss") else None

@staticmethod
def _translate_change_to_cloudevent(
Expand All @@ -237,10 +240,12 @@ def _translate_change_to_cloudevent(
async def _publish_changes(
self, changes: Dict[int, str], ee_connection: WebSocketClientProtocol
) -> None:
assert self._ens_id is not None
assert self._ensemble_id is not None
events = deque(
[
JobQueue._translate_change_to_cloudevent(self._ens_id, real_id, status)
JobQueue._translate_change_to_cloudevent(
self._ensemble_id, real_id, status
)
for real_id, status in changes.items()
]
)
Expand All @@ -249,11 +254,7 @@ async def _publish_changes(
events.popleft()

async def _jobqueue_publisher(self) -> None:
ee_headers = Headers()
if self._ee_token is not None:
ee_headers["token"] = self._ee_token

if self._ee_uri is None:
if self._ee is None:
# If no ensemble evaluator present, we will publish to the log
assert self._changes_to_publish is not None
while (
Expand All @@ -262,9 +263,13 @@ async def _jobqueue_publisher(self) -> None:
logger.warning(f"State change in jobqueue.execute(): {change}")
return

ee_headers = Headers()
if self._ee.token is not None:
ee_headers["token"] = self._ee.token

async for ee_connection in connect(
self._ee_uri,
ssl=self._ee_ssl_context,
self._ee.uri,
ssl=self._ee.ssl_context,
extra_headers=ee_headers,
open_timeout=60,
ping_timeout=60,
Expand Down Expand Up @@ -418,21 +423,24 @@ def add_dispatch_information_to_jobs_file(
self,
experiment_id: Optional[str] = None,
) -> None:
assert self._ee is not None
for job in self._realizations:
cert_path = f"{job.realization.run_arg.runpath}/{CERT_FILE}"
if self._ee_cert is not None:
if self._ee.certificate is not None:
with open(cert_path, "w", encoding="utf-8") as cert_file:
cert_file.write(str(self._ee_cert))
cert_file.write(str(self._ee.certificate))
with open(
f"{job.realization.run_arg.runpath}/{JOBS_FILE}", "r+", encoding="utf-8"
) as jobs_file:
data = json.load(jobs_file)

data["ens_id"] = self._ens_id
data["ens_id"] = self._ensemble_id
data["real_id"] = job.realization.run_arg.iens
data["dispatch_url"] = self._ee_uri
data["ee_token"] = self._ee_token
data["ee_cert_path"] = cert_path if self._ee_cert is not None else None
data["dispatch_url"] = self._ee.uri
data["ee_token"] = self._ee.token
data["ee_cert_path"] = (
cert_path if self._ee.certificate is not None else None
)
data["experiment_id"] = experiment_id

jobs_file.seek(0)
Expand Down

0 comments on commit d05e43b

Please sign in to comment.