diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 6cd533be01c..03f8a76d14a 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -10,6 +10,7 @@ import logging import ssl from collections import deque +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union from cloudevents.conversion import to_json @@ -77,6 +78,14 @@ def _queue_state_event_type(state: str) -> str: return _queue_state_to_event_type_map[state] +@dataclass +class _EnsembleEvaluator: + uri: str + token: Optional[str] + certificate: Optional[Union[str, bytes]] = None + ssl_context: Optional[Union[ssl.SSLContext, bool]] = None + + class JobQueue: """Represents a queue of realizations (aka Jobs) to be executed on a cluster.""" @@ -88,12 +97,8 @@ def __init__(self, queue_config: "QueueConfig"): self._queue_stopped = False - # Wrap these in a dataclass? - self._ens_id: Optional[str] = None - self._ee_uri: Optional[str] = None - self._ee_cert: Optional[Union[str, bytes]] = None - self._ee_token: Optional[str] = None - self._ee_ssl_context: Optional[Union[ssl.SSLContext, bool]] = None + self._ensemble_id: Optional[str] = None + self._ee: Optional[_EnsembleEvaluator] = None self._changes_to_publish: Optional[ asyncio.Queue[Union[Dict[int, str], object]] @@ -206,18 +211,16 @@ def set_ee_info( ee_token: Optional[str] = None, verify_context: bool = True, ) -> None: - self._ens_id = ens_id - self._ee_token = ee_token - - self._ee_uri = ee_uri + self._ensemble_id = ens_id + self._ee = _EnsembleEvaluator(token=ee_token, uri=ee_uri) if ee_cert is not None: - self._ee_cert = ee_cert - self._ee_token = ee_token - self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self._ee.certificate = ee_cert + self._ee.token = ee_token + self._ee.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) if verify_context: - self._ee_ssl_context.load_verify_locations(cadata=ee_cert) + self._ee.ssl_context.load_verify_locations(cadata=ee_cert) else: - self._ee_ssl_context = True if ee_uri.startswith("wss") else None + self._ee.ssl_context = True if ee_uri.startswith("wss") else None @staticmethod def _translate_change_to_cloudevent( @@ -237,10 +240,12 @@ def _translate_change_to_cloudevent( async def _publish_changes( self, changes: Dict[int, str], ee_connection: WebSocketClientProtocol ) -> None: - assert self._ens_id is not None + assert self._ensemble_id is not None events = deque( [ - JobQueue._translate_change_to_cloudevent(self._ens_id, real_id, status) + JobQueue._translate_change_to_cloudevent( + self._ensemble_id, real_id, status + ) for real_id, status in changes.items() ] ) @@ -249,11 +254,7 @@ async def _publish_changes( events.popleft() async def _jobqueue_publisher(self) -> None: - ee_headers = Headers() - if self._ee_token is not None: - ee_headers["token"] = self._ee_token - - if self._ee_uri is None: + if self._ee is None: # If no ensemble evaluator present, we will publish to the log assert self._changes_to_publish is not None while ( @@ -262,9 +263,13 @@ async def _jobqueue_publisher(self) -> None: logger.warning(f"State change in jobqueue.execute(): {change}") return + ee_headers = Headers() + if self._ee.token is not None: + ee_headers["token"] = self._ee.token + async for ee_connection in connect( - self._ee_uri, - ssl=self._ee_ssl_context, + self._ee.uri, + ssl=self._ee.ssl_context, extra_headers=ee_headers, open_timeout=60, ping_timeout=60, @@ -418,21 +423,24 @@ def add_dispatch_information_to_jobs_file( self, experiment_id: Optional[str] = None, ) -> None: + assert self._ee is not None for job in self._realizations: cert_path = f"{job.realization.run_arg.runpath}/{CERT_FILE}" - if self._ee_cert is not None: + if self._ee.certificate is not None: with open(cert_path, "w", encoding="utf-8") as cert_file: - cert_file.write(str(self._ee_cert)) + cert_file.write(str(self._ee.certificate)) with open( f"{job.realization.run_arg.runpath}/{JOBS_FILE}", "r+", encoding="utf-8" ) as jobs_file: data = json.load(jobs_file) - data["ens_id"] = self._ens_id + data["ens_id"] = self._ensemble_id data["real_id"] = job.realization.run_arg.iens - data["dispatch_url"] = self._ee_uri - data["ee_token"] = self._ee_token - data["ee_cert_path"] = cert_path if self._ee_cert is not None else None + data["dispatch_url"] = self._ee.uri + data["ee_token"] = self._ee.token + data["ee_cert_path"] = ( + cert_path if self._ee.certificate is not None else None + ) data["experiment_id"] = experiment_id jobs_file.seek(0)