Skip to content

Commit

Permalink
Finish updating code for Python3 and Django2 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
skruger committed Jan 6, 2019
1 parent 237efca commit 9d0b233
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 28 deletions.
19 changes: 15 additions & 4 deletions provider/oauth2/backends.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64

from provider.utils import now
from provider.oauth2.forms import ClientAuthForm, PublicPasswordGrantForm
from provider.oauth2.models import AccessToken
Expand Down Expand Up @@ -28,8 +30,9 @@ def authenticate(self, request=None):
return None

try:
basic, base64 = auth.split(' ')
client_id, client_secret = base64.decode('base64').split(':')
basic, enc_user_passwd = auth.split(' ')
user_pass = base64.b64decode(enc_user_passwd).decode('utf8')
client_id, client_secret = user_pass.split(':')

form = ClientAuthForm({
'client_id': client_id,
Expand All @@ -53,7 +56,11 @@ def authenticate(self, request=None):
if request is None:
return None

form = ClientAuthForm(request.REQUEST)
if hasattr(request, 'REQUEST'):
args = request.REQUEST
else:
args = request.POST or request.GET
form = ClientAuthForm(args)

if form.is_valid():
return form.cleaned_data.get('client')
Expand All @@ -74,7 +81,11 @@ def authenticate(self, request=None):
if request is None:
return None

form = PublicPasswordGrantForm(request.REQUEST)
if hasattr(request, 'REQUEST'):
args = request.REQUEST
else:
args = request.POST or request.GET
form = PublicPasswordGrantForm(args)

if form.is_valid():
return form.cleaned_data.get('client')
Expand Down
5 changes: 3 additions & 2 deletions provider/oauth2/forms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from six import string_types
from django import forms
from django.contrib.auth import authenticate
from django.conf import settings
Expand Down Expand Up @@ -51,7 +52,7 @@ class ScopeModelChoiceField(forms.ModelMultipleChoiceField):
# widget = forms.TextInput

def to_python(self, value):
if isinstance(value, basestring):
if isinstance(value, string_types):
return [s for s in value.split(' ') if s != '']
else:
return value
Expand Down Expand Up @@ -159,7 +160,7 @@ def save(self, **kwargs):

grant = Grant(**kwargs)
grant.save()
grant.scope = self.cleaned_data.get('scope')
grant.scope.set(self.cleaned_data.get('scope'))
return grant


Expand Down
42 changes: 25 additions & 17 deletions provider/oauth2/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
import json
from six.moves.urllib_parse import urlparse, parse_qs
import datetime
from six.moves.urllib_parse import urlparse, parse_qs

from unittest import SkipTest
from django.http import QueryDict
from django.conf import settings
Expand Down Expand Up @@ -49,14 +51,16 @@ def get_password(self):

def _login_and_authorize(self, url_func=None):
if url_func is None:
url_func = lambda: self.auth_url() + '?client_id=%s&response_type=code&state=abc' % self.get_client().client_id
url_func = lambda: self.auth_url() + '?client_id={}&response_type=code&state=abc'.format(
self.get_client().client_id
)

response = self.client.get(url_func())
response = self.client.get(self.auth_url2())

response = self.client.post(self.auth_url2(), {'authorize': True, 'scope': 'read'})
self.assertEqual(302, response.status_code, response.content)
self.assertTrue(self.redirect_url() in response['Location'])
self.assertIn(self.redirect_url(), response['Location'])


class AuthorizationTest(BaseOAuth2TestCase):
Expand Down Expand Up @@ -90,31 +94,31 @@ def test_authorization_requires_client_id(self):
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue("An unauthorized client tried to access your resources." in response.content)
self.assertIn("An unauthorized client tried to access your resources.", response.content.decode('utf8'))

def test_authorization_rejects_invalid_client_id(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=123')
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue("An unauthorized client tried to access your resources." in response.content)
self.assertIn(b"An unauthorized client tried to access your resources.", response.content)

def test_authorization_requires_response_type(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=%s' % self.get_client().client_id)
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"No 'response_type' supplied.") in response.content)
self.assertIn(escape(u"No 'response_type' supplied."), response.content.decode('utf8'))

def test_authorization_requires_supported_response_type(self):
self.login()
response = self.client.get(self.auth_url() + '?client_id=%s&response_type=unsupported' % self.get_client().client_id)
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"'unsupported' is not a supported response type.") in response.content)
self.assertIn(escape(u"'unsupported' is not a supported response type."), response.content.decode('utf8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code' % self.get_client().client_id)
response = self.client.get(self.auth_url2())
Expand All @@ -133,7 +137,7 @@ def test_authorization_requires_a_valid_redirect_uri(self):
response = self.client.get(self.auth_url2())

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"The requested redirect didn't match the client settings.") in response.content)
self.assertIn(escape("The requested redirect didn't match the client settings."), response.content.decode('utf8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&redirect_uri=%s' % (
self.get_client().client_id,
Expand All @@ -148,7 +152,7 @@ def test_authorization_requires_a_valid_scope(self):
response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&scope=invalid+invalid2' % self.get_client().client_id)

self.assertEqual(400, response.status_code)
self.assertTrue(escape(u"Invalid scope.") in response.content)
self.assertIn(escape(u"Invalid scope."), response.content.decode('utf8'))

response = self.client.get(self.auth_url() + '?client_id=%s&response_type=code&scope=%s' % (
self.get_client().client_id,
Expand Down Expand Up @@ -428,9 +432,12 @@ class AuthBackendTest(BaseOAuth2TestCase):

def test_basic_client_backend(self):
request = type('Request', (object,), {'META': {}})()
request.META['HTTP_AUTHORIZATION'] = "Basic " + "{0}:{1}".format(
user_pass = "{0}:{1}".format(
self.get_client().client_id,
self.get_client().client_secret).encode('base64')
self.get_client().client_secret
)
user_pass64 = base64.b64encode(user_pass.encode('utf8')).decode('utf8')
request.META['HTTP_AUTHORIZATION'] = "Basic {}".format(user_pass64)

self.assertEqual(BasicClientBackend().authenticate(request).id,
2, "Didn't return the right client.")
Expand Down Expand Up @@ -470,13 +477,13 @@ def test_authorization_enforces_SSL(self):
response = self.client.get(self.auth_url())

self.assertEqual(400, response.status_code)
self.assertTrue("A secure connection is required." in response.content)
self.assertIn("A secure connection is required.", response.content.decode('utf8'))

def test_access_token_enforces_SSL(self):
response = self.client.post(self.access_token_url(), {})

self.assertEqual(400, response.status_code)
self.assertTrue("A secure connection is required." in response.content)
self.assertIn("A secure connection is required.", response.content.decode('utf8'))


class ClientFormTest(TestCase):
Expand Down Expand Up @@ -548,11 +555,12 @@ def test_clear_expired(self):

self.assertEqual(302, response.status_code)
location = response['Location']
self.assertFalse('error' in location)
self.assertTrue('code' in location)

self.assertNotIn('error', location)
self.assertIn('code', location)
print(location)
# verify that Grant with code exists
code = parse_qs(location)['code'][0]
parsed_location = urlparse(location)
code = parse_qs(parsed_location.query)['code'][0]
self.assertTrue(Grant.objects.filter(code=code).exists())

# use the code/grant
Expand Down
3 changes: 2 additions & 1 deletion provider/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
See :class:`provider.scope.to_int` on how scopes are combined.
"""
from functools import reduce

from .constants import SCOPES

Expand Down Expand Up @@ -73,7 +74,7 @@ def to_names(scope):
"""
return [
name
for (name, value) in SCOPE_NAME_DICT.iteritems()
for (name, value) in SCOPE_NAME_DICT.items()
if check(value, scope)
]

Expand Down
1 change: 0 additions & 1 deletion provider/templates/provider/authorize.html
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{% load scope %}
{% load url from future %}
{% block content %}
{% if not error %}
<p>{{ client.name }} would like to access your data with the following permissions:</p>
Expand Down
2 changes: 1 addition & 1 deletion provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def clear_data(self, request):
"""
Clear all OAuth related data from the session store.
"""
for key in request.session.keys():
for key in list(request.session.keys()):
if key.startswith(constants.SESSION_KEY):
del request.session[key]

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
Django>=2.1
shortuuid>=0.4
six>=0.11.0
sqlparse>=0.2.4
26 changes: 24 additions & 2 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from django import VERSION as DJANGO_VERSION

DEBUG = True
TEMPLATE_DEBUG = DEBUG

ADMINS = (
('Tester', '[email protected]'),
Expand Down Expand Up @@ -61,7 +60,23 @@
'provider.oauth2',
)

MIDDLEWARE_CLASSES = (
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
# 'django.template.context_processors.debug',
# 'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
# 'django.contrib.messages.context_processors.messages',
],
},
},
]

MIDDLEWARE = (
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
Expand All @@ -70,6 +85,13 @@
'django.middleware.clickjacking.XFrameOptionsMiddleware',
)

PASSWORD_HASHERS = [
'django.contrib.auth.hashers.PBKDF2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
'django.contrib.auth.hashers.BCryptSHA256PasswordHasher',
'django.contrib.auth.hashers.SHA1PasswordHasher', # Used by unit tests
]

# Use DiscoverRunner on Django 1.7 and above
if DJANGO_VERSION[0] == 1 and DJANGO_VERSION[1] >= 7:
TEST_RUNNER = 'django.test.runner.DiscoverRunner'
Expand Down

0 comments on commit 9d0b233

Please sign in to comment.