-
Notifications
You must be signed in to change notification settings - Fork 171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add SearchAfterMixin for ES search_after capability #4536
Open
Ali-D-Akbar
wants to merge
3
commits into
master
Choose a base branch
from
aakbar/PROD-4233
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
94 changes: 94 additions & 0 deletions
94
course_discovery/apps/api/v2/tests/test_views/test_catalog_queries.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import urllib | ||
|
||
from rest_framework.reverse import reverse | ||
|
||
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase | ||
from course_discovery.apps.core.tests.factories import UserFactory | ||
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin | ||
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory | ||
|
||
|
||
class CatalogQueryViewSetTests(ElasticsearchTestMixin, APITestCase):
    """
    Tests for the v2 catalog `query_contains` endpoint.
    """
    def setUp(self):
        super().setUp()
        self.user = UserFactory(is_staff=True, is_superuser=True)
        self.client.force_authenticate(self.user)
        self.course = CourseFactory(partner=self.partner, key='simple_key')
        self.course_run = CourseRunFactory(course=self.course, key='simple/key/run')
        self.url_base = reverse('api:v2:catalog-query_contains')
        self.error_message = 'CatalogQueryContains endpoint requires query and identifiers list(s)'
        self.refresh_index()

    def _identifiers(self):
        """ Return the identifier parameters for the course run and course created in setUp. """
        return {
            'course_run_ids': self.course_run.key,
            'course_uuids': self.course.uuid,
        }

    def _fetch(self, params):
        """ Issue a GET against the endpoint with `params` encoded as the querystring. """
        encoded = urllib.parse.urlencode(params)
        return self.client.get(f'{self.url_base}/?{encoded}')

    def test_contains_single_course_run(self):
        """ A query matching only the course run reports the run as contained and the course as not. """
        response = self._fetch({'query': 'id:' + self.course_run.key, **self._identifiers()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: True, str(self.course.uuid): False}

    def test_contains_single_course(self):
        """ A query matching only the course reports the course as contained and the run as not. """
        response = self._fetch({'query': 'key:' + self.course.key, **self._identifiers()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: False, str(self.course.uuid): True}

    def test_contains_course_and_run(self):
        """ The broadest query reports both the course and its run as contained. """
        self.course.course_runs.add(self.course_run)
        self.course.save()
        response = self._fetch({'query': 'org:*', **self._identifiers()})
        assert response.status_code == 200
        assert response.data == {self.course_run.key: True, str(self.course.uuid): True}

    def test_no_identifiers(self):
        """ A request without any identifier list is rejected with a 400. """
        response = self._fetch({'query': 'id:*'})
        assert response.status_code == 400
        assert response.data == self.error_message

    def test_no_query(self):
        """ A request without a query is rejected with a 400. """
        response = self._fetch(self._identifiers())
        assert response.status_code == 400
        assert response.data == self.error_message

    def test_incorrect_queries(self):
        """ A request whose query string is malformed is rejected with a 400. """
        response = self._fetch({'query': 'title:', **self._identifiers()})
        assert response.status_code == 400
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,17 @@ | ||
"""API v2 URLs.""" | ||
|
||
from django.urls import re_path | ||
from rest_framework import routers | ||
|
||
from course_discovery.apps.api.v2.views import search as search_views | ||
from course_discovery.apps.api.v2.views.catalog_queries import CatalogQueryContainsViewSet | ||
|
||
app_name = 'v2' | ||
|
||
urlpatterns = [ | ||
re_path(r'^catalog/query_contains/?', CatalogQueryContainsViewSet.as_view(), name='catalog-query_contains'), | ||
] | ||
|
||
router = routers.SimpleRouter() | ||
router.register(r'search/all', search_views.AggregateSearchViewSet, basename='search-all') | ||
urlpatterns = router.urls | ||
urlpatterns += router.urls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import logging | ||
from uuid import UUID | ||
|
||
from elasticsearch_dsl.query import Q as ESDSLQ | ||
from rest_framework import status | ||
from rest_framework.generics import GenericAPIView | ||
from rest_framework.permissions import DjangoModelPermissions, IsAuthenticated | ||
from rest_framework.response import Response | ||
|
||
from course_discovery.apps.api.mixins import ValidElasticSearchQueryRequiredMixin | ||
from course_discovery.apps.course_metadata.models import Course, CourseRun, SearchAfterMixin | ||
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument, CourseRunDocument | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class CatalogQueryContainsViewSet(ValidElasticSearchQueryRequiredMixin, GenericAPIView, SearchAfterMixin):
    """
    Report which of the requested courses and course runs are matched by an Elasticsearch query.

    Access requires an authenticated user with Django model permissions on `Course`.
    """
    permission_classes = (IsAuthenticated, DjangoModelPermissions)
    queryset = Course.objects.all()

    def get(self, request):
        """
        Determine if a set of courses and/or course runs is found in the query results.

        Query parameters:
            query -- Elasticsearch querystring to evaluate (required).
            course_run_ids -- comma-separated course run keys (optional).
            course_uuids -- comma-separated course UUIDs (optional).
            At least one of `course_run_ids` / `course_uuids` must be supplied.

        Returns
            dict: mapping of course and run identifiers included in the request to boolean values
                indicating whether the associated course or run is contained in the queryset
                described by the query found in the request.

            A 400 response is returned when `query` or both identifier lists are missing, or when
            any entry of `course_uuids` is not a valid UUID.
        """
        query = request.GET.get('query')
        course_run_ids = request.GET.get('course_run_ids', None)
        course_uuids = request.GET.get('course_uuids', None)
        partner = self.request.site.partner

        if not (query and (course_run_ids or course_uuids)):
            return Response(
                'CatalogQueryContains endpoint requires query and identifiers list(s)',
                status=status.HTTP_400_BAD_REQUEST
            )

        log.info(
            "Attempting search against query %s with course UUIDs %s and course run IDs %s",
            query, course_uuids, course_run_ids
        )

        # Parse the UUIDs before issuing any Elasticsearch request so that malformed client input
        # is rejected as a bad request instead of surfacing as an unhandled ValueError.
        if course_uuids:
            try:
                course_uuids = [UUID(course_uuid) for course_uuid in course_uuids.split(',')]
            except ValueError:
                return Response(
                    'CatalogQueryContains endpoint requires course_uuids to be valid UUIDs',
                    status=status.HTTP_400_BAD_REQUEST
                )

        identified_course_ids = set()
        specified_course_ids = []

        if course_run_ids:
            course_run_ids = course_run_ids.split(',')
            # Copy so that appending the course UUIDs below does not also grow `course_run_ids`.
            specified_course_ids = list(course_run_ids)
            identified_course_ids.update(
                self.search(
                    query,
                    queryset=CourseRun.objects.all(),
                    partner=ESDSLQ('term', partner=partner.short_code),
                    identifiers=ESDSLQ('terms', **{'key.raw': course_run_ids}),
                    document=CourseRunDocument
                ).values_list('key', flat=True)
            )

        if course_uuids:
            specified_course_ids += course_uuids

            log.info("Specified course ids: %s", specified_course_ids)
            identified_course_ids.update(
                self.search(
                    query,
                    queryset=Course.objects.all(),
                    partner=ESDSLQ('term', partner=partner.short_code),
                    identifiers=ESDSLQ('terms', **{'uuid': course_uuids}),
                    document=CourseDocument
                ).values_list('uuid', flat=True)
            )
        log.info("Identified %d course ids: %s", len(identified_course_ids), identified_course_ids)

        # Run keys are compared as strings and course UUIDs as UUID objects; the response keys are
        # always the string form of the identifier.
        contains = {str(identifier): identifier in identified_course_ids for identifier in specified_course_ids}
        return Response(contains)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -1149,6 +1149,75 @@ def search(cls, query, queryset=None): | |||
return filtered_queryset | ||||
|
||||
|
||||
class SearchAfterMixin:
    """
    Represents objects to query Elasticsearch with `search_after` pagination and load by primary key.
    """

    @classmethod
    def search(cls, query, queryset=None, page_size=settings.ELASTICSEARCH_DSL_QUERYSET_PAGINATION, partner=None,
               identifiers=None, document=None):
        """
        Queries the Elasticsearch index page by page using `search_after` and loads the matches by primary key.

        Args:
            query (str) -- Elasticsearch querystring (e.g. `title:intro*`)
            queryset (models.QuerySet) -- base queryset to filter, defaults to `cls.objects.all()`
            page_size (int) -- Number of hits requested per Elasticsearch page.
            partner (object) -- Optional query clause added to the bool `must` list.
            identifiers (object) -- Optional query clause (e.g. a `terms` query on UUIDs or keys)
                added to the bool `must` list.
            document (object) -- Elasticsearch document class to search; defaults to the first
                document registered for `cls`.

        Returns:
            QuerySet: `queryset` restricted to the primary keys returned by Elasticsearch, or
            `queryset` unchanged when the cleaned query is the wildcard `(*)`.
        """
        query = clean_query(query)
        # Test against None instead of truthiness: bool(queryset) would run the query, and an empty
        # base queryset would be replaced by `cls.objects.all()` (which fails when `cls` has no manager).
        if queryset is None:
            queryset = cls.objects.all()

        if query == '(*)':
            # Early-exit optimization. Wildcard searching is very expensive in elasticsearch. And since we just
            # want everything, we don't need to actually query elasticsearch at all.
            return queryset

        logger.info("Attempting Elasticsearch document search against query: %s", query)
        es_document = document or next(iter(registry.get_documents(models=(cls,))), None)

        must_queries = [ESDSLQ('query_string', query=query, analyze_wildcard=True)]
        if partner:
            must_queries.append(partner)
        if identifiers:
            must_queries.append(identifiers)

        dsl_query = ESDSLQ('bool', must=must_queries)

        # The query, sort order and page size are the same for every page; only the cursor changes.
        base_search = (
            es_document.search()
            .query(dsl_query)
            .sort('id')
            .extra(size=page_size)
        )

        all_ids = set()
        search_after = None

        while True:
            search = base_search.extra(search_after=search_after) if search_after else base_search

            # NOTE(review): Elasticsearch errors raised here are not caught in this method. Per the
            # review discussion they are presumably handled by the calling view's custom dispatch -- confirm.
            results = search.execute()

            ids = {result.pk for result in results}
            if not ids:
                logger.info("No more results found.")
                break

            all_ids.update(ids)
            last_hit = results[-1]
            search_after = last_hit.meta.sort if last_hit else None
            logger.info("Fetched %d records; total so far: %d", len(ids), len(all_ids))

            if not search_after:
                # Without a cursor the next request would repeat the first page indefinitely.
                break

        filtered_queryset = queryset.filter(pk__in=all_ids)
        # len() evaluates the filtered queryset, as in the original implementation.
        logger.info("Filtered queryset of size %d for query: %s", len(filtered_queryset), query)
        return filtered_queryset
|
||||
|
||||
class Collaborator(TimeStampedModel): | ||||
""" | ||||
Collaborator model, defining any collaborators who helped write course content. | ||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
course_discovery/apps/course_metadata/tests/test_mixins.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from unittest.mock import patch | ||
|
||
from django.test import TestCase | ||
|
||
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin | ||
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument | ||
from course_discovery.apps.course_metadata.tests import factories | ||
|
||
|
||
class TestSearchAfterMixin(ElasticsearchTestMixin, TestCase):
    """
    Tests for SearchAfterMixin.
    The mixin is exercised through the `CourseProxy` proxy model, which extends it to mimic Courses.
    """
    def setUp(self):
        super().setUp()

        self.total_courses = 5
        factories.CourseFactory.create_batch(self.total_courses)

    @patch("course_discovery.apps.course_metadata.models.registry.get_documents")
    def test_fetch_all_courses(self, mock_get_documents):
        """
        Paging two at a time must still return every course exactly once.
        """
        mock_get_documents.return_value = [CourseDocument]

        courses = factories.CourseProxy.search(query='Course*', page_size=2)

        distinct_courses = set(courses)
        self.assertEqual(len(courses), len(distinct_courses), 'Queryset contains duplicate entries.')
        self.assertEqual(len(courses), self.total_courses)

    def test_wildcard_query_early_exit(self):
        """
        A `*` query is expected to hit the wildcard early exit and return every course.
        """
        courses = factories.CourseProxy.search(query='*')

        self.assertEqual(len(courses), self.total_courses)
        self.assertQuerysetEqual(
            courses.order_by("id"),
            factories.Course.objects.all().order_by("id"),
            transform=lambda course: course
        )
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This uses values_list while course_run_ids uses a comprehension. We should make them consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's rather consistent now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it? The code is still the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I misunderstood the earlier comment. Lemme use values_list for both.