Skip to content

Commit

Permalink
WIP on asyncio index and composite_query method
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Oct 18, 2024
1 parent 3edd2a7 commit 6d99df7
Show file tree
Hide file tree
Showing 9 changed files with 1,000 additions and 92 deletions.
16 changes: 9 additions & 7 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import logging
from typing import Optional, Dict, Any, Union, List, Tuple, Literal
from typing import Optional, Dict, Any, Union, Literal

from .index_host_store import IndexHostStore

Expand All @@ -10,7 +10,12 @@
from pinecone.core.openapi.shared.api_client import ApiClient


from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import (
normalize_host,
setup_openapi_client,
build_plugin_setup_client,
parse_non_empty_args,
)
from pinecone.core.openapi.control.models import (
CreateCollectionRequest,
CreateIndexRequest,
Expand Down Expand Up @@ -317,9 +322,6 @@ def create_index(

api_instance = self.index_api

def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Collect ``(name, value)`` pairs into a dict, omitting pairs whose value is ``None``.

    Args:
        args: ``(arg_name, value)`` pairs.

    Returns:
        A dict mapping each ``arg_name`` to its value for every pair whose
        value is not ``None``. Falsy values other than ``None`` (``0``,
        ``""``, ``False``, empty containers) are kept.
    """
    # Identity check against None (not truthiness) so legitimate falsy
    # values are still forwarded.
    return {arg_name: val for arg_name, val in args if val is not None}

if deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
else:
Expand All @@ -329,7 +331,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
if "serverless" in spec:
index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
elif "pod" in spec:
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("environment", spec["pod"].get("environment")),
("metadata_config", spec["pod"].get("metadata_config")),
Expand All @@ -351,7 +353,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region)
)
elif isinstance(spec, PodSpec):
args_dict = _parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("replicas", spec.replicas),
("shards", spec.shards),
Expand Down
1 change: 1 addition & 0 deletions pinecone/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"""

from .index_grpc import GRPCIndex
from .index_grpc_asyncio import GRPCIndexAsyncio
from .pinecone import PineconeGRPC
from .config import GRPCClientConfig

Expand Down
29 changes: 21 additions & 8 deletions pinecone/grpc/grpc_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from functools import wraps
from typing import Dict, Tuple, Optional

Expand Down Expand Up @@ -62,20 +63,32 @@ async def run_asyncio(
credentials: Optional[CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[Compression] = None,
semaphore: Optional[asyncio.Semaphore] = None,
):
@wraps(func)
async def wrapped():
user_provided_metadata = metadata or {}
_metadata = self._prepare_metadata(user_provided_metadata)
try:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
if semaphore is not None:
async with semaphore:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
else:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

Expand Down
62 changes: 17 additions & 45 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
from typing import Optional, Dict, Union, List, cast

from google.protobuf import json_format

from tqdm.autonotebook import tqdm

from pinecone.utils import parse_non_empty_args
from .utils import (
dict_to_proto_struct,
parse_fetch_response,
parse_query_response,
parse_stats_response,
parse_sparse_values_arg,
)
from .vector_factory_grpc import VectorFactoryGRPC
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture
from .sparse_vector import SparseVectorTypedDict
from .config import GRPCClientConfig

from pinecone.core.openapi.data.models import (
FetchResponse,
Expand All @@ -36,10 +42,7 @@
)
from pinecone import Vector as NonGRPCVector
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture

from .config import GRPCClientConfig
from pinecone.config import Config
from grpc._channel import Channel

Expand All @@ -49,11 +52,6 @@
_logger = logging.getLogger(__name__)


class SparseVectorTypedDict(TypedDict):
    """Typed dict with the two keys of a dict-form sparse vector: ``indices`` and ``values``."""

    # Integer indices of the sparse vector.
    # NOTE(review): presumably the positions of the non-zero entries — confirm.
    indices: List[int]
    # Float values of the sparse vector.
    # NOTE(review): presumably aligned one-to-one with ``indices`` — confirm.
    values: List[float]


class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""

Expand Down Expand Up @@ -152,7 +150,7 @@ def upsert(

vectors = list(map(VectorFactoryGRPC.build, vectors))
if async_req:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
Expand All @@ -178,7 +176,7 @@ def upsert(
def _upsert_batch(
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
) -> UpsertResponse:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)

Expand Down Expand Up @@ -285,7 +283,7 @@ def delete(
else:
filter_struct = None

args_dict = self._parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("ids", ids),
("delete_all", delete_all),
Expand Down Expand Up @@ -322,7 +320,7 @@ def fetch(
"""
timeout = kwargs.pop("timeout", None)

args_dict = self._parse_non_empty_args([("namespace", namespace)])
args_dict = parse_non_empty_args([("namespace", namespace)])

request = FetchRequest(ids=ids, **args_dict, **kwargs)
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
Expand Down Expand Up @@ -388,8 +386,8 @@ def query(
else:
filter_struct = None

sparse_vector = self._parse_sparse_values_arg(sparse_vector)
args_dict = self._parse_non_empty_args(
sparse_vector = parse_sparse_values_arg(sparse_vector)
args_dict = parse_non_empty_args(
[
("vector", vector),
("id", id),
Expand Down Expand Up @@ -456,8 +454,8 @@ def update(
set_metadata_struct = None

timeout = kwargs.pop("timeout", None)
sparse_values = self._parse_sparse_values_arg(sparse_values)
args_dict = self._parse_non_empty_args(
sparse_values = parse_sparse_values_arg(sparse_values)
args_dict = parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata_struct),
Expand Down Expand Up @@ -506,7 +504,7 @@ def list_paginated(
Returns: SimpleListResponse object which contains the list of ids, the namespace name, pagination information, and usage showing the number of read_units consumed.
"""
args_dict = self._parse_non_empty_args(
args_dict = parse_non_empty_args(
[
("prefix", prefix),
("limit", limit),
Expand Down Expand Up @@ -585,36 +583,10 @@ def describe_index_stats(
filter_struct = dict_to_proto_struct(filter)
else:
filter_struct = None
args_dict = self._parse_non_empty_args([("filter", filter_struct)])
args_dict = parse_non_empty_args([("filter", filter_struct)])
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_stats_response(json_response)

@staticmethod
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
    """Turn ``(name, value)`` pairs into a dict, skipping every pair whose value is ``None``.

    Args:
        args: ``(arg_name, value)`` pairs.

    Returns:
        A dict containing only the pairs whose value is not ``None``.
        Other falsy values (``0``, ``False``, ``""``, ``[]``) are retained.
    """
    # ``is not None`` rather than a truthiness test: 0 / False / "" must
    # still be passed through.
    return {arg_name: val for arg_name, val in args if val is not None}

@staticmethod
def _parse_sparse_values_arg(
    sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]],
) -> Optional[GRPCSparseValues]:
    """Normalize a sparse-values argument to a ``GRPCSparseValues`` instance (or ``None``).

    Args:
        sparse_values: ``None``, an existing ``GRPCSparseValues``, or a dict
            containing both an ``"indices"`` and a ``"values"`` key.

    Returns:
        ``None`` when given ``None``; the same object when given a
        ``GRPCSparseValues``; otherwise a new ``GRPCSparseValues`` built from
        the dict's ``"indices"`` and ``"values"`` entries.

    Raises:
        ValueError: If the argument is not ``None``, not a
            ``GRPCSparseValues``, and not a dict holding both required keys.
    """
    if sparse_values is None:
        return None

    # Already in the target type: pass through unchanged.
    if isinstance(sparse_values, GRPCSparseValues):
        return sparse_values

    if (
        not isinstance(sparse_values, dict)
        or "indices" not in sparse_values
        or "values" not in sparse_values
    ):
        # The trailing space after the first literal matters: adjacent string
        # literals are concatenated verbatim, and without it the message reads
        # "...List[float]}.Received: ...".
        raise ValueError(
            "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}. "
            f"Received: {sparse_values}"
        )

    return GRPCSparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
Loading

0 comments on commit 6d99df7

Please sign in to comment.