Skip to content

Commit

Permalink
Use PoolThreadExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Nov 13, 2024
1 parent 743e6f5 commit 0587ae5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
10 changes: 10 additions & 0 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pinecone import Config
from .config import GRPCClientConfig
from .grpc_runner import GrpcRunner
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins

Expand All @@ -29,10 +30,12 @@ def __init__(
config: Config,
channel: Optional[Channel] = None,
grpc_config: Optional[GRPCClientConfig] = None,
pool_threads: Optional[int] = None,
_endpoint_override: Optional[str] = None,
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.pool_threads = pool_threads

self._endpoint_override = _endpoint_override

Expand All @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs):
except Exception as e:
_logger.error(f"Error loading plugins in GRPCIndex: {e}")

@property
def threadpool_executor(self):
if self._pool is None:
pt = self.pool_threads or 10
self._pool = ThreadPoolExecutor(max_workers=pt)
return self._pool

@property
@abstractmethod
def stub_class(self):
Expand Down
11 changes: 5 additions & 6 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,25 +426,24 @@ def query_namespaces(

target_namespaces = set(namespaces) # dedup namespaces
futures = [
self.query(
self.threadpool_executor.submit(
self.query,
vector=vector,
namespace=ns,
top_k=overall_topk,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
async_req=True,
async_req=False,
**kwargs,
)
for ns in target_namespaces
]

only_futures = cast(Iterable[Future], futures)
for future in as_completed(only_futures):
response = future.result()
json_result = json_format.MessageToDict(response)
aggregator.add_results(json_result)
for response in as_completed(only_futures):
aggregator.add_results(response.result())

final_results = aggregator.get_results()
return final_results
Expand Down
4 changes: 3 additions & 1 deletion pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,13 @@ def Index(self, name: str = "", host: str = "", **kwargs):
# Use host if it is provided, otherwise get host from describe_index
index_host = host or self.index_host_store.get_host(self.index_api, self.config, name)

pt = kwargs.pop("pool_threads", None) or self.pool_threads

config = ConfigBuilder.build(
api_key=self.config.api_key,
host=index_host,
source_tag=self.config.source_tag,
proxy_url=self.config.proxy_url,
ssl_ca_certs=self.config.ssl_ca_certs,
)
return GRPCIndex(index_name=name, config=config, **kwargs)
return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs)

0 comments on commit 0587ae5

Please sign in to comment.