Skip to content

Commit

Permalink
WIP on threadpool impl of query_namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Oct 25, 2024
1 parent 247a329 commit 31eacbd
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
102 changes: 100 additions & 2 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from pinecone.core.openapi.data.api.data_plane_api import DataPlaneApi
from ..utils import setup_openapi_client, parse_non_empty_args
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from multiprocessing.pool import ApplyResult

__all__ = [
"Index",
Expand Down Expand Up @@ -361,7 +363,7 @@ def query(
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryResponse:
) -> Union[QueryResponse, ApplyResult[QueryResponse]]:
"""
The Query operation searches a namespace, using a query vector.
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
Expand Down Expand Up @@ -403,6 +405,39 @@ def query(
and namespace name.
"""

response = self._query(
*args,
top_k=top_k,
vector=vector,
id=id,
namespace=namespace,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
**kwargs,
)

if kwargs.get("async_req", False):
return response
else:
return parse_query_response(response)

def _query(
self,
*args,
top_k: int,
vector: Optional[List[float]] = None,
id: Optional[str] = None,
namespace: Optional[str] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryResponse:
if len(args) > 0:
raise ValueError(
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
Expand Down Expand Up @@ -435,7 +470,70 @@ def query(
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
)
return parse_query_response(response)
return response

def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[
        Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
    ] = None,
    show_progress: Optional[bool] = True,
    **kwargs,
) -> QueryNamespacesResults:
    """Query multiple namespaces in parallel and aggregate the results.

    Issues one asynchronous query (``async_req=True``) per distinct
    namespace, then merges all responses into a single top_k result set
    using :class:`QueryResultsAggregator`.

    Args:
        vector: The dense query vector. Must be non-empty.
        namespaces: Namespaces to query. Must contain at least one entry;
            duplicates are queried only once.
        top_k: Number of aggregated results to return. Defaults to 10
            when not provided.
        filter: Optional metadata filter applied to every subquery.
        include_values: Whether to include vector values in the results.
        include_metadata: Whether to include metadata in the results.
        sparse_vector: Optional sparse component of the query.
        show_progress: Currently unused. NOTE(review): leftover
            commented-out tqdm code suggested this was meant to drive a
            progress bar — confirm intent before removing the parameter.
        **kwargs: Additional options forwarded to each subquery.

    Returns:
        QueryNamespacesResults containing the merged top_k matches across
        all queried namespaces.

    Raises:
        ValueError: If ``namespaces`` or ``vector`` is empty.
    """
    if len(namespaces) == 0:
        raise ValueError("At least one namespace must be specified")
    if len(vector) == 0:
        raise ValueError("Query vector must not be empty")

    overall_topk = top_k if top_k is not None else 10
    aggregator = QueryResultsAggregator(top_k=overall_topk)

    # The caller may only want top_k=1 results overall, but each subquery
    # must return at least 2 results for the aggregator to merge them
    # correctly, so enforce a floor of 2 on the per-namespace top_k. The
    # extra results are discarded during aggregation.
    subquery_topk = max(overall_topk, 2)

    target_namespaces = set(namespaces)  # dedup namespaces
    async_results = [
        self.query(
            vector=vector,
            namespace=ns,
            top_k=subquery_topk,
            filter=filter,
            include_values=include_values,
            include_metadata=include_metadata,
            sparse_vector=sparse_vector,
            async_req=True,
            **kwargs,
        )
        for ns in target_namespaces
    ]

    # Block on each in-flight request and fold its response into the
    # aggregator as it completes.
    for async_result in async_results:
        aggregator.add_results(async_result.get())

    return aggregator.get_results()

@validate_and_convert_errors
def update(
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pinecone/grpc/index_grpc_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
parse_sparse_values_arg,
)
from .vector_factory_grpc import VectorFactoryGRPC
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from ..data.query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults


class GRPCIndexAsyncio(GRPCIndexBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pinecone.grpc.query_results_aggregator import (
from pinecone.data.query_results_aggregator import (
QueryResultsAggregator,
QueryResultsAggregatorInvalidTopKError,
QueryResultsAggregregatorNotEnoughResultsError,
Expand Down

0 comments on commit 31eacbd

Please sign in to comment.