Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRY the identity pick and validation – now without mocking #84

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions app/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@


class NoIdentityError(RuntimeError):
"""
There is no current identity. The request is not authenticated.
"""
pass


class InvalidIdentityError(ValueError):
"""
The identity header is missing or invalid.
"""
pass


Expand All @@ -21,26 +31,30 @@ def _pick_identity():
try:
payload = request.headers[_IDENTITY_HEADER]
except KeyError:
abort(Forbidden.code)
raise InvalidIdentityError("The identity header is missing.")

try:
return from_encoded(payload)
except (KeyError, TypeError, ValueError):
abort(Forbidden.code)
raise InvalidIdentityError("The identity header cannot be decoded.")


def _validate(identity):
try:
validate(identity)
except Exception:
abort(Forbidden.code)
raise InvalidIdentityError("The identity header is invalid.")


def requires_identity(view_func):
@wraps(view_func)
def _wrapper(*args, **kwargs):
identity = _pick_identity()
_validate(identity)
try:
identity = _pick_identity()
_validate(identity)
except InvalidIdentityError:
abort(Forbidden.code)

ctx = _request_ctx_stack.top
ctx.identity = identity
return view_func(*args, **kwargs)
Expand Down
201 changes: 178 additions & 23 deletions test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,40 @@
import os

from api import api_operation
from app import create_app
from app.auth import (
_IDENTITY_HEADER,
InvalidIdentityError,
_validate,
_pick_identity,
requires_identity,
)
from app.config import Config
from app.auth.identity import from_dict, from_encoded, from_json, Identity, validate
from base64 import b64encode
from contextlib import contextmanager
from json import dumps
from unittest import main, TestCase
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden


def _encode_header(dict_):
"""
Encode the header payload dictionary.
"""
json = dumps(dict_)
return b64encode(json.encode())


def _identity():
"""
Create a valid Identity object.
"""
return Identity(account_number="some number")


class ApiOperationTestCase(TestCase):
"""
Test the API operation decorator that increments the request counter with every
Expand Down Expand Up @@ -47,17 +67,7 @@ def test_return_value_is_passed(self):
self.assertEqual(old_func.return_value, new_func())


class AuthIdentityConstructorTestCase(TestCase):
"""
Tests the Identity module constructors.
"""

@staticmethod
def _identity():
return Identity(account_number="some number")


class AuthIdentityFromDictTest(AuthIdentityConstructorTestCase):
class AuthIdentityFromDictTest(TestCase):
"""
Tests creating an Identity from a dictionary.
"""
Expand All @@ -66,7 +76,7 @@ def test_valid(self):
"""
Initialize the Identity object with a valid dictionary.
"""
identity = self._identity()
identity = _identity()

dict_ = {
"account_number": identity.account_number,
Expand All @@ -86,7 +96,7 @@ def test_invalid(self):
from_dict(dict_)


class AuthIdentityFromJsonTest(AuthIdentityConstructorTestCase):
class AuthIdentityFromJsonTest(TestCase):
"""
Tests creating an Identity from a JSON string.
"""
Expand All @@ -95,7 +105,7 @@ def test_valid(self):
"""
Initialize the Identity object with a valid JSON string.
"""
identity = self._identity()
identity = _identity()

dict_ = {"identity": identity._asdict()}
json = dumps(dict_)
Expand Down Expand Up @@ -126,7 +136,7 @@ def test_invalid_format(self):
Initializing the Identity object with a JSON string that is not
formatted correctly.
"""
identity = self._identity()
identity = _identity()

dict_ = identity._asdict()
json = dumps(dict_)
Expand All @@ -135,7 +145,7 @@ def test_invalid_format(self):
from_json(json)


class AuthIdentityFromEncodedTest(AuthIdentityConstructorTestCase):
class AuthIdentityFromEncodedTest(TestCase):
"""
Tests creating an Identity from a Base64 encoded JSON string, which is what is in
the HTTP header.
Expand All @@ -146,11 +156,10 @@ def test_valid(self):
Initialize the Identity object with an encoded payload – a base64-encoded JSON.
That would typically be a raw HTTP header content.
"""
identity = self._identity()
identity = _identity()

dict_ = {"identity": identity._asdict()}
json = dumps(dict_)
base64 = b64encode(json.encode())
base64 = _encode_header(dict_)

try:
self.assertEqual(identity, from_encoded(base64))
Expand Down Expand Up @@ -178,7 +187,7 @@ def test_invalid_format(self):
Initializing the Identity object with an valid Base64 encoded payload
that does not contain the "identity" field.
"""
identity = self._identity()
identity = _identity()

dict_ = identity._asdict()
json = dumps(dict_)
Expand Down Expand Up @@ -210,14 +219,160 @@ def test_invalid(self):
validate(identity)

def test__validate_identity(self):
with self.assertRaises(Forbidden):
"""
A specific exception is raised if the identity cannot be decoded.
"""
with self.assertRaises(InvalidIdentityError):
_validate(None)
with self.assertRaises(Forbidden):
with self.assertRaises(InvalidIdentityError):
_validate("")
with self.assertRaises(Forbidden):
with self.assertRaises(InvalidIdentityError):
_validate({})


class AuthPickIdentityTestCase(TestCase):
"""
The identity is read and decoded from the header. If it’s missing or undecodeable,
an Exception is raised.
"""
def setUp(self):
self.app = create_app(config_name="testing")

@contextmanager
def _test_request_context(self, headers):
with self.app.test_request_context(headers=headers) as context:
yield context

def test_identity_is_invalid_if_header_is_missing(self):
with self._test_request_context({}):
with self.assertRaises(InvalidIdentityError):
_pick_identity()

def test_identity_is_invalid_if_decode_fails(self):
payload = "invalid"
with self._test_request_context({_IDENTITY_HEADER: payload}):
with self.assertRaises(InvalidIdentityError):
# b64decode raises ValueError.
_pick_identity()

def test_identity_is_invalid_if_identity_key_is_missing(self):
payload = _encode_header({})
with self._test_request_context({_IDENTITY_HEADER: payload}):
with self.assertRaises(InvalidIdentityError):
# dict["_identity"] raises KeyError.
_pick_identity()

def test_identity_is_invalid_if_account_number_is_missing(self):
payload = _encode_header({"identity": {}})
with self._test_request_context({_IDENTITY_HEADER: payload}):
with self.assertRaises(InvalidIdentityError):
# Failed "account_number" in dict check raises TypeError.
_pick_identity()

def test_decoded_identity_is_returned(self):
identity = _identity()
payload = _encode_header({"identity": identity._asdict()})
with self._test_request_context({_IDENTITY_HEADER: payload}):
result = _pick_identity()
self.assertEqual(identity, result)


class AuthValidate(TestCase):
"""
The retrieved identity is validated and if it’s not valid, an exception is raised.
"""
def test_no_exception_is_raised_if_identity_is_valid(self):
identity = Identity(account_number="some account")
try:
_validate(identity)
self.assertTrue(True)
except InvalidIdentityError:
self.fail()

def test_exception_is_raised_if_identity_is_not_valid(self):
with self.assertRaises(InvalidIdentityError):
_validate(Identity(account_number=None))


class AuthRequiresIdentityTestCase(TestCase):
"""
Tests the requires_identity decorator for that it doesn’t accept a request with an
invalid identity header.
"""
def setUp(self):
self.app = create_app(config_name="testing")
self._dummy_calls = []

@contextmanager
def _test_request_context(self, headers):
with self.app.test_request_context(headers=headers) as context:
yield context

@requires_identity
def _dummy_view_func(self, *args, **kwargs):
self._dummy_calls.append((args, kwargs))
return "some return value"

def test_request_is_aborted_with_forbidden_if_identity_cant_be_picked(self):
with self._test_request_context({}) as request:
with self.assertRaises(Forbidden):
self._dummy_view_func()

def test_identity_is_not_put_into_request_context_if_it_cant_be_picked(self):
with self._test_request_context({}) as request:
with self.assertRaises(Exception):
self._dummy_view_func()
self.assertFalse(hasattr(request, "identity"))

def test_request_is_aborted_with_forbidden_if_identity_is_not_valid(self):
identity = {"account_number": None}
payload = _encode_header({"identity": identity})
with self._test_request_context({_IDENTITY_HEADER: payload}) as request:
with self.assertRaises(Forbidden):
self._dummy_view_func()

def test_identity_is_not_put_into_request_context_if_its_not_valid(self):
identity = {"account_number": None}
payload = _encode_header({"identity": identity})
with self._test_request_context({_IDENTITY_HEADER: payload}) as request:
with self.assertRaises(Exception):
self._dummy_view_func()
self.assertFalse(hasattr(request, "identity"))

def test_identity_is_put_into_request_context_if_it_is_valid(self):
identity = _identity()
payload = _encode_header({"identity": identity._asdict()})
with self._test_request_context({_IDENTITY_HEADER: payload}) as request:
self._dummy_view_func()
self.assertTrue(hasattr(request, "identity"))
self.assertEqual(identity, request.identity)

def test_view_func_is_called_if_identity_is_valid(self):
identity = _identity()
payload = _encode_header({"identity": identity._asdict()})
with self._test_request_context({_IDENTITY_HEADER: payload}):
self._dummy_view_func()
self.assertEqual(1, len(self._dummy_calls))

def test_view_func_is_called_with_original_args(self):
args, kwargs = ("some", "args"), {"some": "kwargs"}

identity = _identity()
payload = _encode_header({"identity": identity._asdict()})
with self._test_request_context({_IDENTITY_HEADER: payload}):
self._dummy_view_func(*args, **kwargs)

self.assertEqual(self._dummy_calls[0], (args, kwargs))

def test_view_func_result_is_returned(self):
identity = _identity()
payload = _encode_header({"identity": identity._asdict()})
with self._test_request_context({_IDENTITY_HEADER: payload}):
result = self._dummy_view_func()

self.assertEqual("some return value", result)


@pytest.mark.usefixtures("monkeypatch")
def test_noauthmode(monkeypatch):
with monkeypatch.context() as m:
Expand Down