Skip to content

Commit

Permalink
refactor: refactor Course.retrieve to prevent calling get_object twice (
Browse files Browse the repository at this point in the history
#4219)

* refactor: refactor Course.retrieve to prevent calling get_object twice

* fix: refetch entitlements if they are created
  • Loading branch information
zawan-ila authored Jan 10, 2024
1 parent 54f3973 commit 81d6691
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 4 additions & 4 deletions course_discovery/apps/api/v1/tests/test_views/test_courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_get(self):
""" Verify the endpoint returns the details for a single course. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})

with self.assertNumQueries(43, threshold=3):
with self.assertNumQueries(26, threshold=3):
response = self.client.get(url)
assert response.status_code == 200
assert response.data == self.serialize_course(self.course)
Expand All @@ -112,7 +112,7 @@ def test_get_uuid(self):
""" Verify the endpoint returns the details for a single course with UUID. """
url = reverse('api:v1:course-detail', kwargs={'key': self.course.uuid})

with self.assertNumQueries(44):
with self.assertNumQueries(27):
response = self.client.get(url)
assert response.status_code == 200
assert response.data == self.serialize_course(self.course)
Expand All @@ -121,7 +121,7 @@ def test_get_exclude_deleted_programs(self):
""" Verify the endpoint returns no deleted associated programs """
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
with self.assertNumQueries(43):
with self.assertNumQueries(26):
response = self.client.get(url)
assert response.status_code == 200
assert response.data.get('programs') == []
Expand All @@ -134,7 +134,7 @@ def test_get_include_deleted_programs(self):
ProgramFactory(courses=[self.course], status=ProgramStatus.Deleted)
url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_deleted_programs=1'
with self.assertNumQueries(47):
with self.assertNumQueries(29):
response = self.client.get(url)
assert response.status_code == 200
assert response.data == self.serialize_course(self.course, extra_context={'include_deleted_programs': True})
Expand Down
8 changes: 6 additions & 2 deletions course_discovery/apps/api/v1/views/courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,12 @@ def retrieve(self, request, *args, **kwargs):
course = self.get_object()
if get_query_param(request, 'editable') and not course.entitlements.exists():
create_missing_entitlement(course)

return super().retrieve(request, *args, **kwargs)
course.refresh_from_db(fields=['entitlements'])
# Rather than call super().retrieve, we instantiate the serializer and return its
# data ourselves. This is to prevent duplicate calls (and hence duplicate queries)
# to self.get_object. Note that we have called get_object once already(see above).
serializer = self.get_serializer(course)
return Response(serializer.data)


class CourseRecommendationViewSet(RetrieveModelMixin, viewsets.GenericViewSet):
Expand Down

0 comments on commit 81d6691

Please sign in to comment.