Skip to content

Commit

Permalink
feat[bckend-middleware]:Implemented custom middleware to check for bl…
Browse files Browse the repository at this point in the history
…acklisted access tokens in Redis.
  • Loading branch information
shikharpa committed Apr 1, 2024
1 parent d4b819d commit 185897d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 7 deletions.
20 changes: 13 additions & 7 deletions services/api/kalvi/api/views/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from rest_framework_simplejwt.views import TokenRefreshView as SimpleJWTTokenRefreshView
from django.contrib.auth import authenticate
from db.renderer import UserRenderer
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework.permissions import AllowAny
from rest_framework.permissions import AllowAny, IsAuthenticated
from django.utils.encoding import smart_str
from django.core.validators import validate_email
from django.utils.http import urlsafe_base64_decode
Expand Down Expand Up @@ -159,15 +159,20 @@ def post(self, request, uid, token, format=None):
class SignOutEndpoint(APIView):
def post(self, request):
refresh_token = request.data.get('refresh_token')
access_token = request.headers.get('Authorization').split(' ')[1]
if refresh_token:
try:
# Connect to Redis
redis_conn = get_redis_connection()
token = str(RefreshToken(refresh_token))
# Blacklist the token in Redis
expiration_time = int(settings.SIMPLE_JWT['REFRESH_TOKEN_LIFETIME'].total_seconds())
redis_conn.set(token, 'blacklisted')
redis_conn.expire(token, expiration_time)
refresh_token_str = str(RefreshToken(refresh_token))
access_token_str = str(AccessToken(access_token))
# Blacklist tokens in Redis
refresh_exp_time = int(settings.SIMPLE_JWT['REFRESH_TOKEN_LIFETIME'].total_seconds())
access_exp_time = int(settings.SIMPLE_JWT['ACCESS_TOKEN_LIFETIME'].total_seconds())
with redis_conn.pipeline() as pipe:
pipe.set(refresh_token_str, 'blacklisted', ex=refresh_exp_time)
pipe.set(access_token_str, 'blacklisted', ex=access_exp_time)
pipe.execute()
return Response({'message': 'Logged out successfully'}, status=status.HTTP_200_OK)
except Exception:
return Response({'error': 'Please try after some time'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
Expand All @@ -176,6 +181,7 @@ def post(self, request):

# Viewset class for getting access token from refresh token.
class TokenRefreshView(SimpleJWTTokenRefreshView):
#permission_classes = [IsAuthenticated]
def post(self, request, *args, **kwargs):
refresh_token = request.data.get('refresh')
if refresh_token:
Expand Down
Empty file.
27 changes: 27 additions & 0 deletions services/api/kalvi/middlewares/tokenMiddleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from rest_framework_simplejwt.tokens import AccessToken
from rest_framework import status
from kalvi.api.views.auth import get_redis_connection
from django.http import JsonResponse

class CustomOutstandingTokenMiddleware:
"""
This middleware checks for blacklisted tokens in Redis. It does not perform any token validation.
"""
def __init__(self, get_response=None):
self.get_response = get_response

def __call__(self, request):
try:
auth_header = request.headers.get('Authorization')
if not auth_header:
return self.get_response(request)
token = auth_header.split()[1]
token_obj = str(AccessToken(token))
redis_conn = get_redis_connection()
# Check if the token is blacklisted in Redis
if redis_conn.exists(token_obj):
return JsonResponse({'error': "Access token is blacklisted"}, status=status.HTTP_401_UNAUTHORIZED)
else:
return self.get_response(request)
except Exception:
return self.get_response(request)
1 change: 1 addition & 0 deletions services/api/kalvi/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"kalvi.middlewares.tokenMiddleware.CustomOutstandingTokenMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]
Expand Down

0 comments on commit 185897d

Please sign in to comment.