Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add refresh endpoint for JWT tokens #2570

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8c81ea8
Add pyjwt dependency
stveit Feb 6, 2023
8bde9e3
Add jwt refresh token model
stveit Feb 6, 2023
d8f5a64
Add methods for generating jwt tokens
stveit Feb 6, 2023
090f8e8
Add function for expiring jwt token
stveit Feb 6, 2023
b7a8fa6
Put model in manage namespace
stveit Feb 6, 2023
b7db477
Mark certain functions as properties
stveit Nov 7, 2023
348612d
Make jwtconf data easier to mock
stveit Nov 7, 2023
0404d5c
fixup! Mark certain functions as properties
stveit Nov 7, 2023
0ca8096
Rename properties to be more explicit
stveit Nov 7, 2023
2103df5
Add static method for decoding token
stveit Nov 7, 2023
6cd5833
Update docstring
stveit Nov 9, 2023
10c3e7c
Improve function documentation
stveit Nov 9, 2023
3386f29
Use jwt conf directly
stveit Nov 9, 2023
217cd4a
Mark function as private
stveit Nov 9, 2023
104ee18
Remove properties and use directly in expire funn
stveit Nov 10, 2023
a7554c4
Set exp claim to the past when expiring
stveit Nov 10, 2023
9eebe3f
Dont change nbf claim in expire
stveit Nov 10, 2023
7522e96
Use property once
stveit Nov 13, 2023
0a300cf
Remove property decorators
stveit Nov 13, 2023
3fcbc40
Reorder class
stveit Nov 13, 2023
b0b6e07
Fix wrong function name being used
stveit Nov 13, 2023
5416307
Add tests
stveit Nov 13, 2023
17759b3
Shorten docstring
stveit Nov 13, 2023
5fe8ee2
Use term "data" instead of "body"
stveit Nov 13, 2023
0163627
Add refresh view
stveit Feb 6, 2023
67d742d
Add url for refresh view
stveit Feb 6, 2023
35737f8
Add tests
stveit Nov 22, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion python/nav/models/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
#
"""Models for the NAV API"""

from datetime import datetime
from datetime import datetime, timedelta
from typing import Dict, Any

import jwt

from django.db import models
from django.urls import reverse

from nav.adapters import HStoreField
from nav.models.fields import VarcharField
from nav.models.profiles import Account
from nav.jwtconf import JWTConf


class APIToken(models.Model):
Expand Down Expand Up @@ -66,3 +70,93 @@

class Meta(object):
db_table = 'apitoken'


class JWTRefreshToken(models.Model):
"""RefreshTokens are used for generating new access tokens"""

token = VarcharField()
name = VarcharField(unique=True)
description = models.TextField(null=True, blank=True)

ACCESS_EXPIRE_DELTA = timedelta(hours=1)
REFRESH_EXPIRE_DELTA = timedelta(days=1)

def __str__(self):
return self.token

Check warning on line 86 in python/nav/models/api.py

View check run for this annotation

Codecov / codecov/patch

python/nav/models/api.py#L86

Added line #L86 was not covered by tests

def data(self) -> Dict[str, Any]:
"""Data of token as a dict"""
return self._decode_token(self.token)

def is_active(self) -> bool:
"""True if token is active. A token is considered active when
the nbf claim is in the past and the exp claim is in the future
"""
now = datetime.now()
data = self.data()
nbf = datetime.fromtimestamp(data['nbf'])
exp = datetime.fromtimestamp(data['exp'])
return now >= nbf and now < exp

def expire(self):
"""Expires the token"""
# Base claims for expired token on existing claims
expired_data = self.data()
expired_data['exp'] = (datetime.now() - timedelta(hours=1)).timestamp()
self.token = self._encode_token(expired_data)
self.save()

@classmethod
def generate_access_token(cls, token_data: Dict[str, Any] = {}) -> str:
"""Generates and returns an access token in JWT format.
Will use `token_data` as a basis for the new token,
but certain claims will be overridden.
"""
return cls._generate_token(token_data, cls.ACCESS_EXPIRE_DELTA, "access_token")

@classmethod
def generate_refresh_token(cls, token_data: Dict[str, Any] = {}) -> str:
"""Generates and returns a refresh token in JWT format.
Will use `token_data` as a basis for the new token,
but certain claims will be overridden.
"""
return cls._generate_token(
token_data, cls.REFRESH_EXPIRE_DELTA, "refresh_token"
)

@classmethod
def _generate_token(
cls, token_data: Dict[str, Any], expiry_delta: timedelta, token_type: str
) -> str:
"""Generates and returns a token in JWT format. Will use `token_data` as a basis
for the new token, but certain claims will be overridden
"""
new_token = dict(token_data)
now = datetime.now()
name = JWTConf().get_nav_name()
updated_claims = {
'exp': (now + expiry_delta).timestamp(),
'nbf': now.timestamp(),
'iat': now.timestamp(),
'aud': name,
'iss': name,
'token_type': token_type,
}
new_token.update(updated_claims)
return cls._encode_token(new_token)

@classmethod
def _encode_token(cls, token_data: Dict[str, Any]) -> str:
"""Returns an encoded token in JWT format"""
return jwt.encode(
token_data, JWTConf().get_nav_private_key(), algorithm="RS256"
)

@classmethod
def _decode_token(cls, token: str) -> Dict[str, Any]:
"""Decodes a token in JWT format and returns the data of the decoded token"""
return jwt.decode(token, options={'verify_signature': False})

class Meta(object):
db_table = 'jwtrefreshtoken'
7 changes: 7 additions & 0 deletions python/nav/models/sql/changes/sc.05.05.0002.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
--- Create table for storing JWT refresh tokens
CREATE TABLE manage.JWTRefreshToken (
id SERIAL PRIMARY KEY,
token VARCHAR NOT NULL,
name VARCHAR NOT NULL UNIQUE,
description VARCHAR
);
1 change: 1 addition & 0 deletions python/nav/web/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@
name="prefix-usage-detail",
),
re_path(r'^', include(router.urls)),
re_path(r'^refresh/$', views.JWTRefreshViewSet.as_view(), name='jwt-refresh'),
]
28 changes: 28 additions & 0 deletions python/nav/web/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from oidc_auth.authentication import JSONWebTokenAuthentication

from nav.models import manage, event, cabling, rack, profiles
from nav.models.api import JWTRefreshToken
from nav.models.fields import INFINITY, UNRESOLVED
from nav.web.servicecheckers import load_checker_classes
from nav.util import auth_token, is_valid_cidr
Expand Down Expand Up @@ -1107,6 +1108,33 @@ class RackViewSet(NAVAPIMixin, viewsets.ReadOnlyModelViewSet):
search_fields = ['rackname']


class JWTRefreshViewSet(APIView):
"""
Accepts a valid refresh token.
Returns a new refresh token and an access token.
"""

def post(self, request):
try:
db_token = JWTRefreshToken.objects.get(
token=request.data.get('refresh_token')
)
except JWTRefreshToken.DoesNotExist:
return Response("Invalid token", status=status.HTTP_403_FORBIDDEN)
if not db_token.is_active():
return Response("Inactive token", status=status.HTTP_403_FORBIDDEN)
token_data = db_token.data()
access_token = JWTRefreshToken.generate_access_token(token_data)
refresh_token = JWTRefreshToken.generate_refresh_token(token_data)
db_token.token = refresh_token
db_token.save()
response_data = {
'access_token': access_token,
'refresh_token': refresh_token,
}
return Response(response_data)


def get_or_create_token(request):
"""Gets an existing token or creates a new one. If the old token has
expired, create a new one.
Expand Down
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ napalm==3.4.1

backports.zoneinfo ; python_version < '3.9'
git+https://github.com/Uninett/[email protected]#egg=drf-oidc-auth

pyjwt>=2.6.0
153 changes: 153 additions & 0 deletions tests/integration/jwt_refresh_endpoint_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import pytest

from unittest.mock import Mock, patch

from django.urls import reverse
from rest_framework.reverse import reverse_lazy
from nav.models.api import JWTRefreshToken


def test_token_not_in_database_should_be_rejected(db, api_client, url):
token = JWTRefreshToken.generate_refresh_token()
assert not JWTRefreshToken.objects.filter(token=token).exists()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 403


def test_expired_token_should_be_rejected(db, api_client, url):
token = JWTRefreshToken.generate_refresh_token()
db_token = JWTRefreshToken(token=token)
db_token.save()
db_token.expire()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': db_token.token,
},
)
assert response.status_code == 403


def test_valid_token_should_be_accepted(db, api_client, url):
token = JWTRefreshToken.generate_refresh_token()
db_token = JWTRefreshToken(token=token)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200


def test_valid_token_should_be_replaced_by_new_token_in_db(db, api_client, url):
token = JWTRefreshToken.generate_refresh_token()
db_token = JWTRefreshToken(token=token)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200
assert not JWTRefreshToken.objects.filter(token=token).exists()
new_token = response.data.get("refresh_token")
assert JWTRefreshToken.objects.filter(token=new_token).exists()


def test_should_include_access_and_refresh_token_in_response(db, api_client, url):
token = JWTRefreshToken.generate_refresh_token()
db_token = JWTRefreshToken(token=token)
db_token.save()
response = api_client.post(
url,
follow=True,
data={
'refresh_token': token,
},
)
assert response.status_code == 200
assert "access_token" in response.data
assert "refresh_token" in response.data


@pytest.fixture()
def url():
return reverse('api:1:jwt-refresh')


@pytest.fixture(scope="module", autouse=True)
def jwtconf_mock(private_key, nav_name) -> str:
"""Mocks the get_nave_name and get_nav_private_key functions for
the JWTConf class
"""
with patch("nav.models.api.JWTConf") as _jwtconf_mock:
instance = _jwtconf_mock.return_value
instance.get_nav_name = Mock(return_value=nav_name)
instance.get_nav_private_key = Mock(return_value=private_key)
yield _jwtconf_mock


@pytest.fixture(scope="module")
def private_key() -> str:
"""Yields a private key in PEM format"""
key = """-----BEGIN PRIVATE KEY-----
MIIEuwIBADANBgkqhkiG9w0BAQEFAASCBKUwggShAgEAAoIBAQCp+4AEZM4uYZKu
/hrKzySMTFFx3/ncWo6XAFpADQHXLOwRB9Xh1/OwigHiqs/wHRAAmnrlkwCCQA8r
xiHBAMjp5ApbkyggQz/DVijrpSba6Tiy1cyBTZC3cvOK2FpJzsakJLhIXD1HaULO
ClyIJB/YrmHmQc8SL3Uzou5mMpdcBC2pzwmEW1cvQURpnvgrDF8V86GrQkjK6nIP
IEeuW6kbD5lWFAPfLf1ohDWex3yxeSFyXNRApJhbF4HrKFemPkOi7acsky38UomQ
jZgAMHPotJNkQvAHcnXHhg0FcWGdohv5bc/Ctt9GwZOzJxwyJLBBsSewbE310TZi
3oLU1TmvAgMBAAECgf8zrhi95+gdMeKRpwV+TnxOK5CXjqvo0vTcnr7Runf/c9On
WeUtRPr83E4LxuMcSGRqdTfoP0loUGb3EsYwZ+IDOnyWWvytfRoQdExSA2RM1PDo
GRiUN4Dy8CrGNqvnb3agG99Ay3Ura6q5T20n9ykM4qKL3yDrO9fmWyMgRJbAOAYm
xzf7H910mDZghXPpq8nzDky0JLNZcaqbxuPQ3+EI4p2dLNXbNqMPs8Y20JKLeOPs
HikRM0zfhHEJSt5IPFQ54/CzscGHGeCleQINWTgvDLMcE5fJMvbLLZixV+YsBfAq
e2JsSubS+9RI2ktMlSKaemr8yeoIpsXfAiJSHkECgYEA0NKU18xK+9w5IXfgNwI4
peu2tWgwyZSp5R2pdLT7O1dJoLYRoAmcXNePB0VXNARqGxTNypJ9zmMawNmf3YRS
BqG8aKz7qpATlx9OwYlk09fsS6MeVmaur8bHGHP6O+gt7Xg+zhiFPvU9P5LB+C0Z
0d4grEmIxNhJCtJRQOThD8ECgYEA0GKRO9SJdnhw1b6LPLd+o/AX7IEzQDHwdtfi
0h7hKHHGBlUMbIBwwjKmyKm6cSe0PYe96LqrVg+cVf84wbLZPAixhOjyplLznBzF
LqOrfFPfI5lQVhslE1H1CdLlk9eyT96jDgmLAg8EGSMV8aLGj++Gi2l/isujHlWF
BI4YpW8CgYEAsyKyhJzABmbYq5lGQmopZkxapCwJDiP1ypIzd+Z5TmKGytLlM8CK
3iocjEQzlm/jBfBGyWv5eD8UCDOoLEMCiqXcFn+uNJb79zvoN6ZBVGl6TzhTIhNb
73Y5/QQguZtnKrtoRSxLwcJnFE41D0zBRYOjy6gZJ6PSpPHeuiid2QECgYACuZc+
mgvmIbMQCHrXo2qjiCs364SZDU4gr7gGmWLGXZ6CTLBp5tASqgjmTNnkSumfeFvy
ZCaDbJbVxQ2f8s/GajKwEz/BDwqievnVH0zJxmr/kyyqw5Ybh5HVvA1GfqaVRssJ
DvTjZQDft0a9Lyy7ix1OS2XgkcMjTWj840LNPwKBgDPXMBgL5h41jd7jCsXzPhyr
V96RzQkPcKsoVvrCoNi8eoEYgRd9jwfiU12rlXv+fgVXrrfMoJBoYT6YtrxEJVdM
RAjRpnE8PMqCUA8Rd7RFK9Vp5Uo8RxTNvk9yPvDv1+lHHV7lEltIk5PXuKPHIrc1
nNUyhzvJs2Qba2L/huNC
-----END PRIVATE KEY-----"""
yield key


@pytest.fixture()
def public_key() -> str:
"""Yields a public key in PEM format"""
key = """-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqfuABGTOLmGSrv4ays8k
jExRcd/53FqOlwBaQA0B1yzsEQfV4dfzsIoB4qrP8B0QAJp65ZMAgkAPK8YhwQDI
6eQKW5MoIEM/w1Yo66Um2uk4stXMgU2Qt3LzithaSc7GpCS4SFw9R2lCzgpciCQf
2K5h5kHPEi91M6LuZjKXXAQtqc8JhFtXL0FEaZ74KwxfFfOhq0JIyupyDyBHrlup
Gw+ZVhQD3y39aIQ1nsd8sXkhclzUQKSYWxeB6yhXpj5Dou2nLJMt/FKJkI2YADBz
6LSTZELwB3J1x4YNBXFhnaIb+W3PwrbfRsGTsyccMiSwQbEnsGxN9dE2Yt6C1NU5
rwIDAQAB
-----END PUBLIC KEY-----"""
yield key


@pytest.fixture(scope="module")
def nav_name() -> str:
yield "nav"
Loading
Loading