Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group ee variables in dataclass #6624

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 38 additions & 30 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import ssl
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from cloudevents.conversion import to_json
Expand Down Expand Up @@ -77,6 +78,14 @@ def _queue_state_event_type(state: str) -> str:
return _queue_state_to_event_type_map[state]


@dataclass
class _EnsembleEvaluator:
uri: str
token: Optional[str]
certificate: Optional[Union[str, bytes]] = None
ssl_context: Optional[Union[ssl.SSLContext, bool]] = None


class JobQueue:
"""Represents a queue of realizations (aka Jobs) to be executed on a
cluster."""
Expand All @@ -88,12 +97,8 @@ def __init__(self, queue_config: "QueueConfig"):

self._queue_stopped = False

# Wrap these in a dataclass?
self._ens_id: Optional[str] = None
self._ee_uri: Optional[str] = None
self._ee_cert: Optional[Union[str, bytes]] = None
self._ee_token: Optional[str] = None
self._ee_ssl_context: Optional[Union[ssl.SSLContext, bool]] = None
self._ensemble_id: Optional[str] = None
self._ee: Optional[_EnsembleEvaluator] = None

self._changes_to_publish: Optional[
asyncio.Queue[Union[Dict[int, str], object]]
Expand Down Expand Up @@ -206,18 +211,16 @@ def set_ee_info(
ee_token: Optional[str] = None,
verify_context: bool = True,
) -> None:
self._ens_id = ens_id
self._ee_token = ee_token

self._ee_uri = ee_uri
self._ensemble_id = ens_id
self._ee = _EnsembleEvaluator(token=ee_token, uri=ee_uri)
if ee_cert is not None:
self._ee_cert = ee_cert
self._ee_token = ee_token
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ee.certificate = ee_cert
self._ee.token = ee_token
self._ee.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if verify_context:
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
self._ee.ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri.startswith("wss") else None
self._ee.ssl_context = True if ee_uri.startswith("wss") else None

@staticmethod
def _translate_change_to_cloudevent(
Expand All @@ -237,10 +240,12 @@ def _translate_change_to_cloudevent(
async def _publish_changes(
self, changes: Dict[int, str], ee_connection: WebSocketClientProtocol
) -> None:
assert self._ens_id is not None
assert self._ensemble_id is not None
events = deque(
[
JobQueue._translate_change_to_cloudevent(self._ens_id, real_id, status)
JobQueue._translate_change_to_cloudevent(
self._ensemble_id, real_id, status
)
for real_id, status in changes.items()
]
)
Expand All @@ -249,11 +254,7 @@ async def _publish_changes(
events.popleft()

async def _jobqueue_publisher(self) -> None:
ee_headers = Headers()
if self._ee_token is not None:
ee_headers["token"] = self._ee_token

if self._ee_uri is None:
if self._ee is None:
# If no ensemble evaluator present, we will publish to the log
assert self._changes_to_publish is not None
while (
Expand All @@ -262,9 +263,13 @@ async def _jobqueue_publisher(self) -> None:
logger.warning(f"State change in jobqueue.execute(): {change}")
return

ee_headers = Headers()
if self._ee.token is not None:
ee_headers["token"] = self._ee.token

async for ee_connection in connect(
self._ee_uri,
ssl=self._ee_ssl_context,
self._ee.uri,
ssl=self._ee.ssl_context,
extra_headers=ee_headers,
open_timeout=60,
ping_timeout=60,
Expand Down Expand Up @@ -418,21 +423,24 @@ def add_dispatch_information_to_jobs_file(
self,
experiment_id: Optional[str] = None,
) -> None:
assert self._ee is not None
for job in self._realizations:
cert_path = f"{job.realization.run_arg.runpath}/{CERT_FILE}"
if self._ee_cert is not None:
if self._ee.certificate is not None:
with open(cert_path, "w", encoding="utf-8") as cert_file:
cert_file.write(str(self._ee_cert))
cert_file.write(str(self._ee.certificate))
with open(
f"{job.realization.run_arg.runpath}/{JOBS_FILE}", "r+", encoding="utf-8"
) as jobs_file:
data = json.load(jobs_file)

data["ens_id"] = self._ens_id
data["ens_id"] = self._ensemble_id
data["real_id"] = job.realization.run_arg.iens
data["dispatch_url"] = self._ee_uri
data["ee_token"] = self._ee_token
data["ee_cert_path"] = cert_path if self._ee_cert is not None else None
data["dispatch_url"] = self._ee.uri
data["ee_token"] = self._ee.token
data["ee_cert_path"] = (
cert_path if self._ee.certificate is not None else None
)
data["experiment_id"] = experiment_id

jobs_file.seek(0)
Expand Down
Loading