Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

providers/oauth2: Add provider federation between OAuth2 Providers #12083

Merged
merged 12 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion authentik/blueprints/v1/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@
from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.oauth2.models import (
AccessToken,
AuthorizationCode,
DeviceToken,
RefreshToken,
)
from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
from authentik.rbac.models import Role
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
Expand Down Expand Up @@ -125,6 +130,7 @@ def excluded_models() -> list[type[Model]]:
MicrosoftEntraProviderGroup,
EndpointDevice,
EndpointDeviceConnection,
DeviceToken,
)


Expand Down
3 changes: 2 additions & 1 deletion authentik/providers/oauth2/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class Meta:
"sub_mode",
"property_mappings",
"issuer_mode",
"jwks_sources",
"jwt_federation_sources",
"jwt_federation_providers",
]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 5.0.9 on 2024-11-22 14:25

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("authentik_providers_oauth2", "0024_remove_oauth2provider_redirect_uris_and_more"),
]

operations = [
migrations.RenameField(
model_name="oauth2provider",
old_name="jwks_sources",
new_name="jwt_federation_sources",
),
migrations.AddField(
model_name="oauth2provider",
name="jwt_federation_providers",
field=models.ManyToManyField(
blank=True, default=None, to="authentik_providers_oauth2.oauth2provider"
),
),
]
3 changes: 2 additions & 1 deletion authentik/providers/oauth2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
related_name="oauth2provider_encryption_key_set",
)

jwks_sources = models.ManyToManyField(
jwt_federation_sources = models.ManyToManyField(
OAuthSource,
verbose_name=_(
"Any JWT signed by the JWK of the selected source can be used to authenticate."
Expand All @@ -253,6 +253,7 @@ class OAuth2Provider(WebfingerProvider, Provider):
default=None,
blank=True,
)
jwt_federation_providers = models.ManyToManyField("OAuth2Provider", blank=True, default=None)

@cached_property
def jwt_key(self) -> tuple[str | PrivateKeyTypes, str]:
Expand Down
228 changes: 228 additions & 0 deletions authentik/providers/oauth2/tests/test_token_cc_jwt_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""Test token view"""

from datetime import datetime, timedelta
from json import loads

from django.test import RequestFactory
from django.urls import reverse
from django.utils.timezone import now
from jwt import decode

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group
from authentik.core.tests.utils import create_test_cert, create_test_flow, create_test_user
from authentik.lib.generators import generate_id
from authentik.policies.models import PolicyBinding
from authentik.providers.oauth2.constants import (
GRANT_TYPE_CLIENT_CREDENTIALS,
SCOPE_OPENID,
SCOPE_OPENID_EMAIL,
SCOPE_OPENID_PROFILE,
TOKEN_TYPE,
)
from authentik.providers.oauth2.models import (
AccessToken,
OAuth2Provider,
RedirectURI,
RedirectURIMatchingMode,
ScopeMapping,
)
from authentik.providers.oauth2.tests.utils import OAuthTestCase


class TestTokenClientCredentialsJWTProvider(OAuthTestCase):
"""Test token (client_credentials, with JWT) view"""

@apply_blueprint("system/providers-oauth2.yaml")
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.other_cert = create_test_cert()
self.cert = create_test_cert()

self.other_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
signing_key=self.other_cert,
)
self.other_provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(
name=generate_id(), slug=generate_id(), provider=self.other_provider
)

self.provider: OAuth2Provider = OAuth2Provider.objects.create(
name="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
)
self.provider.jwt_federation_providers.add(self.other_provider)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

def test_invalid_type(self):
"""test invalid type"""
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "foo",
"client_assertion": "foo.bar",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_jwt(self):
"""test invalid JWT"""
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": "foo.bar",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_signature(self):
"""test invalid JWT"""
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token + "foo",
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_expired(self):
"""test invalid JWT"""
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() - timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_no_app(self):
"""test invalid JWT"""
self.app.provider = None
self.app.save()
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_invalid_access_denied(self):
"""test invalid JWT"""
group = Group.objects.create(name="foo")
PolicyBinding.objects.create(
group=group,
target=self.app,
order=0,
)
token = self.provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 400)
body = loads(response.content.decode())
self.assertEqual(body["error"], "invalid_grant")

def test_successful(self):
"""test successful"""
user = create_test_user()
token = self.other_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
}
)
AccessToken.objects.create(
provider=self.other_provider,
token=token,
user=user,
auth_time=now(),
)

response = self.client.post(
reverse("authentik_providers_oauth2:token"),
{
"grant_type": GRANT_TYPE_CLIENT_CREDENTIALS,
"scope": f"{SCOPE_OPENID} {SCOPE_OPENID_EMAIL} {SCOPE_OPENID_PROFILE}",
"client_id": self.provider.client_id,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": token,
},
)
self.assertEqual(response.status_code, 200)
body = loads(response.content.decode())
self.assertEqual(body["token_type"], TOKEN_TYPE)
_, alg = self.provider.jwt_key
jwt = decode(
body["access_token"],
key=self.provider.signing_key.public_key,
algorithms=[alg],
audience=self.provider.client_id,
)
self.assertEqual(jwt["given_name"], user.name)
self.assertEqual(jwt["preferred_username"], user.username)
21 changes: 14 additions & 7 deletions authentik/providers/oauth2/tests/test_token_cc_jwt_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,16 @@ class TestTokenClientCredentialsJWTSource(OAuthTestCase):
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.other_cert = create_test_cert()
# Provider used as a helper to sign JWTs with the same key as the OAuth source has
self.helper_provider = OAuth2Provider.objects.create(
name=generate_id(),
authorization_flow=create_test_flow(),
signing_key=self.other_cert,
)
self.cert = create_test_cert()

jwk = JWKSView().get_jwk_for_key(self.cert, "sig")
jwk = JWKSView().get_jwk_for_key(self.other_cert, "sig")
self.source: OAuthSource = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
Expand All @@ -62,7 +69,7 @@ def setUp(self) -> None:
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://testserver")],
signing_key=self.cert,
)
self.provider.jwks_sources.add(self.source)
self.provider.jwt_federation_sources.add(self.source)
self.provider.property_mappings.set(ScopeMapping.objects.all())
self.app = Application.objects.create(name="test", slug="test", provider=self.provider)

Expand Down Expand Up @@ -100,7 +107,7 @@ def test_invalid_jwt(self):

def test_invalid_signature(self):
"""test invalid JWT"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand All @@ -122,7 +129,7 @@ def test_invalid_signature(self):

def test_invalid_expired(self):
"""test invalid JWT"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() - timedelta(hours=2),
Expand All @@ -146,7 +153,7 @@ def test_invalid_no_app(self):
"""test invalid JWT"""
self.app.provider = None
self.app.save()
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand Down Expand Up @@ -174,7 +181,7 @@ def test_invalid_access_denied(self):
target=self.app,
order=0,
)
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand All @@ -196,7 +203,7 @@ def test_invalid_access_denied(self):

def test_successful(self):
"""test successful"""
token = self.provider.encode(
token = self.helper_provider.encode(
{
"sub": "foo",
"exp": datetime.now() + timedelta(hours=2),
Expand Down
2 changes: 1 addition & 1 deletion authentik/providers/oauth2/views/device_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def validate_code(self, code: int) -> HttpResponse | None:


class OAuthDeviceCodeStage(ChallengeStageView):
"""Flow challenge for users to enter device codes"""
"""Flow challenge for users to enter device code"""

response_class = OAuthDeviceCodeChallengeResponse

Expand Down
Loading
Loading