Skip to content

Commit

Permalink
remove license header
Browse files Browse the repository at this point in the history
  • Loading branch information
gdj0nes committed Sep 26, 2023
1 parent 1ba7d59 commit d87165b
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 227 deletions.
97 changes: 57 additions & 40 deletions pinecone/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#
# Copyright (c) 2020-2021 Pinecone Systems Inc. All right reserved.
#
import logging
import sys
from typing import NamedTuple, List
Expand All @@ -16,13 +13,17 @@
from pinecone.core.client.exceptions import ApiKeyError
from pinecone.core.api_action import ActionAPI, WhoAmIResponse
from pinecone.core.utils import warn_deprecated, check_kwargs
from pinecone.core.utils.constants import CLIENT_VERSION, PARENT_LOGGER_NAME, DEFAULT_PARENT_LOGGER_LEVEL, \
TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_KEEPCNT
from pinecone.core.utils.constants import (
CLIENT_VERSION,
PARENT_LOGGER_NAME,
DEFAULT_PARENT_LOGGER_LEVEL,
TCP_KEEPIDLE,
TCP_KEEPINTVL,
TCP_KEEPCNT,
)
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

__all__ = [
"Config", "init"
]
__all__ = ["Config", "init"]

_logger = logging.getLogger(__name__)
_parent_logger = logging.getLogger(PARENT_LOGGER_NAME)
Expand Down Expand Up @@ -63,10 +64,10 @@ def reset(self, config_file=None, **kwargs):

# Get the environment first. Make sure that it is not overwritten in subsequent config objects.
environment = (
kwargs.pop("environment", None)
or os.getenv("PINECONE_ENVIRONMENT")
or file_config.pop("environment", None)
or "us-west1-gcp"
kwargs.pop("environment", None)
or os.getenv("PINECONE_ENVIRONMENT")
or file_config.pop("environment", None)
or "us-west1-gcp"
)
config = config._replace(environment=environment)

Expand Down Expand Up @@ -102,24 +103,21 @@ def reset(self, config_file=None, **kwargs):

if not self._config.project_name:
config = config._replace(
**self._preprocess_and_validate_config({'project_name': whoami_response.projectname}))
**self._preprocess_and_validate_config({"project_name": whoami_response.projectname})
)

self._config = config

# Set OpenAPI client config
default_openapi_config = OpenApiConfiguration.get_default_copy()
default_openapi_config.ssl_ca_cert = certifi.where()
openapi_config = (
kwargs.pop("openapi_config", None)
or default_openapi_config
)
openapi_config = kwargs.pop("openapi_config", None) or default_openapi_config

openapi_config.socket_options = self._get_socket_options()

config = config._replace(openapi_config=openapi_config)
self._config = config


def _preprocess_and_validate_config(self, config: dict) -> dict:
"""Normalize, filter, and validate config keys/values.
Expand All @@ -128,9 +126,9 @@ def _preprocess_and_validate_config(self, config: dict) -> dict:
"""
# general preprocessing and filtering
result = {k: v for k, v in config.items() if k in ConfigBase._fields if v is not None}
result.pop('environment', None)
result.pop("environment", None)
# validate api key
api_key = result.get('api_key')
api_key = result.get("api_key")
# if api_key:
# try:
# uuid.UUID(api_key)
Expand All @@ -152,11 +150,12 @@ def _load_config_file(self, config_file: str) -> dict:
return config_obj

@staticmethod
def _get_socket_options(do_keep_alive: bool = True,
keep_alive_idle_sec: int = TCP_KEEPIDLE,
keep_alive_interval_sec: int = TCP_KEEPINTVL,
keep_alive_tries: int = TCP_KEEPCNT
) -> List[tuple]:
def _get_socket_options(
do_keep_alive: bool = True,
keep_alive_idle_sec: int = TCP_KEEPIDLE,
keep_alive_interval_sec: int = TCP_KEEPINTVL,
keep_alive_tries: int = TCP_KEEPCNT,
) -> List[tuple]:
"""
Returns the socket options to pass to OpenAPI's Rest client
Args:
Expand All @@ -179,8 +178,12 @@ def _get_socket_options(do_keep_alive: bool = True,
# TCP Keep Alive Probes for different platforms
platform = sys.platform
# TCP Keep Alive Probes for Linux
if platform == 'linux' and hasattr(socket, "TCP_KEEPIDLE") and hasattr(socket, "TCP_KEEPINTVL") \
and hasattr(socket, "TCP_KEEPCNT"):
if (
platform == "linux"
and hasattr(socket, "TCP_KEEPIDLE")
and hasattr(socket, "TCP_KEEPINTVL")
and hasattr(socket, "TCP_KEEPCNT")
):
socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keep_alive_idle_sec)]
socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keep_alive_interval_sec)]
socket_params += [(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keep_alive_tries)]
Expand All @@ -193,7 +196,7 @@ def _get_socket_options(do_keep_alive: bool = True,
# socket.ioctl((socket.SIO_KEEPALIVE_VALS, (1, keep_alive_idle_sec * 1000, keep_alive_interval_sec * 1000)))

# TCP Keep Alive Probes for Mac OS
elif platform == 'darwin':
elif platform == "darwin":
TCP_KEEPALIVE = 0x10
socket_params += [(socket.IPPROTO_TCP, TCP_KEEPALIVE, keep_alive_interval_sec)]

Expand Down Expand Up @@ -226,15 +229,22 @@ def LOG_LEVEL(self):
"""
warn_deprecated(
description='LOG_LEVEL is deprecated. Use the standard logging module logger "pinecone" instead.',
deprecated_in='2.0.2',
removal_in='3.0.0'
deprecated_in="2.0.2",
removal_in="3.0.0",
)
return logging.getLevelName(logging.getLogger('pinecone').level)


def init(api_key: str = None, host: str = None, environment: str = None, project_name: str = None,
log_level: str = None, openapi_config: OpenApiConfiguration = None,
config: str = "~/.pinecone", **kwargs):
return logging.getLevelName(logging.getLogger("pinecone").level)


def init(
api_key: str = None,
host: str = None,
environment: str = None,
project_name: str = None,
log_level: str = None,
openapi_config: OpenApiConfiguration = None,
config: str = "~/.pinecone",
**kwargs
):
"""Initializes the Pinecone client.
:param api_key: Required if not set in config file or by environment variable ``PINECONE_API_KEY``.
Expand All @@ -246,13 +256,20 @@ def init(api_key: str = None, host: str = None, environment: str = None, project
:param log_level: Deprecated since v2.0.2 [Will be removed in v3.0.0]; use the standard logging module to manage logger "pinecone" instead.
"""
check_kwargs(init, kwargs)
Config.reset(project_name=project_name, api_key=api_key, controller_host=host, environment=environment,
openapi_config=openapi_config, config_file=config, **kwargs)
Config.reset(
project_name=project_name,
api_key=api_key,
controller_host=host,
environment=environment,
openapi_config=openapi_config,
config_file=config,
**kwargs
)
if log_level:
warn_deprecated(
description='log_level is deprecated. Use the standard logging module to manage logger "pinecone" instead.',
deprecated_in='2.0.2',
removal_in='3.0.0'
deprecated_in="2.0.2",
removal_in="3.0.0",
)


Expand Down
69 changes: 43 additions & 26 deletions pinecone/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#
# Copyright (c) 2020-2021 Pinecone Systems Inc. All right reserved.
#
import inspect
import logging
import re
Expand All @@ -24,11 +21,15 @@
DNS_COMPATIBLE_REGEX = re.compile("^[a-z0-9]([a-z0-9]|[-])+[a-z0-9]$")


def dump_numpy_public(np_array: 'np.ndarray', compressed: bool = False) -> 'vector_column_service_pb2.NdArray':
def dump_numpy_public(np_array: "np.ndarray", compressed: bool = False) -> "vector_column_service_pb2.NdArray":
"""
Dump numpy array to vector_column_service_pb2.NdArray
"""
warn_deprecated('dump_numpy_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
warn_deprecated(
"dump_numpy_public and all numpy-related features will be removed in a future version",
deprecated_in="2.2.1",
removal_in="3.0.0",
)
protobuf_arr = vector_column_service_pb2.NdArray()
protobuf_arr.dtype = str(np_array.dtype)
protobuf_arr.shape.extend(np_array.shape)
Expand All @@ -40,24 +41,30 @@ def dump_numpy_public(np_array: 'np.ndarray', compressed: bool = False) -> 'vect
return protobuf_arr


def dump_strings_public(strs: List[str], compressed: bool = False) -> 'vector_column_service_pb2.NdArray':
warn_deprecated('dump_strings_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
return dump_numpy_public(np.array(strs, dtype='S'), compressed=compressed)
def dump_strings_public(strs: List[str], compressed: bool = False) -> "vector_column_service_pb2.NdArray":
warn_deprecated(
"dump_strings_public and all numpy-related features will be removed in a future version",
deprecated_in="2.2.1",
removal_in="3.0.0",
)
return dump_numpy_public(np.array(strs, dtype="S"), compressed=compressed)


def get_version():
return Path(__file__).parent.parent.parent.joinpath('__version__').read_text().strip()
return Path(__file__).parent.parent.parent.joinpath("__version__").read_text().strip()


def get_environment():
return Path(__file__).parent.parent.parent.joinpath('__environment__').read_text().strip()
return Path(__file__).parent.parent.parent.joinpath("__environment__").read_text().strip()


def validate_dns_name(name):
if not DNS_COMPATIBLE_REGEX.match(name):
raise ValueError("{} is invalid - service names and node names must consist of lower case "
"alphanumeric characters or '-', start with an alphabetic character, and end with an "
"alphanumeric character (e.g. 'my-name', or 'abc-123')".format(name))
raise ValueError(
"{} is invalid - service names and node names must consist of lower case "
"alphanumeric characters or '-', start with an alphabetic character, and end with an "
"alphanumeric character (e.g. 'my-name', or 'abc-123')".format(name)
)


def _generate_request_id() -> str:
Expand All @@ -70,31 +77,35 @@ def fix_tuple_length(t, n):


def get_user_agent():
client_id = f'python-client-{get_version()}'
user_agent_details = {'requests': requests.__version__, 'urllib3': urllib3.__version__}
user_agent = '{} ({})'.format(client_id, ', '.join([f'{k}:{v}' for k, v in user_agent_details.items()]))
client_id = f"python-client-{get_version()}"
user_agent_details = {"requests": requests.__version__, "urllib3": urllib3.__version__}
user_agent = "{} ({})".format(client_id, ", ".join([f"{k}:{v}" for k, v in user_agent_details.items()]))
return user_agent


def dict_to_proto_struct(d: dict) -> 'Struct':
def dict_to_proto_struct(d: dict) -> "Struct":
if not d:
d = {}
s = Struct()
s.update(d)
return s


def proto_struct_to_dict(s: 'Struct') -> dict:
def proto_struct_to_dict(s: "Struct") -> dict:
return json_format.MessageToDict(s)


def load_numpy_public(proto_arr: 'vector_column_service_pb2.NdArray') -> 'np.ndarray':
def load_numpy_public(proto_arr: "vector_column_service_pb2.NdArray") -> "np.ndarray":
"""
Load numpy array from protobuf
:param proto_arr:
:return:
"""
warn_deprecated('load_numpy_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
warn_deprecated(
"load_numpy_public and all numpy-related features will be removed in a future version",
deprecated_in="2.2.1",
removal_in="3.0.0",
)
if len(proto_arr.shape) == 0:
return np.array([])
if proto_arr.compressed:
Expand All @@ -104,16 +115,22 @@ def load_numpy_public(proto_arr: 'vector_column_service_pb2.NdArray') -> 'np.nda
return numpy_arr.reshape(proto_arr.shape)


def load_strings_public(proto_arr: 'vector_column_service_pb2.NdArray') -> List[str]:
warn_deprecated('load_strings_public and all numpy-related features will be removed in a future version', deprecated_in='2.2.1', removal_in='3.0.0')
return [str(item, 'utf-8') for item in load_numpy_public(proto_arr)]
def load_strings_public(proto_arr: "vector_column_service_pb2.NdArray") -> List[str]:
warn_deprecated(
"load_strings_public and all numpy-related features will be removed in a future version",
deprecated_in="2.2.1",
removal_in="3.0.0",
)
return [str(item, "utf-8") for item in load_numpy_public(proto_arr)]

def warn_deprecated(description: str = '', deprecated_in: str = None, removal_in: str = None):
message = f'DEPRECATED since v{deprecated_in} [Will be removed in v{removal_in}]: {description}'

def warn_deprecated(description: str = "", deprecated_in: str = None, removal_in: str = None):
message = f"DEPRECATED since v{deprecated_in} [Will be removed in v{removal_in}]: {description}"
warnings.warn(message, FutureWarning)


def check_kwargs(caller, given):
argspec = inspect.getfullargspec(caller)
diff = set(given).difference(argspec.args)
if diff:
logging.exception(caller.__name__ + ' had unexpected keyword argument(s): ' + ', '.join(diff), exc_info=False)
logging.exception(caller.__name__ + " had unexpected keyword argument(s): " + ", ".join(diff), exc_info=False)
Loading

0 comments on commit d87165b

Please sign in to comment.