Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add SearchAfterMixin for ES search_after capability #4536

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import urllib
import urllib.parse

from rest_framework.reverse import reverse

from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory


class CatalogQueryViewSetTests(ElasticsearchTestMixin, APITestCase):
    """
    Unit tests for the v2 `catalog-query_contains` endpoint (CatalogQueryContainsViewSet).
    """

    def setUp(self):
        super().setUp()
        self.user = UserFactory(is_staff=True, is_superuser=True)
        self.client.force_authenticate(self.user)
        self.course = CourseFactory(partner=self.partner, key='simple_key')
        self.course_run = CourseRunFactory(course=self.course, key='simple/key/run')
        self.url_base = reverse('api:v2:catalog-query_contains')
        self.error_message = 'CatalogQueryContains endpoint requires query and identifiers list(s)'
        self.refresh_index()

    def _identifier_params(self):
        """ Return the identifier querystring parameters for the course and course run created in setUp. """
        return {
            'course_run_ids': self.course_run.key,
            'course_uuids': self.course.uuid,
        }

    def _get_contains(self, params):
        """ Issue a GET against the endpoint with `params` urlencoded into the querystring. """
        qs = urllib.parse.urlencode(params)
        return self.client.get(f'{self.url_base}/?{qs}')

    def test_contains_single_course_run(self):
        """ Verify that a single course_run is contained in a query. """
        response = self._get_contains({'query': 'id:' + self.course_run.key, **self._identifier_params()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: True, str(self.course.uuid): False}

    def test_contains_single_course(self):
        """ Verify that a single course is contained in a query. """
        response = self._get_contains({'query': 'key:' + self.course.key, **self._identifier_params()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: False, str(self.course.uuid): True}

    def test_contains_course_and_run(self):
        """ Verify that both the course and the run are contained in the broadest query. """
        self.course.course_runs.add(self.course_run)
        self.course.save()
        response = self._get_contains({'query': 'org:*', **self._identifier_params()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: True, str(self.course.uuid): True}

    def test_no_identifiers(self):
        """ Verify that a 400 status is returned if request does not contain any identifier lists. """
        response = self._get_contains({'query': 'id:*'})
        assert response.status_code == 400
        assert response.data == self.error_message

    def test_no_query(self):
        """ Verify that a 400 status is returned if request does not contain a querystring. """
        response = self._get_contains(self._identifier_params())
        assert response.status_code == 400
        assert response.data == self.error_message

    def test_incorrect_queries(self):
        """ Verify that a 400 status is returned if request contains incorrect query string. """
        response = self._get_contains({'query': 'title:', **self._identifier_params()})
        assert response.status_code == 400
8 changes: 7 additions & 1 deletion course_discovery/apps/api/v2/urls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""API v2 URLs."""

from django.urls import re_path
from rest_framework import routers

from course_discovery.apps.api.v2.views import search as search_views
from course_discovery.apps.api.v2.views.catalog_queries import CatalogQueryContainsViewSet

app_name = 'v2'

# `/?$` accepts the endpoint with or without a trailing slash; the `$` anchor keeps the
# pattern from also matching arbitrary longer paths such as `catalog/query_contains/foo`.
urlpatterns = [
    re_path(r'^catalog/query_contains/?$', CatalogQueryContainsViewSet.as_view(), name='catalog-query_contains'),
]

router = routers.SimpleRouter()
router.register(r'search/all', search_views.AggregateSearchViewSet, basename='search-all')
# Extend rather than rebind: rebinding `urlpatterns` to `router.urls` would drop the explicit
# catalog route declared above.
urlpatterns += router.urls
75 changes: 75 additions & 0 deletions course_discovery/apps/api/v2/views/catalog_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
from uuid import UUID

from elasticsearch_dsl.query import Q as ESDSLQ
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import DjangoModelPermissions, IsAuthenticated
from rest_framework.response import Response

from course_discovery.apps.api.mixins import ValidElasticSearchQueryRequiredMixin
from course_discovery.apps.course_metadata.models import Course, CourseRun, SearchAfterMixin
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument, CourseRunDocument

log = logging.getLogger(__name__)


class CatalogQueryContainsViewSet(ValidElasticSearchQueryRequiredMixin, GenericAPIView, SearchAfterMixin):
    """
    Report which of the requested courses and/or course runs are matched by an Elasticsearch query.

    Lookups go through `SearchAfterMixin.search`, restricted to the request site's partner and to
    the identifiers supplied in the request.
    """
    permission_classes = (IsAuthenticated, DjangoModelPermissions)
    queryset = Course.objects.all()

    def get(self, request):
        """
        Determine if a set of courses and/or course runs is found in the query results.

        Query parameters
            query (str): Elasticsearch querystring to evaluate (required).
            course_run_ids (str): comma-separated course run keys.
            course_uuids (str): comma-separated course UUIDs.
            At least one of `course_run_ids` / `course_uuids` must be supplied.

        Returns
            dict: mapping of course and run identifiers included in the request to boolean values
                indicating whether the associated course or run is contained in the queryset
                described by the query found in the request.
            A 400 response is returned when the query or both identifier lists are missing, or when
            `course_uuids` contains a value that is not a valid UUID.
        """
        query = request.GET.get('query')
        course_run_ids = request.GET.get('course_run_ids', None)
        course_uuids = request.GET.get('course_uuids', None)
        partner = self.request.site.partner

        if not (query and (course_run_ids or course_uuids)):
            return Response(
                'CatalogQueryContains endpoint requires query and identifiers list(s)',
                status=status.HTTP_400_BAD_REQUEST
            )

        log.info(
            "Attempting search against query %s with course UUIDs %s and course run IDs %s",
            query, course_uuids, course_run_ids
        )

        run_keys = course_run_ids.split(',') if course_run_ids else []
        try:
            # Parse up front so a malformed UUID is reported as a client error before any
            # Elasticsearch request is made, instead of surfacing as an unhandled ValueError.
            parsed_uuids = [UUID(value) for value in course_uuids.split(',')] if course_uuids else []
        except ValueError:
            return Response(
                'CatalogQueryContains endpoint requires course_uuids to be valid UUIDs',
                status=status.HTTP_400_BAD_REQUEST
            )

        # Build a fresh list so the run keys used for the Elasticsearch filter are never mutated.
        specified_course_ids = [*run_keys, *parsed_uuids]
        identified_course_ids = set()

        # Both lookups below read their identifiers with `values_list(..., flat=True)`.
        if run_keys:
            identified_course_ids.update(
                self.search(
                    query,
                    queryset=CourseRun.objects.all(),
                    partner=ESDSLQ('term', partner=partner.short_code),
                    identifiers=ESDSLQ('terms', **{'key.raw': run_keys}),
                    document=CourseRunDocument
                ).values_list('key', flat=True)
            )

        if parsed_uuids:
            log.info("Specified course ids: %s", specified_course_ids)
            identified_course_ids.update(
                self.search(
                    query,
                    queryset=Course.objects.all(),
                    partner=ESDSLQ('term', partner=partner.short_code),
                    identifiers=ESDSLQ('terms', **{'uuid': parsed_uuids}),
                    document=CourseDocument
                ).values_list('uuid', flat=True)
            )

        log.info("Identified %d course ids: %s", len(identified_course_ids), identified_course_ids)

        contains = {str(identifier): identifier in identified_course_ids for identifier in specified_course_ids}
        return Response(contains)
69 changes: 69 additions & 0 deletions course_discovery/apps/course_metadata/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,75 @@ def search(cls, query, queryset=None):
return filtered_queryset


class SearchAfterMixin:
    """
    Represents objects to query Elasticsearch with `search_after` pagination and load by primary key.
    """

    @classmethod
    def search(cls, query, queryset=None, page_size=settings.ELASTICSEARCH_DSL_QUERYSET_PAGINATION, partner=None,
               identifiers=None, document=None):
        """
        Queries the Elasticsearch index with optional pagination using `search_after`.

        Args:
            query (str) -- Elasticsearch querystring (e.g. `title:intro*`)
            queryset (models.QuerySet) -- base queryset to search, defaults to objects.all()
            page_size (int) -- Number of results per page.
            partner (object) -- Elasticsearch DSL query restricting results to a partner; added to the
                ES query as a `must` clause.
            identifiers (object) -- Elasticsearch DSL query restricting results to given UUIDs or keys;
                added to the ES query as a `must` clause.
            document (object) -- Elasticsearch document class to search; defaults to the first document
                registered for `cls`.

        Returns:
            QuerySet -- `queryset` filtered down to the primary keys matched in Elasticsearch.

        Raises:
            ValueError -- if no document is given and none is registered for `cls`.
        """
        query = clean_query(query)
        # Compare against None instead of relying on truthiness: `queryset or ...` would execute the
        # queryset and silently replace an empty one with *all* objects.
        if queryset is None:
            queryset = cls.objects.all()

        extra_filters = [clause for clause in (partner, identifiers) if clause]
        matches_everything = query == '(*)'

        if matches_everything and not extra_filters:
            # Early-exit optimization. Wildcard searching is very expensive in elasticsearch. And since we just
            # want everything, we don't need to actually query elasticsearch at all.
            return queryset

        logger.info("Attempting Elasticsearch document search against query: %s", query)
        es_document = document or next(iter(registry.get_documents(models=(cls,))), None)
        if es_document is None:
            raise ValueError(f'No Elasticsearch document is registered for {cls.__name__}')

        # With a wildcard query the partner/identifier filters alone describe the result set, so the
        # expensive query_string clause is skipped while the filters are still honoured.
        must_queries = [] if matches_everything else [ESDSLQ('query_string', query=query, analyze_wildcard=True)]
        must_queries.extend(extra_filters)
        dsl_query = ESDSLQ('bool', must=must_queries)

        # Everything except `search_after` is identical for each page, so build it once.
        base_search = (
            es_document.search()
            .query(dsl_query)
            .sort('id')
            .extra(size=page_size)
        )

        all_ids = set()
        search_after = None

        while True:
            page = base_search.extra(search_after=search_after) if search_after else base_search

            # NOTE(review): request errors are deliberately not caught here. Per the PR discussion the
            # API view's custom dispatch reports them as an InvalidQuery response -- confirm that
            # non-API callers handle them as well.
            hits = list(page.execute())
            if not hits:
                logger.info("No more results found.")
                break

            all_ids.update(hit.pk for hit in hits)
            logger.info("Fetched %d records; total so far: %d", len(hits), len(all_ids))

            if page_size is not None and len(hits) < page_size:
                # A short page is the last page; skip the extra round trip that would come back empty.
                break

            search_after = hits[-1].meta.sort
            if not search_after:
                # Without sort values the next request would just repeat the first page forever.
                break

        filtered_queryset = queryset.filter(pk__in=all_ids)
        # Log the number of matched ids rather than len(queryset), which would run the DB query
        # only to produce a log line.
        logger.info("Elasticsearch matched %d ids for query: %s", len(all_ids), query)
        return filtered_queryset


class Collaborator(TimeStampedModel):
"""
Collaborator model, defining any collaborators who helped write course content.
Expand Down
12 changes: 12 additions & 0 deletions course_discovery/apps/course_metadata/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,3 +1011,15 @@ class Meta:

course_run = factory.SubFactory(CourseRunFactory)
restriction_type = FuzzyChoice([name for name, __ in CourseRunRestrictionType.choices])


class CourseProxy(SearchAfterMixin, Course):
    """
    Proxy model for testing SearchAfterMixin with Course.

    SearchAfterMixin is listed before Course, so `CourseProxy.search` resolves to the mixin's
    implementation (presumably shadowing any `search` that Course itself provides -- confirm).
    """
    class Meta:
        # Declared as a Django proxy model of Course.
        proxy = True


class CourseProxyFactory(CourseFactory):
    """
    Factory for the CourseProxy proxy model.

    Subclasses CourseFactory and only overrides the model that is built.
    """
    class Meta:
        model = CourseProxy
45 changes: 45 additions & 0 deletions course_discovery/apps/course_metadata/tests/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest.mock import patch

from django.test import TestCase

from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument
from course_discovery.apps.course_metadata.tests import factories


class TestSearchAfterMixin(ElasticsearchTestMixin, TestCase):
    """
    Tests for SearchAfterMixin, exercised through the `CourseProxy` proxy model, which mixes
    SearchAfterMixin into Course.
    """

    def setUp(self):
        super().setUp()

        course_count = 5
        self.total_courses = course_count
        factories.CourseFactory.create_batch(course_count)

    @patch("course_discovery.apps.course_metadata.models.registry.get_documents")
    def test_fetch_all_courses(self, mock_get_documents):
        """
        Paging through the results two at a time returns every course exactly once.
        """
        # Make the (patched) registry lookup return CourseDocument for the proxy model.
        mock_get_documents.return_value = [CourseDocument]

        found = factories.CourseProxy.search(query='Course*', page_size=2)
        found_count = len(found)

        self.assertEqual(found_count, len(set(found)), 'Queryset contains duplicate entries.')
        self.assertEqual(found_count, self.total_courses)

    def test_wildcard_query_early_exit(self):
        """
        Test the early exit optimization when the query is `(*)`.
        """
        found = factories.CourseProxy.search(query='*')

        self.assertEqual(len(found), self.total_courses)

        every_course = factories.Course.objects.all().order_by("id")
        self.assertQuerysetEqual(
            found.order_by("id"),
            every_course,
            transform=lambda x: x
        )
Loading
Loading