diff --git a/python/job_runner/__init__.py b/python/job_runner/__init__.py index 017359bcd9..dbfb19a91e 100644 --- a/python/job_runner/__init__.py +++ b/python/job_runner/__init__.py @@ -1,2 +1,3 @@ +CERT_FILE = ".ee.pem" JOBS_FILE = "jobs.json" LOG_URL = "http://devnull.statoil.no:4444" diff --git a/python/job_runner/cli.py b/python/job_runner/cli.py index ec2aa1ca40..a1cf13707c 100644 --- a/python/job_runner/cli.py +++ b/python/job_runner/cli.py @@ -10,14 +10,20 @@ from job_runner import JOBS_FILE -def _setup_reporters(is_interactive_run, ee_id, evaluator_url): +def _setup_reporters( + is_interactive_run, ee_id, evaluator_url, ee_token=None, ee_cert_path=None +): reporters = [] if is_interactive_run: reporters.append(reporting.Interactive()) elif ee_id: reporters.append(reporting.File(sync_disc_timeout=0)) reporters.append(reporting.Network()) - reporters.append(reporting.Event(evaluator_url=evaluator_url)) + reporters.append( + reporting.Event( + evaluator_url=evaluator_url, token=ee_token, cert_path=ee_cert_path + ) + ) else: reporters.append(reporting.File()) reporters.append(reporting.Network()) @@ -49,12 +55,16 @@ def main(args): with open(JOBS_FILE, "r") as json_file: jobs_data = json.load(json_file) ee_id = jobs_data.get("ee_id") + ee_token = jobs_data.get("ee_token") + ee_cert_path = jobs_data.get("ee_cert_path") evaluator_url = jobs_data.get("dispatch_url") except ValueError as e: raise IOError("Job Runner cli failed to load JSON-file.{}".format(str(e))) is_interactive_run = len(parsed_args.job) > 0 - reporters = _setup_reporters(is_interactive_run, ee_id, evaluator_url) + reporters = _setup_reporters( + is_interactive_run, ee_id, evaluator_url, ee_token, ee_cert_path + ) job_runner = JobRunner(jobs_data) diff --git a/python/job_runner/reporting/event.py b/python/job_runner/reporting/event.py index 9706fdbe6a..3424117faa 100644 --- a/python/job_runner/reporting/event.py +++ b/python/job_runner/reporting/event.py @@ -21,8 +21,14 @@ class TransitionError(ValueError): class Event: - def __init__(self, evaluator_url): + def __init__(self, evaluator_url, token=None, cert_path=None): self._evaluator_url = evaluator_url + self._token = token + if cert_path is not None: + with open(cert_path) as f: + self._cert = f.read() + else: + self._cert = None self._ee_id = None self._real_id = None @@ -63,7 +69,7 @@ def report(self, msg): self._state = new_state def _dump_event(self, event): - with Client(self._evaluator_url) as client: + with Client(self._evaluator_url, self._token, self._cert) as client: client.send(to_json(event).decode()) def _step_path(self): diff --git a/python/job_runner/util/client.py b/python/job_runner/util/client.py index 9f6e3d315a..b381ad95ca 100644 --- a/python/job_runner/util/client.py +++ b/python/job_runner/util/client.py @@ -2,7 +2,9 @@ from websockets import ConnectionClosedOK import asyncio import cloudevents +import ssl from websockets.exceptions import ConnectionClosed +from websockets.http import Headers class Client: @@ -14,17 +16,38 @@ def __exit__(self, exc_type, exc_value, exc_traceback): self.loop.run_until_complete(self.websocket.close()) self.loop.close() - def __init__(self, url, max_retries=10, timeout_multiplier=5): + def __init__( + self, url, token=None, cert=None, max_retries=10, timeout_multiplier=5 + ): if url is None: raise ValueError("url was None") self.url = url + self.token = token + self._extra_headers = Headers() + if token is not None: + self._extra_headers["token"] = token + + # Mimics the behavior of the ssl argument when connection to + # websockets. If none is specified it will deduce based on the url, + # if True it will enforce TLS, and if you want to use self signed + # certificates you need to pass an ssl_context with the certificate + # loaded. + if cert is not None: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(cadata=cert) + else: + ssl_context = True if url.startswith("wss") else None + self._ssl_context = ssl_context + self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier self.websocket = None self.loop = asyncio.new_event_loop() async def get_websocket(self): - return await websockets.connect(self.url) + return await websockets.connect( + self.url, ssl=self._ssl_context, extra_headers=self._extra_headers + ) async def _send(self, msg): for retry in range(self._max_retries + 1): diff --git a/python/res/job_queue/queue.py b/python/res/job_queue/queue.py index 6d015a5654..c6556450a3 100644 --- a/python/res/job_queue/queue.py +++ b/python/res/job_queue/queue.py @@ -23,12 +23,14 @@ import json import logging import time +import ssl import typing import websockets +from websockets.http import Headers from cloudevents.http import CloudEvent, to_json from cwrap import BaseCClass -from job_runner import JOBS_FILE +from job_runner import JOBS_FILE, CERT_FILE from res import ResPrototype from res.job_queue import JobQueueNode, JobStatusType, ThreadStatus @@ -432,8 +434,20 @@ async def _publish_changes(changes, websocket): for event in events: await websocket.send(to_json(event)) - async def execute_queue_async(self, ws_uri, pool_sema, evaluators): - async with websockets.connect(ws_uri) as websocket: + async def execute_queue_async( + self, ws_uri, pool_sema, evaluators, cert=None, token=None + ): + if cert is not None: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_verify_locations(cadata=cert) + else: + ssl_context = True if ws_uri.startswith("wss") else None + headers = Headers() + if token is not None: + headers["token"] = token + async with websockets.connect( + ws_uri, ssl=ssl_context, extra_headers=headers + ) as websocket: await JobQueue._publish_changes(self.snapshot(), websocket) try: @@ -569,14 +583,23 @@ def snapshot(self) -> typing.Optional[typing.Dict[int, str]]: return None return snapshot - def add_ensemble_evaluator_information_to_jobs_file(self, ee_id, dispatch_url): + def add_ensemble_evaluator_information_to_jobs_file( + self, ee_id, dispatch_url, cert, token + ): for q_index, q_node in enumerate(self.job_list): + cert_path = f"{q_node.run_path}/{CERT_FILE}" + with open(cert_path, "w") as cert_file: + cert_file.write(cert) with open(f"{q_node.run_path}/{JOBS_FILE}", "r+") as jobs_file: data = json.load(jobs_file) + data["ee_id"] = ee_id data["real_id"] = self._qindex_to_iens[q_index] data["step_id"] = 0 data["dispatch_url"] = dispatch_url + data["ee_token"] = token + data["ee_cert_path"] = cert_path + jobs_file.seek(0) jobs_file.truncate() json.dump(data, jobs_file, indent=4)