Skip to content

Commit

Permalink
Add user serializer (#4940)
Browse files Browse the repository at this point in the history
* add user serializer, enable it for v3/users/current, v3/users/retrieve and v3/users/update

* update tests to account for new user serializer

* update tests to account for new user serializer
  • Loading branch information
crutan authored Feb 3, 2025
1 parent 12e60e2 commit 2729052
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 76 deletions.
6 changes: 6 additions & 0 deletions seed/landing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,9 @@ def save(self, *args, **kwargs):
if self.email.lower() != self.username:
self.email = self.username
return super().save(*args, **kwargs)

def serialize(self):
from seed.serializers.users import UserSerializer

serializer = UserSerializer(self)
return serializer.data
42 changes: 42 additions & 0 deletions seed/serializers/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
SEED Platform (TM), Copyright (c) Alliance for Sustainable Energy, LLC, and other contributors.
See also https://github.com/SEED-platform/seed/blob/main/LICENSE.md
"""

from django_otp import devices_for_user
from django_otp.plugins.otp_email.models import EmailDevice
from django_otp.plugins.otp_totp.models import TOTPDevice
from rest_framework import serializers

from seed.landing.models import SEEDUser as User
from seed.views.main import _get_default_org as get_default_org_for_user


class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ("first_name", "last_name", "email", "username", "api_key", "is_superuser", "id", "pk")

def to_representation(self, instance):
ret = super().to_representation(instance)

two_factor_devices = list(devices_for_user(instance))
if two_factor_devices and isinstance(two_factor_devices[0], EmailDevice):
ret["two_factor_method"] = "email"
elif two_factor_devices and isinstance(two_factor_devices[0], TOTPDevice):
ret["two_factor_method"] = "token"
else:
ret["two_factor_method"] = "disabled"

additional_fields = dict(
list(
zip(
("org_id", "org_name", "org_role", "ali_name", "ali_id", "is_ali_root", "is_ali_leaf"),
get_default_org_for_user(instance),
)
)
)
for k, v in additional_fields.items():
ret[k] = v

return ret
85 changes: 60 additions & 25 deletions seed/tests/test_account_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from seed.utils.organizations import create_organization
from seed.utils.users import get_js_role, get_role_from_js
from seed.views.main import _get_default_org
from seed.views.main import _get_default_org as get_default_org_for_user
from seed.views.v3.organizations import _dict_org


Expand Down Expand Up @@ -575,27 +576,70 @@ def test_update_user(self):
json.dumps(user_data),
content_type="application/json",
)
self.assertEqual(
json.loads(resp.content), {"status": "success", "api_key": "", "email": "[email protected]", "first_name": "bob", "last_name": "d"}
)
(
initial_org_id,
initial_org_name,
initial_org_user_role,
access_level_instance_name,
access_level_instance_id,
is_ali_root,
is_ali_leaf,
) = get_default_org_for_user(self.user)
profile = {
"username": "[email protected]",
"email": "[email protected]",
"first_name": "bob",
"last_name": "d",
"ali_id": access_level_instance_id,
"ali_name": access_level_instance_name,
"api_key": "",
"is_ali_root": is_ali_root,
"is_ali_leaf": is_ali_leaf,
"org_id": initial_org_id,
"org_name": initial_org_name,
"org_role": initial_org_user_role,
"pk": self.user.pk,
"id": self.user.pk,
"two_factor_method": "disabled",
"is_superuser": self.user.is_superuser,
}
self.assertEqual(json.loads(resp.content), profile)

def test_get_user_profile(self):
"""test for get_user_profile"""
resp = self.client.get(
reverse_lazy("api:v3:user-detail", args=[self.user.pk]),
content_type="application/json",
)
self.assertEqual(
json.loads(resp.content),
{
"status": "success",
"api_key": "",
"email": "[email protected]",
"first_name": "Johnny",
"last_name": "Energy",
"two_factor_method": "disabled",
},
)
(
initial_org_id,
initial_org_name,
initial_org_user_role,
access_level_instance_name,
access_level_instance_id,
is_ali_root,
is_ali_leaf,
) = get_default_org_for_user(self.user)
profile = {
"username": "[email protected]",
"email": "[email protected]",
"first_name": "Johnny",
"last_name": "Energy",
"ali_id": access_level_instance_id,
"ali_name": access_level_instance_name,
"api_key": "",
"is_ali_root": is_ali_root,
"is_ali_leaf": is_ali_leaf,
"org_id": initial_org_id,
"org_name": initial_org_name,
"org_role": initial_org_user_role,
"pk": self.user.pk,
"id": self.user.pk,
"two_factor_method": "disabled",
"is_superuser": self.user.is_superuser,
}

self.assertEqual(json.loads(resp.content), profile)
resp = self.client.post(
reverse_lazy("api:v3:user-generate-api-key", args=[self.user.pk]),
content_type="application/json",
Expand All @@ -604,17 +648,8 @@ def test_get_user_profile(self):
reverse_lazy("api:v3:user-detail", args=[self.user.pk]),
content_type="application/json",
)
self.assertEqual(
json.loads(resp.content),
{
"status": "success",
"api_key": User.objects.get(pk=self.user.pk).api_key,
"email": "[email protected]",
"first_name": "Johnny",
"last_name": "Energy",
"two_factor_method": "disabled",
},
)
profile["api_key"] = User.objects.get(pk=self.user.pk).api_key
self.assertEqual(json.loads(resp.content), profile)

def test_generate_api_key(self):
"""test for generate_api_key
Expand Down
8 changes: 4 additions & 4 deletions seed/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_user_profile(self):
self.assertEqual(r.status_code, 200)

r = json.loads(r.content)
self.assertEqual(r["status"], "success")
self.assertEqual(r["first_name"], "Jaqen")
self.assertEqual(r["last_name"], "H'ghar")
self.client.logout()
Expand Down Expand Up @@ -258,10 +257,11 @@ def test_update_user(self):
)

# re-retrieve the user profile
r = self.client.get("/api/v3/users/" + str(self.user.pk) + "/", follow=True, **self.headers)
r = json.loads(r.content)
res = self.client.get("/api/v3/users/" + str(self.user.pk) + "/", follow=True, **self.headers)
r = json.loads(res.content)

self.assertEqual(res.status_code, 200)

self.assertEqual(r["status"], "success")
self.assertEqual(r["first_name"], "Arya")
self.assertEqual(r["last_name"], "Stark")

Expand Down
51 changes: 4 additions & 47 deletions seed/views/v3/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.db import IntegrityError
from django.http import JsonResponse
from django_otp import devices_for_user
from django_otp.plugins.otp_email.models import EmailDevice
from django_otp.plugins.otp_totp.models import TOTPDevice
from drf_yasg.utils import swagger_auto_schema
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
Expand All @@ -28,7 +26,6 @@
from seed.utils.api_schema import AutoSchemaHelper, swagger_auto_schema_org_query_param
from seed.utils.organizations import create_organization
from seed.utils.users import get_role_from_js
from seed.views.main import _get_default_org as get_default_org_for_user

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -259,23 +256,8 @@ def current(self, request):
required: true
type: string
"""
response = dict(
list(
zip(
("org_id", "org_name", "org_role", "ali_name", "ali_id", "is_ali_root", "is_ali_leaf"),
get_default_org_for_user(request.user),
)
)
)
response["pk"] = request.user.id
response["id"] = request.user.id
response["first_name"] = request.user.first_name
response["last_name"] = request.user.last_name
response["email"] = request.user.email
response["username"] = request.user.username
response["is_superuser"] = request.user.is_superuser
response["api_key"] = request.user.api_key
return JsonResponse(response)

return JsonResponse(request.user.serialize())

@swagger_auto_schema(
manual_parameters=[AutoSchemaHelper.query_org_id_field()],
Expand Down Expand Up @@ -383,24 +365,7 @@ def retrieve(self, request, pk=None):
else:
return content

two_factor_devices = list(devices_for_user(user))
if two_factor_devices and isinstance(two_factor_devices[0], EmailDevice):
two_factor_method = "email"
elif two_factor_devices and isinstance(two_factor_devices[0], TOTPDevice):
two_factor_method = "token"
else:
two_factor_method = "disabled"

return JsonResponse(
{
"status": "success",
"first_name": user.first_name,
"last_name": user.last_name,
"email": user.email,
"api_key": user.api_key,
"two_factor_method": two_factor_method,
}
)
return JsonResponse(user.serialize())

@ajax_request_class
@action(detail=True, methods=["POST"])
Expand Down Expand Up @@ -458,15 +423,7 @@ def update(self, request, pk=None):
user.email = json_user.get("email")
user.username = json_user.get("email", "").lower()
user.save()
return JsonResponse(
{
"status": "success",
"first_name": user.first_name,
"last_name": user.last_name,
"email": user.email,
"api_key": user.api_key,
}
)
return JsonResponse(user.serialize())

@swagger_auto_schema(
request_body=AutoSchemaHelper.schema_factory(
Expand Down

0 comments on commit 2729052

Please sign in to comment.