diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index c8ce0ad763c..447e3d08b58 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -194,23 +194,21 @@ async def _evaluate_inner( # pylint: disable=too-many-branches ) ] - # Tell queue to pass info to the jobs-file - # NOTE: This touches files on disk... - sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION) - self._job_queue.add_dispatch_information_to_jobs_file( + self._job_queue.set_ee_info( + ee_uri=self._config.dispatch_uri, ens_id=self.id_, - dispatch_url=self._config.dispatch_uri, - cert=self._config.cert, - token=self._config.token, + ee_cert=self._config.cert, + ee_token=self._config.token, ) + # Tell queue to pass info to the jobs-file + # NOTE: This touches files on disk... + self._job_queue.add_dispatch_information_to_jobs_file() + + sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION) result: str = await self._job_queue.execute( - self._config.dispatch_uri, - self.id_, sema, queue_evaluators, - ee_cert=self._config.cert, - ee_token=self._config.token, ) except Exception: diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 73bf2df21fc..3832d7048c3 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -112,7 +112,7 @@ def __init__(self, queue_config: QueueConfig): self._max_submit = queue_config.max_submit self._pool_sema = BoundedSemaphore(value=CONCURRENT_INTERNALIZATION) - self.ens_id: Optional[str] = None + self._ens_id: Optional[str] = None self._ee_uri: Optional[str] = None self._ee_cert: Optional[Union[str, bytes]] = None self._ee_token: Optional[str] = None @@ -214,6 +214,27 @@ def launch_jobs(self, pool_sema: Semaphore) -> None: max_submit=self.max_submit, ) + def set_ee_info( + self, + ee_uri: str, + ens_id: str, + ee_cert: Optional[Union[str, bytes]] = None, + ee_token: Optional[str] = None, + verify_context: bool = True, + ) -> None: + self._ens_id = ens_id + self._ee_token = ee_token + + self._ee_uri = ee_uri + if ee_cert is not None: + self._ee_cert = ee_cert + self._ee_token = ee_token + self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if verify_context: + self._ee_ssl_context.load_verify_locations(cadata=ee_cert) + else: + self._ee_ssl_context = True if ee_uri.startswith("wss") else None + @staticmethod def _translate_change_to_cloudevent( ens_id: str, real_id: int, status: str @@ -232,10 +253,10 @@ def _translate_change_to_cloudevent( async def _publish_changes( self, changes: Dict[int, str], ee_connection: WebSocketClientProtocol ) -> None: - assert self.ens_id is not None # mypy + assert self._ens_id is not None # mypy events = deque( [ - JobQueue._translate_change_to_cloudevent(self.ens_id, real_id, status) + JobQueue._translate_change_to_cloudevent(self._ens_id, real_id, status) for real_id, status in changes.items() ] ) @@ -274,30 +295,13 @@ async def _jobqueue_publisher(self) -> None: async def execute( self, - ee_uri: Optional[str] = None, - ens_id: Optional[str] = None, pool_sema: Optional[threading.BoundedSemaphore] = None, evaluators: Optional[Iterable[Callable[..., Any]]] = None, - ee_cert: Optional[Union[str, bytes]] = None, - ee_token: Optional[str] = None, ) -> str: - self.ens_id = ens_id - self._ee_token = ee_token - if pool_sema is not None: self._pool_sema = pool_sema if evaluators is None: evaluators = [] - if ee_uri is not None: - self._ee_uri = ee_uri - if ee_cert is not None: - self._ee_cert = ee_cert - self._ee_token = ee_token - self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ee_ssl_context.load_verify_locations(cadata=ee_cert) - else: - self._ee_ssl_context = True if ee_uri.startswith("wss") else None - if self._ee_uri is not None: self._changes_to_publish = asyncio.Queue() asyncio.create_task(self._jobqueue_publisher()) @@ -425,27 +429,23 @@ def changes_without_transition(self) -> Tuple[Dict[int, str], List[JobStatus]]: def add_dispatch_information_to_jobs_file( self, - ens_id: str, - dispatch_url: str, - cert: Optional[str], - token: Optional[str], experiment_id: Optional[str] = None, ) -> None: for q_index, q_node in enumerate(self.job_list): cert_path = f"{q_node.run_path}/{CERT_FILE}" - if cert is not None: + if self._ee_cert is not None: with open(cert_path, "w", encoding="utf-8") as cert_file: - cert_file.write(cert) + cert_file.write(self._ee_cert) with open( f"{q_node.run_path}/{JOBS_FILE}", "r+", encoding="utf-8" ) as jobs_file: data = json.load(jobs_file) - data["ens_id"] = ens_id + data["ens_id"] = self._ens_id data["real_id"] = self._differ.qindex_to_iens(q_index) - data["dispatch_url"] = dispatch_url - data["ee_token"] = token - data["ee_cert_path"] = cert_path if cert is not None else None + data["dispatch_url"] = self._ee_uri + data["ee_token"] = self._ee_token + data["ee_cert_path"] = cert_path if self._ee_cert is not None else None data["experiment_id"] = experiment_id jobs_file.seek(0) diff --git a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py index 5dd868869a9..26c243cebac 100644 --- a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py +++ b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py @@ -55,7 +55,8 @@ async def test_happy_path( for real in ensemble.reals: queue.add_realization(real, callback_timeout=None) - await queue.execute(url, "ee_0", threading.BoundedSemaphore(value=10), None) + queue.set_ee_info(ee_uri=url, ens_id="ee_0") + await queue.execute(pool_sema=threading.BoundedSemaphore(value=10)) done.set_result(None) await mock_ws_task diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index 8da74c02dd5..2472a13d0d2 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -70,7 +70,6 @@ def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypat if cancel: mon.signal_cancel() cancel = False - print("tracking..") assert evaluator._ensemble.status == state.ENSEMBLE_STATE_CANCELLED @@ -102,9 +101,7 @@ def test_run_legacy_ensemble_exception(tmpdir, make_ensemble_builder, monkeypatc state.ENSEMBLE_STATE_FAILED, state.ENSEMBLE_STATE_STOPPED, ]: - print("FAILED OR STOPPED") monitor.signal_done() - print("retracking") assert evaluator._ensemble.status == state.ENSEMBLE_STATE_FAILED # realisations should not finish, thus not creating a status-file diff --git a/tests/unit_tests/job_queue/test_job_queue.py b/tests/unit_tests/job_queue/test_job_queue.py index 8dacc7f86e5..e5cabca2f47 100644 --- a/tests/unit_tests/job_queue/test_job_queue.py +++ b/tests/unit_tests/job_queue/test_job_queue.py @@ -232,11 +232,14 @@ def test_add_dispatch_info(tmpdir, monkeypatch, simple_script): runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)] for runpath in runpaths: (runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8") - job_queue.add_dispatch_information_to_jobs_file( + job_queue.set_ee_info( + ee_uri=dispatch_url, ens_id=ens_id, - dispatch_url=dispatch_url, - cert=cert, - token=token, + ee_cert=cert, + ee_token=token, + verify_context=False, + ) + job_queue.add_dispatch_information_to_jobs_file( experiment_id="experiment_id", ) @@ -262,12 +265,10 @@ def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script): runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)] for runpath in runpaths: (runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8") - job_queue.add_dispatch_information_to_jobs_file( - ens_id=ens_id, - dispatch_url=dispatch_url, - cert=cert, - token=token, + job_queue.set_ee_info( + ee_uri=dispatch_url, ens_id=ens_id, ee_cert=cert, ee_token=token ) + job_queue.add_dispatch_information_to_jobs_file() for runpath in runpaths: job_file_path = runpath / "jobs.json"