Skip to content
This repository has been archived by the owner on Jul 19, 2021. It is now read-only.

Commit

Permalink
Add certificates and tokens to websocket communication
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Apr 27, 2021
1 parent 3e329f0 commit 93469fc
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 11 deletions.
1 change: 1 addition & 0 deletions python/job_runner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
CERT_FILE = ".ee.pem"
JOBS_FILE = "jobs.json"
LOG_URL = "http://devnull.statoil.no:4444"
16 changes: 13 additions & 3 deletions python/job_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
from job_runner import JOBS_FILE


def _setup_reporters(is_interactive_run, ee_id, evaluator_url):
def _setup_reporters(
is_interactive_run, ee_id, evaluator_url, ee_token=None, ee_cert_path=None
):
reporters = []
if is_interactive_run:
reporters.append(reporting.Interactive())
elif ee_id:
reporters.append(reporting.File(sync_disc_timeout=0))
reporters.append(reporting.Network())
reporters.append(reporting.Event(evaluator_url=evaluator_url))
reporters.append(
reporting.Event(
evaluator_url=evaluator_url, token=ee_token, cert_path=ee_cert_path
)
)
else:
reporters.append(reporting.File())
reporters.append(reporting.Network())
Expand Down Expand Up @@ -49,12 +55,16 @@ def main(args):
with open(JOBS_FILE, "r") as json_file:
jobs_data = json.load(json_file)
ee_id = jobs_data.get("ee_id")
ee_token = jobs_data.get("ee_token")
ee_cert_path = jobs_data.get("ee_cert_path")
evaluator_url = jobs_data.get("dispatch_url")
except ValueError as e:
raise IOError("Job Runner cli failed to load JSON-file.{}".format(str(e)))

is_interactive_run = len(parsed_args.job) > 0
reporters = _setup_reporters(is_interactive_run, ee_id, evaluator_url)
reporters = _setup_reporters(
is_interactive_run, ee_id, evaluator_url, ee_token, ee_cert_path
)

job_runner = JobRunner(jobs_data)

Expand Down
10 changes: 8 additions & 2 deletions python/job_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@ class TransitionError(ValueError):


class Event:
def __init__(self, evaluator_url):
def __init__(self, evaluator_url, token=None, cert_path=None):
self._evaluator_url = evaluator_url
self._token = token
if cert_path is not None:
with open(cert_path) as f:
self._cert = f.read()
else:
self._cert = None

self._ee_id = None
self._real_id = None
Expand Down Expand Up @@ -63,7 +69,7 @@ def report(self, msg):
self._state = new_state

def _dump_event(self, event):
with Client(self._evaluator_url) as client:
with Client(self._evaluator_url, self._token, self._cert) as client:
client.send(to_json(event).decode())

def _step_path(self):
Expand Down
27 changes: 25 additions & 2 deletions python/job_runner/util/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from websockets import ConnectionClosedOK
import asyncio
import cloudevents
import ssl
from websockets.exceptions import ConnectionClosed
from websockets.http import Headers


class Client:
Expand All @@ -14,17 +16,38 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
self.loop.run_until_complete(self.websocket.close())
self.loop.close()

def __init__(self, url, max_retries=10, timeout_multiplier=5):
def __init__(
self, url, token=None, cert=None, max_retries=10, timeout_multiplier=5
):
if url is None:
raise ValueError("url was None")
self.url = url
self.token = token
self._extra_headers = Headers()
if token is not None:
self._extra_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
if cert is not None:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(cadata=cert)
else:
ssl_context = True if url.startswith("wss") else None
self._ssl_context = ssl_context

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket = None
self.loop = asyncio.new_event_loop()

async def get_websocket(self):
return await websockets.connect(self.url)
return await websockets.connect(
self.url, ssl=self._ssl_context, extra_headers=self._extra_headers
)

async def _send(self, msg):
for retry in range(self._max_retries + 1):
Expand Down
31 changes: 27 additions & 4 deletions python/res/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
import json
import logging
import time
import ssl
import typing

import websockets
from websockets.http import Headers
from cloudevents.http import CloudEvent, to_json
from cwrap import BaseCClass
from job_runner import JOBS_FILE
from job_runner import JOBS_FILE, CERT_FILE
from res import ResPrototype
from res.job_queue import JobQueueNode, JobStatusType, ThreadStatus

Expand Down Expand Up @@ -432,8 +434,20 @@ async def _publish_changes(changes, websocket):
for event in events:
await websocket.send(to_json(event))

async def execute_queue_async(self, ws_uri, pool_sema, evaluators):
async with websockets.connect(ws_uri) as websocket:
async def execute_queue_async(
self, ws_uri, pool_sema, evaluators, cert=None, token=None
):
if cert is not None:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(cadata=cert)
else:
ssl_context = True if ws_uri.startswith("wss") else None
headers = Headers()
if token is not None:
headers["token"] = token
async with websockets.connect(
ws_uri, ssl=ssl_context, extra_headers=headers
) as websocket:
await JobQueue._publish_changes(self.snapshot(), websocket)

try:
Expand Down Expand Up @@ -569,14 +583,23 @@ def snapshot(self) -> typing.Optional[typing.Dict[int, str]]:
return None
return snapshot

def add_ensemble_evaluator_information_to_jobs_file(self, ee_id, dispatch_url):
def add_ensemble_evaluator_information_to_jobs_file(
self, ee_id, dispatch_url, cert, token
):
for q_index, q_node in enumerate(self.job_list):
cert_path = f"{q_node.run_path}/{CERT_FILE}"
with open(cert_path, "w") as cert_file:
cert_file.write(cert)
with open(f"{q_node.run_path}/{JOBS_FILE}", "r+") as jobs_file:
data = json.load(jobs_file)

data["ee_id"] = ee_id
data["real_id"] = self._qindex_to_iens[q_index]
data["step_id"] = 0
data["dispatch_url"] = dispatch_url
data["ee_token"] = token
data["ee_cert_path"] = cert_path

jobs_file.seek(0)
jobs_file.truncate()
json.dump(data, jobs_file, indent=4)

0 comments on commit 93469fc

Please sign in to comment.