From 345822c1d5ac06de46163abb8efaee8a3da8011c Mon Sep 17 00:00:00 2001
From: Jen Hamon
Date: Fri, 18 Oct 2024 09:33:35 -0400
Subject: [PATCH] More test cases for aggregator

---
 .gitignore                                 |   5 +-
 pinecone/grpc/index_grpc_asyncio.py        |  39 +++--
 pinecone/grpc/pinecone.py                  |   8 +-
 pinecone/grpc/query_results_aggregator.py  |  72 ++++++++--
 .../test_query_results_aggregator.py       | 135 +++++++++++-------
 5 files changed, 176 insertions(+), 83 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4200d51d..52d700b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,7 +137,7 @@ venv.bak/
 .ropeproject
 
 # pdocs documentation
-# We want to exclude any locally generated artifacts, but we rely on 
+# We want to exclude any locally generated artifacts, but we rely on
 # keeping documentation assets in the docs/ folder.
 docs/*
 !docs/pinecone-python-client-fork.png
@@ -155,4 +155,5 @@ dmypy.json
 *.hdf5
 *~
 
-tests/integration/proxy_config/logs
\ No newline at end of file
+tests/integration/proxy_config/logs
+*.parquet
diff --git a/pinecone/grpc/index_grpc_asyncio.py b/pinecone/grpc/index_grpc_asyncio.py
index 7a9aa8da..593fd3c3 100644
--- a/pinecone/grpc/index_grpc_asyncio.py
+++ b/pinecone/grpc/index_grpc_asyncio.py
@@ -112,7 +112,18 @@ async def upsert(
             for batch in vector_batches
         ]
 
-        return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches")
+        if namespace is not None:
+            pbar_desc = f"Upserted vectors in namespace '{namespace}'"
+        else:
+            pbar_desc = "Upserted vectors in namespace ''"
+
+        upserted_count = 0
+        with tqdm(total=len(vectors), disable=not show_progress, desc=pbar_desc) as pbar:
+            for task in asyncio.as_completed(tasks):
+                res = await task
+                pbar.update(res.upserted_count)
+                upserted_count += res.upserted_count
+        return UpsertResponse(upserted_count=upserted_count)
 
     async def _upsert_batch(
         self,
@@ -173,12 +184,12 @@ async def _query(
         )
         return json_format.MessageToDict(response)
 
-    async def composite_query(
+    async def query(
         self,
         vector: Optional[List[float]] = None,
         id: Optional[str] = None,
         namespace: Optional[str] = None,
-        top_k: Optional[int] = None,
+        top_k: Optional[int] = 10,
         filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
         include_values: Optional[bool] = None,
         include_metadata: Optional[bool] = None,
@@ -244,12 +255,11 @@ async def composite_query(
         )
         return parse_query_response(json_response, _check_type=False)
 
-    async def multi_namespace_query(
+    async def composite_query(
         self,
         vector: Optional[List[float]] = None,
-        id: Optional[str] = None,
-        namespaces: Optional[str] = None,
-        top_k: Optional[int] = None,
+        namespaces: Optional[List[str]] = None,
+        top_k: Optional[int] = 10,
         filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
         include_values: Optional[bool] = None,
         include_metadata: Optional[bool] = None,
@@ -258,12 +268,13 @@ async def multi_namespace_query(
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
     ) -> Awaitable[CompositeQueryResults]:
+        aggregator_lock = asyncio.Lock()
         semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
+        aggregator = QueryResultsAggregator(top_k=top_k)
 
-        queries = [
+        query_tasks = [
             self._query(
                 vector=vector,
-                id=id,
                 namespace=ns,
                 top_k=top_k,
                 filter=filter,
@@ -276,13 +287,11 @@ async def multi_namespace_query(
             for ns in namespaces
         ]
 
-        results = await asyncio.gather(*queries, return_exceptions=True)
+        for query_task in asyncio.as_completed(query_tasks):
+            response = await query_task
+            async with aggregator_lock:
+                aggregator.add_results(response)
 
-        aggregator = QueryResultsAggregator(top_k=top_k)
-        for result in results:
-            if isinstance(result, Exception):
-                continue
-            aggregator.add_results(result)
         final_results = aggregator.get_results()
         return final_results
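To make the renamed surface area concrete, here is a minimal usage sketch of the new `query` / `composite_query` split. This is not part of the patch: it assumes the `AsyncioIndex` factory added in the pinecone.py diff below, and the API key, host, and vector values are hypothetical placeholders.

```python
import asyncio

from pinecone.grpc import PineconeGRPC


async def main():
    pc = PineconeGRPC(api_key="YOUR_API_KEY")  # hypothetical credentials
    index = pc.AsyncioIndex(host="example-index.svc.pinecone.io")  # hypothetical host

    # Single-namespace query: formerly composite_query(), now plain query().
    single = await index.query(vector=[0.1, 0.2, 0.3], namespace="ns1", top_k=10)

    # Multi-namespace fan-out: formerly multi_namespace_query(), now
    # composite_query(). Each namespace is queried concurrently and the
    # responses are merged into a single top_k list by QueryResultsAggregator.
    combined = await index.composite_query(
        vector=[0.1, 0.2, 0.3],
        namespaces=["ns1", "ns2", "ns3"],
        top_k=10,
    )
    for match in combined.matches:
        print(match.id, match.score, match.namespace)


asyncio.run(main())
```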
diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py
index 1878a940..8c3b1c00 100644
--- a/pinecone/grpc/pinecone.py
+++ b/pinecone/grpc/pinecone.py
@@ -48,7 +48,7 @@ class PineconeGRPC(Pinecone):
 
     """
 
-    def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
+    def Index(self, name: str = "", host: str = "", **kwargs):
         """
         Target an index for data operations.
 
@@ -119,6 +119,12 @@ def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
             index.query(vector=[...], top_k=10)
         ```
         """
+        return self._init_index(name=name, host=host, use_asyncio=False, **kwargs)
+
+    def AsyncioIndex(self, name: str = "", host: str = "", **kwargs):
+        return self._init_index(name=name, host=host, use_asyncio=True, **kwargs)
+
+    def _init_index(self, name: str, host: str, use_asyncio=False, **kwargs):
         if name == "" and host == "":
             raise ValueError("Either name or host must be specified")
 
diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/grpc/query_results_aggregator.py
index 71da3868..238cff17 100644
--- a/pinecone/grpc/query_results_aggregator.py
+++ b/pinecone/grpc/query_results_aggregator.py
@@ -81,39 +81,85 @@ def __repr__(self):
         )
 
 
+class QueryResultsAggregationEmptyResultsError(Exception):
+    def __init__(self, namespace: str):
+        super().__init__(
+            f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?"
+        )
+
+
+class QueryResultsAggregatorNotEnoughResultsError(Exception):
+    def __init__(self, top_k: int, num_results: int):
+        super().__init__(
+            f"Cannot infer the similarity metric from fewer than two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 so that the ordering of scores reveals the metric. Requested top_k was {top_k} but only {num_results} match(es) were returned."
+        )
+
+
+class QueryResultsAggregatorInvalidTopKError(Exception):
+    def __init__(self, top_k: int):
+        super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.")
+
+
 class QueryResultsAggregator:
     def __init__(self, top_k: int):
+        if top_k < 1:
+            raise QueryResultsAggregatorInvalidTopKError(top_k)
         self.top_k = top_k
         self.usage_read_units = 0
         self.heap = []
         self.insertion_counter = 0
+        self.is_dotproduct = None
         self.read = False
 
+    def __is_dotproduct_index(self, matches):
+        # The interpretation of the score depends on the similarity metric used.
+        # Unlike other index types, in indexes configured for dotproduct,
+        # a higher score is better. We have to infer this is the case by inspecting
+        # the order of the scores in the results.
+        for i in range(1, len(matches)):
+            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
+                return False
+        return True
+
     def add_results(self, results: QueryResponse):
         if self.read:
+            # Adding results after get_results() would silently drop them,
+            # because the heap has already been emptied to build the ordered
+            # results, so raise instead.
             raise ValueError("Results have already been read. Cannot add more results.")
 
-        self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
+        matches = results.get("matches", [])
         ns = results.get("namespace")
-        for match in results.get("matches", []):
-            self.insertion_counter += 1
-            score = match.get("score")
-            if len(self.heap) < self.top_k:
-                heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
-            else:
-                heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
+        self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
+
+        if self.is_dotproduct is None:
+            if len(matches) == 0:
+                raise QueryResultsAggregationEmptyResultsError(ns)
+            if len(matches) == 1:
+                raise QueryResultsAggregatorNotEnoughResultsError(self.top_k, len(matches))
+            self.is_dotproduct = self.__is_dotproduct_index(matches)
+
+        if self.is_dotproduct:
+            # For dotproduct, a higher score is better, so the heap keys are the
+            # raw scores; the min-heap root is the weakest match retained so far.
+            for match in matches:
+                self.insertion_counter += 1
+                score = match.get("score")
+                if len(self.heap) < self.top_k:
+                    heapq.heappush(self.heap, (score, -self.insertion_counter, match, ns))
+                else:
+                    heapq.heappushpop(self.heap, (score, -self.insertion_counter, match, ns))
+        else:
+            for match in matches:
+                self.insertion_counter += 1
+                score = match.get("score")
+                if len(self.heap) < self.top_k:
+                    heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
+                else:
+                    heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
 
     def get_results(self) -> CompositeQueryResults:
         if self.read:
-            # This is mainly just to sanity check in test cases which get quite confusing
-            # if you read results twice due to the heap being emptied each time you read
-            # results into an ordered list.
-            raise ValueError("Results have already been read. Cannot read again.")
+            return self.final_results
         self.read = True
-        return CompositeQueryResults(
+        self.final_results = CompositeQueryResults(
             usage=Usage(read_units=self.usage_read_units),
             matches=[
                 ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
             ][::-1],
         )
+        return self.final_results
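The heap bookkeeping above is easier to check in isolation. Below is a standalone sketch (plain Python, not part of the patch) of the same top-k merge for the lower-is-better case: scores are negated so Python's min-heap keeps the k best matches, and `heappushpop` evicts the current worst.

```python
import heapq


def merge_top_k(result_sets, k):
    # Mirrors QueryResultsAggregator for a lower-is-better metric:
    # push (-score, -counter) so the min-heap root is the worst retained
    # match, which heappushpop evicts first.
    heap, counter = [], 0
    for namespace, matches in result_sets:
        for match in matches:
            counter += 1
            entry = (-match["score"], -counter, match, namespace)
            if len(heap) < k:
                heapq.heappush(heap, entry)
            else:
                heapq.heappushpop(heap, entry)
    # Popping yields ascending (-score, -counter); reversing gives the final
    # best-first ordering, with earlier insertions winning ties.
    return [heapq.heappop(heap) for _ in range(len(heap))][::-1]


merged = merge_top_k(
    [("ns1", [{"id": "1", "score": 0.1}, {"id": "2", "score": 0.3}]),
     ("ns2", [{"id": "3", "score": 0.2}, {"id": "4", "score": 0.4}])],
    k=3,
)
print([(m[2]["id"], m[3]) for m in merged])  # [('1', 'ns1'), ('3', 'ns2'), ('2', 'ns1')]
```

Negating the insertion counter alongside the score means later duplicates sort behind earlier ones once the final list is reversed, which is the behavior test_inserting_duplicate_scores_stable_ordering locks in.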
diff --git a/tests/unit_grpc/test_query_results_aggregator.py b/tests/unit_grpc/test_query_results_aggregator.py
index 85ee680d..d4213ae6 100644
--- a/tests/unit_grpc/test_query_results_aggregator.py
+++ b/tests/unit_grpc/test_query_results_aggregator.py
@@ -1,13 +1,12 @@
-from pinecone.grpc.query_results_aggregator import QueryResultsAggregator
+from pinecone.grpc.query_results_aggregator import (
+    QueryResultsAggregator,
+    QueryResultsAggregatorInvalidTopKError,
+    QueryResultsAggregatorNotEnoughResultsError,
+)
+import pytest
 
 
 class TestQueryResultsAggregator:
-    def test_empty_results(self):
-        aggregator = QueryResultsAggregator(top_k=3)
-        results = aggregator.get_results()
-        assert results.usage.read_units == 0
-        assert len(results.matches) == 0
-
     def test_keeps_running_usage_total(self):
         aggregator = QueryResultsAggregator(top_k=3)
 
@@ -87,40 +86,44 @@ def test_inserting_duplicate_scores_stable_ordering(self):
         assert results.matches[4].id == "4"  # 0.22
         assert results.matches[4].namespace == "ns1"
 
-    # def test_returns_topk(self):
-    #     aggregator = QueryResultsAggregator(top_k=5)
+    def test_correctly_handles_dotproduct_metric(self):
+        # For this index metric, the higher the score, the more similar the vectors are.
+        # We have to infer that we have this type of index by seeing whether scores are
+        # sorted in descending or ascending order.
+        aggregator = QueryResultsAggregator(top_k=3)
 
-    #     results1 = QueryResponse(
-    #         matches=[
-    #             ScoredVector(id="1", score=0.1, vector=[]),
-    #             ScoredVector(id="2", score=0.11, vector=[]),
-    #             ScoredVector(id="3", score=0.12, vector=[]),
-    #             ScoredVector(id="4", score=0.13, vector=[]),
-    #             ScoredVector(id="5", score=0.14, vector=[]),
-    #         ],
-    #         usage=Usage(read_units=5)
-    #     )
-    #     aggregator.add_results(results1)
+        desc_results1 = {
+            "matches": [
+                {"id": "1", "score": 0.9, "values": []},
+                {"id": "2", "score": 0.8, "values": []},
+                {"id": "3", "score": 0.7, "values": []},
+                {"id": "4", "score": 0.6, "values": []},
+                {"id": "5", "score": 0.5, "values": []},
+            ],
+            "usage": {"readUnits": 5},
+            "namespace": "ns1",
+        }
+        aggregator.add_results(desc_results1)
 
-    #     results2 = QueryResponse(
-    #         matches=[
-    #             ScoredVector(id="7", score=0.101, vector=[]),
-    #             ScoredVector(id="8", score=0.102, vector=[]),
-    #             ScoredVector(id="9", score=0.121, vector=[]),
-    #             ScoredVector(id="10", score=0.2, vector=[]),
-    #             ScoredVector(id="11", score=0.4, vector=[]),
-    #         ],
-    #         usage=Usage(read_units=7)
-    #     )
-    #     aggregator.add_results(results2)
+        results2 = {
+            "matches": [
+                {"id": "7", "score": 0.77, "values": []},
+                {"id": "8", "score": 0.88, "values": []},
+                {"id": "9", "score": 0.99, "values": []},
+                {"id": "10", "score": 0.1010, "values": []},
+                {"id": "11", "score": 0.1111, "values": []},
+            ],
+            "usage": {"readUnits": 7},
+            "namespace": "ns2",
+        }
+        aggregator.add_results(results2)
 
-    #     combined = aggregator.get_results()
-    #     assert len(combined.matches) == 5
-    #     assert combined.matches[0].id == "1"  # 0.1
-    #     assert combined.matches[1].id == "7"  # 0.101
-    #     assert combined.matches[2].id == "8"  # 0.102
-    #     assert combined.matches[3].id == "3"  # 0.12
-    #     assert combined.matches[4].id == "9"  # 0.121
+        results = aggregator.get_results()
+        assert results.usage.read_units == 12
+        assert len(results.matches) == 3
+        assert results.matches[0].id == "9"  # 0.99
+        assert results.matches[1].id == "1"  # 0.9
+        assert results.matches[2].id == "8"  # 0.88
 
 
 class TestQueryResultsAggregatorOutputUX:
@@ -139,7 +142,8 @@ def test_can_interact_with_attributes(self):
                     "list": [1, 2, 3],
                     "list2": ["foo", "bar"],
                 },
-            }
+            },
+            {"id": "2", "score": 0.4},
         ],
         "usage": {"readUnits": 5},
         "namespace": "ns1",
@@ -153,7 +157,7 @@ def test_can_interact_with_attributes(self):
         assert results.matches[0].values == [0.31, 0.32, 0.33, 0.34, 0.35, 0.36]
 
     def test_can_interact_like_dict(self):
-        aggregator = QueryResultsAggregator(top_k=1)
+        aggregator = QueryResultsAggregator(top_k=3)
         results1 = {
             "matches": [
                 {
                     "id": "1",
                     "score": 0.45,
                     "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36],
                     "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]},
                     "metadata": {
                         "hello": "world",
                         "number": 42,
                         "list": [1, 2, 3],
                         "list2": ["foo", "bar"],
                     },
-                }
+                },
+                {
+                    "id": "2",
+                    "score": 0.4,
+                    "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36],
+                    "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]},
+                    "metadata": {
+                        "hello": "world",
+                        "number": 42,
+                        "list": [1, 2, 3],
+                        "list2": ["foo", "bar"],
+                    },
+                },
             ],
             "usage": {"readUnits": 5},
             "namespace": "ns1",
@@ -235,27 +251,42 @@ def test_can_print_complete_results_without_error(self, capsys):
                     "list": [1, 2, 3],
                     "list2": ["foo", "bar"],
                 },
-            }
-        ],
-        "usage": {"readUnits": 5},
-        "namespace": "ns1",
-    }
-    aggregator.add_results(results1)
-
-    results1 = {
-        "matches": [
+            },
             {
                 "id": "2",
                 "score": 0.4,
                 "values": [0.31, 0.32, 0.33, 0.34, 0.35, 0.36],
                 "sparse_values": {"indices": [1, 2], "values": [0.2, 0.4]},
                 "metadata": {"boolean": True, "nullish": None},
-            }
+            },
         ],
         "usage": {"readUnits": 5},
-        "namespace": "ns2",
+        "namespace": "ns1",
     }
         aggregator.add_results(results1)
         results = aggregator.get_results()
         print(results)
         capsys.readouterr()
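The output-UX tests above assert that aggregated matches can be used both as objects and as dictionaries. Here is a short standalone sketch of the interaction pattern being exercised; it assumes the patched library is installed, the fixture values mirror the tests, and the dict-style accessors are inferred from the test name rather than shown in this diff.

```python
from pinecone.grpc.query_results_aggregator import QueryResultsAggregator

aggregator = QueryResultsAggregator(top_k=2)
aggregator.add_results(
    {
        "matches": [
            {"id": "1", "score": 0.45, "values": [0.31, 0.32]},
            {"id": "2", "score": 0.4, "values": []},
        ],
        "usage": {"readUnits": 5},
        "namespace": "ns1",
    }
)

results = aggregator.get_results()

# Attribute-style access, as in test_can_interact_with_attributes.
first = results.matches[0]
print(first.id, first.namespace, first.values)

# Dict-style access, as in test_can_interact_like_dict.
print(first["id"], first["score"])

# After this patch, reading twice no longer raises: get_results() caches
# and returns the same CompositeQueryResults object.
assert aggregator.get_results() is results
```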
+
+
+class TestQueryAggregatorEdgeCases:
+    def test_top_k_too_small(self):
+        with pytest.raises(QueryResultsAggregatorInvalidTopKError):
+            QueryResultsAggregator(top_k=0)
+
+    def test_matches_too_small(self):
+        aggregator = QueryResultsAggregator(top_k=3)
+        results1 = {
+            "matches": [{"id": "1", "score": 0.1}],
+            "usage": {"readUnits": 5},
+            "namespace": "ns1",
+        }
+        with pytest.raises(QueryResultsAggregatorNotEnoughResultsError):
+            aggregator.add_results(results1)
+
+    def test_empty_results(self):
+        aggregator = QueryResultsAggregator(top_k=3)
+        results = aggregator.get_results()
+        assert results is not None
+        assert results.usage.read_units == 0
+        assert len(results.matches) == 0
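The edge-case tests hinge on the metric-inference rule: with fewer than two matches in the first response, the aggregator cannot tell whether higher or lower scores are better, so it raises rather than guessing. A standalone restatement of the heuristic (mirroring `__is_dotproduct_index` above, not part of the patch):

```python
def infers_dotproduct(scores):
    # Mirrors QueryResultsAggregator.__is_dotproduct_index: any increase
    # between consecutive scores implies a lower-is-better distance metric.
    return all(later <= earlier for earlier, later in zip(scores, scores[1:]))


print(infers_dotproduct([0.9, 0.8, 0.7]))  # True: descending, higher is better
print(infers_dotproduct([0.1, 0.2, 0.3]))  # False: ascending, lower is better
print(infers_dotproduct([0.5]))            # Vacuously True, which is why a single match raises
```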