diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/grpc/query_results_aggregator.py index 0ef92be5..e1608c1f 100644 --- a/pinecone/grpc/query_results_aggregator.py +++ b/pinecone/grpc/query_results_aggregator.py @@ -89,20 +89,22 @@ def __init__(self, namespace: str): class QueryResultsAggregregatorNotEnoughResultsError(Exception): - def __init__(self, top_k: int, num_results: int): + def __init__(self, num_results: int): super().__init__( - f"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores. Expected at least {top_k} results but got {num_results}." + "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores." ) class QueryResultsAggregatorInvalidTopKError(Exception): def __init__(self, top_k: int): - super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.") + super().__init__( + f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2." + ) class QueryResultsAggregator: def __init__(self, top_k: int): - if top_k < 1: + if top_k < 2: raise QueryResultsAggregatorInvalidTopKError(top_k) self.top_k = top_k self.usage_read_units = 0 @@ -155,11 +157,14 @@ def add_results(self, results: Dict[str, Any]): self.usage_read_units += results.get("usage", {}).get("readUnits", 0) if len(matches) == 0: - raise QueryResultsAggregationEmptyResultsError(ns) + return if self.is_dotproduct is None: if len(matches) == 1: - raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches)) + # This condition should match the second time we add results containing + # only one match. We need at least two matches in a single response in order + # to infer the similarity metric + raise QueryResultsAggregregatorNotEnoughResultsError(len(matches)) self.is_dotproduct = self._is_dotproduct_index(matches) if self.is_dotproduct: diff --git a/tests/unit_grpc/test_query_results_aggregator.py b/tests/unit_grpc/test_query_results_aggregator.py index e78d255b..b4c78802 100644 --- a/tests/unit_grpc/test_query_results_aggregator.py +++ b/tests/unit_grpc/test_query_results_aggregator.py @@ -269,7 +269,7 @@ def test_still_correct_with_early_return_generated_dotproduct(self): class TestQueryResultsAggregatorOutputUX: def test_can_interact_with_attributes(self): - aggregator = QueryResultsAggregator(top_k=1) + aggregator = QueryResultsAggregator(top_k=2) results1 = { "matches": [ { @@ -414,6 +414,8 @@ class TestQueryAggregatorEdgeCases: def test_topK_too_small(self): with pytest.raises(QueryResultsAggregatorInvalidTopKError): QueryResultsAggregator(top_k=0) + with pytest.raises(QueryResultsAggregatorInvalidTopKError): + QueryResultsAggregator(top_k=1) def test_matches_too_small(self): aggregator = QueryResultsAggregator(top_k=3) @@ -431,3 +433,121 @@ def test_empty_results(self): assert results is not None assert results.usage.read_units == 0 assert len(results.matches) == 0 + + def test_empty_results_with_usage(self): + aggregator = QueryResultsAggregator(top_k=3) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + assert results is not None + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_exactly_one_result(self): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results1) + + results2 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results2) + results = aggregator.get_results() + assert results.usage.read_units == 10 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns2" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns2" + assert results.matches[2].score == 0.2 + + def test_two_result_sets_with_single_result_errors(self): + with pytest.raises(QueryResultsAggregregatorNotEnoughResultsError): + aggregator = QueryResultsAggregator(top_k=3) + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = { + "matches": [{"id": "2", "score": 0.01}], + "usage": {"readUnits": 5}, + "namespace": "ns2", + } + aggregator.add_results(results2) + + def test_single_result_after_index_type_known_no_error(self): + aggregator = QueryResultsAggregator(top_k=3) + + results3 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns3", + } + aggregator.add_results(results3) + + results1 = { + "matches": [{"id": "1", "score": 0.1}], + "usage": {"readUnits": 5}, + "namespace": "ns1", + } + aggregator.add_results(results1) + results2 = {"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"} + aggregator.add_results(results2) + + results = aggregator.get_results() + assert results.usage.read_units == 15 + assert len(results.matches) == 3 + assert results.matches[0].id == "2" + assert results.matches[0].namespace == "ns3" + assert results.matches[0].score == 0.01 + assert results.matches[1].id == "1" + assert results.matches[1].namespace == "ns1" + assert results.matches[1].score == 0.1 + assert results.matches[2].id == "3" + assert results.matches[2].namespace == "ns3" + assert results.matches[2].score == 0.2 + + def test_all_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 15 + assert len(results.matches) == 0 + + def test_some_empty_results(self): + aggregator = QueryResultsAggregator(top_k=10) + results2 = { + "matches": [{"id": "2", "score": 0.01}, {"id": "3", "score": 0.2}], + "usage": {"readUnits": 5}, + "namespace": "ns0", + } + aggregator.add_results(results2) + + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns1"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns2"}) + aggregator.add_results({"matches": [], "usage": {"readUnits": 5}, "namespace": "ns3"}) + + results = aggregator.get_results() + + assert results.usage.read_units == 20 + assert len(results.matches) == 2