Skip to content

Commit

Permalink
feat: add load from dict function to CredentialProvider, GlobalConfig…
Browse files Browse the repository at this point in the history
…, ClientConfig, and CogniteClient
  • Loading branch information
nodegard committed Jul 23, 2024
1 parent ab02470 commit e316362
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 104 deletions.
112 changes: 71 additions & 41 deletions cognite/client/_cognite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from cognite.client._api.vision import VisionAPI
from cognite.client._api.workflows import WorkflowAPI
from cognite.client._api_client import APIClient
from cognite.client.config import ClientConfig, GlobalConfig, global_config
from cognite.client.config import ClientConfig, global_config
from cognite.client.credentials import CredentialProvider, OAuthClientCredentials, OAuthInteractive
from cognite.client.utils._auxiliary import get_current_sdk_version

Expand Down Expand Up @@ -219,59 +219,89 @@ def default_oauth_interactive(
return cls.default(project, cdf_cluster, credentials, client_name)

@classmethod
def from_dict(cls, config: dict[str, Any]) -> CogniteClient:
cognite_sdk_config_input = config.get("cognite")
if cognite_sdk_config_input is None:
raise ValueError("cognite section is missing in the configuration file")

global_config_input = cognite_sdk_config_input.get("global_config")
if global_config_input:
# TODO: set global config based on input
# GlobalConfig.from_dictionary(**global_config_input)
global_config = GlobalConfig() # noqa: F841

client_config_input = cognite_sdk_config_input.get("client_config")
if client_config_input:
credentials_config_input = client_config_input.get("credentials")
if credentials_config_input is None:
raise ValueError("credentials section is missing in the configuration file")
else:
credentials = CredentialProvider.load(credentials_config_input)
client_config_input["credentials"] = credentials
client_config = ClientConfig(**client_config_input)
else:
raise ValueError("client_config section is missing in the configuration file")
def load(cls, config: dict) -> CogniteClient:
"""Loads a dictionary of configuration fields into a cognite client object.
return cls(client_config)
Args:
config (dict): A dictionary containing configuration values needed to create a CogniteClient.
@classmethod
Returns:
CogniteClient: A cognite client object.
Examples:
Create a cognite client object from a dictionary input:
>>> from cognite.client import CogniteClient
>>> import os
>>> config = {
... "client_name": "abcd",
... "project": "cdf-project",
... "base_url": "https://api.cognitedata.com/",
... "client_credentials": {
... "client_id": "abcd",
... "client_secret": os.environ["OAUTH_CLIENT_SECRET"],
... "token_url": os.environ["TOKEN_URL"],
... "scopes": ["https://greenfield.cognitedata.com/.default"],
... # Any additional IDP-specific token args. e.g.
... "audience": "some-audience",
... }
... }
>>> client = CogniteClient.load(config)
"""
return cls(ClientConfig.load(config))

@classmethod # TODO: design discussion on if we should have this method or not, and if it should sub envs or not
def from_yaml(cls, file_path: str | Path) -> CogniteClient:
# TODO: docstring, type hints, and error handling
"""Loads a YAML file containing configuration fields into a cognite client object.
Any environment variables in the YAML file will be replaced with their defined values given they are referenced
using the following syntax: ${ENV_VAR_NAME} (recommended) or $ENV_VAR_NAME.
Note: The environment variables must be defined in the current environment and there are no implicit environment
variables available in the YAML file (e.g. CDF_PROJECT will not automatically replace the project name
unless `project: ${CDF_PROJECT}` is defined in the YAML file).
Args:
file_path (str | Path): The path to the YAML file containing the configuration values needed to create a CogniteClient.
Returns:
CogniteClient: A cognite client object.
Examples:
Create a cognite client object from a YAML file, using envs from the current environment:
>>> config.yaml
>>> project: $MY_CDF_PROJECT
>>> base_url: https://${MY_CDF_CLUSTER}.cognitedata.com/
>>> client_credentials:
>>> token: ${MY_CDF_TOKEN}
>>> from cognite.client import CogniteClient
>>> client = CogniteClient.from_yaml("config.yaml")
Create a cognite client object from a YAML file, using envs from a .env file:
>>> from cognite.client import CogniteClient
>>> from dotenv import load_dotenv
>>> load_dotenv()
>>> client = CogniteClient.from_yaml("config.yaml")
"""
file_path = Path(file_path)
if not file_path.is_file():
raise ValueError(f"File {file_path} is not a file")

try:
with file_path.open("r") as file_raw:
sub_template = Template(file_raw.read()) # FIXME: use string.Template or expand yaml.SafeLoader class

env_dict = dict(os.environ) # FIXME: is load_dotenv() needed?

# TODO: get all missing env vars and raise error if any (without using 3.11 Template attributes)
# if not sub_template.is_valid(): # type: ignore[attr-defined]
# raise ValueError("Invalid template")

# all_identifiers = sub_template.get_identifiers() # type: ignore[attr-defined]

# missing_env_vars = set(all_identifiers) - set(env_dict.keys())
# if missing_env_vars:
# raise ValueError(f"Missing environment variables: {missing_env_vars}")
env_sub_template = Template(file_raw.read())

file_env_parsed = sub_template.safe_substitute(env_dict)
try:
file_env_parsed = env_sub_template.substitute(dict(os.environ))
except KeyError as e:
raise ValueError(f"Error substituting environment variable: {e}")
except ValueError as e:
raise ValueError(f"Error substituting environment variable: {e}")

config_input = yaml.safe_load(file_env_parsed)
except yaml.YAMLError as e:
raise ValueError(f"Error parsing YAML file {file_path}: {e}")

return cls.from_dict(config_input)
return cls.load(config_input)
68 changes: 68 additions & 0 deletions cognite/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ def __init__(self) -> None:
self.max_workers: int = 5
self.silence_feature_preview_warnings: bool = False

@classmethod
def load(cls, config: dict) -> GlobalConfig:
"""Loads a dictionary of configuration fields into a client config object.
Note: This must be done before instantiating a CogniteClient for the configuration to take effect.
Args:
config (dict): A dictionary containing configuration values defined in the GlobalConfig class.
Returns:
GlobalConfig: A global configuration object.
Examples:
Create a global config object from a dictionary input:
>>> from cognite.client.config import GlobalConfig
>>> config = {
... "max_retries": 5,
... "disable_ssl": True,
... }
>>> global_config = GlobalConfig.load(config)
"""
global_config = cls()
for key, value in config.items():
if not hasattr(global_config, key):
raise ValueError(f"Invalid key in global config: {key}")
setattr(global_config, key, value)

return global_config


global_config = GlobalConfig()

Expand Down Expand Up @@ -163,3 +193,41 @@ def default(
credentials=credentials,
base_url=f"https://{cdf_cluster}.cognitedata.com/",
)

@classmethod
def load(cls, config: dict) -> ClientConfig:
"""Loads a dictionary of configuration fields into a client config object.
Args:
config (dict): A dictionary containing configuration values defined in the ClientConfig class.
Returns:
ClientConfig: A client config object.
Examples:
Create a client config object from a dictionary input:
>>> from cognite.client.config import ClientConfig
>>> import os
>>> config = {
... "client_name": "abcd",
... "project": "cdf-project",
... "base_url": "https://api.cognitedata.com/",
... "client_credentials": {
... "client_id": "abcd",
... "client_secret": os.environ["OAUTH_CLIENT_SECRET"],
... "token_url": os.environ["TOKEN_URL"],
... "scopes": ["https://greenfield.cognitedata.com/.default"],
... # Any additional IDP-specific token args. e.g.
... "audience": "some-audience",
... }
... }
>>> client_config = ClientConfig.load(config)
"""
try:
credentials_config_input = config.pop("credentials")
except KeyError:
raise ValueError("'credentials' is a required field and must be included in the input dictionary.")

credentials = CredentialProvider.load(credentials_config_input)
return ClientConfig(credentials=credentials, **config)
65 changes: 46 additions & 19 deletions cognite/client/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,64 @@ def authorization_header(self) -> tuple[str, str]:
raise NotImplementedError

@classmethod
def load(cls, resource: dict) -> CredentialProvider:
def load(cls, config: dict) -> CredentialProvider:
"""Create a CredentialProvider from a configuration dictionary.
Args:
resource (dict): The type of credential provider.
config (dict): A dictionary containing the configuration for the credential provider.
The dictionary must contain exactly one top level key, which is the type of the credential provider and must be one of the following strings:
"token", "client_credentials", "interactive", "device_code", "client_certificate".
The value of the key is a dictionary containing the configuration for the credential provider.
Returns:
CredentialProvider: Initialized credential provider of the specified type.
Examples:
Get a token credential provider:
>>> from cognite.client.credentials import CredentialProvider
>>> credential_provider = CredentialProvider.from_config("token", "my secret token")
>>> config = {"token": "my secret token"}
>>> credential_provider = CredentialProvider.load(config)
Get a client credential provider:
>>> from cognite.client.credentials import CredentialProvider
>>> import os
>>> config = {
... "client_credentials": {
... "client_id": "abcd",
... "client_secret": os.environ["OAUTH_CLIENT_SECRET"],
... "token_url": os.environ["TOKEN_URL"],
... "scopes": ["https://greenfield.cognitedata.com/.default"],
... # Any additional IDP-specific token args. e.g.
... "audience": "some-audience",
... }
... }
>>> credential_provider = CredentialProvider.load(config)
"""
if len(resource) != 1:
raise ValueError("Credential provider configuration must contain exactly one key-value pair.")

credential_type, config = next(iter(resource.items()))

if credential_type == "token":
return Token(config)
elif credential_type == "o_auth_client_credentials":
return OAuthClientCredentials(**config)
elif credential_type == "o_auth_interactive":
return OAuthInteractive(**config)
elif credential_type == "o_auth_device_code":
return OAuthDeviceCode(**config)
elif credential_type == "o_auth_client_certificate":
return OAuthClientCertificate(**config)
if not isinstance(config, dict) or len(config) != 1:
raise ValueError(
"Credential provider configuration must be a dictionary containing exactly one top level key."
)

credential_type, credential_config = next(iter(config.items()))

supported_credential_types = {
"token": Token,
"client_credentials": OAuthClientCredentials,
"interactive": OAuthInteractive,
"device_code": OAuthDeviceCode,
"client_certificate": OAuthClientCertificate,
}

if credential_type not in supported_credential_types.keys():
raise ValueError(
f"Invalid credential provider type: '{credential_type}', the valid options are {list(supported_credential_types.keys())}."
)
elif credential_type == "token":
return supported_credential_types[credential_type](credential_config)
else:
raise ValueError(f"The provided credential type {credential_type} is not supported.")
return supported_credential_types[credential_type](**credential_config)


class Token(CredentialProvider):
Expand Down
22 changes: 17 additions & 5 deletions tests/tests_unit/test_cognite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def set_env_vars(monkeypatch):
env_vars = {
"COGNITE_PROJECT": "test-project",
"COGNITE_CLIENT_NAME": "test-project",
"credential_type": "o_auth_client_credentials",
"credential_type": "client_credentials",
"URL": "test",
"COGNITE_CLIENT_SECRET": "test-client-secret",
"COGNITE_DEBUG": "true",
Expand Down Expand Up @@ -116,9 +116,21 @@ def test_verify_ssl_enabled_by_default(self, rsps, client_config_w_token_factory
assert client._api_client._http_client_with_retry.session.verify is True
assert client._api_client._http_client.session.verify is True

def test_client_from_yaml(self):
path = os.path.join(os.path.dirname(__file__), "test_config.yaml")
client = CogniteClient.from_yaml(path)
def test_client_load(self):
config = {
"project": "test-project",
"client_name": "cognite-sdk-python",
"debug": True,
"credentials": {
"client_credentials": {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"token_url": TOKEN_URL,
"scopes": ["https://test.com/.default", "https://test.com/.admin"],
}
},
}
client = CogniteClient.load(config)
assert client.config.project == "test-project"
assert client.config.credentials.client_id == "test-client-id"
assert client.config.credentials.client_secret == "test-client-secret"
Expand All @@ -138,7 +150,7 @@ def test_client_from_yaml_with_envs(self, set_env_vars):

def test_client_from_yaml_missing_envs(self):
path = os.path.join(os.path.dirname(__file__), "test_config_envs.yaml")
with pytest.raises(ValueError, match="Missing environment variables: .*"):
with pytest.raises(ValueError, match=r"Error substituting environment variable: .*"):
CogniteClient.from_yaml(path)


Expand Down
50 changes: 50 additions & 0 deletions tests/tests_unit/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest

from cognite.client.config import ClientConfig, GlobalConfig
from cognite.client.credentials import Token


class TestGlobalConfig:
def test_load(self):
config = {
"max_workers": 5,
"max_retries": 3,
}
global_config = GlobalConfig.load(config)
assert global_config.max_workers == 5
assert global_config.max_retries == 3

def test_load_non_existent_attr(self):
config = {
"test": 10,
}
with pytest.raises(ValueError, match=r"Invalid key in global config: .*"):
GlobalConfig.load(config)


class TestClientConfig:
def test_default(self):
config = {
"project": "test-project",
"cdf_cluster": "test-cluster",
"credentials": Token("abc"),
"client_name": "test-client",
}
client_config = ClientConfig.default(**config)
assert client_config.project == "test-project"
assert client_config.base_url == "https://test-cluster.cognitedata.com"
assert isinstance(client_config.credentials, Token)
assert client_config.client_name == "test-client"

def test_load(self):
config = {
"project": "test-project",
"base_url": "https://test-cluster.cognitedata.com/",
"credentials": {"token": "abc"},
"client_name": "test-client",
}
client_config = ClientConfig.load(config)
assert client_config.project == "test-project"
assert client_config.base_url == "https://test-cluster.cognitedata.com"
assert isinstance(client_config.credentials, Token)
assert client_config.client_name == "test-client"
Loading

0 comments on commit e316362

Please sign in to comment.