Skip to content

Commit

Permalink
Fix JWT cache (#393)
Browse files Browse the repository at this point in the history
- Refactor JWT cache code. Mostly moving into a separate class that will
strictly handle cache logic. Add an abstract JwtCache class so it's easy
to extend to an eventual remote redis call.
- Introduce `cachetools`.
- Add `JWT_CACHING_TTL_SECONDS` which allows us to specify when tokens
should be evicted. Before, it was possible that Confidant would
provision expired JWTs from its custom cache.
  • Loading branch information
skiptomyliu authored Aug 7, 2023
1 parent 96c5f5b commit 7df3c5b
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 159 deletions.
3 changes: 2 additions & 1 deletion confidant/routes/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from flask import blueprints, jsonify, request

from confidant import authnz
from confidant.services.jwkmanager import jwk_manager
from confidant.services.jwkmanager import JWKManager
from confidant.schema.jwks import jwt_response_schema, JWTResponse, \
jwks_list_response_schema, JWKSListResponse
from confidant.settings import ACL_MODULE
Expand All @@ -13,6 +13,7 @@
blueprint = blueprints.Blueprint('jwks', __name__)

acl_module_check = misc.load_module(ACL_MODULE)
jwk_manager = JWKManager()


@blueprint.route('/v1/jwks/token', methods=['GET'], defaults={'id': None})
Expand Down
140 changes: 91 additions & 49 deletions confidant/services/jwkmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,75 @@
from typing import Tuple

import jwt
from abc import ABC, abstractmethod
from cerberus import Validator
from confidant.settings import JWT_ACTIVE_SIGNING_KEYS
from confidant.settings import JWT_CACHING_ENABLED
from confidant.settings import JWT_CERTIFICATE_AUTHORITIES
from confidant.settings import JWT_DEFAULT_JWT_EXPIRATION_SECONDS
from confidant.settings import JWT_CACHING_MAX_SIZE
from confidant.settings import JWT_CACHING_TTL_SECONDS
from confidant.utils import stats
from jwcrypto import jwk

from cachetools import TTLCache
from jwcrypto import jwk

logger = logging.getLogger(__name__)

CA_SCHEMA = {
'crt': {'type': 'string', 'required': True},
'key': {'type': 'string', 'required': True},
'passphrase': {'type': 'string', 'required': True},
'passphrase': {'type': 'string', 'required': True, 'nullable': True},
'kid': {'type': 'string', 'required': True},
}


class JwtCache(ABC):

@abstractmethod
def get_jwt(self, kid: str, requester: str, user: str) -> str:
raise NotImplementedError()

@abstractmethod
def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None:
raise NotImplementedError()


class LocalJwtCache(JwtCache):
def __init__(self) -> None:
self._token_cache = TTLCache(
maxsize=JWT_CACHING_MAX_SIZE,
ttl=JWT_CACHING_TTL_SECONDS
)

def cache_key(self, kid: str, requester: str, user: str) -> str:
return f'{kid}:{requester}:{user}'

def get_jwt(self, kid: str, requester: str, user: str) -> str:
cached_jwt = self._token_cache.get(self.cache_key(kid, requester, user))
return cached_jwt

def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None:
self._token_cache[self.cache_key(kid, requester, user)] = jwt


# XXX: TODO add remote redis cache
class RedisCache(JwtCache):
def __init__(self) -> None:
raise NotImplementedError()

def get_jwt(self, kid: str, requester: str, user: str) -> str:
raise NotImplementedError()

def set_jwt(self, kid: str, requester: str, user: str, jwt: str) -> None:
raise NotImplementedError()


class JWKManager:
def __init__(self) -> None:
self._keys = {}
self._token_cache = {}
# XXX: TODO add hook here to point to remote redis cache
self._jwt_cache = LocalJwtCache()
self._pem_cache = {}

self._load_certificate_authorities()
Expand All @@ -41,7 +87,8 @@ def _load_certificate_authorities(self) -> None:
for environment in JWT_CERTIFICATE_AUTHORITIES:
for ca in JWT_CERTIFICATE_AUTHORITIES[environment]:
if validator.validate(ca):
self.set_key(environment, ca['kid'],
self.set_key(environment,
ca['kid'],
ca['key'],
passphrase=ca['passphrase'])
else:
Expand Down Expand Up @@ -80,58 +127,56 @@ def _get_key(self, kid: str, environment: str):
)
return self._pem_cache[environment][kid]

def get_jwt(self, environment: str, payload: dict,
def _get_active_kids(self) -> List[str]:
return list(JWT_ACTIVE_SIGNING_KEYS.values())

def get_jwt(self, environment: str,
payload: dict,
expiration_seconds: int = JWT_DEFAULT_JWT_EXPIRATION_SECONDS,
algorithm: str = 'RS256') -> str:

kid, key = self.get_active_key(environment)
if not key:
raise ValueError('No active key for this environment')

if 'user' not in payload:
user = payload.get('user')
requester = payload.get('requester')

if not user:
raise ValueError('Please include the user in the payload')

if 'requester' not in payload:
if not requester:
raise ValueError('Please include the requester in the payload')

user = payload['user']
requester = payload['requester']
if kid not in self._token_cache:
self._token_cache[kid] = {}

if requester not in self._token_cache[kid]:
self._token_cache[kid][requester] = {}

now = datetime.now(tz=timezone.utc)

# return token from cache
if user in self._token_cache[kid][requester].keys() \
and JWT_CACHING_ENABLED:
if now < self._token_cache[kid][requester][user]['expiry']:
stats.incr('jwt.get_jwt.cache.hit')
return self._token_cache[kid][requester][user]['token']

# cache miss, generate new token and update cache
expiry = now + timedelta(seconds=expiration_seconds)
payload.update({
'iat': now,
'nbf': now,
'exp': expiry,
})

with stats.timer('jwt.get_jwt.encode'):
token = jwt.encode(
payload=payload,
headers={'kid': kid},
key=key,
algorithm=algorithm,
)

self._token_cache[kid][requester][user] = {
'expiry': expiry,
'token': token
}
stats.incr('jwt.get_jwt.create')
return token
jwt_str = None
if JWT_CACHING_ENABLED:
jwt_str = self._jwt_cache.get_jwt(kid, requester, user)
if jwt_str:
stats.incr(f'jwt.get_jwt.cache.{kid}.{requester}.hit')
else:
stats.incr(f'jwt.get_jwt.cache.{kid}.{requester}.miss')

# cache miss, create a new jwt
if not jwt_str:
now = datetime.now(tz=timezone.utc)
expiry = now + timedelta(seconds=expiration_seconds)
payload.update({
'iat': now,
'nbf': now,
'exp': expiry,
})
with stats.timer('jwt.get_jwt.encode'):
jwt_str = jwt.encode(
payload=payload,
headers={'kid': kid},
key=key,
algorithm=algorithm,
)
stats.incr(f'jwt.get_jwt.{kid}.{requester}.create')
if JWT_CACHING_ENABLED:
self._jwt_cache.set_jwt(kid, requester, user, jwt_str)

return jwt_str

def get_active_key(self, environment: str) -> Tuple[str, Optional[jwk.JWK]]:
# The active signing key used to sign JWTs
Expand All @@ -153,6 +198,3 @@ def get_jwks(self, environment: str, algorithm: str = 'RS256') \
else:
stats.incr(f'jwt.get_jwks.{environment}.miss')
return []


jwk_manager = JWKManager()
29 changes: 21 additions & 8 deletions confidant/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,27 @@ def str_env(var_name, default=''):

JWT_CACHING_ENABLED = bool_env('JWT_CACHING_ENABLED', False)

JWT_CACHING_MAX_SIZE = int_env('JWT_CACHING_MAX_SIZE', 1000)

# Maximum time JWTs are stored in Confidant's cache.
# Warning: this needs to be considerably less than the JWT TTL
# to avoid Confidant issuing very short lived JWTs.
# Enabling the cache means JWTs will have a varying minimum/maximum TTL window.
# example:
# JWT_CACHING_TTL_SECONDS = 900 (15 min)
# JWT_DEFAULT_JWT_EXPIRATION_SECONDS = 3600 (1 hr)
# This means JWTs issued will be between 3600 and 2700 seconds
JWT_CACHING_TTL_SECONDS = int_env('JWT_CACHING_TTL_SECONDS', 900)

JWT_DEFAULT_JWT_EXPIRATION_SECONDS = int_env(
'JWT_DEFAULT_JWT_EXPIRATION_SECONDS', 3600
)

# Key IDs from CERTIFICATE_AUTHORITIES that should be used to sign new JWTs,
# provide a JSON with the following format:
# {"staging": "some_kid", "production": "some_kid"}
JWT_ACTIVE_SIGNING_KEYS = json.loads(str_env('JWT_ACTIVE_SIGNING_KEYS', '{}'))

# Configuration validation
_settings_failures = False
if len(set(SCOPED_AUTH_KEYS.values())) != len(SCOPED_AUTH_KEYS.values()):
Expand All @@ -651,11 +672,3 @@ def get(name, default=None):

# Module that will perform an external ACL check on API endpoints
ACL_MODULE = str_env('ACL_MODULE', 'confidant.authnz.rbac:default_acl')
JWT_DEFAULT_JWT_EXPIRATION_SECONDS = int_env(
'JWT_DEFAULT_JWT_EXPIRATION_SECONDS', 3600
)

# Key IDs from CERTIFICATE_AUTHORITIES that should be used to sign new JWTs,
# provide a JSON with the following format:
# {"staging": "some_kid", "production": "some_kid"}
JWT_ACTIVE_SIGNING_KEYS = json.loads(str_env('JWT_ACTIVE_SIGNING_KEYS', '{}'))
3 changes: 3 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -200,5 +200,8 @@ pyjwt>=2.6.0
jwcrypto
cerberus

# caching jwt
cachetools==5.2.0

# for typing
mypy
6 changes: 4 additions & 2 deletions requirements3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ botocore==1.12.227
# boto3
# pynamodb
# s3transfer
cachetools==5.2.0
# via -r requirements.in
cerberus==1.3.4
# via -r requirements.in
certifi==2023.5.7
Expand Down Expand Up @@ -203,10 +205,10 @@ zope-interface==6.0
# via gevent

pip==23.1.2
# via -r piptools_requirements3.txt
# via -r piptools_requirements.txt
setuptools==68.0.0
# via
# -r piptools_requirements3.txt
# -r piptools_requirements.txt
# cerberus
# gevent
# gunicorn
Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import pytest

from jwcrypto import jwk
Expand Down Expand Up @@ -151,7 +150,7 @@ def test_jwks():

@pytest.fixture
def test_certificate_authorities():
return json.dumps({
return {
'test': [
{
'crt': TEST_CERTIFICATE.decode('utf-8'),
Expand All @@ -174,4 +173,4 @@ def test_certificate_authorities():
'kid': '0h7R8dL0rU-b3p3onft_BPfuRW1Ld7YjsFnOWJuFXUE',
},
],
})
}
Loading

0 comments on commit 7df3c5b

Please sign in to comment.