diff --git a/rest_framework_sso/keys.py b/rest_framework_sso/keys.py index 2d64874..d42bd24 100644 --- a/rest_framework_sso/keys.py +++ b/rest_framework_sso/keys.py @@ -1,6 +1,5 @@ # coding: utf-8 import os -import six from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key @@ -34,7 +33,7 @@ def get_key_file_name(keys, issuer, key_id=None): if not keys.get(issuer): raise InvalidKeyError("No keys defined for the given issuer") issuer_keys = keys.get(issuer) - if isinstance(issuer_keys, (str, six.text_type)): + if isinstance(issuer_keys, str): issuer_keys = [issuer_keys] if key_id: issuer_keys = [ik for ik in issuer_keys if key_id in (ik, get_key_id(ik))] diff --git a/rest_framework_sso/migrations/0002_sessiontoken_last_used_at.py b/rest_framework_sso/migrations/0002_sessiontoken_last_used_at.py index 649354c..c47e62c 100644 --- a/rest_framework_sso/migrations/0002_sessiontoken_last_used_at.py +++ b/rest_framework_sso/migrations/0002_sessiontoken_last_used_at.py @@ -6,13 +6,13 @@ class Migration(migrations.Migration): dependencies = [ - ('rest_framework_sso', '0001_initial'), + ("rest_framework_sso", "0001_initial"), ] operations = [ migrations.AddField( - model_name='sessiontoken', - name='last_used_at', + model_name="sessiontoken", + name="last_used_at", field=models.DateTimeField(blank=True, db_index=True, null=True), ), ] diff --git a/rest_framework_sso/models.py b/rest_framework_sso/models.py index bb4ecb7..0d422c2 100644 --- a/rest_framework_sso/models.py +++ b/rest_framework_sso/models.py @@ -2,16 +2,10 @@ from __future__ import absolute_import, unicode_literals import uuid -import six from django.conf import settings from django.db import models -try: - from django.utils.encoding import python_2_unicode_compatible as smart_text -except ImportError: - from django.utils.encoding import smart_text - from django.utils.translation import ugettext_lazy as _ # Prior to Django 1.5, the AUTH_USER_MODEL setting does not exist. @@ -27,7 +21,6 @@ AUTH_USER_MODEL = getattr(settings, "AUTH_USER_MODEL", "auth.User") -@smart_text class SessionToken(models.Model): """ The default session token model. @@ -54,7 +47,7 @@ class Meta: verbose_name_plural = _("Session tokens") def __str__(self): - return six.text_type(self.id) + return str(self.id) def update_attributes(self, request): if request.META.get("HTTP_X_FORWARDED_FOR"): diff --git a/rest_framework_sso/utils.py b/rest_framework_sso/utils.py index 1d9c7a5..6d262c8 100644 --- a/rest_framework_sso/utils.py +++ b/rest_framework_sso/utils.py @@ -1,7 +1,6 @@ # coding: utf-8 from __future__ import absolute_import, unicode_literals -import six import jwt from datetime import datetime @@ -96,14 +95,14 @@ def decode_jwt_token(token): unverified_claims = jwt.decode(token, verify=False) if unverified_header.get(claims.KEY_ID): - unverified_key_id = six.text_type(unverified_header.get(claims.KEY_ID)) + unverified_key_id = str(unverified_header.get(claims.KEY_ID)) else: unverified_key_id = None if claims.ISSUER not in unverified_claims: raise MissingRequiredClaimError(claims.ISSUER) - unverified_issuer = six.text_type(unverified_claims[claims.ISSUER]) + unverified_issuer = str(unverified_claims[claims.ISSUER]) if api_settings.ACCEPTED_ISSUERS is not None and unverified_issuer not in api_settings.ACCEPTED_ISSUERS: raise InvalidIssuerError("Invalid issuer") diff --git a/tests/test_keys.py b/tests/test_keys.py index e65a375..2901ec8 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -1,6 +1,5 @@ # coding: utf-8 from __future__ import absolute_import, unicode_literals -import six from cryptography.hazmat.backends.openssl.rsa import _RSAPrivateKey, _RSAPublicKey from django.test import TestCase @@ -11,110 +10,124 @@ class TestReadKeyFile(TestCase): def test_read(self): - key_data = keys.read_key_file('test-2048.pem') - self.assertIsInstance(key_data, six.binary_type) - key_data_lines = key_data.decode('utf-8').split('\n') - self.assertIn('-----BEGIN PRIVATE KEY-----', key_data_lines) - self.assertIn('-----END PRIVATE KEY-----', key_data_lines) - self.assertIn('-----BEGIN PUBLIC KEY-----', key_data_lines) - self.assertIn('-----END PUBLIC KEY-----', key_data_lines) + key_data = keys.read_key_file("test-2048.pem") + self.assertIsInstance(key_data, bytes) + key_data_lines = key_data.decode("utf-8").split("\n") + self.assertIn("-----BEGIN PRIVATE KEY-----", key_data_lines) + self.assertIn("-----END PRIVATE KEY-----", key_data_lines) + self.assertIn("-----BEGIN PUBLIC KEY-----", key_data_lines) + self.assertIn("-----END PUBLIC KEY-----", key_data_lines) class TestGetKeyId(TestCase): def test_root_simple(self): - key_id = keys.get_key_id(file_name='keyfile') - self.assertEqual(key_id, 'keyfile') + key_id = keys.get_key_id(file_name="keyfile") + self.assertEqual(key_id, "keyfile") def test_root_pem_extension(self): - key_id = keys.get_key_id(file_name='keyfile.pem') - self.assertEqual(key_id, 'keyfile') + key_id = keys.get_key_id(file_name="keyfile.pem") + self.assertEqual(key_id, "keyfile") def test_subfolder_simple(self): - key_id = keys.get_key_id(file_name='subfolder/keyfile') - self.assertEqual(key_id, 'subfolder/keyfile') + key_id = keys.get_key_id(file_name="subfolder/keyfile") + self.assertEqual(key_id, "subfolder/keyfile") def test_subfolder_pem_extension(self): - key_id = keys.get_key_id(file_name='subfolder/keyfile.pem') - self.assertEqual(key_id, 'subfolder/keyfile') + key_id = keys.get_key_id(file_name="subfolder/keyfile.pem") + self.assertEqual(key_id, "subfolder/keyfile") class TestGetKeyFileName(TestCase): def test_empty_keys(self): - with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'): - keys.get_key_file_name(keys={}, issuer='test-issuer') + with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"): + keys.get_key_file_name(keys={}, issuer="test-issuer") def test_other_issuer_keys(self): - with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'): - keys.get_key_file_name(keys={'other-issuer': ['other-key.pem']}, issuer='test-issuer') + with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"): + keys.get_key_file_name(keys={"other-issuer": ["other-key.pem"]}, issuer="test-issuer") def test_one_key_string(self): - file_name = keys.get_key_file_name(keys={'test-issuer': 'first-key.pem'}, issuer='test-issuer') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name(keys={"test-issuer": "first-key.pem"}, issuer="test-issuer") + self.assertEqual(file_name, "first-key.pem") def test_one_key_list(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name(keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer") + self.assertEqual(file_name, "first-key.pem") def test_one_key_with_key_id(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name(keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer") + self.assertEqual(file_name, "first-key.pem") def test_one_key_incorrect_key_id(self): - with self.assertRaisesMessage(InvalidKeyError, 'No key matches the given key_id'): - keys.get_key_file_name(keys={'test-issuer': ['first-key.pem']}, issuer='test-issuer', key_id='incorrect-key') + with self.assertRaisesMessage(InvalidKeyError, "No key matches the given key_id"): + keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem"]}, issuer="test-issuer", key_id="incorrect-key" + ) def test_two_keys_no_key_id(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer" + ) + self.assertEqual(file_name, "first-key.pem") def test_two_keys_with_key_id_1_exact(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='first-key.pem') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="first-key.pem" + ) + self.assertEqual(file_name, "first-key.pem") def test_two_keys_with_key_id_1_no_pem(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='first-key') - self.assertEqual(file_name, 'first-key.pem') + file_name = keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="first-key" + ) + self.assertEqual(file_name, "first-key.pem") def test_two_keys_with_key_id_2_exact(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='second-key.pem') - self.assertEqual(file_name, 'second-key.pem') + file_name = keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="second-key.pem" + ) + self.assertEqual(file_name, "second-key.pem") def test_two_keys_with_key_id_2_no_pem(self): - file_name = keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='second-key') - self.assertEqual(file_name, 'second-key.pem') + file_name = keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="second-key" + ) + self.assertEqual(file_name, "second-key.pem") def test_two_keys_incorrect_key_id(self): - with self.assertRaisesMessage(InvalidKeyError, 'No key matches the given key_id'): - keys.get_key_file_name(keys={'test-issuer': ['first-key.pem', 'second-key.pem']}, issuer='test-issuer', key_id='incorrect-key') + with self.assertRaisesMessage(InvalidKeyError, "No key matches the given key_id"): + keys.get_key_file_name( + keys={"test-issuer": ["first-key.pem", "second-key.pem"]}, issuer="test-issuer", key_id="incorrect-key" + ) class TestGetPrivateKeyAndKeyId(TestCase): def test_empty_keys(self): - with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'): - keys.get_private_key_and_key_id(issuer='other-issuer') + with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"): + keys.get_private_key_and_key_id(issuer="other-issuer") def test_first_key(self): - private_key, key_id = keys.get_private_key_and_key_id(issuer='test-issuer') + private_key, key_id = keys.get_private_key_and_key_id(issuer="test-issuer") self.assertIsInstance(private_key, _RSAPrivateKey) - self.assertEqual(key_id, 'test-2048') + self.assertEqual(key_id, "test-2048") def test_second_key(self): - private_key, key_id = keys.get_private_key_and_key_id(issuer='test-issuer', key_id='test-1024') + private_key, key_id = keys.get_private_key_and_key_id(issuer="test-issuer", key_id="test-1024") self.assertIsInstance(private_key, _RSAPrivateKey) - self.assertEqual(key_id, 'test-1024') + self.assertEqual(key_id, "test-1024") class TestGetPublicKeyAndKeyId(TestCase): def test_empty_keys(self): - with self.assertRaisesMessage(InvalidKeyError, 'No keys defined for the given issuer'): - keys.get_public_key_and_key_id(issuer='other-issuer') + with self.assertRaisesMessage(InvalidKeyError, "No keys defined for the given issuer"): + keys.get_public_key_and_key_id(issuer="other-issuer") def test_first_key(self): - public_key, key_id = keys.get_public_key_and_key_id(issuer='test-issuer') + public_key, key_id = keys.get_public_key_and_key_id(issuer="test-issuer") self.assertIsInstance(public_key, _RSAPublicKey) - self.assertEqual(key_id, 'test-2048') + self.assertEqual(key_id, "test-2048") def test_second_key(self): - public_key, key_id = keys.get_public_key_and_key_id(issuer='test-issuer', key_id='test-1024') + public_key, key_id = keys.get_public_key_and_key_id(issuer="test-issuer", key_id="test-1024") self.assertIsInstance(public_key, _RSAPublicKey) - self.assertEqual(key_id, 'test-1024') + self.assertEqual(key_id, "test-1024")