Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement query_namespaces over grpc #416

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins

Expand All @@ -29,10 +30,12 @@ def __init__(
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
pool_threads: Optional[int] = None,
_endpoint_override: Optional[str] = None,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.pool_threads = pool_threads

self._endpoint_override = _endpoint_override

Expand All @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs):
except Exception as e:
_logger.error(f"Error loading plugins in GRPCIndex: {e}")

@property
def threadpool_executor(self):
    """Return the shared thread pool, creating it on first access.

    The pool is sized from ``self.pool_threads`` and falls back to 10
    workers when that value is ``None`` (or otherwise falsy). The executor
    is cached on ``self._pool`` so later accesses reuse the same instance.

    Returns:
        The cached ``ThreadPoolExecutor`` for this object.
    """
    # ``_pool`` may never have been assigned before the first access, so
    # read it defensively instead of raising AttributeError.
    pool = getattr(self, "_pool", None)
    if pool is None:
        pool = ThreadPoolExecutor(max_workers=self.pool_threads or 10)
        self._pool = pool
    return pool

@property
@abstractmethod
def stub_class(self):
Expand Down
48 changes: 47 additions & 1 deletion pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast

from google.protobuf import json_format

from tqdm.autonotebook import tqdm
from concurrent.futures import as_completed, Future


from .utils import (
dict_to_proto_struct,
Expand Down Expand Up @@ -35,6 +37,7 @@
SparseValues as GRPCSparseValues,
)
from pinecone import Vector as NonGRPCVector
from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture
Expand Down Expand Up @@ -402,6 +405,49 @@ def query(
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, _check_type=False)

def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
    **kwargs,
) -> QueryNamespacesResults:
    """Run the same query against several namespaces and merge the results.

    One ``self.query`` call per distinct namespace is submitted to
    ``self.threadpool_executor``; each response is fed into a
    ``QueryResultsAggregator`` as it completes.

    Args:
        vector: Query vector; must not be empty.
        namespaces: Namespaces to query; duplicates are collapsed.
        top_k: Result count used both for each per-namespace query and for
            the aggregator. Defaults to 10 when ``None``.
        filter: Forwarded unchanged to ``self.query``.
        include_values: Forwarded unchanged to ``self.query``.
        include_metadata: Forwarded unchanged to ``self.query``.
        sparse_vector: Forwarded unchanged to ``self.query``.
        **kwargs: Extra keyword arguments forwarded to ``self.query``.

    Returns:
        Whatever ``QueryResultsAggregator.get_results()`` returns.

    Raises:
        ValueError: If ``namespaces`` is ``None``/empty or ``vector`` is empty.
    """
    if namespaces is None or not len(namespaces):
        raise ValueError("At least one namespace must be specified")
    if not len(vector):
        raise ValueError("Query vector must not be empty")

    effective_top_k = 10 if top_k is None else top_k
    merged = QueryResultsAggregator(top_k=effective_top_k)

    executor = self.threadpool_executor
    pending = []
    # A set collapses repeated namespaces so each is queried only once.
    for ns_name in set(namespaces):
        pending.append(
            executor.submit(
                self.query,
                vector=vector,
                namespace=ns_name,
                top_k=effective_top_k,
                filter=filter,
                include_values=include_values,
                include_metadata=include_metadata,
                sparse_vector=sparse_vector,
                async_req=False,
                **kwargs,
            )
        )

    # Consume responses in completion order; result() re-raises any
    # exception raised inside the worker.
    for done in as_completed(pending):
        merged.add_results(done.result())

    return merged.get_results()

def update(
self,
id: str,
Expand Down
4 changes: 3 additions & 1 deletion pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ def Index(self, name: str = "", host: str = "", **kwargs):
# Use host if it is provided, otherwise get host from describe_index
index_host = host or self.index_host_store.get_host(self.index_api, self.config, name)

pt = kwargs.pop("pool_threads", None) or self.pool_threads

config = ConfigBuilder.build(
api_key=self.config.api_key,
host=index_host,
source_tag=self.config.source_tag,
proxy_url=self.config.proxy_url,
ssl_ca_certs=self.config.ssl_ca_certs,
)
return GRPCIndex(index_name=name, config=config, **kwargs)
return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs)
4 changes: 0 additions & 4 deletions tests/integration/data/test_query_namespaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import os
from ..helpers import random_string, poll_stats_for_namespace
from pinecone.data.query_results_aggregator import (
QueryResultsAggregatorInvalidTopKError,
Expand All @@ -9,9 +8,6 @@
from pinecone import Vector


@pytest.mark.skipif(
os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest"
)
class TestQueryNamespacesRest:
def test_query_namespaces(self, idx):
ns_prefix = random_string(5)
Expand Down
Loading