diff --git a/pinecone/config.py b/pinecone/config.py
index 4e29f013..dfde4b97 100644
--- a/pinecone/config.py
+++ b/pinecone/config.py
@@ -1,6 +1,3 @@
-#
-# Copyright (c) 2020-2021 Pinecone Systems Inc. All right reserved.
-#
 import logging
 import sys
 from typing import NamedTuple, List
@@ -16,13 +13,17 @@
 from pinecone.core.client.exceptions import ApiKeyError
 from pinecone.core.api_action import ActionAPI, WhoAmIResponse
 from pinecone.core.utils import warn_deprecated, check_kwargs
-from pinecone.core.utils.constants import CLIENT_VERSION, PARENT_LOGGER_NAME, DEFAULT_PARENT_LOGGER_LEVEL, \
-    TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_KEEPCNT
+from pinecone.core.utils.constants import (
+    CLIENT_VERSION,
+    PARENT_LOGGER_NAME,
+    DEFAULT_PARENT_LOGGER_LEVEL,
+    TCP_KEEPIDLE,
+    TCP_KEEPINTVL,
+    TCP_KEEPCNT,
+)
 from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
 
-__all__ = [
-    "Config", "init"
-]
+__all__ = ["Config", "init"]
 
 _logger = logging.getLogger(__name__)
 _parent_logger = logging.getLogger(PARENT_LOGGER_NAME)
@@ -63,10 +64,10 @@ def reset(self, config_file=None, **kwargs):
 
         # Get the environment first. Make sure that it is not overwritten in subsequent config objects.
         environment = (
-                kwargs.pop("environment", None)
-                or os.getenv("PINECONE_ENVIRONMENT")
-                or file_config.pop("environment", None)
-                or "us-west1-gcp"
+            kwargs.pop("environment", None)
+            or os.getenv("PINECONE_ENVIRONMENT")
+            or file_config.pop("environment", None)
+            or "us-west1-gcp"
         )
         config = config._replace(environment=environment)
 
@@ -102,24 +103,21 @@ def reset(self, config_file=None, **kwargs):
 
         if not self._config.project_name:
             config = config._replace(
-                **self._preprocess_and_validate_config({'project_name': whoami_response.projectname}))
+                **self._preprocess_and_validate_config({"project_name": whoami_response.projectname})
+            )
 
         self._config = config
 
         # Set OpenAPI client config
         default_openapi_config = OpenApiConfiguration.get_default_copy()
         default_openapi_config.ssl_ca_cert = certifi.where()
-        openapi_config = (
-            kwargs.pop("openapi_config", None)
-            or default_openapi_config
-        )
+        openapi_config = kwargs.pop("openapi_config", None) or default_openapi_config
 
         openapi_config.socket_options = self._get_socket_options()
         config = config._replace(openapi_config=openapi_config)
 
         self._config = config
 
-
     def _preprocess_and_validate_config(self, config: dict) -> dict:
         """Normalize, filter, and validate config keys/values.
@@ -128,9 +126,9 @@ def _preprocess_and_validate_config(self, config: dict) -> dict:
         """
         # general preprocessing and filtering
         result = {k: v for k, v in config.items() if k in ConfigBase._fields if v is not None}
-        result.pop('environment', None)
+        result.pop("environment", None)
         # validate api key
-        api_key = result.get('api_key')
+        api_key = result.get("api_key")
         # if api_key:
         #     try:
         #         uuid.UUID(api_key)
@@ -152,11 +150,12 @@ def _load_config_file(self, config_file: str) -> dict:
         return config_obj
 
     @staticmethod
-    def _get_socket_options(do_keep_alive: bool = True,
-                            keep_alive_idle_sec: int = TCP_KEEPIDLE,
-                            keep_alive_interval_sec: int = TCP_KEEPINTVL,
-                            keep_alive_tries: int = TCP_KEEPCNT
-                            ) -> List[tuple]:
+    def _get_socket_options(
+        do_keep_alive: bool = True,
+        keep_alive_idle_sec: int = TCP_KEEPIDLE,
+        keep_alive_interval_sec: int = TCP_KEEPINTVL,
+        keep_alive_tries: int = TCP_KEEPCNT,
+    ) -> List[tuple]:
         """
         Returns the socket options to pass to OpenAPI's Rest client
         Args:
@@ -179,8 +178,12 @@ def _get_socket_options(do_keep_alive: bool = True,
         # TCP Keep Alive Probes for different platforms
         platform = sys.platform
         # TCP Keep Alive Probes for Linux
-        if platform == 'linux' and hasattr(socket, "TCP_KEEPIDLE") and hasattr(socket, "TCP_KEEPINTVL") \
-                and hasattr(socket, "TCP_KEEPCNT"):
+        if (
+            platform == "linux"
+            and hasattr(socket, "TCP_KEEPIDLE")
+            and hasattr(socket, "TCP_KEEPINTVL")
+            and hasattr(socket, "TCP_KEEPCNT")
+        ):
             socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keep_alive_idle_sec)]
             socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keep_alive_interval_sec)]
             socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keep_alive_tries)]
@@ -193,7 +196,7 @@ def _get_socket_options(do_keep_alive: bool = True,
         # socket.ioctl((socket.SIO_KEEPALIVE_VALS, (1, keep_alive_idle_sec * 1000, keep_alive_interval_sec * 1000)))
 
         # TCP Keep Alive Probes for Mac OS
-        elif platform == 'darwin':
+        elif platform == "darwin":
             TCP_KEEPALIVE = 0x10
             socket_params += [(socket.IPPROTO_TCP, TCP_KEEPALIVE, keep_alive_interval_sec)]
 
@@ -226,15 +229,22 @@ def LOG_LEVEL(self):
         """
         warn_deprecated(
             description='LOG_LEVEL is deprecated. Use the standard logging module logger "pinecone" instead.',
-            deprecated_in='2.0.2',
-            removal_in='3.0.0'
+            deprecated_in="2.0.2",
+            removal_in="3.0.0",
         )
-        return logging.getLevelName(logging.getLogger('pinecone').level)
-
-
-def init(api_key: str = None, host: str = None, environment: str = None, project_name: str = None,
-         log_level: str = None, openapi_config: OpenApiConfiguration = None,
-         config: str = "~/.pinecone", **kwargs):
+        return logging.getLevelName(logging.getLogger("pinecone").level)
+
+
+def init(
+    api_key: str = None,
+    host: str = None,
+    environment: str = None,
+    project_name: str = None,
+    log_level: str = None,
+    openapi_config: OpenApiConfiguration = None,
+    config: str = "~/.pinecone",
+    **kwargs
+):
     """Initializes the Pinecone client.
 
     :param api_key: Required if not set in config file or by environment variable ``PINECONE_API_KEY``.
@@ -246,13 +256,20 @@ def init(api_key: str = None, host: str = None, environment: str = None, project
     :param log_level: Deprecated since v2.0.2 [Will be removed in v3.0.0]; use the standard logging module to manage logger "pinecone" instead.
     """
     check_kwargs(init, kwargs)
-    Config.reset(project_name=project_name, api_key=api_key, controller_host=host, environment=environment,
-                 openapi_config=openapi_config, config_file=config, **kwargs)
+    Config.reset(
+        project_name=project_name,
+        api_key=api_key,
+        controller_host=host,
+        environment=environment,
+        openapi_config=openapi_config,
+        config_file=config,
+        **kwargs
+    )
     if log_level:
         warn_deprecated(
             description='log_level is deprecated. Use the standard logging module to manage logger "pinecone" instead.',
-            deprecated_in='2.0.2',
-            removal_in='3.0.0'
+            deprecated_in="2.0.2",
+            removal_in="3.0.0",
         )
diff --git a/pinecone/core/utils/__init__.py b/pinecone/core/utils/__init__.py
index cc944eb7..d4cca164 100644
--- a/pinecone/core/utils/__init__.py
+++ b/pinecone/core/utils/__init__.py
@@ -1,6 +1,3 @@
-#
-# Copyright (c) 2020-2021 Pinecone Systems Inc. All right reserved.
-#
 import inspect
 import logging
 import re
@@ -24,11 +21,15 @@
 DNS_COMPATIBLE_REGEX = re.compile("^[a-z0-9]([a-z0-9]|[-])+[a-z0-9]$")
 
 
-def dump_numpy_public(np_array: 'np.ndarray', compressed: bool = False) -> 'vector_column_service_pb2.NdArray':
+def dump_numpy_public(np_array: "np.ndarray", compressed: bool = False) -> "vector_column_service_pb2.NdArray":
     """
    Dump numpy array to vector_column_service_pb2.NdArray
     """
-    warn_deprecated('dump_numpy_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
+    warn_deprecated(
+        "dump_numpy_public and all numpy-related features will be removed in a future version",
+        deprecated_in="2.2.1",
+        removal_in="3.0.0",
+    )
     protobuf_arr = vector_column_service_pb2.NdArray()
     protobuf_arr.dtype = str(np_array.dtype)
     protobuf_arr.shape.extend(np_array.shape)
@@ -40,24 +41,30 @@ def dump_numpy_public(np_array: 'np.ndarray', compressed: bool = False) -> 'vect
     return protobuf_arr
 
 
-def dump_strings_public(strs: List[str], compressed: bool = False) -> 'vector_column_service_pb2.NdArray':
-    warn_deprecated('dump_strings_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
-    return dump_numpy_public(np.array(strs, dtype='S'), compressed=compressed)
+def dump_strings_public(strs: List[str], compressed: bool = False) -> "vector_column_service_pb2.NdArray":
+    warn_deprecated(
+        "dump_strings_public and all numpy-related features will be removed in a future version",
+        deprecated_in="2.2.1",
+        removal_in="3.0.0",
+    )
+    return dump_numpy_public(np.array(strs, dtype="S"), compressed=compressed)
 
 
 def get_version():
-    return Path(__file__).parent.parent.parent.joinpath('__version__').read_text().strip()
+    return Path(__file__).parent.parent.parent.joinpath("__version__").read_text().strip()
 
 
 def get_environment():
-    return Path(__file__).parent.parent.parent.joinpath('__environment__').read_text().strip()
+    return Path(__file__).parent.parent.parent.joinpath("__environment__").read_text().strip()
 
 
 def validate_dns_name(name):
     if not DNS_COMPATIBLE_REGEX.match(name):
-        raise ValueError("{} is invalid - service names and node names must consist of lower case "
-                         "alphanumeric characters or '-', start with an alphabetic character, and end with an "
-                         "alphanumeric character (e.g. 'my-name', or 'abc-123')".format(name))
+        raise ValueError(
+            "{} is invalid - service names and node names must consist of lower case "
+            "alphanumeric characters or '-', start with an alphabetic character, and end with an "
+            "alphanumeric character (e.g. 'my-name', or 'abc-123')".format(name)
+        )
 
 
 def _generate_request_id() -> str:
@@ -70,13 +77,13 @@ def fix_tuple_length(t, n):
 
 
 def get_user_agent():
-    client_id = f'python-client-{get_version()}'
-    user_agent_details = {'requests': requests.__version__, 'urllib3': urllib3.__version__}
-    user_agent = '{} ({})'.format(client_id, ', '.join([f'{k}:{v}' for k, v in user_agent_details.items()]))
+    client_id = f"python-client-{get_version()}"
+    user_agent_details = {"requests": requests.__version__, "urllib3": urllib3.__version__}
+    user_agent = "{} ({})".format(client_id, ", ".join([f"{k}:{v}" for k, v in user_agent_details.items()]))
     return user_agent
 
 
-def dict_to_proto_struct(d: dict) -> 'Struct':
+def dict_to_proto_struct(d: dict) -> "Struct":
     if not d:
         d = {}
     s = Struct()
@@ -84,17 +91,21 @@ def dict_to_proto_struct(d: dict) -> 'Struct':
     return s
 
 
-def proto_struct_to_dict(s: 'Struct') -> dict:
+def proto_struct_to_dict(s: "Struct") -> dict:
     return json_format.MessageToDict(s)
 
 
-def load_numpy_public(proto_arr: 'vector_column_service_pb2.NdArray') -> 'np.ndarray':
+def load_numpy_public(proto_arr: "vector_column_service_pb2.NdArray") -> "np.ndarray":
     """
     Load numpy array from protobuf
     :param proto_arr:
     :return:
     """
-    warn_deprecated('load_numpy_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
+    warn_deprecated(
+        "load_numpy_public and all numpy-related features will be removed in a future version",
+        deprecated_in="2.2.1",
+        removal_in="3.0.0",
+    )
     if len(proto_arr.shape) == 0:
         return np.array([])
     if proto_arr.compressed:
@@ -104,16 +115,22 @@ def load_numpy_public(proto_arr: 'vector_column_service_pb2.NdArray') -> 'np.nda
     return numpy_arr.reshape(proto_arr.shape)
 
 
-def load_strings_public(proto_arr: 'vector_column_service_pb2.NdArray') -> List[str]:
-    warn_deprecated('load_strings_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
-    return [str(item, 'utf-8') for item in load_numpy_public(proto_arr)]
+def load_strings_public(proto_arr: "vector_column_service_pb2.NdArray") -> List[str]:
+    warn_deprecated(
+        "load_strings_public and all numpy-related features will be removed in a future version",
+        deprecated_in="2.2.1",
+        removal_in="3.0.0",
+    )
+    return [str(item, "utf-8") for item in load_numpy_public(proto_arr)]
 
-def warn_deprecated(description: str = '', deprecated_in: str = None, removal_in: str = None):
-    message = f'DEPRECATED since v{deprecated_in} [Will be removed in v{removal_in}]: {description}'
+
+def warn_deprecated(description: str = "", deprecated_in: str = None, removal_in: str = None):
+    message = f"DEPRECATED since v{deprecated_in} [Will be removed in v{removal_in}]: {description}"
     warnings.warn(message, FutureWarning)
 
+
 def check_kwargs(caller, given):
     argspec = inspect.getfullargspec(caller)
     diff = set(given).difference(argspec.args)
     if diff:
-        logging.exception(caller.__name__ + ' had unexpected keyword argument(s): ' + ', '.join(diff), exc_info=False)
\ No newline at end of file
+        logging.exception(caller.__name__ + " had unexpected keyword argument(s): " + ", ".join(diff), exc_info=False)
diff --git a/pinecone/index.py b/pinecone/index.py
index d7b9b7fc..f3e68ba2 100644
--- a/pinecone/index.py
+++ b/pinecone/index.py
@@ -1,6 +1,3 @@
-#
-# Copyright (c) 2020-2021 Pinecone Systems Inc. All right reserved.
-#
 from tqdm.autonotebook import tqdm
 from importlib.util import find_spec
 import numbers
@@ -12,39 +9,77 @@
 from .core.client.model.sparse_values import SparseValues
 from pinecone import Config
 from pinecone.core.client import ApiClient
-from .core.client.models import FetchResponse, ProtobufAny, QueryRequest, QueryResponse, QueryVector, RpcStatus, \
-    ScoredVector, SingleQueryResults, DescribeIndexStatsResponse, UpsertRequest, UpsertResponse, UpdateRequest, \
-    Vector, DeleteRequest, UpdateRequest, DescribeIndexStatsRequest
+from .core.client.models import (
+    FetchResponse,
+    ProtobufAny,
+    QueryRequest,
+    QueryResponse,
+    QueryVector,
+    RpcStatus,
+    ScoredVector,
+    SingleQueryResults,
+    DescribeIndexStatsResponse,
+    UpsertRequest,
+    UpsertResponse,
+    UpdateRequest,
+    Vector,
+    DeleteRequest,
+    UpdateRequest,
+    DescribeIndexStatsRequest,
+)
 from pinecone.core.client.api.vector_operations_api import VectorOperationsApi
 from pinecone.core.utils import fix_tuple_length, get_user_agent, warn_deprecated
 import copy
 
 __all__ = [
-    "Index", "FetchResponse", "ProtobufAny", "QueryRequest", "QueryResponse", "QueryVector", "RpcStatus",
-    "ScoredVector", "SingleQueryResults", "DescribeIndexStatsResponse", "UpsertRequest", "UpsertResponse",
-    "UpdateRequest", "Vector", "DeleteRequest", "UpdateRequest", "DescribeIndexStatsRequest", "SparseValues"
+    "Index",
+    "FetchResponse",
+    "ProtobufAny",
+    "QueryRequest",
+    "QueryResponse",
+    "QueryVector",
+    "RpcStatus",
+    "ScoredVector",
+    "SingleQueryResults",
+    "DescribeIndexStatsResponse",
+    "UpsertRequest",
+    "UpsertResponse",
+    "UpdateRequest",
+    "Vector",
+    "DeleteRequest",
+    "UpdateRequest",
+    "DescribeIndexStatsRequest",
+    "SparseValues",
 ]
 
 from .core.utils.constants import REQUIRED_VECTOR_FIELDS, OPTIONAL_VECTOR_FIELDS
 from .core.utils.error_handling import validate_and_convert_errors
 
 _OPENAPI_ENDPOINT_PARAMS = (
-    '_return_http_data_only', '_preload_content', '_request_timeout',
-    '_check_input_type', '_check_return_type', '_host_index', 'async_req'
+    "_return_http_data_only",
+    "_preload_content",
+    "_request_timeout",
+    "_check_input_type",
+    "_check_return_type",
+    "_host_index",
+    "async_req",
 )
 
+
 def parse_query_response(response: QueryResponse, unary_query: bool):
     if unary_query:
-        response._data_store.pop('results', None)
+        response._data_store.pop("results", None)
     else:
-        response._data_store.pop('matches', None)
-        response._data_store.pop('namespace', None)
+        response._data_store.pop("matches", None)
+        response._data_store.pop("namespace", None)
     return response
 
+
 def upsert_numpy_deprecation_notice(context):
     numpy_deprecataion_notice = "The ability to pass a numpy ndarray as part of a dictionary argument to upsert() will be removed in a future version of the pinecone client. To remove this warning, use the numpy.ndarray.tolist method to convert your ndarray into a python list before calling upsert()."
     message = " ".join([context, numpy_deprecataion_notice])
-    warn_deprecated(message, deprecated_in='2.2.1', removal_in='3.0.0')
+    warn_deprecated(message, deprecated_in="2.2.1", removal_in="3.0.0")
+
 
 class Index(ApiClient):
@@ -52,30 +87,29 @@ class Index(ApiClient):
     """
     A client for interacting with a Pinecone index via REST API.
     For improved performance, use the Pinecone GRPC index client.
     """
+
     def __init__(self, index_name: str, pool_threads=1):
         openapi_client_config = copy.deepcopy(Config.OPENAPI_CONFIG)
         openapi_client_config.api_key = openapi_client_config.api_key or {}
-        openapi_client_config.api_key['ApiKeyAuth'] = openapi_client_config.api_key.get('ApiKeyAuth', Config.API_KEY)
+        openapi_client_config.api_key["ApiKeyAuth"] = openapi_client_config.api_key.get("ApiKeyAuth", Config.API_KEY)
         openapi_client_config.server_variables = openapi_client_config.server_variables or {}
         openapi_client_config.server_variables = {
-            **{
-                'environment': Config.ENVIRONMENT,
-                'index_name': index_name,
-                'project_name': Config.PROJECT_NAME
-            },
-            **openapi_client_config.server_variables
+            **{"environment": Config.ENVIRONMENT, "index_name": index_name, "project_name": Config.PROJECT_NAME},
+            **openapi_client_config.server_variables,
         }
         super().__init__(configuration=openapi_client_config, pool_threads=pool_threads)
         self.user_agent = get_user_agent()
         self._vector_api = VectorOperationsApi(self)
 
     @validate_and_convert_errors
-    def upsert(self,
-               vectors: Union[List[Vector], List[tuple], List[dict]],
-               namespace: Optional[str] = None,
-               batch_size: Optional[int] = None,
-               show_progress: bool = True,
-               **kwargs) -> UpsertResponse:
+    def upsert(
+        self,
+        vectors: Union[List[Vector], List[tuple], List[dict]],
+        namespace: Optional[str] = None,
+        batch_size: Optional[int] = None,
+        show_progress: bool = True,
+        **kwargs
+    ) -> UpsertResponse:
         """
         The upsert operation writes vectors into a namespace.
         If a new value is upserted for an existing vector id, it will overwrite the previous value.
@@ -139,55 +173,58 @@ def upsert(self,
 
         Returns: UpsertResponse, includes the number of vectors upserted.
         """
-        _check_type = kwargs.pop('_check_type', False)
+        _check_type = kwargs.pop("_check_type", False)
 
-        if kwargs.get('async_req', False) and batch_size is not None:
-            raise ValueError('async_req is not supported when batch_size is provided.'
-                             'To upsert in parallel, please follow: '
-                             'https://docs.pinecone.io/docs/insert-data#sending-upserts-in-parallel')
+        if kwargs.get("async_req", False) and batch_size is not None:
+            raise ValueError(
+                "async_req is not supported when batch_size is provided."
+                "To upsert in parallel, please follow: "
+                "https://docs.pinecone.io/docs/insert-data#sending-upserts-in-parallel"
+            )
 
         if batch_size is None:
             return self._upsert_batch(vectors, namespace, _check_type, **kwargs)
 
         if not isinstance(batch_size, int) or batch_size <= 0:
-            raise ValueError('batch_size must be a positive integer')
+            raise ValueError("batch_size must be a positive integer")
 
-        pbar = tqdm(total=len(vectors), disable=not show_progress, desc='Upserted vectors')
+        pbar = tqdm(total=len(vectors), disable=not show_progress, desc="Upserted vectors")
         total_upserted = 0
         for i in range(0, len(vectors), batch_size):
-            batch_result = self._upsert_batch(vectors[i:i + batch_size], namespace, _check_type, **kwargs)
+            batch_result = self._upsert_batch(vectors[i : i + batch_size], namespace, _check_type, **kwargs)
             pbar.update(batch_result.upserted_count)
             # we can't use here pbar.n for the case show_progress=False
            total_upserted += batch_result.upserted_count
 
         return UpsertResponse(upserted_count=total_upserted)
 
-    def _upsert_batch(self,
-                      vectors: List[Vector],
-                      namespace: Optional[str],
-                      _check_type: bool,
-                      **kwargs) -> UpsertResponse:
-
-        args_dict = self._parse_non_empty_args([('namespace', namespace)])
+    def _upsert_batch(
+        self, vectors: List[Vector], namespace: Optional[str], _check_type: bool, **kwargs
+    ) -> UpsertResponse:
+        args_dict = self._parse_non_empty_args([("namespace", namespace)])
 
         def _dict_to_vector(item):
             item_keys = set(item.keys())
             if not item_keys.issuperset(REQUIRED_VECTOR_FIELDS):
                 raise ValueError(
-                    f"Vector dictionary is missing required fields: {list(REQUIRED_VECTOR_FIELDS - item_keys)}")
+                    f"Vector dictionary is missing required fields: {list(REQUIRED_VECTOR_FIELDS - item_keys)}"
+                )
 
             excessive_keys = item_keys - (REQUIRED_VECTOR_FIELDS | OPTIONAL_VECTOR_FIELDS)
             if len(excessive_keys) > 0:
-                raise ValueError(f"Found excess keys in the vector dictionary: {list(excessive_keys)}. "
-                                 f"The allowed keys are: {list(REQUIRED_VECTOR_FIELDS | OPTIONAL_VECTOR_FIELDS)}")
+                raise ValueError(
+                    f"Found excess keys in the vector dictionary: {list(excessive_keys)}. "
+                    f"The allowed keys are: {list(REQUIRED_VECTOR_FIELDS | OPTIONAL_VECTOR_FIELDS)}"
+                )
 
-            if 'sparse_values' in item:
-                if not isinstance(item['sparse_values'], Mapping):
+            if "sparse_values" in item:
+                if not isinstance(item["sparse_values"], Mapping):
                     raise ValueError(
-                        f"Column `sparse_values` is expected to be a dictionary, found {type(item['sparse_values'])}")
+                        f"Column `sparse_values` is expected to be a dictionary, found {type(item['sparse_values'])}"
+                    )
 
-                indices = item['sparse_values'].get('indices', None)
-                values = item['sparse_values'].get('values', None)
+                indices = item["sparse_values"].get("indices", None)
+                values = item["sparse_values"].get("values", None)
 
                 if isinstance(values, np.ndarray):
                     upsert_numpy_deprecation_notice("Deprecated type passed in sparse_values['values'].")
@@ -196,27 +233,28 @@ def _dict_to_vector(item):
                     upsert_numpy_deprecation_notice("Deprecated type passed in sparse_values['indices'].")
                     indices = indices.tolist()
                 try:
-                    item['sparse_values'] = SparseValues(indices=indices, values=values)
+                    item["sparse_values"] = SparseValues(indices=indices, values=values)
                 except TypeError as e:
-                    raise ValueError("Found unexpected data in column `sparse_values`. "
-                                     "Expected format is `'sparse_values': {'indices': List[int], 'values': List[float]}`."
-                                     ) from e
+                    raise ValueError(
+                        "Found unexpected data in column `sparse_values`. "
+                        "Expected format is `'sparse_values': {'indices': List[int], 'values': List[float]}`."
+                    ) from e
 
-            if 'metadata' in item:
-                metadata = item.get('metadata')
+            if "metadata" in item:
+                metadata = item.get("metadata")
                 if not isinstance(metadata, Mapping):
                     raise TypeError(f"Column `metadata` is expected to be a dictionary, found {type(metadata)}")
-
-            if isinstance(item['values'], np.ndarray):
+
+            if isinstance(item["values"], np.ndarray):
                 upsert_numpy_deprecation_notice("Deprecated type passed in 'values'.")
-                item['values'] = item['values'].tolist()
+                item["values"] = item["values"].tolist()
 
             try:
                 return Vector(**item)
             except TypeError as e:
                 # if not isinstance(item['values'], Iterable) or not isinstance(item['values'][0], numbers.Real):
                 #     raise TypeError(f"Column `values` is expected to be a list of floats")
-                if not isinstance(item['values'], Iterable) or not isinstance(item['values'][0], numbers.Real):
+                if not isinstance(item["values"], Iterable) or not isinstance(item["values"][0], numbers.Real):
                     raise TypeError(f"Column `values` is expected to be a list of floats")
                 raise
 
@@ -225,9 +263,11 @@ def _vector_transform(item: Union[Vector, Tuple]):
                 return item
             elif isinstance(item, tuple):
                 if len(item) > 3:
-                    raise ValueError(f"Found a tuple of length {len(item)} which is not supported. "
-                                     f"Vectors can be represented as tuples either the form (id, values, metadata) or (id, values). "
-                                     f"To pass sparse values please use either dicts or a Vector objects as inputs.")
+                    raise ValueError(
+                        f"Found a tuple of length {len(item)} which is not supported. "
+                        f"Vectors can be represented as tuples either the form (id, values, metadata) or (id, values). "
+                        f"To pass sparse values please use either dicts or a Vector objects as inputs."
+                    )
                 id, values, metadata = fix_tuple_length(item, 3)
                 return Vector(id=id, values=values, metadata=metadata or {}, _check_type=_check_type)
             elif isinstance(item, Mapping):
@@ -247,14 +287,12 @@ def _vector_transform(item: Union[Vector, Tuple]):
     @staticmethod
     def _iter_dataframe(df, batch_size):
         for i in range(0, len(df), batch_size):
-            batch = df.iloc[i:i + batch_size].to_dict(orient="records")
+            batch = df.iloc[i : i + batch_size].to_dict(orient="records")
             yield batch
 
-    def upsert_from_dataframe(self,
-                              df,
-                              namespace: str = None,
-                              batch_size: int = 500,
-                              show_progress: bool = True) -> UpsertResponse:
+    def upsert_from_dataframe(
+        self, df, namespace: str = None, batch_size: int = 500, show_progress: bool = True
+    ) -> UpsertResponse:
         """Upserts a dataframe into the index.
 
         Args:
@@ -266,7 +304,9 @@ def upsert_from_dataframe(self,
         try:
             import pandas as pd
         except ImportError:
-            raise RuntimeError("The `pandas` package is not installed. Please install pandas to use `upsert_from_dataframe()`")
+            raise RuntimeError(
+                "The `pandas` package is not installed. Please install pandas to use `upsert_from_dataframe()`"
+            )
 
         if not isinstance(df, pd.DataFrame):
             raise ValueError(f"Only pandas dataframes are supported. Found: {type(df)}")
@@ -285,52 +325,53 @@ def upsert_from_dataframe(self,
         return UpsertResponse(upserted_count=upserted_count)
 
     @validate_and_convert_errors
-    def delete(self,
-               ids: Optional[List[str]] = None,
-               delete_all: Optional[bool] = None,
-               namespace: Optional[str] = None,
-               filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
-               **kwargs) -> Dict[str, Any]:
+    def delete(
+        self,
+        ids: Optional[List[str]] = None,
+        delete_all: Optional[bool] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
         """
-        The Delete operation deletes vectors from the index, from a single namespace.
-        No error raised if the vector id does not exist.
-        Note: for any delete call, if namespace is not specified, the default namespace is used.
-
-        Delete can occur in the following mutual exclusive ways:
-        1. Delete by ids from a single namespace
-        2. Delete all vectors from a single namespace by setting delete_all to True
-        3. Delete all vectors from a single namespace by specifying a metadata filter
-            (note that for this option delete all must be set to False)
-
-        API reference: https://docs.pinecone.io/reference/delete_post
+        The Delete operation deletes vectors from the index, from a single namespace.
+        No error raised if the vector id does not exist.
+        Note: for any delete call, if namespace is not specified, the default namespace is used.
+
+        Delete can occur in the following mutual exclusive ways:
+        1. Delete by ids from a single namespace
+        2. Delete all vectors from a single namespace by setting delete_all to True
+        3. Delete all vectors from a single namespace by specifying a metadata filter
+            (note that for this option delete all must be set to False)
+
+        API reference: https://docs.pinecone.io/reference/delete_post
+
+        Examples:
+            >>> index.delete(ids=['id1', 'id2'], namespace='my_namespace')
+            >>> index.delete(delete_all=True, namespace='my_namespace')
+            >>> index.delete(filter={'key': 'value'}, namespace='my_namespace')
+
+        Args:
+            ids (List[str]): Vector ids to delete [optional]
+            delete_all (bool): This indicates that all vectors in the index namespace should be deleted.. [optional]
+                               Default is False.
+            namespace (str): The namespace to delete vectors from [optional]
+                             If not specified, the default namespace is used.
+            filter (Dict[str, Union[str, float, int, bool, List, dict]]):
+                If specified, the metadata filter here will be used to select the vectors to delete.
+                This is mutually exclusive with specifying ids to delete in the ids param or using delete_all=True.
+                See https://www.pinecone.io/docs/metadata-filtering/.. [optional]
 
-        Examples:
-            >>> index.delete(ids=['id1', 'id2'], namespace='my_namespace')
-            >>> index.delete(delete_all=True, namespace='my_namespace')
-            >>> index.delete(filter={'key': 'value'}, namespace='my_namespace')
-
-        Args:
-            ids (List[str]): Vector ids to delete [optional]
-            delete_all (bool): This indicates that all vectors in the index namespace should be deleted.. [optional]
-                               Default is False.
-            namespace (str): The namespace to delete vectors from [optional]
-                             If not specified, the default namespace is used.
-            filter (Dict[str, Union[str, float, int, bool, List, dict]]):
-                If specified, the metadata filter here will be used to select the vectors to delete.
-                This is mutually exclusive with specifying ids to delete in the ids param or using delete_all=True.
-                See https://www.pinecone.io/docs/metadata-filtering/.. [optional]
-
-        Keyword Args:
-          Supports OpenAPI client keyword arguments. See pinecone.core.client.models.DeleteRequest for more details.
+        Keyword Args:
+          Supports OpenAPI client keyword arguments. See pinecone.core.client.models.DeleteRequest for more details.
 
-        Returns: An empty dictionary if the delete operation was successful.
+        Returns: An empty dictionary if the delete operation was successful.
         """
-        _check_type = kwargs.pop('_check_type', False)
-        args_dict = self._parse_non_empty_args([('ids', ids),
-                                                ('delete_all', delete_all),
-                                                ('namespace', namespace),
-                                                ('filter', filter)])
+        _check_type = kwargs.pop("_check_type", False)
+        args_dict = self._parse_non_empty_args(
+            [("ids", ids), ("delete_all", delete_all), ("namespace", namespace), ("filter", filter)]
+        )
 
         return self._vector_api.delete(
             DeleteRequest(
@@ -342,10 +383,7 @@ def delete(self,
         )
 
     @validate_and_convert_errors
-    def fetch(self,
-              ids: List[str],
-              namespace: Optional[str] = None,
-              **kwargs) -> FetchResponse:
+    def fetch(self, ids: List[str], namespace: Optional[str] = None, **kwargs) -> FetchResponse:
         """
         The fetch operation looks up and returns vectors, by ID, from a single namespace.
         The returned vectors include the vector data and/or metadata.
@@ -366,21 +404,23 @@ def fetch(self,
         Returns: FetchResponse object which contains the list of Vector objects, and namespace name.
         """
-        args_dict = self._parse_non_empty_args([('namespace', namespace)])
+        args_dict = self._parse_non_empty_args([("namespace", namespace)])
         return self._vector_api.fetch(ids=ids, **args_dict, **kwargs)
 
     @validate_and_convert_errors
-    def query(self,
-              vector: Optional[List[float]] = None,
-              id: Optional[str] = None,
-              queries: Optional[Union[List[QueryVector], List[Tuple]]] = None,
-              top_k: Optional[int] = None,
-              namespace: Optional[str] = None,
-              filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
-              include_values: Optional[bool] = None,
-              include_metadata: Optional[bool] = None,
-              sparse_vector: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
-              **kwargs) -> QueryResponse:
+    def query(
+        self,
+        vector: Optional[List[float]] = None,
+        id: Optional[str] = None,
+        queries: Optional[Union[List[QueryVector], List[Tuple]]] = None,
+        top_k: Optional[int] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
+        **kwargs
+    ) -> QueryResponse:
         """
         The Query operation searches a namespace, using a query vector.
         It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -426,6 +466,7 @@ def query(self,
         Returns: QueryResponse object which contains the list of the closest vectors as ScoredVector objects,
                  and namespace name.
         """
+
         def _query_transform(item):
             if isinstance(item, QueryVector):
                 return item
@@ -439,19 +480,23 @@ def _query_transform(item):
                 return QueryVector(values=item, _check_type=_check_type)
             raise ValueError(f"Invalid query vector value passed: cannot interpret type {type(item)}")
 
-        _check_type = kwargs.pop('_check_type', False)
+        _check_type = kwargs.pop("_check_type", False)
         queries = list(map(_query_transform, queries)) if queries is not None else None
 
         sparse_vector = self._parse_sparse_values_arg(sparse_vector)
-        args_dict = self._parse_non_empty_args([('vector', vector),
-                                                ('id', id),
-                                                ('queries', queries),
-                                                ('top_k', top_k),
-                                                ('namespace', namespace),
-                                                ('filter', filter),
-                                                ('include_values', include_values),
-                                                ('include_metadata', include_metadata),
-                                                ('sparse_vector', sparse_vector)])
+        args_dict = self._parse_non_empty_args(
+            [
+                ("vector", vector),
+                ("id", id),
+                ("queries", queries),
+                ("top_k", top_k),
+                ("namespace", namespace),
+                ("filter", filter),
+                ("include_values", include_values),
+                ("include_metadata", include_metadata),
+                ("sparse_vector", sparse_vector),
+            ]
+        )
         response = self._vector_api.query(
             QueryRequest(
                 **args_dict,
@@ -463,14 +508,15 @@ def _query_transform(item):
         return parse_query_response(response, vector is not None or id)
 
     @validate_and_convert_errors
-    def update(self,
-               id: str,
-               values: Optional[List[float]] = None,
-               set_metadata: Optional[Dict[str,
-                                           Union[str, float, int, bool, List[int], List[float], List[str]]]] = None,
-               namespace: Optional[str] = None,
-               sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
-               **kwargs) -> Dict[str, Any]:
+    def update(
+        self,
+        id: str,
+        values: Optional[List[float]] = None,
+        set_metadata: Optional[Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]] = None,
+        namespace: Optional[str] = None,
+        sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
         """
         The Update operation updates vector in a namespace.
         If a value is included, it will overwrite the previous value.
@@ -502,24 +548,30 @@ def update(self,
 
         Returns: An empty dictionary if the update was successful.
         """
-        _check_type = kwargs.pop('_check_type', False)
+        _check_type = kwargs.pop("_check_type", False)
         sparse_values = self._parse_sparse_values_arg(sparse_values)
-        args_dict = self._parse_non_empty_args([('values', values),
-                                                ('set_metadata', set_metadata),
-                                                ('namespace', namespace),
-                                                ('sparse_values', sparse_values)])
-        return self._vector_api.update(UpdateRequest(
+        args_dict = self._parse_non_empty_args(
+            [
+                ("values", values),
+                ("set_metadata", set_metadata),
+                ("namespace", namespace),
+                ("sparse_values", sparse_values),
+            ]
+        )
+        return self._vector_api.update(
+            UpdateRequest(
                 id=id,
                 **args_dict,
                 _check_type=_check_type,
                 **{k: v for k, v in kwargs.items() if k not in _OPENAPI_ENDPOINT_PARAMS}
             ),
-            **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS})
+            **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}
+        )
 
     @validate_and_convert_errors
-    def describe_index_stats(self,
-                             filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
-                             **kwargs) -> DescribeIndexStatsResponse:
+    def describe_index_stats(
+        self, filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, **kwargs
+    ) -> DescribeIndexStatsResponse:
         """
         The DescribeIndexStats operation returns statistics about the index's contents.
         For example: The vector count per namespace and the number of dimensions.
@@ -537,8 +589,8 @@ def describe_index_stats(self,
 
         Returns: DescribeIndexStatsResponse object which contains stats about the index.
         """
-        _check_type = kwargs.pop('_check_type', False)
-        args_dict = self._parse_non_empty_args([('filter', filter)])
+        _check_type = kwargs.pop("_check_type", False)
+        args_dict = self._parse_non_empty_args([("filter", filter)])
 
         return self._vector_api.describe_index_stats(
             DescribeIndexStatsRequest(
@@ -555,8 +607,8 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
 
     @staticmethod
     def _parse_sparse_values_arg(
-        sparse_values: Optional[Union[SparseValues,
-                                      Dict[str, Union[List[float], List[int]]]]]) -> Optional[SparseValues]:
+        sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]]
+    ) -> Optional[SparseValues]:
         if sparse_values is None:
             return None
@@ -566,6 +618,7 @@ def _parse_sparse_values_arg(
         if not isinstance(sparse_values, dict) or "indices" not in sparse_values or "values" not in sparse_values:
             raise ValueError(
                 "Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
-                f"Received: {sparse_values}")
+                f"Received: {sparse_values}"
+            )
 
         return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
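Note for reviewers: the diff above appears to be a formatting-only pass (black-style quote normalization, line wrapping, and blank-line changes); no behavior should change. As a quick way to sanity-check that the public surface reads the same before and after, here is a minimal usage sketch of the API touched by these files, based on the docstrings shown in the diff. The index name, namespace, vector values, and metadata below are hypothetical placeholders and are not part of the patch.

# Illustrative only -- not part of the patch. All names and values are placeholders.
import pinecone

# init() falls back to the PINECONE_API_KEY / PINECONE_ENVIRONMENT environment
# variables, and to "us-west1-gcp", as seen in Config.reset() above.
pinecone.init(api_key="YOUR_API_KEY", environment="us-west1-gcp")

# Index wraps the OpenAPI client configured in Index.__init__ above.
index = pinecone.Index("example-index")  # hypothetical index name

# upsert() accepts (id, values) or (id, values, metadata) tuples, dicts, or
# Vector objects; passing batch_size exercises the chunked path in Index.upsert().
index.upsert(
    vectors=[
        ("vec1", [0.1, 0.2, 0.3], {"genre": "drama"}),
        ("vec2", [0.2, 0.3, 0.4], {"genre": "action"}),
    ],
    namespace="example-namespace",
    batch_size=100,
)

# A unary query (single vector) returns matches directly -- the unary branch
# of parse_query_response() above.
response = index.query(
    vector=[0.1, 0.2, 0.3],
    top_k=2,
    namespace="example-namespace",
    include_metadata=True,
    filter={"genre": {"$eq": "drama"}},
)

Running a sketch like this against the client before and after the patch should produce identical results, which is the expected outcome for a formatting-only change.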