More test cases for aggregator
jhamon committed Oct 18, 2024
1 parent 6d99df7 commit 345822c
Showing 5 changed files with 176 additions and 83 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -137,7 +137,7 @@ venv.bak/
.ropeproject

# pdocs documentation
-# We want to exclude any locally generated artifacts, but we rely on
+# We want to exclude any locally generated artifacts, but we rely on
# keeping documentation assets in the docs/ folder.
docs/*
!docs/pinecone-python-client-fork.png
@@ -155,4 +155,5 @@ dmypy.json
*.hdf5
*~

-tests/integration/proxy_config/logs
+tests/integration/proxy_config/logs
+*.parquet
39 changes: 24 additions & 15 deletions pinecone/grpc/index_grpc_asyncio.py
@@ -112,7 +112,18 @@ async def upsert(
            for batch in vector_batches
        ]

-        return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches")
+        if namespace is not None:
+            pbar_desc = f"Upserted vectors in namespace '{namespace}'"
+        else:
+            pbar_desc = "Upserted vectors in namespace ''"
+
+        upserted_count = 0
+        with tqdm(total=len(vectors), disable=not show_progress, desc=pbar_desc) as pbar:
+            for task in asyncio.as_completed(tasks):
+                res = await task
+                pbar.update(res.upserted_count)
+                upserted_count += res.upserted_count
+        return UpsertResponse(upserted_count=upserted_count)
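A sketch of the new behavior from the caller's side (the index handle, vector data, and namespace value are illustrative assumptions; `show_progress` and the single aggregated `UpsertResponse` come from the change above):

```python
async def upsert_demo(index, vectors):
    # Batches run concurrently; one tqdm bar tracks total upserted vectors.
    response = await index.upsert(
        vectors=vectors,
        namespace="example-ns",  # illustrative
        show_progress=True,
    )
    # Per-batch counts are now folded into a single UpsertResponse.
    return response.upserted_count
```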

async def _upsert_batch(
self,
@@ -173,12 +184,12 @@ async def _query(
)
return json_format.MessageToDict(response)

-    async def composite_query(
+    async def query(
        self,
        vector: Optional[List[float]] = None,
        id: Optional[str] = None,
        namespace: Optional[str] = None,
-        top_k: Optional[int] = None,
+        top_k: Optional[int] = 10,
        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
        include_values: Optional[bool] = None,
        include_metadata: Optional[bool] = None,
@@ -244,12 +255,11 @@ async def composite_query(
)
return parse_query_response(json_response, _check_type=False)

-    async def multi_namespace_query(
+    async def composite_query(
        self,
        vector: Optional[List[float]] = None,
-        id: Optional[str] = None,
-        namespaces: Optional[str] = None,
-        top_k: Optional[int] = None,
+        namespaces: Optional[List[str]] = None,
+        top_k: Optional[int] = 10,
        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
        include_values: Optional[bool] = None,
        include_metadata: Optional[bool] = None,
@@ -258,12 +268,13 @@
        semaphore: Optional[asyncio.Semaphore] = None,
        **kwargs,
    ) -> Awaitable[CompositeQueryResults]:
+        aggregator_lock = asyncio.Lock()
        semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
+        aggregator = QueryResultsAggregator(top_k=top_k)

-        queries = [
+        query_tasks = [
            self._query(
                vector=vector,
-                id=id,
                namespace=ns,
                top_k=top_k,
                filter=filter,
@@ ... @@
for ns in namespaces
]

-        results = await asyncio.gather(*queries, return_exceptions=True)
+        for query_task in asyncio.as_completed(query_tasks):
+            response = await query_task
+            async with aggregator_lock:
+                aggregator.add_results(response)

-        aggregator = QueryResultsAggregator(top_k=top_k)
-        for result in results:
-            if isinstance(result, Exception):
-                continue
-            aggregator.add_results(result)
        final_results = aggregator.get_results()
        return final_results
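A hedged usage sketch for the renamed multi-namespace method (the handle, query vector, and namespace names are illustrative; the parameters mirror the signature above):

```python
async def search_all_namespaces(index, query_vector):
    # One query per namespace is issued concurrently; results are merged
    # into a single top_k list by QueryResultsAggregator.
    results = await index.composite_query(
        vector=query_vector,
        namespaces=["ns-a", "ns-b", "ns-c"],
        top_k=10,
        include_metadata=True,
    )
    print(results.usage)  # read units are summed across the namespace queries
    return results.matches
```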

8 changes: 7 additions & 1 deletion pinecone/grpc/pinecone.py
@@ -48,7 +48,7 @@ class PineconeGRPC(Pinecone):
"""

def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
def Index(self, name: str = "", host: str = "", **kwargs):
"""
Target an index for data operations.
@@ -119,6 +119,12 @@ def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
index.query(vector=[...], top_k=10)
```
"""
+        return self._init_index(name=name, host=host, use_asyncio=False, **kwargs)
+
+    def AsyncioIndex(self, name: str = "", host: str = "", **kwargs):
+        return self._init_index(name=name, host=host, use_asyncio=True, **kwargs)
+
+    def _init_index(self, name: str, host: str, use_asyncio=False, **kwargs):
        if name == "" and host == "":
            raise ValueError("Either name or host must be specified")

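With the use_asyncio flag removed from Index, the transport is now chosen by method instead (the API key and host below are placeholders):

```python
from pinecone.grpc import PineconeGRPC

pc = PineconeGRPC(api_key="YOUR_API_KEY")

# Synchronous gRPC data-plane client, as before
index = pc.Index(host="my-index-abc123.svc.example.pinecone.io")

# New asyncio variant, routed through _init_index(use_asyncio=True)
asyncio_index = pc.AsyncioIndex(host="my-index-abc123.svc.example.pinecone.io")
```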
72 changes: 59 additions & 13 deletions pinecone/grpc/query_results_aggregator.py
@@ -81,39 +81,85 @@ def __repr__(self):
)


+class QueryResultsAggregationEmptyResultsError(Exception):
+    def __init__(self, namespace: str):
+        super().__init__(
+            f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?"
+        )
+
+
+class QueryResultsAggregatorNotEnoughResultsError(Exception):
+    def __init__(self, top_k: int, num_results: int):
+        super().__init__(
+            f"Cannot interpret results without at least two matches. To aggregate results from multiple queries, top_k must be greater than 1 so that the similarity metric can be correctly inferred from the scores. Expected at least {top_k} results but got {num_results}."
+        )
+
+
+class QueryResultsAggregatorInvalidTopKError(Exception):
+    def __init__(self, top_k: int):
+        super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.")
+
+
class QueryResultsAggregator:
    def __init__(self, top_k: int):
+        if top_k < 1:
+            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap = []
        self.insertion_counter = 0
+        self.is_dotproduct = None
        self.read = False

+    def __is_dotproduct_index(self, matches):
+        # The interpretation of the score depends on the similarity metric used.
+        # Unlike other index types, in indexes configured for dotproduct,
+        # a higher score is better. We have to infer this is the case by inspecting
+        # the order of the scores in the results.
+        for i in range(1, len(matches)):
+            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
+                return False
+        return True
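To make the inference rule concrete, here are two made-up score sequences and how __is_dotproduct_index would classify them:

```python
# Matches arrive best-first from a single namespace query (values invented):
never_increasing = [{"score": 0.92}, {"score": 0.75}, {"score": 0.41}]
# -> no increase found: inferred dotproduct (higher score is better)

has_increase = [{"score": 0.12}, {"score": 0.34}, {"score": 0.80}]
# -> 0.34 > 0.12: not dotproduct (distance-like metric, lower is better)
```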

    def add_results(self, results: QueryResponse):
        if self.read:
+            # This is mainly just to sanity check in test cases which get quite confusing
+            # if you read results twice due to the heap being emptied when constructing
+            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

+        self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
+        matches = results.get("matches", [])
        ns = results.get("namespace")
-        for match in results.get("matches", []):
-            self.insertion_counter += 1
-            score = match.get("score")
-            if len(self.heap) < self.top_k:
-                heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
-            else:
-                heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
-        self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

+        if self.is_dotproduct is None:
+            if len(matches) == 0:
+                raise QueryResultsAggregationEmptyResultsError(ns)
+            if len(matches) == 1:
+                raise QueryResultsAggregatorNotEnoughResultsError(self.top_k, len(matches))
+            self.is_dotproduct = self.__is_dotproduct_index(matches)
+
+        if self.is_dotproduct:
+            raise NotImplementedError("Dotproduct indexes are not yet supported.")
+        else:
+            # Scores are negated so the min-heap keeps the top_k lowest-scoring
+            # matches; the insertion counter breaks ties in arrival order.
+            for match in matches:
+                self.insertion_counter += 1
+                score = match.get("score")
+                if len(self.heap) < self.top_k:
+                    heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
+                else:
+                    heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))

    def get_results(self) -> CompositeQueryResults:
        if self.read:
            # This is mainly just to sanity check in test cases which get quite confusing
            # if you read results twice due to the heap being emptied each time you read
            # results into an ordered list.
-            raise ValueError("Results have already been read. Cannot read again.")
+            return self.final_results
        self.read = True

-        return CompositeQueryResults(
+        self.final_results = CompositeQueryResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
+        return self.final_results
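A minimal end-to-end sketch of the aggregator, feeding plain dicts in place of QueryResponse objects (add_results only relies on .get(), so dicts work for illustration; scores increase within each response, so the index is inferred as distance-based):

```python
aggregator = QueryResultsAggregator(top_k=2)
aggregator.add_results({
    "namespace": "ns1",
    "usage": {"readUnits": 5},
    "matches": [{"id": "a", "score": 0.10}, {"id": "b", "score": 0.40}],
})
aggregator.add_results({
    "namespace": "ns2",
    "usage": {"readUnits": 5},
    "matches": [{"id": "c", "score": 0.20}, {"id": "d", "score": 0.50}],
})

results = aggregator.get_results()
# results.usage is Usage(read_units=10); the merged top_k keeps the two
# lowest-scoring matches, "a" (0.10) then "c" (0.20), best first.
```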