This repository has been archived by the owner on Dec 1, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e8a4629
commit 9ef5bb6
Showing
4 changed files
with
168 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" | ||
Open ID Connect (OIDC) related code. | ||
""" | ||
|
||
import base64 | ||
import json | ||
import requests | ||
import six | ||
import time | ||
|
||
|
||
# If our token is about to expire we should refresh in anticipation | ||
EXPIRY_BUFFER_SECONDS = 10 | ||
|
||
|
||
def _pad_b64(b64): | ||
"""Fix padding for base64 value if necessary""" | ||
pad_len = len(b64) % 4 | ||
if pad_len != 0: | ||
missing_padding = (4 - pad_len) | ||
b64 += '=' * missing_padding | ||
return b64 | ||
|
||
|
||
def _id_token_expired(id_token): | ||
"""Is this id token expired?""" | ||
parts = id_token.split('.') | ||
if len(parts) != 3: | ||
raise RuntimeError('ID Token is not valid') | ||
payload_b64 = _pad_b64(parts[1]) | ||
if isinstance(payload_b64, six.binary_type): | ||
payload_b64 = six.text_type(payload_b64, encoding='utf-8') | ||
payload = base64.b64decode(payload_b64) | ||
payload_json = json.loads(payload) | ||
expiry = payload_json['exp'] | ||
now = int(time.time()) | ||
return (now + EXPIRY_BUFFER_SECONDS) > expiry | ||
|
||
|
||
def _token_endpoint(auth_config): | ||
"""Get the token endpoint from the well known config""" | ||
idp_issuer_url = auth_config.get('idp-issuer-url') | ||
|
||
if not idp_issuer_url: | ||
raise RuntimeError('idp-issuer-url not found in config') | ||
|
||
discovery_endpoint = idp_issuer_url + '/.well-known/openid-configuration' | ||
r = requests.get(discovery_endpoint) | ||
r.raise_for_status() | ||
discovery_json = r.json() | ||
return discovery_json['token_endpoint'] | ||
|
||
|
||
def _refresh_id_token(auth_config): | ||
"""Generate a new id token from the refresh token""" | ||
refresh_token = auth_config.get('refresh-token') | ||
|
||
if not refresh_token: | ||
raise RuntimeError('id-token missing or expired and refresh-token is missing') | ||
|
||
client_id = auth_config.get('client-id') | ||
if not client_id: | ||
raise RuntimeError('client-id not found in auth config') | ||
|
||
client_secret = auth_config.get('client-secret') | ||
if not client_secret: | ||
raise RuntimeError('client-secret not found in auth config') | ||
|
||
token_endpoint = _token_endpoint(auth_config) | ||
data = { | ||
'grant_type': 'refresh_token', | ||
'client_id': client_id, | ||
'client_secret': client_secret, | ||
'refresh_token': refresh_token, | ||
} | ||
r = requests.post(token_endpoint, data=data) | ||
r.raise_for_status() | ||
return r.json()['id_token'] | ||
|
||
|
||
def _persist_credentials(config, id_token): | ||
user_name = config.contexts[config.current_context]['user'] | ||
user = [u['user'] for u in config.doc['users'] if u['name'] == user_name][0] | ||
user['auth-provider']['config']['id-token'] = id_token | ||
config.persist_doc() | ||
config.reload() | ||
|
||
|
||
def _id_token(auth_provider): | ||
"""Return the configured id token if it is not expired, otherwise refresh it""" | ||
auth_config = auth_provider.get('config') | ||
|
||
if not auth_config: | ||
raise RuntimeError('auth-provider config not found') | ||
|
||
id_token = auth_config.get('id-token') | ||
should_persist = False | ||
if not id_token or _id_token_expired(id_token): | ||
id_token = _refresh_id_token(auth_config) | ||
auth_config['id-token'] = id_token | ||
should_persist = True | ||
return id_token, should_persist | ||
|
||
|
||
def handle_oidc(request, config, auth_provider): | ||
"""Handle authentication via Open ID Connect""" | ||
id_token, should_persist = _id_token(auth_provider) | ||
|
||
if should_persist: | ||
_persist_credentials(config, id_token) | ||
|
||
request.headers['Authorization'] = 'Bearer {}'.format(id_token) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
""" | ||
pykube.oidc unittests | ||
""" | ||
import base64 | ||
import logging | ||
import json | ||
|
||
from . import TestCase | ||
from pykube import oidc | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TestOIDC(TestCase): | ||
def test_pad_b64(self): | ||
"""Check that the correct padding is applied to unpadded b64 strings""" | ||
test1 = {"value": b"any carnal pleasure.", | ||
"unpadded": "YW55IGNhcm5hbCBwbGVhc3VyZS4", | ||
"padded": "YW55IGNhcm5hbCBwbGVhc3VyZS4="} | ||
test2 = {"value": b"any carnal pleasure", | ||
"unpadded": "YW55IGNhcm5hbCBwbGVhc3VyZQ", | ||
"padded": "YW55IGNhcm5hbCBwbGVhc3VyZQ=="} | ||
test3 = {"value": b"any carnal pleasur", | ||
"unpadded": "YW55IGNhcm5hbCBwbGVhc3Vy", | ||
"padded": "YW55IGNhcm5hbCBwbGVhc3Vy"} | ||
|
||
for test in [test1, test2, test3]: | ||
padded = oidc._pad_b64(test["unpadded"]) | ||
self.assertEqual(test["padded"], padded) | ||
value = base64.b64decode(padded) | ||
self.assertEqual(test["value"], value) | ||
|
||
def _payload_to_b64(self, payload): | ||
payload_j = json.dumps(payload) | ||
payload_b = payload_j.encode('utf-8') | ||
payload_b64 = base64.b64encode(payload_b) | ||
return payload_b64.decode('utf-8') | ||
|
||
def test_id_token_expired(self): | ||
"""Does the token expiry check work?""" | ||
id_token_fmt = 'YW55IGNhcm5hbCBwbGVhc3VyZS4.{}.YW55IGNhcm5hbCBwbGVhc3VyZS4' | ||
|
||
payload_expired = {'exp': 0} | ||
payload_expired_b64 = self._payload_to_b64(payload_expired) | ||
id_token_expired = id_token_fmt.format(payload_expired_b64) | ||
self.assertTrue(oidc._id_token_expired(id_token_expired)) | ||
|
||
payload_valid = {'exp': 99999999999} | ||
payload_valid_b64 = self._payload_to_b64(payload_valid) | ||
id_token_valid = id_token_fmt.format(payload_valid_b64) | ||
self.assertFalse(oidc._id_token_expired(id_token_valid)) |