Skip to content

Commit

Permalink
feat: filter restricted runs on APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
zawan-ila committed Jun 14, 2024
1 parent 9a5b5aa commit 13714a7
Show file tree
Hide file tree
Showing 30 changed files with 464 additions and 64 deletions.
24 changes: 14 additions & 10 deletions course_discovery/apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from course_discovery.apps.api.fields import (
HtmlField, ImageField, SlugRelatedFieldWithReadSerializer, SlugRelatedTranslatableField, StdImageSerializerField
)
from course_discovery.apps.api.utils import StudioAPI
from course_discovery.apps.api.utils import StudioAPI, get_excluded_restriction_types
from course_discovery.apps.catalogs.models import Catalog
from course_discovery.apps.core.api_client.lms import LMSAPIClient
from course_discovery.apps.core.utils import update_instance
Expand Down Expand Up @@ -1638,8 +1638,10 @@ class CourseWithRecommendationsSerializer(FlexFieldsSerializerMixin, TimestampMo
recommendations = serializers.SerializerMethodField()

def get_recommendations(self, course):
excluded_restriction_types = get_excluded_restriction_types(self.context['request'])
recommended_courses = course.recommendations(excluded_restriction_types=excluded_restriction_types)
return CourseRecommendationSerializer(
course.recommendations(),
recommended_courses,
many=True,
context={
'request': self.context.get('request'),
Expand Down Expand Up @@ -1996,7 +1998,7 @@ def get_organization_logo_override_url(self, obj):
return None

@classmethod
def prefetch_queryset(cls, partner, queryset=None):
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
# Explicitly check if the queryset is None before selecting related
queryset = queryset if queryset is not None else Program.objects.filter(partner=partner)

Expand All @@ -2020,7 +2022,7 @@ def prefetch_queryset(cls, partner, queryset=None):
'degree__rankings',
'degree__quick_facts',
'labels',
Prefetch('courses', queryset=MinimalProgramCourseSerializer.prefetch_queryset()),
Prefetch('courses', queryset=MinimalProgramCourseSerializer.prefetch_queryset(course_runs=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
)

Expand Down Expand Up @@ -2165,8 +2167,8 @@ class MinimalExtendedProgramSerializer(MinimalProgramSerializer):
expected_learning_items = serializers.SlugRelatedField(many=True, read_only=True, slug_field='value')

@classmethod
def prefetch_queryset(cls, partner, queryset=None):
queryset = super().prefetch_queryset(partner=partner, queryset=queryset)
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
queryset = super().prefetch_queryset(partner=partner, queryset=queryset, course_runs=course_runs)

return queryset.prefetch_related(
'expected_learning_items',
Expand Down Expand Up @@ -2209,7 +2211,7 @@ class ProgramSerializer(MinimalProgramSerializer):
product_source = SourceSerializer(required=False, read_only=True)

@classmethod
def prefetch_queryset(cls, partner, queryset=None):
def prefetch_queryset(cls, partner, queryset=None, course_runs=None):
"""
Prefetch the related objects that will be serialized with a `Program`.
Expand Down Expand Up @@ -2255,7 +2257,7 @@ def prefetch_queryset(cls, partner, queryset=None):
'instructor_ordering',
# We need the full Course prefetch here to get CourseRun information that methods on the Program
# model iterate across (e.g. language). These fields aren't prefetched by the minimal Course serializer.
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner)),
Prefetch('courses', queryset=CourseSerializer.prefetch_queryset(partner=partner, course_runs=course_runs)),
Prefetch('authoring_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('credit_backing_organizations', queryset=OrganizationSerializer.prefetch_queryset(partner)),
Prefetch('corporate_endorsements', queryset=CorporateEndorsementSerializer.prefetch_queryset()),
Expand Down Expand Up @@ -2302,11 +2304,13 @@ class PathwaySerializer(BaseModelSerializer):
course_run_statuses = serializers.ReadOnlyField()

@classmethod
def prefetch_queryset(cls, partner):
def prefetch_queryset(cls, partner, course_runs=None):
queryset = Pathway.objects.filter(partner=partner)

return queryset.prefetch_related(
Prefetch('programs', queryset=MinimalProgramSerializer.prefetch_queryset(partner=partner)),
Prefetch('programs', queryset=MinimalProgramSerializer.prefetch_queryset(
partner=partner, course_runs=course_runs
)),
)

class Meta:
Expand Down
13 changes: 11 additions & 2 deletions course_discovery/apps/api/tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2347,7 +2347,9 @@ def test_detail_fields_in_response(self, is_post_request):
'staff': MinimalPersonSerializer(course_run.staff, many=True,
context={'request': request}).data,
'content_language': course_run.language.code if course_run.language else None,

'restriction_type': (
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
)
}],
'uuid': str(course.uuid),
'subjects': [subject.name for subject in course.subjects.all()],
Expand Down Expand Up @@ -2418,6 +2420,9 @@ def get_expected_data(cls, course, course_run, course_skill, seat):
'estimated_hours': get_course_run_estimated_hours(course_run),
'first_enrollable_paid_seat_price': course_run.first_enrollable_paid_seat_price or 0.0,
'is_enrollable': course_run.is_enrollable,
'restriction_type': (
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
)
}],
'uuid': str(course.uuid),
'subjects': [subject.name for subject in course.subjects.all()],
Expand Down Expand Up @@ -2549,6 +2554,9 @@ def get_expected_data(cls, course_run, course_skill, request):
'first_enrollable_paid_seat_sku': course_run.first_enrollable_paid_seat_sku(),
'first_enrollable_paid_seat_price': course_run.first_enrollable_paid_seat_price,
'is_enrollable': course_run.is_enrollable,
'restriction_type': (
course_run.restricted_run.restriction_type if hasattr(course_run, 'restricted_run') else None
)
}


Expand Down Expand Up @@ -2751,7 +2759,8 @@ def get_expected_data(cls, learner_pathway, request):
'visible_via_association': True,
'steps': LearnerPathwayStepSerializer(
learner_pathway.steps.all(),
many=True
many=True,
context={'request': request}
).data,
'created': serialize_datetime(learner_pathway.created),
}
Expand Down
6 changes: 6 additions & 0 deletions course_discovery/apps/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from course_discovery.apps.core.api_client.lms import LMSAPIClient
from course_discovery.apps.core.utils import serialize_datetime
from course_discovery.apps.course_metadata.choices import CourseRunRestrictionType
from course_discovery.apps.course_metadata.models import CourseRun

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,6 +200,11 @@ def increment_character(character):
return chr(ord(character) + 1) if character != 'z' else 'a'


def get_excluded_restriction_types(request):
include_restricted = request.query_params.get('include_restricted', '').split(',')
return list(set(CourseRunRestrictionType.values) - set(include_restricted))


class StudioAPI:
"""
A convenience class for talking to the Studio API - designed to allow subclassing by the publisher django app,
Expand Down
29 changes: 28 additions & 1 deletion course_discovery/apps/api/v1/tests/test_views/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.models import Course, CourseType
from course_discovery.apps.course_metadata.tests.factories import (
CourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory
CourseRunFactory, RestrictedCourseRunFactory, SeatFactory, SeatTypeFactory, SubjectFactory
)
from course_discovery.conftest import get_course_run_states

Expand Down Expand Up @@ -335,6 +335,33 @@ def test_courses(self, state):
assert response.status_code == 200
assert response.data['results'] == []

@ddt.data([True, 2], [False, 1])
@ddt.unpack
def test_courses_with_restricted_runs(self, include_restriction_param, expected_result_count):
url = reverse('api:v1:catalog-courses', kwargs={'id': self.catalog.id})
Course.objects.all().delete()

now = datetime.datetime.now(pytz.UTC)
future = now + datetime.timedelta(days=30)
course_run = CourseRunFactory.create(
course__title='ABC Test Course With Archived', end=future, enrollment_end=future
)
restricted_course_run = CourseRunFactory.create(
course=course_run.course,
course__title='ABC Test Course With Archived', end=future, enrollment_end=future,
status=CourseRunStatus.Published
)
RestrictedCourseRunFactory(course_run=restricted_course_run, restriction_type='custom-b2b-enterprise')
SeatFactory.create(course_run=course_run)
SeatFactory.create(course_run=restricted_course_run)

if include_restriction_param:
url += '?include_restricted=custom-b2b-enterprise'

response = self.client.get(url)
assert response.status_code == 200
assert len(response.data['results'][0]['course_runs']) == expected_result_count

def test_courses_with_include_archived(self):
"""
Verify the endpoint returns the list of available and archived courses if include archived
Expand Down
38 changes: 38 additions & 0 deletions course_discovery/apps/api/v1/tests/test_views/test_course_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytz
import responses
from django.contrib.auth.models import Group
from django.core.management import call_command
from django.db.models.functions import Lower
from django.db.models.signals import pre_save
from django.test import override_settings
Expand Down Expand Up @@ -1211,6 +1212,43 @@ def test_list_sorted_by_course_start_date(self):
self.serialize_course_run(CourseRun.objects.all().order_by('start'), many=True)
)

@ddt.data(True, False)
def test_list_include_restricted(self, include_restriction_param):
restricted_run = CourseRunFactory(course__partner=self.partner)
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
url = reverse('api:v1:course_run-list')
if include_restriction_param:
url += '?include_restricted=custom-b2c'

with self.assertNumQueries(14, threshold=3):
response = self.client.get(url)

assert response.status_code == 200
retrieved_keys = [r['key'] for r in response.data['results']]
if include_restriction_param:
assert restricted_run.key in retrieved_keys
else:
assert restricted_run.key not in retrieved_keys

@ddt.data([True, 4], [False, 3])
@ddt.unpack
def test_list_query_include_restricted(self, include_restriction_param, expected_result_count):
CourseRunFactory.create_batch(3, title='Some cool title', course__partner=self.partner)
CourseRunFactory(title='non-cool title')
restricted_run = CourseRunFactory(title='Some cool title', course__partner=self.partner)
RestrictedCourseRunFactory(course_run=restricted_run, restriction_type='custom-b2c')
query = 'title:Some cool title'
url = '{root}?q={query}'.format(root=reverse('api:v1:course_run-list'), query=query)
if include_restriction_param:
url += '&include_restricted=custom-b2c,custom-b2b-enterprise'

call_command('search_index', '--rebuild', '-f')

with self.assertNumQueries(30, threshold=3):
response = self.client.get(url)

assert len(response.data['results']) == expected_result_count

def test_list_query(self):
""" Verify the endpoint returns a filtered list of courses """
course_runs = CourseRunFactory.create_batch(3, title='Some random title', course__partner=self.partner)
Expand Down
64 changes: 63 additions & 1 deletion course_discovery/apps/api/v1/tests/test_views/test_courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from course_discovery.apps.course_metadata.tests.factories import (
CourseEditorFactory, CourseEntitlementFactory, CourseFactory, CourseLocationRestrictionFactory, CourseRunFactory,
CourseTypeFactory, GeoLocationFactory, LevelTypeFactory, OrganizationFactory, ProductValueFactory, ProgramFactory,
SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
RestrictedCourseRunFactory, SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
)
from course_discovery.apps.course_metadata.toggles import IS_SUBDIRECTORY_SLUG_FORMAT_ENABLED
from course_discovery.apps.course_metadata.utils import data_modified_timestamp_update, ensure_draft_world
Expand Down Expand Up @@ -278,6 +278,68 @@ def test_course_runs_are_ordered(self):
self.assertListEqual(response.data['course_run_keys'], expected_keys)
self.assertListEqual([run['key'] for run in response.data['course_runs']], expected_keys)

@ddt.data(True, False)
def test_course_runs_restriction(self, include_restriction_param):
run_restricted = CourseRunFactory(
course=self.course,
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
status=CourseRunStatus.Published
)
run_not_restricted = CourseRunFactory(
course=self.course,
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
status=CourseRunStatus.Unpublished
)
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
SeatFactory(course_run=run_restricted)
SeatFactory(course_run=run_not_restricted)

url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
if include_restriction_param:
url += '?include_restricted=custom-b2c'
with self.assertNumQueries(36, threshold=3):
response = self.client.get(url)
assert response.status_code == 200

if not include_restriction_param:
self.assertEqual(response.data['course_run_keys'], [run_not_restricted.key])
self.assertEqual(response.data['course_run_statuses'], [run_not_restricted.status])
self.assertEqual(len(response.data['course_runs']), 1)
self.assertEqual(response.data['advertised_course_run_uuid'], None)
else:
self.assertEqual(set(response.data['course_run_keys']), {run_not_restricted.key, run_restricted.key})
self.assertEqual(
set(response.data['course_run_statuses']),
{run_not_restricted.status, run_restricted.status}
)
self.assertEqual(len(response.data['course_runs']), 2)
self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid)

def test_course_runs_restriction_param(self):
run_restricted = CourseRunFactory(
course=self.course,
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
status=CourseRunStatus.Published
)
run_not_restricted = CourseRunFactory(
course=self.course,
start=datetime.datetime(2033, 1, 1, tzinfo=pytz.UTC),
status=CourseRunStatus.Unpublished
)
RestrictedCourseRunFactory(course_run=run_restricted, restriction_type='custom-b2c')
SeatFactory(course_run=run_restricted)

url = reverse('api:v1:course-detail', kwargs={'key': self.course.key})
url += '?include_restricted=custom-b2c'
with self.assertNumQueries(36, threshold=3):
response = self.client.get(url)
assert response.status_code == 200

self.assertEqual(set(response.data['course_run_keys']), {run_not_restricted.key, run_restricted.key})
self.assertEqual(set(response.data['course_run_statuses']), {run_not_restricted.status, run_restricted.status})
self.assertEqual(len(response.data['course_runs']), 2)
self.assertEqual(response.data['advertised_course_run_uuid'], run_restricted.uuid)

def test_list(self):
""" Verify the endpoint returns a list of all courses. """
url = reverse('api:v1:course-list')
Expand Down
26 changes: 22 additions & 4 deletions course_discovery/apps/api/v1/tests/test_views/test_programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
from course_discovery.apps.api.v1.views.programs import ProgramViewSet
from course_discovery.apps.core.tests.factories import USER_PASSWORD, UserFactory
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.course_metadata.choices import ProgramStatus
from course_discovery.apps.course_metadata.choices import CourseRunStatus, ProgramStatus
from course_discovery.apps.course_metadata.models import CourseType, Program, ProgramType
from course_discovery.apps.course_metadata.tests.factories import (
CorporateEndorsementFactory, CourseFactory, CourseRunFactory, CurriculumCourseMembershipFactory, CurriculumFactory,
CurriculumProgramMembershipFactory, DegreeAdditionalMetadataFactory, DegreeFactory, EndorsementFactory,
ExpectedLearningItemFactory, JobOutlookItemFactory, OrganizationFactory, PersonFactory, ProgramFactory,
ProgramTypeFactory, VideoFactory
ProgramTypeFactory, RestrictedCourseRunFactory, VideoFactory
)


Expand Down Expand Up @@ -48,13 +48,16 @@ def setup(self, client, django_assert_num_queries, partner):
self.partner = partner
self.request = request

def create_program(self, courses=None, program_type=None):
def create_program(self, courses=None, program_type=None, include_restricted_run=False):
organizations = [OrganizationFactory(partner=self.partner)]
person = PersonFactory()

if courses is None:
courses = [CourseFactory(partner=self.partner)]
CourseRunFactory(course=courses[0], staff=[person])
course_run = CourseRunFactory(course=courses[0], staff=[person])

if include_restricted_run:
RestrictedCourseRunFactory(course_run=course_run, restriction_type='custom-b2c')

if program_type is None:
program_type = ProgramTypeFactory()
Expand Down Expand Up @@ -216,6 +219,21 @@ def test_list(self):

self.assert_list_results(self.list_path, expected, 26)

@pytest.mark.parametrize("include_restriction_param", [True, False])
def test_list_restricted_runs(self, include_restriction_param):
self.create_program(include_restricted_run=True)
query_param_string = "?include_restricted=custom-b2c" if include_restriction_param else ""
resp = self.client.get(self.list_path + query_param_string)

if include_restriction_param:
assert resp.data['results'][0]['courses'][0]['course_runs']
assert resp.data['results'][0]['courses'][0]['course_run_statuses']
assert resp.data['results'][0]['course_run_statuses'] == [CourseRunStatus.Published]
else:
assert not resp.data['results'][0]['courses'][0]['course_runs']
assert not resp.data['results'][0]['courses'][0]['course_run_statuses']
assert resp.data['results'][0]['course_run_statuses'] == []

def test_extended_query_param_fields(self):
""" Verify that the `extended` query param will result in an extended amount of fields returned. """
for _ in range(3):
Expand Down
Loading

0 comments on commit 13714a7

Please sign in to comment.