Skip to content

Commit

Permalink
Fix mypy errors in grpc code
Browse files — browse the repository at this point in the history
  • Loading branch information
jhamon committed Nov 4, 2023
1 parent feaf5a4 commit 310cdda
Show file tree
Hide file tree
Showing 9 changed files with 413 additions and 348 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ jobs:
- name: Run unit tests (GRPC)
run: poetry run pytest --cov=pinecone --timeout=120 tests/unit_grpc

- name: mypy check
run: |
# Still lots of errors when running on the whole package (especially
# in the generated core module), but we can check these subpackages
# so we don't add new regressions.
poetry run mypy pinecone/grpc
package:
name: Check packaging
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion pinecone/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .index_grpc import GRPCIndex
from .pinecone import Pinecone
from .pinecone import PineconeGRPC
12 changes: 6 additions & 6 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from abc import ABC, abstractmethod
from functools import wraps
from typing import Dict
from typing import Dict, Optional

import certifi
import grpc
from grpc._channel import _InactiveRpcError
from grpc._channel import _InactiveRpcError, Channel
import json

from .retry import RetryConfig
Expand All @@ -29,15 +29,15 @@ def __init__(
self,
index_name: str,
config: Config,
channel=None,
grpc_config: GRPCClientConfig = None,
_endpoint_override: str = None,
channel: Optional[Channel] =None,
grpc_config: Optional[GRPCClientConfig] = None,
_endpoint_override: Optional[str] = None,
):
self.name = index_name

self.grpc_client_config = grpc_config or GRPCClientConfig()
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()
self.fixed_metadata = {"api-key": config.API_KEY, "service-name": index_name, "client-version": CLIENT_VERSION}
self.fixed_metadata = {"api-key": config.api_key, "service-name": index_name, "client-version": CLIENT_VERSION}
self._endpoint_override = _endpoint_override

self.method_config = json.dumps(
Expand Down
2 changes: 1 addition & 1 deletion pinecone/grpc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GRPCClientConfig(NamedTuple):
conn_timeout: int = 1
reuse_channel: bool = True
retry_config: Optional[RetryConfig] = None
grpc_channel_options: Dict[str, str] = None
grpc_channel_options: Optional[Dict[str, str]] = None

@classmethod
def _from_dict(cls, kwargs: dict):
Expand Down
51 changes: 29 additions & 22 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import numbers
from typing import Optional, Dict, Iterable, Union, List, Tuple, Any
from typing import Optional, Dict, Iterable, Union, List, Tuple, Any, TypedDict, cast
from collections.abc import Mapping

from google.protobuf import json_format
Expand Down Expand Up @@ -41,6 +41,10 @@

_logger = logging.getLogger(__name__)

# Shape of a sparse vector passed as a plain dict: parallel index/value lists.
SparseVectorTypedDict = TypedDict(
    "SparseVectorTypedDict",
    {
        "indices": List[int],
        "values": List[float],
    },
)


class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""
Expand Down Expand Up @@ -119,7 +123,7 @@ def upsert(
"https://docs.pinecone.io/docs/performance-tuning"
)

def _dict_to_grpc_vector(item):
def _dict_to_grpc_vector(item) -> GRPCVector:
item_keys = set(item.keys())
if not item_keys.issuperset(REQUIRED_VECTOR_FIELDS):
raise ValueError(
Expand Down Expand Up @@ -150,7 +154,7 @@ def _dict_to_grpc_vector(item):
) from e

metadata = item.get("metadata", None)
if metadata is not None and not isinstance(metadata, Mapping):
if metadata is not None and not isinstance(metadata, Dict):
raise TypeError(f"Column `metadata` is expected to be a dictionary, found {type(metadata)}")

try:
Expand All @@ -163,11 +167,11 @@ def _dict_to_grpc_vector(item):

except TypeError as e:
# No need to raise a dedicated error for `id` - protobuf's error message is clear enough
if not isinstance(item["values"], Iterable) or not isinstance(item["values"][0], numbers.Real):
if not isinstance(item["values"], Iterable) or not isinstance(item["values"].__iter__().__next__(), numbers.Real):
raise TypeError(f"Column `values` is expected to be a list of floats")
raise

def _vector_transform(item):
def _vector_transform(item) -> GRPCVector:
if isinstance(item, GRPCVector):
return item
elif isinstance(item, tuple):
Expand All @@ -178,7 +182,7 @@ def _vector_transform(item):
f"To pass sparse values please use either dicts or GRPCVector objects as inputs."
)
id, values, metadata = fix_tuple_length(item, 3)
return GRPCVector(id=id, values=values, metadata=dict_to_proto_struct(metadata) or {})
return GRPCVector(id=id, values=values, metadata=dict_to_proto_struct(metadata) or None)
elif isinstance(item, Mapping):
return _dict_to_grpc_vector(item)
raise ValueError(f"Invalid vector value passed: cannot interpret type {type(item)}")
Expand Down Expand Up @@ -218,11 +222,11 @@ def _upsert_batch(
def upsert_from_dataframe(
self,
df,
namespace: str = None,
namespace: str = "",
batch_size: int = 500,
use_async_requests: bool = True,
show_progress: bool = True,
) -> None:
) -> UpsertResponse:
"""Upserts a dataframe into the index.
Args:
Expand Down Expand Up @@ -251,11 +255,13 @@ def upsert_from_dataframe(
results.append(res)

if use_async_requests:
results = [async_result.result() for async_result in tqdm(results, desc="collecting async responses")]
cast_results = cast(List[PineconeGrpcFuture], results)
results = [async_result.result() for async_result in tqdm(cast_results, desc="collecting async responses")]

upserted_count = 0
for res in results:
upserted_count += res.upserted_count
if hasattr(res, 'upserted_count') and isinstance(res.upserted_count, int):
upserted_count += res.upserted_count

return UpsertResponse(upserted_count=upserted_count)

Expand Down Expand Up @@ -307,10 +313,10 @@ def delete(
"""

if filter is not None:
filter = dict_to_proto_struct(filter)
filter_struct = dict_to_proto_struct(filter)

args_dict = self._parse_non_empty_args(
[("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter)]
[("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter_struct)]
)
timeout = kwargs.pop("timeout", None)

Expand Down Expand Up @@ -356,7 +362,7 @@ def query(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[Union[GRPCSparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> QueryResponse:
"""
Expand Down Expand Up @@ -415,7 +421,7 @@ def _query_transform(item):
queries = list(map(_query_transform, queries)) if queries is not None else None

if filter is not None:
filter = dict_to_proto_struct(filter)
filter_struct = dict_to_proto_struct(filter)

sparse_vector = self._parse_sparse_values_arg(sparse_vector)
args_dict = self._parse_non_empty_args(
Expand All @@ -425,7 +431,7 @@ def _query_transform(item):
("queries", queries),
("namespace", namespace),
("top_k", top_k),
("filter", filter),
("filter", filter_struct),
("include_values", include_values),
("include_metadata", include_metadata),
("sparse_vector", sparse_vector),
Expand All @@ -437,7 +443,8 @@ def _query_transform(item):
timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, vector is not None or id, _check_type=False)
unary_query = True if vector is not None or id else False
return parse_query_response(json_response, unary_query, _check_type=False)

def update(
self,
Expand All @@ -446,7 +453,7 @@ def update(
values: Optional[List[float]] = None,
set_metadata: Optional[Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]] = None,
namespace: Optional[str] = None,
sparse_values: Optional[Union[GRPCSparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> Union[UpdateResponse, PineconeGrpcFuture]:
"""
Expand Down Expand Up @@ -479,14 +486,14 @@ def update(
Returns: UpdateResponse (contains no data) or a PineconeGrpcFuture object if async_req is True.
"""
if set_metadata is not None:
set_metadata = dict_to_proto_struct(set_metadata)
set_metadata_struct = dict_to_proto_struct(set_metadata)
timeout = kwargs.pop("timeout", None)

sparse_values = self._parse_sparse_values_arg(sparse_values)
args_dict = self._parse_non_empty_args(
[
("values", values),
("set_metadata", set_metadata),
("set_metadata", set_metadata_struct),
("namespace", namespace),
("sparse_values", sparse_values),
]
Expand Down Expand Up @@ -518,8 +525,8 @@ def describe_index_stats(
Returns: DescribeIndexStatsResponse object which contains stats about the index.
"""
if filter is not None:
filter = dict_to_proto_struct(filter)
args_dict = self._parse_non_empty_args([("filter", filter)])
filter_struct = dict_to_proto_struct(filter)
args_dict = self._parse_non_empty_args([("filter", filter_struct)])
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
Expand All @@ -533,7 +540,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:

@staticmethod
def _parse_sparse_values_arg(
sparse_values: Optional[Union[GRPCSparseValues, Dict[str, Union[List[float], List[int]]]]]
sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]]
) -> Optional[GRPCSparseValues]:
if sparse_values is None:
return None
Expand Down
6 changes: 4 additions & 2 deletions pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from ..control.pinecone import Pinecone
from ..config.config import ConfigBuilder
from .index_grpc import GRPCIndex

class Pinecone(Pinecone):
class PineconeGRPC(Pinecone):
def Index(self, name: str):
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
return GRPCIndex(api_key=self.config.API_KEY, host=index_host)
config = ConfigBuilder.build(api_key=self.config.api_key, host=index_host)
return GRPCIndex(index_name=name, config=config)
26 changes: 20 additions & 6 deletions pinecone/grpc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@
NamespaceSummary,
)

from typing import NamedTuple, Optional

# Bundle of keyword arguments forwarded to QueryResponse; a NamedTuple keeps
# the field set and their optionality explicit for mypy.
QueryResponseKwargs = NamedTuple(
    "QueryResponseKwargs",
    [
        ("check_type", bool),
        ("namespace", Optional[str]),
        ("matches", Optional[list]),
        ("results", Optional[list]),
    ],
)

def _generate_request_id() -> str:
return str(uuid.uuid4())

def dict_to_proto_struct(d: dict) -> "Struct":
def dict_to_proto_struct(d: Optional[dict]) -> "Struct":
if not d:
d = {}
s = Struct()
Expand Down Expand Up @@ -80,13 +88,19 @@ def parse_query_response(response: dict, unary_query: bool, _check_type: bool =
)
m.append(sc)

kwargs = {"_check_type": _check_type}
if unary_query:
kwargs["namespace"] = response.get("namespace", "")
kwargs["matches"] = m
namespace = response.get("namespace", "")
matches = m
results = None
else:
kwargs["results"] = res
return QueryResponse(**kwargs)
namespace = None
matches = None
results = res

kw = QueryResponseKwargs(_check_type, namespace, matches, results)
kw_dict = kw._asdict()
kw_dict["_check_type"] = kw.check_type
return QueryResponse(**kw._asdict())


def parse_stats_response(response: dict):
Expand Down
Loading

0 comments on commit 310cdda

Please sign in to comment.