Skip to content

Commit

Permalink
Remove PY2 compat, reformat using black
Browse files Browse the repository at this point in the history
  • Loading branch information
lnagel committed Apr 9, 2020
1 parent f2741bc commit dab6662
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 68 deletions.
3 changes: 1 addition & 2 deletions rest_framework_sso/keys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# coding: utf-8
import os
import six

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
Expand Down Expand Up @@ -34,7 +33,7 @@ def get_key_file_name(keys, issuer, key_id=None):
if not keys.get(issuer):
raise InvalidKeyError("No keys defined for the given issuer")
issuer_keys = keys.get(issuer)
if isinstance(issuer_keys, (str, six.text_type)):
if isinstance(issuer_keys, str):
issuer_keys = [issuer_keys]
if key_id:
issuer_keys = [ik for ik in issuer_keys if key_id in (ik, get_key_id(ik))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
class Migration(migrations.Migration):

dependencies = [
('rest_framework_sso', '0001_initial'),
("rest_framework_sso", "0001_initial"),
]

operations = [
migrations.AddField(
model_name='sessiontoken',
name='last_used_at',
model_name="sessiontoken",
name="last_used_at",
field=models.DateTimeField(blank=True, db_index=True, null=True),
),
]
9 changes: 1 addition & 8 deletions rest_framework_sso/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@
from __future__ import absolute_import, unicode_literals

import uuid
import six

from django.conf import settings
from django.db import models

try:
from django.utils.encoding import python_2_unicode_compatible as smart_text
except ImportError:
from django.utils.encoding import smart_text

from django.utils.translation import ugettext_lazy as _

# Prior to Django 1.5, the AUTH_USER_MODEL setting does not exist.
Expand All @@ -27,7 +21,6 @@
AUTH_USER_MODEL = getattr(settings, "AUTH_USER_MODEL", "auth.User")


@smart_text
class SessionToken(models.Model):
"""
The default session token model.
Expand All @@ -54,7 +47,7 @@ class Meta:
verbose_name_plural = _("Session tokens")

def __str__(self):
return six.text_type(self.id)
return str(self.id)

def update_attributes(self, request):
if request.META.get("HTTP_X_FORWARDED_FOR"):
Expand Down
5 changes: 2 additions & 3 deletions rest_framework_sso/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# coding: utf-8
from __future__ import absolute_import, unicode_literals

import six
import jwt

from datetime import datetime
Expand Down Expand Up @@ -96,14 +95,14 @@ def decode_jwt_token(token):
unverified_claims = jwt.decode(token, verify=False)

if unverified_header.get(claims.KEY_ID):
unverified_key_id = six.text_type(unverified_header.get(claims.KEY_ID))
unverified_key_id = str(unverified_header.get(claims.KEY_ID))
else:
unverified_key_id = None

if claims.ISSUER not in unverified_claims:
raise MissingRequiredClaimError(claims.ISSUER)

unverified_issuer = six.text_type(unverified_claims[claims.ISSUER])
unverified_issuer = str(unverified_claims[claims.ISSUER])

if api_settings.ACCEPTED_ISSUERS is not None and unverified_issuer not in api_settings.ACCEPTED_ISSUERS:
raise InvalidIssuerError("Invalid issuer")
Expand Down
117 changes: 65 additions & 52 deletions tests/test_keys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# coding: utf-8
from __future__ import absolute_import, unicode_literals
import six

from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey, _RSAPublicKey
from django.test import TestCase
Expand All @@ -11,110 +10,124 @@

class TestReadKeyFile(TestCase):
def test_read(self):
key_data = keys.read_key_file('test-2048.pem')
self.assertIsInstance(key_data, six.binary_type)
key_data_lines = key_data.decode('utf-8').split('\n')
self.assertIn('-----BEGIN PRIVATE KEY-----', key_data_lines)
self.assertIn('-----END PRIVATE KEY-----', key_data_lines)
self.assertIn('-----BEGIN PUBLIC KEY-----', key_data_lines)
self.assertIn('-----END PUBLIC KEY-----', key_data_lines)
key_data = keys.read_key_file("test-2048.pem")
self.assertIsInstance(key_data, bytes)
key_data_lines = key_data.decode("utf-8").split("\n")
self.assertIn("-----BEGIN PRIVATE KEY-----", key_data_lines)
self.assertIn("-----END PRIVATE KEY-----", key_data_lines)
self.assertIn("-----BEGIN PUBLIC KEY-----", key_data_lines)
self.assertIn("-----END PUBLIC KEY-----", key_data_lines)


class TestGetKeyId(TestCase):
def test_root_simple(self):
key_id = keys.get_key_id(file_name='keyfile')
self.assertEqual(key_id, 'keyfile')
key_id = keys.get_key_id(file_name="keyfile")
self.assertEqual(key_id, "keyfile")

def test_root_pem_extension(self):
key_id = keys.get_key_id(file_name='keyfile.pem')
self.assertEqual(key_id, 'keyfile')
key_id = keys.get_key_id(file_name="keyfile.pem")
self.assertEqual(key_id, "keyfile")

def test_subfolder_simple(self):
key_id = keys.get_key_id(file_name='subfolder/keyfile')
self.assertEqual(key_id, 'subfolder/keyfile')
key_id = keys.get_key_id(file_name="subfolder/keyfile")
self.assertEqual(key_id, "subfolder/keyfile")

def test_subfolder_pem_extension(self):
key_id = keys.get_key_id(file_name='subfolder/keyfile.pem')
self.assertEqual(key_id, 'subfolder/keyfile')
key_id = keys.get_key_id(file_name="subfolder/keyfile.pem")
self.assertEqual(key_id, "subfolder/keyfile")


class TestGetKeyFileName(TestCase):
def test_empty_keys(self):
with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'):
keys.get_key_file_name(keys={}, issuer='test-issuer')
with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"):
keys.get_key_file_name(keys={}, issuer="test-issuer")

def test_other_issuer_keys(self):
with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'):
keys.get_key_file_name(keys={'other-issuer': ['other-key.pem']}, issuer='test-issuer')
with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"):
keys.get_key_file_name(keys={"other-issuer": ["other-key.pem"]}, issuer="test-issuer")

def test_one_key_string(self):
file_name = keys.get_key_file_name(keys={'test-issuer': 'first-key.pem'}, issuer='test-issuer')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(keys={"test-issuer": "first-key.pem"}, issuer="test-issuer")
self.assertEqual(file_name, "first-key.pem")

def test_one_key_list(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer")
self.assertEqual(file_name, "first-key.pem")

def test_one_key_with_key_id(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer")
self.assertEqual(file_name, "first-key.pem")

def test_one_key_incorrect_key_id(self):
with self.assertRaisesMessage(InvalidKeyError, 'No key matches the given key_id'):
keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer', key_id='incorrect-key')
with self.assertRaisesMessage(InvalidKeyError, "No key matches the given key_id"):
keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer", key_id="incorrect-key"
)

def test_two_keys_no_key_id(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer"
)
self.assertEqual(file_name, "first-key.pem")

def test_two_keys_with_key_id_1_exact(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='first-key.pem')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="first-key.pem"
)
self.assertEqual(file_name, "first-key.pem")

def test_two_keys_with_key_id_1_no_pem(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='first-key')
self.assertEqual(file_name, 'first-key.pem')
file_name = keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="first-key"
)
self.assertEqual(file_name, "first-key.pem")

def test_two_keys_with_key_id_2_exact(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='second-key.pem')
self.assertEqual(file_name, 'second-key.pem')
file_name = keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="second-key.pem"
)
self.assertEqual(file_name, "second-key.pem")

def test_two_keys_with_key_id_2_no_pem(self):
file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='second-key')
self.assertEqual(file_name, 'second-key.pem')
file_name = keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="second-key"
)
self.assertEqual(file_name, "second-key.pem")

def test_two_keys_incorrect_key_id(self):
with self.assertRaisesMessage(InvalidKeyError, 'No key matches the given key_id'):
keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='incorrect-key')
with self.assertRaisesMessage(InvalidKeyError, "No key matches the given key_id"):
keys.get_key_file_name(
keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="incorrect-key"
)


class TestGetPrivateKeyAndKeyId(TestCase):
def test_empty_keys(self):
with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'):
keys.get_private_key_and_key_id(issuer='other-issuer')
with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"):
keys.get_private_key_and_key_id(issuer="other-issuer")

def test_first_key(self):
private_key, key_id = keys.get_private_key_and_key_id(issuer='test-issuer')
private_key, key_id = keys.get_private_key_and_key_id(issuer="test-issuer")
self.assertIsInstance(private_key, _RSAPrivateKey)
self.assertEqual(key_id, 'test-2048')
self.assertEqual(key_id, "test-2048")

def test_second_key(self):
private_key, key_id = keys.get_private_key_and_key_id(issuer='test-issuer', key_id='test-1024')
private_key, key_id = keys.get_private_key_and_key_id(issuer="test-issuer", key_id="test-1024")
self.assertIsInstance(private_key, _RSAPrivateKey)
self.assertEqual(key_id, 'test-1024')
self.assertEqual(key_id, "test-1024")


class TestGetPublicKeyAndKeyId(TestCase):
def test_empty_keys(self):
with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'):
keys.get_public_key_and_key_id(issuer='other-issuer')
with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"):
keys.get_public_key_and_key_id(issuer="other-issuer")

def test_first_key(self):
public_key, key_id = keys.get_public_key_and_key_id(issuer='test-issuer')
public_key, key_id = keys.get_public_key_and_key_id(issuer="test-issuer")
self.assertIsInstance(public_key, _RSAPublicKey)
self.assertEqual(key_id, 'test-2048')
self.assertEqual(key_id, "test-2048")

def test_second_key(self):
public_key, key_id = keys.get_public_key_and_key_id(issuer='test-issuer', key_id='test-1024')
public_key, key_id = keys.get_public_key_and_key_id(issuer="test-issuer", key_id="test-1024")
self.assertIsInstance(public_key, _RSAPublicKey)
self.assertEqual(key_id, 'test-1024')
self.assertEqual(key_id, "test-1024")

0 comments on commit dab6662

Please sign in to comment.