From a1bc758089a0739dabc9996b9497c816e7b6d967 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Tue, 15 Oct 2024 11:38:36 -0400 Subject: [PATCH] Add unit tests for grpc runner --- pinecone/grpc/grpc_runner.py | 28 +++++++++++++++------------- pinecone/grpc/index_grpc.py | 6 +----- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/pinecone/grpc/grpc_runner.py b/pinecone/grpc/grpc_runner.py index 86a00095..253a6b33 100644 --- a/pinecone/grpc/grpc_runner.py +++ b/pinecone/grpc/grpc_runner.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional from grpc._channel import _InactiveRpcError @@ -8,6 +8,8 @@ from .config import GRPCClientConfig from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION from pinecone.exceptions.exceptions import PineconeException +from grpc import CallCredentials, Compression +from google.protobuf.message import Message class GrpcRunner: @@ -26,12 +28,12 @@ def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfi def run( self, func, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, + request: Message, + timeout: Optional[int] = None, + metadata: Optional[Dict[str, str]] = None, + credentials: Optional[CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[Compression] = None, ): @wraps(func) def wrapped(): @@ -54,12 +56,12 @@ def wrapped(): async def run_asyncio( self, func, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, + request: Message, + timeout: Optional[int] = None, + metadata: Optional[Dict[str, str]] = None, + credentials: Optional[CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[Compression] = None, ): @wraps(func) async def wrapped(): diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 877cd246..6269c23d 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -155,11 +155,7 @@ def upsert( return UpsertResponse(upserted_count=total_upserted) def _upsert_batch( - self, - vectors: List[GRPCVector], - namespace: Optional[str], - timeout: Optional[float], - **kwargs, + self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs ) -> UpsertResponse: args_dict = self._parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict)