Skip to content

Commit

Permalink
Calculate the proper counts for facets
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Aug 7, 2024
1 parent 2a8c48a commit 6cebc0c
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 113 deletions.
6 changes: 3 additions & 3 deletions src/mass/adapters/outbound/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def __init__(self, *, collection):
async def aggregate( # noqa: PLR0913, D102
self,
*,
selected_fields: list[models.FieldLabel],
facet_fields: list[models.FieldLabel],
query: str,
filters: list[models.Filter],
facet_fields: list[models.FieldLabel],
selected_fields: list[models.FieldLabel],
sorting_parameters: list[models.SortingParameter],
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter],
) -> JsonObject:
# don't carry out aggregation if the collection is empty
if not await self._collection.find_one():
Expand Down
13 changes: 6 additions & 7 deletions src/mass/adapters/outbound/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,11 @@ def pipeline_facet_sort_and_paginate(
{
"$group": {
"_id": {"$getField": {"field": field, "input": path}},
"count": {"$sum": 1},
"uniqueIds": {"$addToSet": "$_id"},
}
},
{
"$addFields": {"value": "$_id"}
}, # rename "_id" to "value" on each option
{"$match": {"_id": {"$ne": None}}},
{"$addFields": {"value": "$_id", "count": {"$size": "$uniqueIds"}}},
{"$unset": "_id"},
{"$sort": {"value": 1}},
)
Expand Down Expand Up @@ -169,13 +168,13 @@ def pipeline_project(*, facet_fields: list[models.FieldLabel]) -> JsonObject:

def build_pipeline( # noqa: PLR0913
*,
query: str,
filters: list[models.Filter],
facet_fields: list[models.FieldLabel],
selected_fields: list[models.FieldLabel],
query: str,
filters: list[models.Filter],
sorting_parameters: list[models.SortingParameter],
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter],
) -> list[JsonObject]:
"""Build aggregation pipeline based on query"""
pipeline: list[JsonObject] = []
Expand Down
10 changes: 6 additions & 4 deletions src/mass/core/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,19 @@ async def delete_resource(self, *, resource_id: str, class_name: str) -> None:
except ResourceNotFoundError as err:
raise self.ResourceNotFoundError(resource_id=resource_id) from err

async def handle_query( # noqa: D102, PLR0913
async def handle_query( # noqa: C901, D102, PLR0913
self,
*,
class_name: str,
query: str,
filters: list[models.Filter],
query: str = "",
filters: list[models.Filter] | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
) -> models.QueryResults:
# set empty list if not provided
if filters is None:
filters = []
if not sorting_parameters:
if query:
sorting_parameters = [
Expand Down
6 changes: 3 additions & 3 deletions src/mass/ports/inbound/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ async def handle_query( # noqa: PLR0913
self,
*,
class_name: str,
query: str,
filters: list[models.Filter],
query: str = "",
filters: list[models.Filter] | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
) -> models.QueryResults:
"""Processes a query
Expand Down
6 changes: 3 additions & 3 deletions src/mass/ports/outbound/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class AggregatorPort(ABC):
async def aggregate( # noqa: PLR0913
self,
*,
selected_fields: list[models.FieldLabel],
facet_fields: list[models.FieldLabel],
query: str,
filters: list[models.Filter],
facet_fields: list[models.FieldLabel],
selected_fields: list[models.FieldLabel],
sorting_parameters: list[models.SortingParameter],
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter],
) -> JsonObject:
"""Applies an aggregation pipeline to a mongodb collection"""
...
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ def recreate_mongodb_indexes(self) -> None:
async def handle_query(
self,
class_name: str,
query: str,
filters: list[models.Filter],
query: str = "",
filters: list[models.Filter] | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
skip: int = 0,
limit: int | None = None,
sorting_parameters: list[models.SortingParameter] | None = None,
) -> models.QueryResults:
"""Handle a query."""
return await self._query_handler.handle_query(
Expand Down
18 changes: 7 additions & 11 deletions tests/fixtures/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,29 @@ searchable_classes:
description: Dataset with embedded references.
facetable_fields:
- key: category
name: Category
- key: city
name: Field 1
- key: "object.type"
name: Object Type
- key: object.type
selected_fields:
- key: id_
name: ID
- key: type
name: Location Type
- key: "object.type"
name: Object Type
name: Location ype
- key: object.type
EmptyCollection:
description: An empty collection to test the index creation.
facetable_fields:
- key: fun_fact
name: Fun Fact
selected_fields: []
SortingTests:
description: Data for testing sorting functionality.
facetable_fields:
- key: field
name: Field
selected_fields: []
RelevanceTests:
description: Data for testing sorting by relevance.
facetable_fields:
- key: field
name: Field
- key: data
name: Data
selected_fields: []
FilteringTests:
description: Data for testing filtering on using single and multi-valued fields.
Expand All @@ -61,6 +53,10 @@ searchable_classes:
name: Food
- key: friends.name
name: Friend
- key: items.type
name: Item
- key: items.color
name: Item color
- key: special.features.fur.color
name: Fur color
selected_fields:
Expand Down
69 changes: 69 additions & 0 deletions tests/fixtures/test_data/FilteringTests.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
}
],
"id_": "1",
"items": [
{
"color": "gold",
"type": "coin"
},
{
"color": "silver",
"type": "pistol"
},
{
"color": "red",
"type": "shirt"
}
],
"name": "Jack",
"special": {
"features": [
Expand Down Expand Up @@ -42,6 +56,20 @@
}
],
"id_": "2",
"items": [
{
"color": "pink",
"type": "collar"
},
{
"color": "pink",
"type": "shirt"
},
{
"color": "pink",
"type": "bowl"
}
],
"name": "Bruiser",
"special": {
"features": [
Expand Down Expand Up @@ -71,6 +99,32 @@
}
],
"id_": "3",
"items": [
{
"color": "blue",
"type": "collar"
},
{
"color": "green",
"type": "collar"
},
{
"color": "brown",
"type": "collar"
},
{
"color": "gold",
"type": "collar"
},
{
"color": "blue",
"type": "shirt"
},
{
"color": "white",
"type": "bowl"
}
],
"name": "Lady",
"special": {
"features": [
Expand Down Expand Up @@ -109,6 +163,20 @@
}
],
"id_": "4",
"items": [
{
"color": "red",
"type": "bowl"
},
{
"color": "white",
"type": "bowl"
},
{
"color": "yellow",
"type": "bowl"
}
],
"name": "Garfield",
"special": {
"features": [
Expand Down Expand Up @@ -144,6 +212,7 @@
}
],
"id_": "5",
"items": [],
"name": "Flipper",
"special": {
"features": [
Expand Down
14 changes: 4 additions & 10 deletions tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ async def test_resource_upsert(
):
"""Try upserting with no pre-existing resource with matching ID (i.e. insert)"""
# get all the documents in the collection
results_all = await joint_fixture.handle_query(
class_name=CLASS_NAME, query="", filters=[]
)
results_all = await joint_fixture.handle_query(class_name=CLASS_NAME)
assert results_all.count > 0

# define content of resource
Expand Down Expand Up @@ -72,9 +70,7 @@ async def test_resource_upsert(
await joint_fixture.consume_event()

# verify that the resource was added
updated_resources = await joint_fixture.handle_query(
class_name=CLASS_NAME, query="", filters=[]
)
updated_resources = await joint_fixture.handle_query(class_name=CLASS_NAME)
if is_insert:
assert updated_resources.count - results_all.count == 1
else:
Expand All @@ -94,9 +90,7 @@ async def test_resource_delete(joint_fixture: JointFixture):
"""Test resource deletion via event consumption"""
# get all the documents in the collection
targeted_initial_results = await joint_fixture.handle_query(
class_name=CLASS_NAME,
query='"1HotelAlpha-id"',
filters=[],
class_name=CLASS_NAME, query='"1HotelAlpha-id"'
)
assert targeted_initial_results.count == 1
assert targeted_initial_results.hits[0].id_ == "1HotelAlpha-id"
Expand All @@ -117,7 +111,7 @@ async def test_resource_delete(joint_fixture: JointFixture):

# get all the documents in the collection
results_post_delete = await joint_fixture.handle_query(
class_name=CLASS_NAME, query='"1HotelAlpha-id"', filters=[]
class_name=CLASS_NAME, query='"1HotelAlpha-id"'
)

assert results_post_delete.count == 0
Loading

0 comments on commit 6cebc0c

Please sign in to comment.