Fix mypy issues
jhamon committed Oct 18, 2024
1 parent dbee57b commit f708236
Showing 4 changed files with 53 additions and 27 deletions.
35 changes: 21 additions & 14 deletions pinecone/grpc/index_grpc_asyncio.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, List, Dict, Awaitable
+from typing import Optional, Union, List, Dict, Awaitable, Any
 
 from tqdm.asyncio import tqdm
 import asyncio
@@ -91,14 +91,14 @@ async def upsert(
         max_concurrent_requests: Optional[int] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[UpsertResponse]:
+    ) -> UpsertResponse:
         timeout = kwargs.pop("timeout", None)
         vectors = list(map(VectorFactoryGRPC.build, vectors))
         semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
 
         if batch_size is None:
             return await self._upsert_batch(
-                vectors, namespace, timeout=timeout, semaphore=semaphore, **kwargs
+                vectors=vectors, namespace=namespace, timeout=timeout, semaphore=semaphore, **kwargs
             )
 
         if not isinstance(batch_size, int) or batch_size <= 0:
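Most of the signature changes in this file follow one rule: for an `async def` function, the declared return type is the value the coroutine produces once awaited, not `Awaitable[T]`; mypy wraps the annotation in a coroutine type automatically. A minimal standalone sketch of the error the old annotations triggered (not code from this commit):

```python
import asyncio
from typing import Awaitable


async def fetch_ok() -> int:
    # Correct for mypy: annotate the awaited value. The function's actual
    # type is inferred as Coroutine[Any, Any, int] automatically.
    return 42


async def fetch_bad() -> Awaitable[int]:
    # mypy: Incompatible return value type (got "int", expected "Awaitable[int]")
    return 42


async def main() -> None:
    value: int = await fetch_ok()  # awaiting yields the annotated type
    print(value)

asyncio.run(main())
```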
@@ -132,7 +132,7 @@ async def _upsert_batch(
         namespace: Optional[str],
         timeout: Optional[int] = None,
         **kwargs,
-    ) -> Awaitable[UpsertResponse]:
+    ) -> UpsertResponse:
         args_dict = parse_non_empty_args([("namespace", namespace)])
         request = UpsertRequest(vectors=vectors, **args_dict)
         return await self.runner.run_asyncio(
@@ -151,7 +151,7 @@ async def _query(
         sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[Dict]:
+    ) -> dict[str, Any]:
         if vector is not None and id is not None:
             raise ValueError("Cannot specify both `id` and `vector`")
 
@@ -182,7 +182,8 @@ async def _query(
         response = await self.runner.run_asyncio(
             self.stub.Query, request, timeout=timeout, semaphore=semaphore
         )
-        return json_format.MessageToDict(response)
+        parsed = json_format.MessageToDict(response)
+        return parsed
 
     async def query(
         self,
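The `dict[str, Any]` annotation matches what `json_format.MessageToDict` actually returns. A hedged sketch of that conversion, using a protobuf `Struct` as a stand-in for the real query response message:

```python
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

# Struct is just a convenient message type for demonstration; the real code
# converts the gRPC Query response message.
msg = Struct()
msg.update({"namespace": "ns1", "matches": []})

parsed = json_format.MessageToDict(msg)  # a plain dict[str, Any]
assert isinstance(parsed, dict)
print(parsed)  # {'namespace': 'ns1', 'matches': []}
```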
@@ -196,7 +197,7 @@ async def query(
         sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[QueryResponse]:
+    ) -> QueryResponse:
         """
         The Query operation searches a namespace, using a query vector.
         It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -257,9 +258,9 @@ async def query(
 
     async def composite_query(
         self,
-        vector: Optional[List[float]] = None,
-        namespaces: Optional[List[str]] = None,
-        top_k: Optional[int] = 10,
+        vector: List[float],
+        namespaces: List[str],
+        top_k: Optional[int] = None,
         filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
         include_values: Optional[bool] = None,
         include_metadata: Optional[bool] = None,
@@ -268,17 +269,23 @@ async def composite_query(
         max_concurrent_requests: Optional[int] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[CompositeQueryResults]:
+    ) -> CompositeQueryResults:
         aggregator_lock = asyncio.Lock()
         semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
 
-        # The caller may only want the topK=1 result across all queries,
+        if len(namespaces) == 0:
+            raise ValueError("At least one namespace must be specified")
+        if len(vector) == 0:
+            raise ValueError("Query vector must not be empty")
+
+        # The caller may only want the top_k=1 result across all queries,
         # but we need to get at least 2 results from each query in order to
         # aggregate them correctly. So we'll temporarily set topK to 2 for the
         # subqueries, and then we'll take the topK=1 results from the aggregated
         # results.
-        aggregator = QueryResultsAggregator(top_k=top_k)
-        subquery_topk = top_k if top_k > 2 else 2
+        overall_topk = top_k if top_k is not None else 10
+        aggregator = QueryResultsAggregator(top_k=overall_topk)
+        subquery_topk = overall_topk if overall_topk > 2 else 2
 
         target_namespaces = set(namespaces)  # dedup namespaces
         query_tasks = [
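The `top_k` handling above can be read in isolation. A small sketch of the new normalization logic, extracted and renamed for illustration (`normalize_topk` is not a function in the codebase):

```python
from typing import Optional, Tuple


def normalize_topk(top_k: Optional[int]) -> Tuple[int, int]:
    """Return (overall_topk, subquery_topk) as composite_query now derives them."""
    overall_topk = top_k if top_k is not None else 10  # old default of 10, moved out of the signature
    subquery_topk = overall_topk if overall_topk > 2 else 2  # each subquery must return >= 2 matches
    return overall_topk, subquery_topk


assert normalize_topk(None) == (10, 10)
assert normalize_topk(1) == (1, 2)  # subqueries over-fetch so aggregation stays correct
assert normalize_topk(50) == (50, 50)
```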
14 changes: 14 additions & 0 deletions pinecone/grpc/query_results.py
@@ -0,0 +1,14 @@
+from typing import TypedDict, List, Dict, Any
+
+
+class ScoredVectorTypedDict(TypedDict):
+    id: str
+    score: float
+    values: List[float]
+    metadata: dict
+
+
+class QueryResultsTypedDict(TypedDict):
+    matches: List[ScoredVectorTypedDict]
+    namespace: str
+    usage: Dict[str, Any]
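A hedged usage sketch for the new TypedDicts; the field values below are purely illustrative:

```python
from pinecone.grpc.query_results import QueryResultsTypedDict, ScoredVectorTypedDict

match: ScoredVectorTypedDict = {
    "id": "vec-1",
    "score": 0.93,
    "values": [0.1, 0.2, 0.3],
    "metadata": {"genre": "documentary"},
}

results: QueryResultsTypedDict = {
    "matches": [match],
    "namespace": "ns1",
    "usage": {"readUnits": 5},  # mirrors the usage dict read by the aggregator
}
```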
29 changes: 17 additions & 12 deletions pinecone/grpc/query_results_aggregator.py
@@ -1,7 +1,7 @@
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Any, Dict
 import json
 import heapq
-from pinecone.core.openapi.data.models import QueryResponse, Usage
+from pinecone.core.openapi.data.models import Usage
 
 from dataclasses import dataclass, asdict
 
@@ -15,14 +15,14 @@ class ScoredVectorWithNamespace:
     sparse_values: dict
     metadata: dict
 
-    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, dict, str]):
+    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
         json_vector = aggregate_results_heap_tuple[2]
         self.namespace = aggregate_results_heap_tuple[3]
-        self.id = json_vector.get("id")
-        self.score = json_vector.get("score")
-        self.values = json_vector.get("values")
-        self.sparse_values = json_vector.get("sparse_values", None)
-        self.metadata = json_vector.get("metadata", None)
+        self.id = json_vector.get("id")  # type: ignore
+        self.score = json_vector.get("score")  # type: ignore
+        self.values = json_vector.get("values")  # type: ignore
+        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
+        self.metadata = json_vector.get("metadata", None)  # type: ignore
 
     def __getitem__(self, key):
         if hasattr(self, key):
@@ -106,10 +106,11 @@ def __init__(self, top_k: int):
             raise QueryResultsAggregatorInvalidTopKError(top_k)
         self.top_k = top_k
         self.usage_read_units = 0
-        self.heap = []
+        self.heap: List[Tuple[float, int, object, str]] = []
         self.insertion_counter = 0
         self.is_dotproduct = None
         self.read = False
+        self.final_results: Optional[CompositeQueryResults] = None
 
     def _is_dotproduct_index(self, matches):
         # The interpretation of the score depends on the similar metric used.
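The new heap annotation documents the item shape that `ScoredVectorWithNamespace.__init__` unpacks: `(score, insertion counter, parsed match, namespace)`. A standalone sketch of why the `int` slot matters; the tuple layout is inferred from the annotations, not copied from this commit:

```python
import heapq
from typing import List, Tuple

heap: List[Tuple[float, int, object, str]] = []

# When two scores tie, heapq compares the int counter next, so it never has
# to order the dict payloads (comparing dicts with < raises TypeError).
heapq.heappush(heap, (0.91, 0, {"id": "a", "score": 0.91}, "ns1"))
heapq.heappush(heap, (0.91, 1, {"id": "b", "score": 0.91}, "ns2"))

print(heapq.heappop(heap))  # (0.91, 0, {'id': 'a', 'score': 0.91}, 'ns1')
```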
@@ -135,15 +136,15 @@ def _process_matches(self, matches, ns, heap_item_fn):
            else:
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))
 
-    def add_results(self, results: QueryResponse):
+    def add_results(self, results: Dict[str, Any]):
         if self.read:
             # This is mainly just to sanity check in test cases which get quite confusing
             # if you read results twice due to the heap being emptied when constructing
             # the ordered results.
             raise ValueError("Results have already been read. Cannot add more results.")
 
         matches = results.get("matches", [])
-        ns = results.get("namespace")
+        ns: str = results.get("namespace", "")
         self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
 
         if len(matches) == 0:
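The `ns` change pairs the annotation with a default so the value can never be `None`. A generic sketch of the pattern, using `Dict[str, str]` so mypy's complaint is visible (the dict in `add_results` is `Dict[str, Any]`, where the default is belt and braces rather than strictly required):

```python
from typing import Dict

d: Dict[str, str] = {"namespace": "ns1"}

# ns: str = d.get("namespace")
# mypy: Incompatible types in assignment (expression has type "Optional[str]", ...)

ns: str = d.get("namespace", "")  # supplying a default removes the None case
print(ns)
```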
@@ -161,7 +162,11 @@ def add_results(self, results: QueryResponse):
 
     def get_results(self) -> CompositeQueryResults:
         if self.read:
-            return self.final_results
+            if self.final_results is not None:
+                return self.final_results
+            else:
+                # I don't think this branch can ever actually be reached, but the type checker disagrees
+                raise ValueError("Results have already been read. Cannot get results again.")
         self.read = True
 
         self.final_results = CompositeQueryResults(
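The `get_results` change is the standard Optional-narrowing dance: once `final_results` is declared `Optional[CompositeQueryResults]`, mypy refuses to return it where a non-Optional is expected until a `None` check narrows it. A minimal sketch of the pattern, independent of this codebase:

```python
from typing import Optional


class Once:
    def __init__(self) -> None:
        self.result: Optional[str] = None
        self.read = False

    def get(self) -> str:
        if self.read:
            if self.result is not None:
                return self.result  # mypy narrows Optional[str] to str here
            raise RuntimeError("read flag set but no result cached")
        self.read = True
        self.result = "computed"
        return self.result
```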
2 changes: 1 addition & 1 deletion pinecone/grpc/utils.py
@@ -10,8 +10,8 @@
     QueryResponse,
     DescribeIndexStatsResponse,
     NamespaceSummary,
-    SparseValues as GRPCSparseValues,
 )
+from pinecone.core.grpc.protos.vector_service_pb2 import SparseValues as GRPCSparseValues
 from .sparse_vector import SparseVectorTypedDict
 
 from google.protobuf.struct_pb2 import Struct