From f7082367351b82533462b7ecaef542da8dbeb820 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 18 Oct 2024 12:23:25 -0400 Subject: [PATCH] Fix mypy issues --- pinecone/grpc/index_grpc_asyncio.py | 35 ++++++++++++++--------- pinecone/grpc/query_results.py | 14 +++++++++ pinecone/grpc/query_results_aggregator.py | 29 +++++++++++-------- pinecone/grpc/utils.py | 2 +- 4 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 pinecone/grpc/query_results.py diff --git a/pinecone/grpc/index_grpc_asyncio.py b/pinecone/grpc/index_grpc_asyncio.py index 0f7f906e..1891ed80 100644 --- a/pinecone/grpc/index_grpc_asyncio.py +++ b/pinecone/grpc/index_grpc_asyncio.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, List, Dict, Awaitable +from typing import Optional, Union, List, Dict, Awaitable, Any from tqdm.asyncio import tqdm import asyncio @@ -91,14 +91,14 @@ async def upsert( max_concurrent_requests: Optional[int] = None, semaphore: Optional[asyncio.Semaphore] = None, **kwargs, - ) -> Awaitable[UpsertResponse]: + ) -> UpsertResponse: timeout = kwargs.pop("timeout", None) vectors = list(map(VectorFactoryGRPC.build, vectors)) semaphore = self._get_semaphore(max_concurrent_requests, semaphore) if batch_size is None: return await self._upsert_batch( - vectors, namespace, timeout=timeout, semaphore=semaphore, **kwargs + vectors=vectors, namespace=namespace, timeout=timeout, semaphore=semaphore, **kwargs ) if not isinstance(batch_size, int) or batch_size <= 0: @@ -132,7 +132,7 @@ async def _upsert_batch( namespace: Optional[str], timeout: Optional[int] = None, **kwargs, - ) -> Awaitable[UpsertResponse]: + ) -> UpsertResponse: args_dict = parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict) return await self.runner.run_asyncio( @@ -151,7 +151,7 @@ async def _query( sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, semaphore: Optional[asyncio.Semaphore] = None, **kwargs, - ) -> Awaitable[Dict]: + ) -> dict[str, Any]: if vector is not None and id is not None: raise ValueError("Cannot specify both `id` and `vector`") @@ -182,7 +182,8 @@ async def _query( response = await self.runner.run_asyncio( self.stub.Query, request, timeout=timeout, semaphore=semaphore ) - return json_format.MessageToDict(response) + parsed = json_format.MessageToDict(response) + return parsed async def query( self, @@ -196,7 +197,7 @@ async def query( sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, semaphore: Optional[asyncio.Semaphore] = None, **kwargs, - ) -> Awaitable[QueryResponse]: + ) -> QueryResponse: """ The Query operation searches a namespace, using a query vector. It retrieves the ids of the most similar items in a namespace, along with their similarity scores. @@ -257,9 +258,9 @@ async def query( async def composite_query( self, - vector: Optional[List[float]] = None, - namespaces: Optional[List[str]] = None, - top_k: Optional[int] = 10, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, include_values: Optional[bool] = None, include_metadata: Optional[bool] = None, @@ -268,17 +269,23 @@ async def composite_query( max_concurrent_requests: Optional[int] = None, semaphore: Optional[asyncio.Semaphore] = None, **kwargs, - ) -> Awaitable[CompositeQueryResults]: + ) -> CompositeQueryResults: aggregator_lock = asyncio.Lock() semaphore = self._get_semaphore(max_concurrent_requests, semaphore) - # The caller may only want the topK=1 result across all queries, + if len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + # The caller may only want the top_k=1 result across all queries, # but we need to get at least 2 results from each query in order to # aggregate them correctly. So we'll temporarily set topK to 2 for the # subqueries, and then we'll take the topK=1 results from the aggregated # results. - aggregator = QueryResultsAggregator(top_k=top_k) - subquery_topk = top_k if top_k > 2 else 2 + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + subquery_topk = overall_topk if overall_topk > 2 else 2 target_namespaces = set(namespaces) # dedup namespaces query_tasks = [ diff --git a/pinecone/grpc/query_results.py b/pinecone/grpc/query_results.py new file mode 100644 index 00000000..b2201b50 --- /dev/null +++ b/pinecone/grpc/query_results.py @@ -0,0 +1,14 @@ +from typing import TypedDict, List, Dict, Any + + +class ScoredVectorTypedDict(TypedDict): + id: str + score: float + values: List[float] + metadata: dict + + +class QueryResultsTypedDict(TypedDict): + matches: List[ScoredVectorTypedDict] + namespace: str + usage: Dict[str, Any] diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/grpc/query_results_aggregator.py index 345f006d..171c9e80 100644 --- a/pinecone/grpc/query_results_aggregator.py +++ b/pinecone/grpc/query_results_aggregator.py @@ -1,7 +1,7 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional, Any, Dict import json import heapq -from pinecone.core.openapi.data.models import QueryResponse, Usage +from pinecone.core.openapi.data.models import Usage from dataclasses import dataclass, asdict @@ -15,14 +15,14 @@ class ScoredVectorWithNamespace: sparse_values: dict metadata: dict - def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, dict, str]): + def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]): json_vector = aggregate_results_heap_tuple[2] self.namespace = aggregate_results_heap_tuple[3] - self.id = json_vector.get("id") - self.score = json_vector.get("score") - self.values = json_vector.get("values") - self.sparse_values = json_vector.get("sparse_values", None) - self.metadata = json_vector.get("metadata", None) + self.id = json_vector.get("id") # type: ignore + self.score = json_vector.get("score") # type: ignore + self.values = json_vector.get("values") # type: ignore + self.sparse_values = json_vector.get("sparse_values", None) # type: ignore + self.metadata = json_vector.get("metadata", None) # type: ignore def __getitem__(self, key): if hasattr(self, key): @@ -106,10 +106,11 @@ def __init__(self, top_k: int): raise QueryResultsAggregatorInvalidTopKError(top_k) self.top_k = top_k self.usage_read_units = 0 - self.heap = [] + self.heap: List[Tuple[float, int, object, str]] = [] self.insertion_counter = 0 self.is_dotproduct = None self.read = False + self.final_results: Optional[CompositeQueryResults] = None def _is_dotproduct_index(self, matches): # The interpretation of the score depends on the similar metric used. @@ -135,7 +136,7 @@ def _process_matches(self, matches, ns, heap_item_fn): else: heapq.heappushpop(self.heap, heap_item_fn(match, ns)) - def add_results(self, results: QueryResponse): + def add_results(self, results: Dict[str, Any]): if self.read: # This is mainly just to sanity check in test cases which get quite confusing # if you read results twice due to the heap being emptied when constructing @@ -143,7 +144,7 @@ def add_results(self, results: QueryResponse): raise ValueError("Results have already been read. Cannot add more results.") matches = results.get("matches", []) - ns = results.get("namespace") + ns: str = results.get("namespace", "") self.usage_read_units += results.get("usage", {}).get("readUnits", 0) if len(matches) == 0: @@ -161,7 +162,11 @@ def add_results(self, results: QueryResponse): def get_results(self) -> CompositeQueryResults: if self.read: - return self.final_results + if self.final_results is not None: + return self.final_results + else: + # I don't think this branch can ever actually be reached, but the type checker disagrees + raise ValueError("Results have already been read. Cannot get results again.") self.read = True self.final_results = CompositeQueryResults( diff --git a/pinecone/grpc/utils.py b/pinecone/grpc/utils.py index 13fd7d98..a6e996ea 100644 --- a/pinecone/grpc/utils.py +++ b/pinecone/grpc/utils.py @@ -10,8 +10,8 @@ QueryResponse, DescribeIndexStatsResponse, NamespaceSummary, - SparseValues as GRPCSparseValues, ) +from pinecone.core.grpc.protos.vector_service_pb2 import SparseValues as GRPCSparseValues from .sparse_vector import SparseVectorTypedDict from google.protobuf.struct_pb2 import Struct