Skip to content
This repository has been archived by the owner on Dec 1, 2018. It is now read-only.

add oidc support #155

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pykube/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from six.moves.urllib.parse import urlparse

from .exceptions import HTTPError
from .oidc import handle_oidc
from .utils import jsonpath_installed, jsonpath_parse


Expand Down Expand Up @@ -104,7 +105,9 @@ def send(self, request, **kwargs):
auth_config.get("expiry"),
config,
)
# @@@ support oidc
elif auth_provider.get("name") == "oidc":
auth_config = auth_provider.get("config")
handle_oidc(request, config, auth_provider)
elif "client-certificate" in config.user:
kwargs["cert"] = (
config.user["client-certificate"].filename(),
Expand Down
112 changes: 112 additions & 0 deletions pykube/oidc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Open ID Connect (OIDC) related code.
"""

import base64
import json
import requests
import six
import time


# If our token is about to expire we should refresh in anticipation
EXPIRY_BUFFER_SECONDS = 10


def _pad_b64(b64):
"""Fix padding for base64 value if necessary"""
pad_len = len(b64) % 4
if pad_len != 0:
missing_padding = (4 - pad_len)
b64 += '=' * missing_padding
return b64


def _id_token_expired(id_token):
"""Is this id token expired?"""
parts = id_token.split('.')
if len(parts) != 3:
raise RuntimeError('ID Token is not valid')
payload_b64 = _pad_b64(parts[1])
if isinstance(payload_b64, six.binary_type):
payload_b64 = six.text_type(payload_b64, encoding='utf-8')
payload = base64.b64decode(payload_b64)
payload_json = json.loads(payload)
expiry = payload_json['exp']
now = int(time.time())
return (now + EXPIRY_BUFFER_SECONDS) > expiry


def _token_endpoint(auth_config):
"""Get the token endpoint from the well known config"""
idp_issuer_url = auth_config.get('idp-issuer-url')

if not idp_issuer_url:
raise RuntimeError('idp-issuer-url not found in config')

discovery_endpoint = idp_issuer_url + '/.well-known/openid-configuration'
r = requests.get(discovery_endpoint)
r.raise_for_status()
discovery_json = r.json()
return discovery_json['token_endpoint']


def _refresh_id_token(auth_config):
"""Generate a new id token from the refresh token"""
refresh_token = auth_config.get('refresh-token')

if not refresh_token:
raise RuntimeError('id-token missing or expired and refresh-token is missing')

client_id = auth_config.get('client-id')
if not client_id:
raise RuntimeError('client-id not found in auth config')

client_secret = auth_config.get('client-secret')
if not client_secret:
raise RuntimeError('client-secret not found in auth config')

token_endpoint = _token_endpoint(auth_config)
data = {
'grant_type': 'refresh_token',
'client_id': client_id,
'client_secret': client_secret,
'refresh_token': refresh_token,
}
r = requests.post(token_endpoint, data=data)
r.raise_for_status()
return r.json()['id_token']


def _persist_credentials(config, id_token):
user_name = config.contexts[config.current_context]['user']
user = [u['user'] for u in config.doc['users'] if u['name'] == user_name][0]
user['auth-provider']['config']['id-token'] = id_token
config.persist_doc()
config.reload()


def _id_token(auth_provider):
"""Return the configured id token if it is not expired, otherwise refresh it"""
auth_config = auth_provider.get('config')

if not auth_config:
raise RuntimeError('auth-provider config not found')

id_token = auth_config.get('id-token')
should_persist = False
if not id_token or _id_token_expired(id_token):
id_token = _refresh_id_token(auth_config)
auth_config['id-token'] = id_token
should_persist = True
return id_token, should_persist


def handle_oidc(request, config, auth_provider):
"""Handle authentication via Open ID Connect"""
id_token, should_persist = _id_token(auth_provider)

if should_persist:
_persist_credentials(config, id_token)

request.headers['Authorization'] = 'Bearer {}'.format(id_token)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

setup(
name="pykube",
version="0.16a1",
version="0.16a2",
description="Python client library for Kubernetes",
long_description=long_description,
author="Eldarion, Inc.",
Expand Down
51 changes: 51 additions & 0 deletions test/test_oidc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
pykube.oidc unittests
"""
import base64
import logging
import json

from . import TestCase
from pykube import oidc

logger = logging.getLogger(__name__)


class TestOIDC(TestCase):
def test_pad_b64(self):
"""Check that the correct padding is applied to unpadded b64 strings"""
test1 = {"value": b"any carnal pleasure.",
"unpadded": "YW55IGNhcm5hbCBwbGVhc3VyZS4",
"padded": "YW55IGNhcm5hbCBwbGVhc3VyZS4="}
test2 = {"value": b"any carnal pleasure",
"unpadded": "YW55IGNhcm5hbCBwbGVhc3VyZQ",
"padded": "YW55IGNhcm5hbCBwbGVhc3VyZQ=="}
test3 = {"value": b"any carnal pleasur",
"unpadded": "YW55IGNhcm5hbCBwbGVhc3Vy",
"padded": "YW55IGNhcm5hbCBwbGVhc3Vy"}

for test in [test1, test2, test3]:
padded = oidc._pad_b64(test["unpadded"])
self.assertEqual(test["padded"], padded)
value = base64.b64decode(padded)
self.assertEqual(test["value"], value)

def _payload_to_b64(self, payload):
payload_j = json.dumps(payload)
payload_b = payload_j.encode('utf-8')
payload_b64 = base64.b64encode(payload_b)
return payload_b64.decode('utf-8')

def test_id_token_expired(self):
"""Does the token expiry check work?"""
id_token_fmt = 'YW55IGNhcm5hbCBwbGVhc3VyZS4.{}.YW55IGNhcm5hbCBwbGVhc3VyZS4'

payload_expired = {'exp': 0}
payload_expired_b64 = self._payload_to_b64(payload_expired)
id_token_expired = id_token_fmt.format(payload_expired_b64)
self.assertTrue(oidc._id_token_expired(id_token_expired))

payload_valid = {'exp': 99999999999}
payload_valid_b64 = self._payload_to_b64(payload_valid)
id_token_valid = id_token_fmt.format(payload_valid_b64)
self.assertFalse(oidc._id_token_expired(id_token_valid))