diff --git a/pinecone/grpc/base.py b/pinecone/grpc/base.py index db1cabf4..e1e26792 100644 --- a/pinecone/grpc/base.py +++ b/pinecone/grpc/base.py @@ -3,18 +3,16 @@ from functools import wraps from typing import Dict, Optional -import certifi import grpc from grpc._channel import _InactiveRpcError, Channel -import json from .retry import RetryConfig +from .channel_factory import GrpcChannelFactory from pinecone import Config from .utils import _generate_request_id from .config import GRPCClientConfig -from pinecone.utils.constants import MAX_MSG_SIZE, REQUEST_ID, CLIENT_VERSION -from pinecone.utils.user_agent import get_user_agent_grpc +from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION from pinecone.exceptions.exceptions import PineconeException _logger = logging.getLogger(__name__) @@ -35,8 +33,6 @@ def __init__( grpc_config: Optional[GRPCClientConfig] = None, _endpoint_override: Optional[str] = None, ): - self.name = index_name - self.config = config self.grpc_client_config = grpc_config or GRPCClientConfig() self.retry_config = self.grpc_client_config.retry_config or RetryConfig() @@ -51,35 +47,10 @@ def __init__( self._endpoint_override = _endpoint_override - self.method_config = json.dumps( - { - "methodConfig": [ - { - "name": [{"service": "VectorService.Upsert"}], - "retryPolicy": { - "maxAttempts": 5, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - }, - { - "name": [{"service": "VectorService"}], - "retryPolicy": { - "maxAttempts": 5, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - }, - ] - } + self.channel_factory = GrpcChannelFactory( + config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False ) - - options = {"grpc.primary_user_agent": get_user_agent_grpc(config)} - self._channel = channel or self._gen_channel(options=options) + self._channel = channel or self._gen_channel() self.stub = self.stub_class(self._channel) @property @@ -93,36 +64,8 @@ def _endpoint(self): grpc_host = f"{grpc_host}:443" return self._endpoint_override if self._endpoint_override else grpc_host - def _gen_channel(self, options=None): - target = self._endpoint() - default_options = { - "grpc.max_send_message_length": MAX_MSG_SIZE, - "grpc.max_receive_message_length": MAX_MSG_SIZE, - "grpc.service_config": self.method_config, - "grpc.enable_retries": True, - "grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE, - } - if self.grpc_client_config.secure: - default_options["grpc.ssl_target_name_override"] = target.split(":")[0] - if self.config.proxy_url: - default_options["grpc.http_proxy"] = self.config.proxy_url - user_provided_options = options or {} - _options = tuple((k, v) for k, v in {**default_options, **user_provided_options}.items()) - _logger.debug( - "creating new channel with endpoint %s options %s and config %s", - target, - _options, - self.grpc_client_config, - ) - if not self.grpc_client_config.secure: - channel = grpc.insecure_channel(target, options=_options) - else: - ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where() - root_cas = open(ca_certs, "rb").read() - tls = grpc.ssl_channel_credentials(root_certificates=root_cas) - channel = grpc.secure_channel(target, tls, options=_options) - - return channel + def _gen_channel(self): + return self.channel_factory.create_channel(self._endpoint()) @property def channel(self): diff --git a/pinecone/grpc/channel_factory.py b/pinecone/grpc/channel_factory.py new file mode 100644 index 00000000..fd7444ca --- /dev/null +++ b/pinecone/grpc/channel_factory.py @@ -0,0 +1,100 @@ +import logging +from typing import Optional + +import certifi +import grpc +import json + +from pinecone import Config +from .config import GRPCClientConfig +from pinecone.utils.constants import MAX_MSG_SIZE +from pinecone.utils.user_agent import get_user_agent_grpc + +_logger = logging.getLogger(__name__) + + +class GrpcChannelFactory: + def __init__( + self, + config: Config, + grpc_client_config: GRPCClientConfig, + use_asyncio: Optional[bool] = False, + ): + self.config = config + self.grpc_client_config = grpc_client_config + self.use_asyncio = use_asyncio + + def _get_service_config(self): + # https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto + return json.dumps( + { + "methodConfig": [ + { + "name": [{"service": "VectorService.Upsert"}], + "retryPolicy": { + "maxAttempts": 5, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + }, + { + "name": [{"service": "VectorService"}], + "retryPolicy": { + "maxAttempts": 5, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": ["UNAVAILABLE"], + }, + }, + ] + } + ) + + def _build_options(self, target): + # For property definitions, see https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h + options = { + "grpc.max_send_message_length": MAX_MSG_SIZE, + "grpc.max_receive_message_length": MAX_MSG_SIZE, + "grpc.service_config": self._get_service_config(), + "grpc.enable_retries": True, + "grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE, + "grpc.primary_user_agent": get_user_agent_grpc(self.config), + } + if self.grpc_client_config.secure: + options["grpc.ssl_target_name_override"] = target.split(":")[0] + if self.config.proxy_url: + options["grpc.http_proxy"] = self.config.proxy_url + + options_tuple = tuple((k, v) for k, v in options.items()) + return options_tuple + + def _build_channel_credentials(self): + ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where() + root_cas = open(ca_certs, "rb").read() + channel_creds = grpc.ssl_channel_credentials(root_certificates=root_cas) + return channel_creds + + def create_channel(self, endpoint): + options_tuple = self._build_options(endpoint) + + _logger.debug( + "Creating new channel with endpoint %s options %s and config %s", + endpoint, + options_tuple, + self.grpc_client_config, + ) + + if not self.grpc_client_config.secure: + create_channel_fn = ( + grpc.aio.insecure_channel if self.use_asyncio else grpc.insecure_channel + ) + channel = create_channel_fn(endpoint, options=options_tuple) + else: + channel_creds = self._build_channel_credentials() + create_channel_fn = grpc.aio.secure_channel if self.use_asyncio else grpc.secure_channel + channel = create_channel_fn(endpoint, credentials=channel_creds, options=options_tuple) + + return channel diff --git a/tests/unit_grpc/test_channel_factory.py b/tests/unit_grpc/test_channel_factory.py new file mode 100644 index 00000000..bac13202 --- /dev/null +++ b/tests/unit_grpc/test_channel_factory.py @@ -0,0 +1,141 @@ +import grpc +import re +import pytest +from unittest.mock import patch, MagicMock, ANY + +from pinecone import Config +from pinecone.grpc.channel_factory import GrpcChannelFactory, GRPCClientConfig +from pinecone.utils.constants import MAX_MSG_SIZE + + +@pytest.fixture +def config(): + return Config(ssl_ca_certs=None, proxy_url=None) + + +@pytest.fixture +def grpc_client_config(): + return GRPCClientConfig(secure=True) + + +class TestGrpcChannelFactory: + def test_create_secure_channel_with_default_settings(self, config, grpc_client_config): + factory = GrpcChannelFactory( + config=config, grpc_client_config=grpc_client_config, use_asyncio=False + ) + endpoint = "test.endpoint:443" + + with patch("grpc.secure_channel") as mock_secure_channel, patch( + "certifi.where", return_value="/path/to/certifi/cacert.pem" + ), patch("builtins.open", new_callable=MagicMock) as mock_open: + # Mock the file object to return bytes when read() is called + mock_file = MagicMock() + mock_file.read.return_value = b"mocked_cert_data" + mock_open.return_value = mock_file + channel = factory.create_channel(endpoint) + + mock_secure_channel.assert_called_once() + assert mock_secure_channel.call_args[0][0] == endpoint + assert isinstance(mock_secure_channel.call_args[1]["options"], tuple) + + options = dict(mock_secure_channel.call_args[1]["options"]) + assert options["grpc.ssl_target_name_override"] == "test.endpoint" + assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE + assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE + assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE + assert "grpc.service_config" in options + assert options["grpc.enable_retries"] is True + assert ( + re.search( + r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"] + ) + is not None + ) + + assert isinstance(channel, MagicMock) + + def test_create_secure_channel_with_proxy(self): + grpc_client_config = GRPCClientConfig(secure=True) + config = Config(proxy_url="http://test.proxy:8080") + factory = GrpcChannelFactory( + config=config, grpc_client_config=grpc_client_config, use_asyncio=False + ) + endpoint = "test.endpoint:443" + + with patch("grpc.secure_channel") as mock_secure_channel: + channel = factory.create_channel(endpoint) + + mock_secure_channel.assert_called_once() + assert "grpc.http_proxy" in dict(mock_secure_channel.call_args[1]["options"]) + assert ( + "http://test.proxy:8080" + == dict(mock_secure_channel.call_args[1]["options"])["grpc.http_proxy"] + ) + assert isinstance(channel, MagicMock) + + def test_create_insecure_channel(self, config): + grpc_client_config = GRPCClientConfig(secure=False) + factory = GrpcChannelFactory( + config=config, grpc_client_config=grpc_client_config, use_asyncio=False + ) + endpoint = "test.endpoint:50051" + + with patch("grpc.insecure_channel") as mock_insecure_channel: + channel = factory.create_channel(endpoint) + + mock_insecure_channel.assert_called_once_with(endpoint, options=ANY) + assert isinstance(channel, MagicMock) + + +class TestGrpcChannelFactoryAsyncio: + def test_create_secure_channel_with_default_settings(self, config, grpc_client_config): + factory = GrpcChannelFactory( + config=config, grpc_client_config=grpc_client_config, use_asyncio=True + ) + endpoint = "test.endpoint:443" + + with patch("grpc.aio.secure_channel") as mock_secure_aio_channel, patch( + "certifi.where", return_value="/path/to/certifi/cacert.pem" + ), patch("builtins.open", new_callable=MagicMock) as mock_open: + # Mock the file object to return bytes when read() is called + mock_file = MagicMock() + mock_file.read.return_value = b"mocked_cert_data" + mock_open.return_value = mock_file + channel = factory.create_channel(endpoint) + + mock_secure_aio_channel.assert_called_once() + assert mock_secure_aio_channel.call_args[0][0] == endpoint + assert isinstance(mock_secure_aio_channel.call_args[1]["options"], tuple) + + options = dict(mock_secure_aio_channel.call_args[1]["options"]) + assert options["grpc.ssl_target_name_override"] == "test.endpoint" + assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE + assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE + assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE + assert "grpc.service_config" in options + assert options["grpc.enable_retries"] is True + assert ( + re.search( + r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"] + ) + is not None + ) + + security_credentials = mock_secure_aio_channel.call_args[1]["credentials"] + assert security_credentials is not None + assert isinstance(security_credentials, grpc.ChannelCredentials) + + assert isinstance(channel, MagicMock) + + def test_create_insecure_channel_asyncio(self, config): + grpc_client_config = GRPCClientConfig(secure=False) + factory = GrpcChannelFactory( + config=config, grpc_client_config=grpc_client_config, use_asyncio=True + ) + endpoint = "test.endpoint:50051" + + with patch("grpc.aio.insecure_channel") as mock_aio_insecure_channel: + channel = factory.create_channel(endpoint) + + mock_aio_insecure_channel.assert_called_once_with(endpoint, options=ANY) + assert isinstance(channel, MagicMock)