diff --git a/rest_framework_jwt/serializers.py b/rest_framework_jwt/serializers.py index 12b10a44..56a40e85 100644 --- a/rest_framework_jwt/serializers.py +++ b/rest_framework_jwt/serializers.py @@ -47,7 +47,7 @@ def validate(self, attrs): } if all(credentials.values()): - user = authenticate(**credentials) + user = authenticate(request=self.context.get('request', None), **credentials) if user: if not user.is_active: diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 2e6c7e53..91a5627b 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -2,9 +2,13 @@ from distutils.version import StrictVersion import django +from django.http import HttpRequest from django.test import TestCase from django.test.utils import override_settings + import rest_framework +from rest_framework.request import Request + from rest_framework_jwt.compat import get_user_model from rest_framework_jwt.serializers import JSONWebTokenSerializer @@ -29,9 +33,14 @@ def setUp(self): 'password': self.password } + def get_serializer(self, **kwargs): + serializer = JSONWebTokenSerializer(**kwargs) + serializer.context['request'] = Request(HttpRequest()), + return serializer + @unittest.skipUnless(drf2, 'not supported in this version') def test_empty_drf2(self): - serializer = JSONWebTokenSerializer() + serializer = self.get_serializer() expected = { 'username': '' } @@ -40,7 +49,7 @@ def test_empty_drf2(self): @unittest.skipUnless(drf3, 'not supported in this version') def test_empty_drf3(self): - serializer = JSONWebTokenSerializer() + serializer = self.get_serializer() expected = { 'username': '', 'password': '' @@ -49,7 +58,7 @@ def test_empty_drf3(self): self.assertEqual(serializer.data, expected) def test_create(self): - serializer = JSONWebTokenSerializer(data=self.data) + serializer = self.get_serializer(data=self.data) is_valid = serializer.is_valid() token = serializer.object['token'] @@ -60,7 +69,7 @@ def test_create(self): def test_invalid_credentials(self): self.data['password'] = 'wrong' - serializer = JSONWebTokenSerializer(data=self.data) + serializer = self.get_serializer(data=self.data) is_valid = serializer.is_valid() expected_error = { @@ -77,7 +86,7 @@ def test_disabled_user(self): self.user.is_active = False self.user.save() - serializer = JSONWebTokenSerializer(data=self.data) + serializer = self.get_serializer(data=self.data) is_valid = serializer.is_valid() expected_error = { @@ -96,7 +105,7 @@ def test_disabled_user_all_users_backend(self): self.user.is_active = False self.user.save() - serializer = JSONWebTokenSerializer(data=self.data) + serializer = self.get_serializer(data=self.data) is_valid = serializer.is_valid() expected_error = { @@ -107,7 +116,7 @@ def test_disabled_user_all_users_backend(self): self.assertEqual(serializer.errors, expected_error) def test_required_fields(self): - serializer = JSONWebTokenSerializer(data={}) + serializer = self.get_serializer(data={}) is_valid = serializer.is_valid() expected_error = {