From 7df3c5b98f0cf39b03a09075410e48ca2ef12d4f Mon Sep 17 00:00:00 2001 From: Dean Liu Date: Mon, 7 Aug 2023 16:07:25 -0700 Subject: [PATCH] Fix JWT cache (#393) - Refactor JWT cache code. Mostly moving into a separate class that will strictly handle cache logic. Add an abstract JwtCache class so it's easy to extend to an eventual remote redis call. - Introduce `cachetools`. - Add `JWT_CACHING_TTL_SECONDS` which allows us to specify when tokens should be evicted. Before, it was possible that Confidant would provision expired JWTs from its custom cache. --- confidant/routes/jwks.py | 3 +- confidant/services/jwkmanager.py | 140 ++++++++----- confidant/settings.py | 29 ++- requirements.in | 3 + requirements3.txt | 6 +- tests/conftest.py | 5 +- .../confidant/services/jwkmanager_test.py | 189 +++++++++--------- 7 files changed, 216 insertions(+), 159 deletions(-) diff --git a/confidant/routes/jwks.py b/confidant/routes/jwks.py index 7deed43d..5737359e 100644 --- a/confidant/routes/jwks.py +++ b/confidant/routes/jwks.py @@ -3,7 +3,7 @@ from flask import blueprints, jsonify, request from confidant import authnz -from confidant.services.jwkmanager import jwk_manager +from confidant.services.jwkmanager import JWKManager from confidant.schema.jwks import jwt_response_schema, JWTResponse, \ jwks_list_response_schema, JWKSListResponse from confidant.settings import ACL_MODULE @@ -13,6 +13,7 @@ blueprint = blueprints.Blueprint('jwks', __name__) acl_module_check = misc.load_module(ACL_MODULE) +jwk_manager = JWKManager() @blueprint.route('/v1/jwks/token', methods=['GET'], defaults={'id': None}) diff --git a/confidant/services/jwkmanager.py b/confidant/services/jwkmanager.py index 6afec444..b1d9e48c 100644 --- a/confidant/services/jwkmanager.py +++ b/confidant/services/jwkmanager.py @@ -8,29 +8,75 @@ from typing import Tuple import jwt +from abc import ABC, abstractmethod from cerberus import Validator from confidant.settings import JWT_ACTIVE_SIGNING_KEYS from confidant.settings 
import JWT_CACHING_ENABLED from confidant.settings import JWT_CERTIFICATE_AUTHORITIES from confidant.settings import JWT_DEFAULT_JWT_EXPIRATION_SECONDS +from confidant.settings import JWT_CACHING_MAX_SIZE +from confidant.settings import JWT_CACHING_TTL_SECONDS from confidant.utils import stats -from jwcrypto import jwk +from cachetools import TTLCache +from jwcrypto import jwk logger = logging.getLogger(__name__) CA_SCHEMA = { 'crt': {'type': 'string', 'required': True}, 'key': {'type': 'string', 'required': True}, - 'passphrase': {'type': 'string', 'required': True}, + 'passphrase': {'type': 'string', 'required': True, 'nullable': True}, 'kid': {'type': 'string', 'required': True}, } +class JwtCache(ABC): + + @abstractmethod + def get_jwt(self, kid: str, requester: str, user: str) -> str: + raise NotImplementedError() + + @abstractmethod + def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None: + raise NotImplementedError() + + +class LocalJwtCache(JwtCache): + def __init__(self) -> None: + self._token_cache = TTLCache( + maxsize=JWT_CACHING_MAX_SIZE, + ttl=JWT_CACHING_TTL_SECONDS + ) + + def cache_key(self, kid: str, requester: str, user: str) -> str: + return f'{kid}:{requester}:{user}' + + def get_jwt(self, kid: str, requester: str, user: str) -> str: + cached_jwt = self._token_cache.get(self.cache_key(kid, requester, user)) + return cached_jwt + + def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None: + self._token_cache[self.cache_key(kid, requester, user)] = jwt + + +# XXX: TODO add remote redis cache +class RedisCache(JwtCache): + def __init__(self) -> None: + raise NotImplementedError() + + def get_jwt(self, kid: str, requester: str, user: str) -> str: + raise NotImplementedError() + + def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None: + raise NotImplementedError() + + class JWKManager: def __init__(self) -> None: self._keys = {} - self._token_cache = {} + # XXX: TODO add hook here to point to 
remote redis cache + self._jwt_cache = LocalJwtCache() self._pem_cache = {} self._load_certificate_authorities() @@ -41,7 +87,8 @@ def _load_certificate_authorities(self) -> None: for environment in JWT_CERTIFICATE_AUTHORITIES: for ca in JWT_CERTIFICATE_AUTHORITIES[environment]: if validator.validate(ca): - self.set_key(environment, ca['kid'], + self.set_key(environment, + ca['kid'], ca['key'], passphrase=ca['passphrase']) else: @@ -80,58 +127,56 @@ def _get_key(self, kid: str, environment: str): ) return self._pem_cache[environment][kid] - def get_jwt(self, environment: str, payload: dict, + def _get_active_kids(self) -> List[str]: + return list(JWT_ACTIVE_SIGNING_KEYS.values()) + + def get_jwt(self, environment: str, + payload: dict, expiration_seconds: int = JWT_DEFAULT_JWT_EXPIRATION_SECONDS, algorithm: str = 'RS256') -> str: + kid, key = self.get_active_key(environment) if not key: raise ValueError('No active key for this environment') - if 'user' not in payload: + user = payload.get('user') + requester = payload.get('requester') + + if not user: raise ValueError('Please include the user in the payload') - if 'requester' not in payload: + if not requester: raise ValueError('Please include the requester in the payload') - user = payload['user'] - requester = payload['requester'] - if kid not in self._token_cache: - self._token_cache[kid] = {} - - if requester not in self._token_cache[kid]: - self._token_cache[kid][requester] = {} - - now = datetime.now(tz=timezone.utc) - - # return token from cache - if user in self._token_cache[kid][requester].keys() \ - and JWT_CACHING_ENABLED: - if now < self._token_cache[kid][requester][user]['expiry']: - stats.incr('jwt.get_jwt.cache.hit') - return self._token_cache[kid][requester][user]['token'] - - # cache miss, generate new token and update cache - expiry = now + timedelta(seconds=expiration_seconds) - payload.update({ - 'iat': now, - 'nbf': now, - 'exp': expiry, - }) - - with stats.timer('jwt.get_jwt.encode'): - token 
= jwt.encode( - payload=payload, - headers={'kid': kid}, - key=key, - algorithm=algorithm, - ) - - self._token_cache[kid][requester][user] = { - 'expiry': expiry, - 'token': token - } - stats.incr('jwt.get_jwt.create') - return token + jwt_str = None + if JWT_CACHING_ENABLED: + jwt_str = self._jwt_cache.get_jwt(kid, requester, user) + if jwt_str: + stats.incr(f'jwt.get_jwt.cache.{kid}.{requester}.hit') + else: + stats.incr(f'jwt.get_jwt.cache.{kid}.{requester}.miss') + + # cache miss, create a new jwt + if not jwt_str: + now = datetime.now(tz=timezone.utc) + expiry = now + timedelta(seconds=expiration_seconds) + payload.update({ + 'iat': now, + 'nbf': now, + 'exp': expiry, + }) + with stats.timer('jwt.get_jwt.encode'): + jwt_str = jwt.encode( + payload=payload, + headers={'kid': kid}, + key=key, + algorithm=algorithm, + ) + stats.incr(f'jwt.get_jwt.{kid}.{requester}.create') + if JWT_CACHING_ENABLED: + self._jwt_cache.set_jwt(kid, requester, user, jwt_str) + + return jwt_str def get_active_key(self, environment: str) -> Tuple[str, Optional[jwk.JWK]]: # The active signing key used to sign JWTs @@ -153,6 +198,3 @@ def get_jwks(self, environment: str, algorithm: str = 'RS256') \ else: stats.incr(f'jwt.get_jwks.{environment}.miss') return [] - - -jwk_manager = JWKManager() diff --git a/confidant/settings.py b/confidant/settings.py index 6daf98be..9748caf2 100644 --- a/confidant/settings.py +++ b/confidant/settings.py @@ -630,6 +630,27 @@ def str_env(var_name, default=''): JWT_CACHING_ENABLED = bool_env('JWT_CACHING_ENABLED', False) +JWT_CACHING_MAX_SIZE = int_env('JWT_CACHING_MAX_SIZE', 1000) + +# Maximum time JWTs are stored in Confidant's cache. +# Warning: this needs to be considerably less than the JWT TTL +# to avoid Confidant issuing very short lived JWTs. +# Enabling the cache means JWTs will have a varying minimum/maximum TTL window. 
+# example: +# JWT_CACHING_TTL_SECONDS = 900 (15 min) +# JWT_DEFAULT_JWT_EXPIRATION_SECONDS = 3600 (1 hr) +# This means issued JWTs have 2700 to 3600 seconds of validity remaining +JWT_CACHING_TTL_SECONDS = int_env('JWT_CACHING_TTL_SECONDS', 900) + +JWT_DEFAULT_JWT_EXPIRATION_SECONDS = int_env( + 'JWT_DEFAULT_JWT_EXPIRATION_SECONDS', 3600 +) + +# Key IDs from CERTIFICATE_AUTHORITIES that should be used to sign new JWTs, +# provide a JSON with the following format: +# {"staging": "some_kid", "production": "some_kid"} +JWT_ACTIVE_SIGNING_KEYS = json.loads(str_env('JWT_ACTIVE_SIGNING_KEYS', '{}')) + # Configuration validation _settings_failures = False if len(set(SCOPED_AUTH_KEYS.values())) != len(SCOPED_AUTH_KEYS.values()): @@ -651,11 +672,3 @@ def get(name, default=None): # Module that will perform an external ACL check on API endpoints ACL_MODULE = str_env('ACL_MODULE', 'confidant.authnz.rbac:default_acl') -JWT_DEFAULT_JWT_EXPIRATION_SECONDS = int_env( - 'JWT_DEFAULT_JWT_EXPIRATION_SECONDS', 3600 -) - -# Key IDs from CERTIFICATE_AUTHORITIES that should be used to sign new JWTs, -# provide a JSON with the following format: -# {"staging": "some_kid", "production": "some_kid"} -JWT_ACTIVE_SIGNING_KEYS = json.loads(str_env('JWT_ACTIVE_SIGNING_KEYS', '{}')) diff --git a/requirements.in b/requirements.in index bf5f4166..dc43accd 100644 --- a/requirements.in +++ b/requirements.in @@ -200,5 +200,8 @@ pyjwt>=2.6.0 jwcrypto cerberus +# caching jwt +cachetools==5.2.0 + # for typing mypy diff --git a/requirements3.txt b/requirements3.txt index 48f4817f..cb6ab5e1 100644 --- a/requirements3.txt +++ b/requirements3.txt @@ -22,6 +22,8 @@ botocore==1.12.227 # boto3 # pynamodb # s3transfer +cachetools==5.2.0 + # via -r requirements.in cerberus==1.3.4 # via -r requirements.in certifi==2023.5.7 @@ -203,10 +205,10 @@ zope-interface==6.0 # via gevent pip==23.1.2 - # via -r piptools_requirements3.txt + # via -r piptools_requirements.txt setuptools==68.0.0 # via - # -r piptools_requirements3.txt + 
# -r piptools_requirements.txt # cerberus # gevent # gunicorn diff --git a/tests/conftest.py b/tests/conftest.py index 0885cb54..af7dbea7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import json import pytest from jwcrypto import jwk @@ -151,7 +150,7 @@ def test_jwks(): @pytest.fixture def test_certificate_authorities(): - return json.dumps({ + return { 'test': [ { 'crt': TEST_CERTIFICATE.decode('utf-8'), @@ -174,4 +173,4 @@ def test_certificate_authorities(): 'kid': '0h7R8dL0rU-b3p3onft_BPfuRW1Ld7YjsFnOWJuFXUE', }, ], - }) + } diff --git a/tests/unit/confidant/services/jwkmanager_test.py b/tests/unit/confidant/services/jwkmanager_test.py index 1e531f17..5df97ed9 100644 --- a/tests/unit/confidant/services/jwkmanager_test.py +++ b/tests/unit/confidant/services/jwkmanager_test.py @@ -1,19 +1,20 @@ import confidant.services.jwkmanager import datetime - -from jwcrypto import jwk - +import base64 +import json import pytest +from jwcrypto import jwk from pytest_mock.plugin import MockerFixture - from typing import Dict, Union - from unittest.mock import patch, Mock - -from confidant.services.jwkmanager import jwk_manager +from confidant.services.jwkmanager import JWKManager +from confidant.services.jwkmanager import LocalJwtCache +from confidant.settings import JWT_CACHING_MAX_SIZE +from confidant.settings import JWT_CACHING_TTL_SECONDS def test_set_key(test_key_pair: jwk.JWK): + jwk_manager = JWKManager() test_private_key = test_key_pair.export_to_pem(private_key=True, password=None) kid = jwk_manager.set_key('test', @@ -23,6 +24,7 @@ def test_set_key(test_key_pair: jwk.JWK): def test_set_key_encrypted(test_encrypted_key: str): + jwk_manager = JWKManager() kid = jwk_manager.set_key('test', 'test-key', test_encrypted_key, passphrase='123456') assert kid == 'test-key' @@ -38,6 +40,7 @@ def test_get_jwt( test_jwk_payload: Dict[str, Union[str, bool]], test_jwt: str ): + jwk_manager = JWKManager() test_private_key = 
test_key_pair.export_to_pem(private_key=True, password=None) mocker.patch( @@ -55,14 +58,18 @@ def test_get_jwt( jwk_manager.set_key('test', test_key_pair.thumbprint(), test_private_key.decode('utf-8')) - result = jwk_manager.get_jwt('test', - test_jwk_payload) + result = jwk_manager.get_jwt('test', test_jwk_payload) assert result == test_jwt +def helper_jwt_parser(jwt_str, field): + payload_str = f"{jwt_str.split('.')[1]}=" + payload_dict = json.loads(base64.b64decode(payload_str)) + return payload_dict[field] + + @patch.object(confidant.services.jwkmanager, 'datetime', Mock(wraps=datetime.datetime)) -@patch.object(confidant.services.jwkmanager, 'JWT_CACHING_ENABLED', True) @patch.object(confidant.services.jwkmanager, 'JWT_ACTIVE_SIGNING_KEYS', {'test': '0h7R8dL0rU-b3p3onft_BPfuRW1Ld7YjsFnOWJuFXUE'}) def test_get_jwt_caches_jwt( @@ -71,97 +78,65 @@ def test_get_jwt_caches_jwt( test_jwk_payload: Dict[str, Union[str, bool]], test_jwt: str ): + + jwk_manager = JWKManager() test_private_key = test_key_pair.export_to_pem(private_key=True, password=None) - mocker.patch( - 'confidant.services.jwkmanager.datetime.now', - return_value=datetime.datetime( - year=2020, - month=10, - day=10, - hour=0, - minute=0, - second=0, - microsecond=0 - ) - ) jwk_manager.set_key('test', test_key_pair.thumbprint(), test_private_key.decode('utf-8')) - result = jwk_manager.get_jwt('test', - test_jwk_payload) - mocker.patch( - 'confidant.services.jwkmanager.datetime.now', - return_value=datetime.datetime( - year=2020, - month=10, - day=10, - hour=0, - minute=1, - second=0, - microsecond=0 - ) + # Test enabling caching + mocker.patch.object(confidant.services.jwkmanager, + 'JWT_CACHING_ENABLED', True) + # Test that if cache doesn't return a jwt, we call set_jwt + local_cache_get_mock = mocker.patch.object( + jwk_manager, + '_jwt_cache' ) - cached_result = jwk_manager.get_jwt('test', - test_jwk_payload) - assert result == test_jwt - assert result == cached_result - - 
-@patch.object(confidant.services.jwkmanager, 'datetime', - Mock(wraps=datetime.datetime)) -@patch.object(confidant.services.jwkmanager, 'JWT_CACHING_ENABLED', False) -@patch.object(confidant.services.jwkmanager, 'JWT_ACTIVE_SIGNING_KEYS', - {'test': '0h7R8dL0rU-b3p3onft_BPfuRW1Ld7YjsFnOWJuFXUE'}) -def test_get_jwt_does_not_cache_jwt( - mocker: MockerFixture, - test_key_pair: jwk.JWK, - test_jwk_payload: Dict[str, Union[str, bool]], - test_jwt: str -): - test_private_key = test_key_pair.export_to_pem(private_key=True, - password=None) - mocker.patch( - 'confidant.services.jwkmanager.datetime.now', - return_value=datetime.datetime( - year=2020, - month=10, - day=10, - hour=0, - minute=0, - second=0, - microsecond=0 - ) + local_cache_get_mock.get_jwt.return_value = None + jwk_manager.get_jwt('test', test_jwk_payload) + assert local_cache_get_mock.get_jwt.called is True + assert local_cache_get_mock.set_jwt.called is True + + # Test that if cache returns a jwt, we don't call set_jwt + local_cache_get_mock = mocker.patch.object( + jwk_manager, + '_jwt_cache' ) - jwk_manager.set_key('test', - test_key_pair.thumbprint(), - test_private_key.decode('utf-8')) - result = jwk_manager.get_jwt('test', - test_jwk_payload) - - mocker.patch( - 'confidant.services.jwkmanager.datetime.now', - return_value=datetime.datetime( - year=2020, - month=10, - day=10, - hour=0, - minute=1, - second=1, - microsecond=0 - ) + local_cache_get_mock.get_jwt.return_value = test_jwt + jwk_manager.get_jwt('test', test_jwk_payload) + assert local_cache_get_mock.get_jwt.called is True + assert local_cache_get_mock.set_jwt.called is False + + # Test cache disabled + mocker.patch.object(confidant.services.jwkmanager, + 'JWT_CACHING_ENABLED', False) + local_cache_get_mock = mocker.patch.object( + jwk_manager, + '_jwt_cache' ) - not_cached_result = jwk_manager.get_jwt('test', - test_jwk_payload) - assert result == test_jwt - assert result != not_cached_result + local_cache_get_mock.get_jwt.return_value = 
None + jwk_manager.get_jwt('test', test_jwk_payload) + assert local_cache_get_mock.get_jwt.called is False + assert local_cache_get_mock.set_jwt.called is False + + # Test that if cache returns a jwt, we don't call set_jwt + local_cache_get_mock = mocker.patch.object( + jwk_manager, + '_jwt_cache' + ) + local_cache_get_mock.get_jwt.return_value = test_jwt + jwk_manager.get_jwt('test', test_jwk_payload) + assert local_cache_get_mock.get_jwt.called is False + assert local_cache_get_mock.set_jwt.called is False def test_get_jwt_raises_no_key_id( test_key_pair: jwk.JWK, test_jwk_payload: Dict[str, Union[str, bool]] ): + jwk_manager = JWKManager() test_private_key = test_key_pair.export_to_pem(private_key=True, password=None) jwk_manager.set_key('test', 'test-key', test_private_key.decode('utf-8')) @@ -171,10 +146,9 @@ def test_get_jwt_raises_no_key_id( def test_get_jwks( test_key_pair: jwk.JWK, - test_jwk_payload: Dict[str, Union[str, bool]], - test_jwt: str, test_jwks: Dict[str, str] ): + jwk_manager = JWKManager() test_private_key = test_key_pair.export_to_pem(private_key=True, password=None) jwk_manager.set_key('testing', @@ -185,11 +159,8 @@ def test_get_jwks( assert result[0] == test_jwks -def test_get_jwks_not_found( - test_key_pair: jwk.JWK, - test_jwk_payload: Dict[str, Union[str, bool]], - test_jwt: str, -): +def test_get_jwks_not_found(): + jwk_manager = JWKManager() result = jwk_manager.get_jwks('non-existent') assert not result @@ -207,6 +178,8 @@ def test_get_jwt_with_ca( with patch.object(confidant.services.jwkmanager, 'JWT_CERTIFICATE_AUTHORITIES', test_certificate_authorities): + + jwk_manager = JWKManager() mocker.patch( 'confidant.services.jwkmanager.datetime.now', return_value=datetime.datetime( @@ -219,6 +192,30 @@ def test_get_jwt_with_ca( microsecond=0 ) ) - result = jwk_manager.get_jwt('test', - test_jwk_payload) - assert result == test_jwt + result = jwk_manager.get_jwt('test', test_jwk_payload) + assert result == test_jwt + + +def 
test_localcache_init(): + localcache = LocalJwtCache() + assert localcache._token_cache.maxsize == JWT_CACHING_MAX_SIZE + assert localcache._token_cache.ttl == JWT_CACHING_TTL_SECONDS + assert len(localcache._token_cache) == 0 + + +def test_localcache_cache_key(): + localcache = LocalJwtCache() + result = localcache.cache_key('marge', 'homer', 'bart') + assert result == 'marge:homer:bart' + + +def test_localcache_get_jwt(): + localcache = LocalJwtCache() + cached_jwt = localcache.get_jwt('marge', 'homer', 'bart') + assert cached_jwt is None + assert len(localcache._token_cache) == 0 + + cached_jwt = localcache.set_jwt('marge', 'homer', 'bart', 'lisa') + cached_jwt = localcache.get_jwt('marge', 'homer', 'bart') + assert cached_jwt == 'lisa' + assert len(localcache._token_cache) == 1