diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c29a35..e9d4dea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +## v2.1.0 - 2024-10-13 + +- Refactored pemission claim mapping + ## v2.0.3 - 2024-10-10 - Made audience optional in cli diff --git a/armasec/__init__.py b/armasec/__init__.py index c833199..4c1cddc 100644 --- a/armasec/__init__.py +++ b/armasec/__init__.py @@ -1,6 +1,6 @@ from armasec.armasec import Armasec from armasec.openid_config_loader import OpenidConfigLoader -from armasec.token_decoder import TokenDecoder +from armasec.token_decoder import TokenDecoder, extract_keycloak_permissions from armasec.token_manager import TokenManager from armasec.token_payload import TokenPayload from armasec.token_security import TokenSecurity @@ -12,4 +12,5 @@ "TokenPayload", "TokenDecoder", "OpenidConfigLoader", + "extract_keycloak_permissions", ] diff --git a/armasec/schemas/armasec_config.py b/armasec/schemas/armasec_config.py index 492419f..7de74c3 100644 --- a/armasec/schemas/armasec_config.py +++ b/armasec/schemas/armasec_config.py @@ -2,7 +2,7 @@ This module provides a pydantic schema describing Armasec's configuration parameters. """ -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union, Callable import snick from pydantic import BaseModel, Field @@ -46,12 +46,13 @@ class DomainConfig(BaseModel): ), ) ) - payload_claim_mapping: Optional[Dict[str, Any]] = Field( + permission_extractor: Optional[Callable[[Dict[str, Any]], List[str]]] = Field( None, description=snick.unwrap( """ - Optional mappings that are applied to map claims to top-level properties of - TokenPayload. See docs for `TokenDecoder` for more info. + Optional function that may be used to extract permissions from the decoded token + dictionary when the permissions are not a top-level claim in the token. + See docs for `TokenDecoder` for more info. """ ), ) diff --git a/armasec/token_decoder.py b/armasec/token_decoder.py index 65f8a68..f81962f 100644 --- a/armasec/token_decoder.py +++ b/armasec/token_decoder.py @@ -7,8 +7,6 @@ from functools import partial from typing import Callable -import jmespath -import buzz from jose import jwt from armasec.exceptions import AuthenticationError, PayloadMappingError @@ -30,7 +28,7 @@ def __init__( algorithm: str = "RS256", debug_logger: Callable[..., None] | None = None, decode_options_override: dict | None = None, - payload_claim_mapping: dict | None = None, + permission_extractor: Callable[[dict], list[str]] | None = None, ): """ Initializes a TokenDecoder. @@ -44,36 +42,46 @@ def __init__( decode_options_override: Options that can override the default behavior of the jwt decode method. For example, one can ignore token expiration by setting this to `{ "verify_exp": False }` - payload_claim_mapping: Optional mappings that are applied to map claims to top-level - attribute of TokenPayload using a dict format of: + permission_extractor: Optional function that may be used to extract permissions from + the decoded token dictionary when the permissions are not a + top-level claim in the token. If not provided, permissions will + be assumed to be a top-level claim in the token. - ``` - { - "top_level_attribute": "decoded.token.JMESPath" - } - ``` - The values _must_ be a valid JMESPath. - - Consider this example: + Consider the example token: ``` { - "permissions": "resource_access.default.roles" + "exp": 1728627701, + "iat": 1728626801, + "jti": "24fdb7ef-d773-4e6b-982a-b8126dd58af7", + "sub": "dfa64115-40b5-46ab-924c-c376e73f631d", + "azp": "my-client", + "resource_access": { + "my-client": { + "roles": [ + "read:stuff" + ] + }, + }, } ``` - The above example would result in a TokenPayload like: + In this example, the permissions are found at + `resource_access.my-client.roles`. To produce a TokenPayload + with the permissions set as expected, you could supply a + permission extractor like this: ``` - TokenPayload(permissions=token["resource_access"]["default"]["roles"]) + def my_extractor(decoded_token: dict) -> list[str]: + resource_key = decoded_token["azp"] + return decoded_token["resource_access"][resource_key]["roles"] ``` - Raises a 500 if the path does not match """ self.algorithm = algorithm self.jwks = jwks self.debug_logger = debug_logger if debug_logger else noop self.decode_options_override = decode_options_override if decode_options_override else {} - self.payload_claim_mapping = payload_claim_mapping if payload_claim_mapping else {} + self.permission_extractor = permission_extractor def get_decode_key(self, token: str) -> dict: """ @@ -128,18 +136,15 @@ def decode(self, token: str, **claims) -> TokenPayload: self.debug_logger(f"Raw payload dictionary is {payload_dict}") with PayloadMappingError.handle_errors( - "Failed to map decoded token to payload", + "Failed to map decoded token to TokenPayload", do_except=partial(log_error, self.debug_logger), ): - for payload_key, token_jmespath in self.payload_claim_mapping.items(): - mapped_value = jmespath.search(token_jmespath, payload_dict) - buzz.require_condition( - mapped_value is not None, - f"No matching values found for claim mapping {token_jmespath} -> {payload_key}", - raise_exc_class=KeyError, + if self.permission_extractor is not None: + self.debug_logger("Attempting to extract permissions.") + payload_dict["permissions"] = self.permission_extractor(payload_dict) + self.debug_logger( + f"Payload dictionary with extracted permissions is {payload_dict}" ) - payload_dict[payload_key] = mapped_value - self.debug_logger(f"Mapped payload dictionary is {payload_dict}") self.debug_logger("Attempting to convert to TokenPayload") token_payload = TokenPayload( @@ -148,3 +153,39 @@ def decode(self, token: str, **claims) -> TokenPayload: ) self.debug_logger(f"Built token_payload as {token_payload}") return token_payload + + +def extract_keycloak_permissions(decoded_token: dict) -> list[str]: + """ + Provide a permission extractor for Keycloak. + + By default, Keycloak packages the roles for a given client + nested within the "resource_access" claim. In order to extract + those roles into the expected permissions in the TokenPayload, + this permission_extractor can be used. + + Here is an example decoded token from Keycloak (with some stuff + removed to improve readability): + + ``` + { + "exp": 1728627701, + "iat": 1728626801, + "jti": "24fdb7ef-d773-4e6b-982a-b8126dd58af7", + "sub": "dfa64115-40b5-46ab-924c-c376e73f631d", + "azp": "my-client", + "resource_access": { + "my-client": { + "roles": [ + "read:stuff" + ] + }, + }, + } + ``` + + This extractor would extract the roles `["read:stuff"]` as the + permissions for the TokenPayload returned by the TokenDecoder. + """ + resource_key = decoded_token["azp"] + return decoded_token["resource_access"][resource_key]["roles"] diff --git a/armasec/token_security.py b/armasec/token_security.py index c14e149..2b5e13f 100644 --- a/armasec/token_security.py +++ b/armasec/token_security.py @@ -225,7 +225,7 @@ def _load_manager(self, domain_config: DomainConfig) -> TokenManager: loader.jwks, domain_config.algorithm, debug_logger=self.debug_logger, - payload_claim_mapping=domain_config.payload_claim_mapping, + permission_extractor=domain_config.permission_extractor, ) return TokenManager( loader.config, diff --git a/docs/source/tutorials/getting_started_with_keycloak.md b/docs/source/tutorials/getting_started_with_keycloak.md index e384ae8..3b9f4c9 100644 --- a/docs/source/tutorials/getting_started_with_keycloak.md +++ b/docs/source/tutorials/getting_started_with_keycloak.md @@ -202,7 +202,7 @@ and this user does not have any roles mapped to it! ## Start up the example app ```python title="example.py" linenums="1" -from armasec import Armasec +from armasec import Armasec, extract_keycloak_permissions from fastapi import FastAPI, Depends @@ -211,7 +211,7 @@ armasec = Armasec( domain="localhost:8080/realms/master", audience="http://keycloak.local", use_https=False, - payload_claim_mapping=dict(permissions="resource_access.armasec_tutorial.roles"), + permission_extractor=extract_keycloak_permissions, debug_logger=print, debug_exceptions=True, ) @@ -224,10 +224,10 @@ async def check_access(): Note in this example that the `use_https` flag must be set to false to allow a local server using unsecured HTTP. -Also not that we need to add a `payload_claim_mapping` because Keycloak does not provide -a permissions claim at the top level. This mapping copies the roles found at -`resource_access.armasec_tutorial.roles` to a top-level attribute of the token payload -called permissions. +Also not that we need to tell Armasec to use `extract_keycloak_permissions()` because +Keycloak does not provide a permissions claim at the top level. This extractor function +extracts the roles from `resource_access.armasec_tutorial.roles` so that they can +be used as the "permissions" in the decoded token payload. Copy the `example.py` app to a local source file called "example.py". diff --git a/poetry.lock b/poetry.lock index 1ea2ec5..c8ba5fe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -928,17 +928,6 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] -[[package]] -name = "jmespath" -version = "1.0.1" -description = "JSON Matching Expressions" -optional = false -python-versions = ">=3.7" -files = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] - [[package]] name = "loguru" version = "0.5.3" @@ -2458,17 +2447,6 @@ rich = ">=10.11.0" shellingham = ">=1.3.0" typing-extensions = ">=3.7.4.3" -[[package]] -name = "types-jmespath" -version = "1.0.2.20240106" -description = "Typing stubs for jmespath" -optional = false -python-versions = ">=3.8" -files = [ - {file = "types-jmespath-1.0.2.20240106.tar.gz", hash = "sha256:b4a65a116bfc1c700a4fd9d24e2e397f4a431122e0320a77b7f1989a6b5d819e"}, - {file = "types_jmespath-1.0.2.20240106-py3-none-any.whl", hash = "sha256:c3e715fcaae9e5f8d74e14328fdedc4f2b3f0e18df17f3e457ae0a18e245bde0"}, -] - [[package]] name = "typing-extensions" version = "4.12.2" @@ -2879,4 +2857,4 @@ cli = ["loguru", "pendulum", "pyperclip", "rich", "typer"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "cdabc0ce4d93feee31a4a6291887db665fbba251c69d4cfa98521b9b8b26395a" +content-hash = "685c85d594d46e0f0915c69427a0ed5e54fee06b2b45a9b158f23608f5b75714" diff --git a/pyproject.toml b/pyproject.toml index 3ef959e..f07e815 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "armasec" -version = "2.0.3" +version = "2.1.0" description = "Injectable FastAPI auth via OIDC" authors = ["Omnivector Engineering Team "] license = "MIT" @@ -27,6 +27,7 @@ pydantic = "^2.7" httpx = "^0" snick = "^1.3" py-buzz = "^4.1" +pluggy = "^1.4.0" # These must be included as a main dependency for the pytest extension to work out of the box respx = "^0" @@ -39,13 +40,10 @@ loguru = {version = "^0.5.3", optional = true} rich = {version = "^13.5.2", optional = true} pendulum = {version = "^3.0.0", optional = true} pyperclip = {version = "^1.8.2", optional = true} -jmespath = "^1.0.1" -pluggy = "^1.4.0" [tool.poetry.extras] cli = ["typer", "loguru", "rich", "pendulum", "pyperclip"] - [tool.poetry.group.dev.dependencies] ipython = ">=7,<9" asgi-lifespan = "^1.0.1" @@ -63,8 +61,6 @@ pygments = "^2.16.1" plummet = {extras = ["time-machine"], version = "^1.2.1"} pytest-mock = "^3.11.1" ruff = "^0.3" -types-jmespath = "^1.0.2.7" - [tool.poetry.scripts] armasec = {callable = "armasec_cli.main:app", extras = ["cli"]} @@ -81,12 +77,7 @@ testpaths = ["tests"] asyncio_mode = "auto" [[tool.mypy.overrides]] -module = [ - "jose", - "buzz", - "snick", - "auto_name_enum", -] +module = ["jose"] ignore_missing_imports = true [tool.ruff] diff --git a/tests/test_token_decoder.py b/tests/test_token_decoder.py index 3f3fc90..879d5ea 100644 --- a/tests/test_token_decoder.py +++ b/tests/test_token_decoder.py @@ -2,13 +2,14 @@ These tests verify the functionality of the TokenDecoder. """ +from uuid import uuid4 from unittest import mock import pytest from armasec.exceptions import AuthenticationError, PayloadMappingError from armasec.schemas.jwks import JWK, JWKs -from armasec.token_decoder import TokenDecoder +from armasec.token_decoder import TokenDecoder, extract_keycloak_permissions def test_get_decode_key(rs256_jwk, build_rs256_token, rs256_kid): @@ -83,9 +84,9 @@ def test_decode__fails_when_jwt_decode_throws_an_error(rs256_jwk): decoder.decode("doesn't matter what token we pass here") -def test_decode__with_payload_claim_mapping(rs256_jwk, build_rs256_token): +def test_decode__with_permission_extractor(rs256_jwk, build_rs256_token): """ - Verify that an RS256Decoder applies a payload_claim_mapping to a valid jwt. + Verify that an RS256Decoder can extract permissions from a valid jwt. """ token = build_rs256_token( claim_overrides=dict( @@ -95,9 +96,13 @@ def test_decode__with_payload_claim_mapping(rs256_jwk, build_rs256_token): resource_access=dict(default=dict(roles=["read:stuff", "write:stuff"])), ), ) + + def extractor(token_dict): + return token_dict["resource_access"]["default"]["roles"] + decoder = TokenDecoder( JWKs(keys=[rs256_jwk]), - payload_claim_mapping=dict(permissions="resource_access.default.roles"), + permission_extractor=extractor, ) token_payload = decoder.decode(token) assert token_payload.sub == "test_decode-test-sub" @@ -105,35 +110,53 @@ def test_decode__with_payload_claim_mapping(rs256_jwk, build_rs256_token): assert token_payload.permissions == ["read:stuff", "write:stuff"] -def test_decode__missing_payload_claim_mapping(rs256_jwk, build_rs256_token): +def test_decode__permission_extractor_raises_error(rs256_jwk, build_rs256_token): """ - Verify that an RS256Decoder throws an error if mapping failed. + Verify that an RS256Decoder handles a failure in the permission extractor. - There will be an error if there is a missing claim mapping. - There will be an error if the jmespath expression is invalid. + If an exception is raised by the permission extractor, it should be handled + by the decoder and PayloadMappingError should be raised instead. """ token = build_rs256_token( claim_overrides=dict( sub="test_decode-test-sub", azp="some-fake-id", + permissions=[], + resource_access=dict(default=dict(roles=["read:stuff", "write:stuff"])), ), ) - decoder = TokenDecoder( - JWKs(keys=[rs256_jwk]), - payload_claim_mapping=dict(foo="bar.baz"), - ) - with pytest.raises( - PayloadMappingError, - match="Failed to map decoded token.*No matching values", - ): - decoder.decode(token) + + def extractor(_): + raise RuntimeError("Boom!") decoder = TokenDecoder( JWKs(keys=[rs256_jwk]), - payload_claim_mapping=dict(foo="bar-baz"), + permission_extractor=extractor, ) with pytest.raises( PayloadMappingError, - match="Failed to map decoded token.*Bad jmespath expression", + match="Failed to map decoded token.*Boom!", ): decoder.decode(token) + + +def test_extract_keycloak_permissions(): + """ + Verify the `extract_keycloak_permissions()` works as intended. + + It should correctly extract's the client's role as the permissions to be used in the + TokenPayload. + """ + client_id = uuid4() + decoded_token = { + "exp": 1728627701, + "iat": 1728626801, + "jti": "24fdb7ef-d773-4e6b-982a-b8126dd58af7", + "sub": "dfa64115-40b5-46ab-924c-c376e73f631d", + "azp": client_id, + "resource_access": { + client_id: {"roles": ["read:stuff"]}, + }, + } + + assert extract_keycloak_permissions(decoded_token) == ["read:stuff"]