From 4203a3e039ed2a1e8e0f1df9b6666122181a86fb Mon Sep 17 00:00:00 2001 From: Shaun Kruger Date: Sun, 26 May 2019 21:53:37 -0600 Subject: [PATCH 1/3] Improve oauth middleware and provide helper mixins --- Vagrantfile | 2 +- docs/changes.rst | 9 +- provider/oauth2/forms.py | 4 +- provider/oauth2/middleware.py | 19 +++- provider/oauth2/mixins.py | 34 +++++++ provider/oauth2/models.py | 2 +- provider/oauth2/tests/__init__.py | 0 provider/oauth2/tests/test_middleware.py | 97 +++++++++++++++++++ .../oauth2/{tests.py => tests/test_views.py} | 19 ++-- provider/oauth2/tests/urls.py | 42 ++++++++ provider/oauth2/views.py | 2 +- provider/views.py | 2 +- tests/settings.py | 8 +- tests/urls.py | 3 +- tox.ini | 11 ++- 15 files changed, 232 insertions(+), 22 deletions(-) create mode 100644 provider/oauth2/mixins.py create mode 100644 provider/oauth2/tests/__init__.py create mode 100644 provider/oauth2/tests/test_middleware.py rename provider/oauth2/{tests.py => tests/test_views.py} (97%) create mode 100644 provider/oauth2/tests/urls.py diff --git a/Vagrantfile b/Vagrantfile index 50036bc9..15eb3c16 100644 --- a/Vagrantfile +++ b/Vagrantfile @@ -48,7 +48,7 @@ Vagrant.configure(2) do |config| # Display the VirtualBox GUI when booting the machine # vb.gui = true # Customize the amount of memory on the VM: - vb.memory = "1024" + vb.memory = "2048" end # # View the documentation for the provider you are using for more diff --git a/docs/changes.rst b/docs/changes.rst index 01ee6654..9ad08951 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,4 +1,11 @@ +v 2.2 +----- +* Improve Oauth2UserMiddleware +* Prevent SessionMiddleware from creating new sessions when using oauth tokens. +* Add OAuthRequiredMixin to allow scope enforcement + v 2.1 +----- * Fixed documentation links. Removed 2.0 package. v 2.0 @@ -7,7 +14,7 @@ v 2.0 v 1.2 ----- -Updated to make skopes configurable in the database and update for Django 1.7 +Updated to make scopes configurable in the database and update for Django 1.7 v 1.0 ----- diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index 548e813a..3666d45d 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -3,7 +3,7 @@ from django.contrib.auth import authenticate from django.conf import settings from django.utils.translation import ugettext as _ -from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES +from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES, PUBLIC from provider.forms import OAuthForm, OAuthValidationError from provider.utils import now from provider.oauth2.models import Client, Grant, RefreshToken, Scope @@ -298,7 +298,7 @@ def clean(self): except Client.DoesNotExist: raise OAuthValidationError({'error': 'invalid_client'}) - if client.client_type != 1: # public + if client.client_type != PUBLIC: # public raise OAuthValidationError({'error': 'invalid_client'}) data['client'] = client diff --git a/provider/oauth2/middleware.py b/provider/oauth2/middleware.py index 8dec181e..2dc7eb57 100644 --- a/provider/oauth2/middleware.py +++ b/provider/oauth2/middleware.py @@ -1,13 +1,16 @@ +from django.conf import settings from django.contrib import auth from django.core.exceptions import ImproperlyConfigured +from django.utils.deprecation import MiddlewareMixin from provider.oauth2.models import AccessToken import logging log = logging.getLogger(__name__) -class Oauth2UserMiddleware(object): + +class Oauth2UserMiddleware(MiddlewareMixin): """ Middleware for using OAuth credentials to authenticate requests @@ -32,6 +35,13 @@ def process_request(self, request): " Insert 'django.contrib.auth.middleware.AuthenticationMiddleware'" " before this Oauth2UserMiddleware class." ) + if 'django.contrib.auth.backends.RemoteUserBackend' not in settings.AUTHENTICATION_BACKENDS: + raise ImproperlyConfigured( + "Remote user authentication backend is required for this module to work." + " Insert 'django.contrib.auth.backends.RemoteUserBackend' into the" + " AUTHENTICATION_BACKENDS list in your settings." + + ) try: access_token_http = self._http_access_token(request) access_token_get = request.GET.get('access_token', access_token_http) @@ -49,6 +59,13 @@ def process_request(self, request): user = auth.authenticate(remote_user=token.user.username) auth.login(request, user) request.oauth2_client = token.client + request.oauth2_token = token except Exception as e: log.error("Oauth2UserMiddleware encountered an exception! " "{}: {}".format(e.__class__.__name__, e)) + + def process_response(self, request, response): + if hasattr(request, 'oauth2_token'): + # Set modified=False to prevent the session from being stored and the cookie from being sent + request.session.modified = False + return response diff --git a/provider/oauth2/mixins.py b/provider/oauth2/mixins.py new file mode 100644 index 00000000..6160c4d4 --- /dev/null +++ b/provider/oauth2/mixins.py @@ -0,0 +1,34 @@ +from django.utils.decorators import classonlymethod +from django.http.response import JsonResponse + + +class OAuthRegisteredScopes(object): + scopes = set() + + +class OAuthRequiredMixin(object): + accepted_oauth_scopes = [] + + @classonlymethod + def as_view(cls, *args, **kwargs): + for scope in cls.accepted_oauth_scopes: + OAuthRegisteredScopes.scopes.add(scope) + + return super(OAuthRequiredMixin, cls).as_view() + + def dispatch(self, request, *args, **kwargs): + scopes = list() + if hasattr(request, 'oauth2_token'): + scopes = set(request.oauth2_token.scope.all().values_list('name', flat=True)) + + if request.user.is_authenticated and scopes.intersection(self.accepted_oauth_scopes): + return super(OAuthRequiredMixin, self).dispatch(request, *args, **kwargs) + + return JsonResponse( + { + 'error': 'bad_access_token', + 'accepted_scopes': sorted(self.accepted_oauth_scopes), + 'token_scopes': sorted(scopes) + }, + status=401 + ) diff --git a/provider/oauth2/models.py b/provider/oauth2/models.py index c02f00e1..4870c3da 100644 --- a/provider/oauth2/models.py +++ b/provider/oauth2/models.py @@ -44,7 +44,7 @@ def __unicode__(self): return self.redirect_uri def get_default_token_expiry(self): - public = (self.client_type == 1) + public = (self.client_type == constants.PUBLIC) return get_token_expiry(public) class Meta: diff --git a/provider/oauth2/tests/__init__.py b/provider/oauth2/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/provider/oauth2/tests/test_middleware.py b/provider/oauth2/tests/test_middleware.py new file mode 100644 index 00000000..e3509e84 --- /dev/null +++ b/provider/oauth2/tests/test_middleware.py @@ -0,0 +1,97 @@ +import json +from six.moves.urllib_parse import urlparse + +from django.shortcuts import reverse +from django.http import QueryDict + +from provider.oauth2.models import Scope +from provider.oauth2.mixins import OAuthRegisteredScopes +from provider.oauth2.tests.test_views import BaseOAuth2TestCase + + +class MiddlewareTestCase(BaseOAuth2TestCase): + fixtures = ['test_oauth2.json'] + + def setUp(self): + if not Scope.objects.filter(name='read').exists(): + Scope.objects.create(name='read') + + def _login_authorize_get_token(self): + required_props = ['access_token', 'token_type'] + + self.login() + self._login_and_authorize() + + response = self.client.get(self.redirect_url()) + query = QueryDict(urlparse(response['Location']).query) + code = query['code'] + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + 'code': code}) + + self.assertEqual(200, response.status_code, response.content) + + token = json.loads(response.content) + + for prop in required_props: + self.assertIn(prop, token, "Access token response missing " + "required property: %s" % prop) + + return token + + def test_mixin_scopes(self): + self.assertIn('read', OAuthRegisteredScopes.scopes) + + def test_no_token(self): + # user_url = self.live_server_url + reverse('tests:user', args=[self.get_user().pk]) + # result = requests.get(user_url) + + user_url = reverse('tests:user', args=[self.get_user().pk]) + result = self.client.get(user_url) + + self.assertEqual(result.status_code, 401) + + def test_token_access(self): + self.login() + token_info = self._login_authorize_get_token() + token = token_info['access_token'] + + # Create a new client to ensure a clean session + oauth_client = self.client_class() + + user_url = reverse('tests:user', args=[self.get_user().pk]) + result = oauth_client.get(user_url, {'access_token': token}) + + self.assertEqual(result.status_code, 200) + result_json = result.json() + self.assertEqual(result_json.get('id'), self.get_user().pk) + + def test_unauthorized_scope(self): + self.login() + token_info = self._login_authorize_get_token() + token = token_info['access_token'] + + badscope_url = reverse('tests:badscope') + + oauth_client = self.client_class() + + result = oauth_client.get(badscope_url, {'access_token': token}) + + self.assertEqual(result.status_code, 401) + result_json = result.json() + # self.assertEqual(result_json.get('id'), self.get_user().pk) + + def test_no_stored_session(self): + self.login() + token_info = self._login_authorize_get_token() + token = token_info['access_token'] + + oauth_client = self.client_class() + + user_url = reverse('tests:user', args=[self.get_user().pk]) + result = oauth_client.get(user_url, {'access_token': token}) + + self.assertNotIn('sessionid', result.cookies) diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests/test_views.py similarity index 97% rename from provider/oauth2/tests.py rename to provider/oauth2/tests/test_views.py index 94544277..f67a1a35 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests/test_views.py @@ -51,9 +51,10 @@ def get_password(self): def _login_and_authorize(self, url_func=None): if url_func is None: - url_func = lambda: self.auth_url() + '?client_id={}&response_type=code&state=abc'.format( - self.get_client().client_id - ) + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc'.format( + self.get_client().client_id + ) response = self.client.get(url_func()) response = self.client.get(self.auth_url2()) @@ -344,7 +345,7 @@ def test_refreshing_an_access_token(self): def test_password_grant_public(self): c = self.get_client() - c.client_type = 1 # public + c.client_type = constants.PUBLIC c.save() response = self.client.post(self.access_token_url(), { @@ -363,7 +364,7 @@ def test_password_grant_public(self): def test_password_grant_confidential(self): c = self.get_client() - c.client_type = 0 # confidential + c.client_type = constants.CONFIDENTIAL c.save() response = self.client.post(self.access_token_url(), { @@ -379,7 +380,7 @@ def test_password_grant_confidential(self): def test_password_grant_confidential_no_secret(self): c = self.get_client() - c.client_type = 0 # confidential + c.client_type = constants.CONFIDENTIAL c.save() response = self.client.post(self.access_token_url(), { @@ -393,7 +394,7 @@ def test_password_grant_confidential_no_secret(self): def test_password_grant_invalid_password_public(self): c = self.get_client() - c.client_type = 1 # public + c.client_type = constants.PUBLIC c.save() response = self.client.post(self.access_token_url(), { @@ -408,7 +409,7 @@ def test_password_grant_invalid_password_public(self): def test_password_grant_invalid_password_confidential(self): c = self.get_client() - c.client_type = 0 # confidential + c.client_type = constants.CONFIDENTIAL c.save() response = self.client.post(self.access_token_url(), { @@ -497,7 +498,7 @@ def test_client_form(self): 'name': 'TestName', 'url': 'http://127.0.0.1:8000', 'redirect_uri': 'http://localhost:8000/', - 'client_type': constants.CLIENT_TYPES[0][0]}) + 'client_type': constants.CONFIDENTIAL}) self.assertTrue(form.is_valid()) form.save() diff --git a/provider/oauth2/tests/urls.py b/provider/oauth2/tests/urls.py new file mode 100644 index 00000000..0eefd116 --- /dev/null +++ b/provider/oauth2/tests/urls.py @@ -0,0 +1,42 @@ +from django.conf.urls import url +from django.http.response import JsonResponse +from django.views.generic import View +from django.contrib.auth.mixins import LoginRequiredMixin +from django.contrib.auth.models import User +from django.shortcuts import get_object_or_404 + +from provider.oauth2.mixins import OAuthRequiredMixin + +app_name = 'tests' + + +class UserView(OAuthRequiredMixin, LoginRequiredMixin, View): + accepted_oauth_scopes = ['read'] + + def get(self, request, *args, **kwargs): + user = get_object_or_404(User, pk=self.kwargs['pk']) + return JsonResponse( + { + 'username': user.username, + 'id': user.pk, + } + ) + + +class BadScopeView(OAuthRequiredMixin, LoginRequiredMixin, View): + accepted_oauth_scopes = ['badscope'] + + def get(self, request, *args, **kwargs): + user = self.request.user + return JsonResponse( + { + 'username': user.username, + 'id': user.pk, + } + ) + + +urlpatterns = [ + url('^badscope$', BadScopeView.as_view(), name='badscope'), + url('^user/(?P\d+)$', UserView.as_view(), name='user'), +] diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index b17697e2..578c4bc4 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -118,7 +118,7 @@ def get_access_token(self, request, user, scope, client): except models.AccessToken.DoesNotExist: # None found... make a new one! at = self.create_access_token(request, user, scope, client) - if client.client_type != 1: + if client.client_type != constants.PUBLIC: self.create_refresh_token(request, user, scope, at, client) return at diff --git a/provider/views.py b/provider/views.py index 6cd8a57f..27c38855 100644 --- a/provider/views.py +++ b/provider/views.py @@ -557,7 +557,7 @@ def password(self, request, data, client): at = self.create_access_token(request, user, scope, client) # Public clients don't get refresh tokens - if client.client_type != 1: + if client.client_type != constants.PUBLIC: rt = self.create_refresh_token(request, user, scope, at, client) return self.access_token_response(at) diff --git a/tests/settings.py b/tests/settings.py index 333f26bd..12e2dc8c 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -70,7 +70,7 @@ # 'django.template.context_processors.debug', # 'django.template.context_processors.request', 'django.contrib.auth.context_processors.auth', - # 'django.contrib.messages.context_processors.messages', + 'django.contrib.messages.context_processors.messages', ], }, }, @@ -81,10 +81,16 @@ 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'provider.oauth2.middleware.Oauth2UserMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', ) +AUTHENTICATION_BACKENDS = [ + 'django.contrib.auth.backends.RemoteUserBackend', + 'django.contrib.auth.backends.ModelBackend', +] + PASSWORD_HASHERS = [ 'django.contrib.auth.hashers.PBKDF2PasswordHasher', 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', diff --git a/tests/urls.py b/tests/urls.py index 5308d4fc..ea504bc0 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -5,5 +5,6 @@ urlpatterns = [ url(r'^admin/', admin.site.urls), - url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')), + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^tests/', include('provider.oauth2.tests.urls', namespace='tests')), ] diff --git a/tox.ini b/tox.ini index 94c8ab73..76d9a7eb 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] toxworkdir={env:TOX_WORK_DIR:.tox} downloadcache = {toxworkdir}/cache/ -envlist = py{2.7,3.6}-django1.11,py3.6-django{2.0,2.1} +envlist = py{2.7,3.6}-django1.11,py3.6-django{2.0,2.1,2.2} [testenv] setenv = @@ -13,7 +13,7 @@ deps = [travis] python = 2.7: py2.7-django1.11 - 3.6: py3.6-django{1.11,2.0,2.1} + 3.6: py3.6-django{1.11,2.0,2.1,2.2} [testenv:py2.7-django1.11] basepython = python2.7 @@ -32,5 +32,10 @@ deps = Django>=2.0,<2.1 [testenv:py3.6-django2.1] basepython = python3.6 -deps = Django>=2.1 +deps = Django>=2.1,<2.2 + {[testenv]deps} + +[testenv:py3.6-django2.2] +basepython = python3.6 +deps = Django>=2.2 {[testenv]deps} From 645ef1003e3d4d75435f9d5b932c1c72d83cc382 Mon Sep 17 00:00:00 2001 From: Shaun Kruger Date: Sun, 26 May 2019 22:07:02 -0600 Subject: [PATCH 2/3] Update build environment for sqlite version upgrade --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 8327ded9..a5885444 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ +dist: xenial sudo: false language: python python: From 071294b5118ec2291b3258696c1c9a4ab8300b04 Mon Sep 17 00:00:00 2001 From: Shaun Kruger Date: Mon, 29 Jul 2019 23:29:41 -0600 Subject: [PATCH 3/3] Increment version --- provider/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/provider/__init__.py b/provider/__init__.py index ab028660..d3c88807 100644 --- a/provider/__init__.py +++ b/provider/__init__.py @@ -1 +1 @@ -__version__ = "2.1" +__version__ = "2.2"