Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor OAuthDeviceCode to support non-Entra IdPs #1892

Merged
merged 16 commits into from
Sep 23, 2024
Merged
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 228 additions & 25 deletions cognite/client/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import atexit
import inspect
import json
import tempfile
import threading
import time
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Protocol, runtime_checkable
Expand All @@ -15,7 +17,7 @@
from requests_oauthlib import OAuth2Session

from cognite.client.exceptions import CogniteAuthError
from cognite.client.utils._auxiliary import load_resource_to_dict
from cognite.client.utils._auxiliary import at_least_one_is_not_none, exactly_one_is_not_none, load_resource_to_dict

_TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT = 30 # Do not change without also updating all the docstrings using it

Expand Down Expand Up @@ -192,12 +194,15 @@ def authorization_header(self) -> tuple[str, str]:

class _WithMsalSerializableTokenCache:
@staticmethod
def _create_serializable_token_cache(cache_path: Path) -> SerializableTokenCache:
def _create_serializable_token_cache(cache_path: Path, clear_cache: bool = False) -> SerializableTokenCache:
token_cache = SerializableTokenCache()

if cache_path.exists():
with cache_path.open() as fh:
token_cache.deserialize(fh.read())
if clear_cache:
cache_path.unlink(missing_ok=True)
else:
with cache_path.open() as fh:
token_cache.deserialize(fh.read())

def __at_exit() -> None:
if token_cache.has_state_changed:
Expand All @@ -211,29 +216,60 @@ def __at_exit() -> None:
def _resolve_token_cache_path(token_cache_path: Path | None, client_id: str) -> Path:
return token_cache_path or Path(tempfile.gettempdir()) / f"cognitetokencache.{client_id}.bin"

def _create_client_app(self, token_cache_path: Path, client_id: str, authority_url: str) -> PublicClientApplication:
def _create_client_app(
self,
token_cache_path: Path,
client_id: str,
authority_url: str | None = None,
oauth_discovery_url: str | None = None,
clear_cache: bool = False,
mem_cache_only: bool = False,
) -> PublicClientApplication:
from cognite.client.config import global_config

# In addition to caching in memory, we also cache the token on disk so it can be reused across processes:
serializable_token_cache = self._create_serializable_token_cache(token_cache_path)
if authority_url and oauth_discovery_url:
raise ValueError(
"Only one of 'authority_url' (for MS Entra) or 'oauth_discovery_url' (for other IdPs) should be provided."
)

# In addition to caching in memory, we also cache the token on disk so it can be reused across processes.
if mem_cache_only:
serializable_token_cache = SerializableTokenCache()
else:
serializable_token_cache = self._create_serializable_token_cache(token_cache_path, clear_cache)
return PublicClientApplication(
client_id=client_id,
authority=authority_url,
token_cache=serializable_token_cache,
verify=not global_config.disable_ssl,
oidc_authority=oauth_discovery_url,
instance_discovery=False, # Turn off to support non-Entra authorities
validate_authority=False, # Turn off to support non-Entra authorities
doctrino marked this conversation as resolved.
Show resolved Hide resolved
)

@staticmethod
def _get_cached_token(cache_path: Path) -> dict[str, Any]:
if not cache_path.exists():
return {}
token = json.loads(cache_path.read_text())
return token


class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache):
"""OAuth credential provider for the device code login flow.

Args:
authority_url (str): OAuth authority url
client_id (str): Your application's client id.
scopes (list[str]): A list of scopes.
authority_url (str | None): MS Entra OAuth authority url, typically "https://login.microsoftonline.com/{tenant_id}"
client_id (str): Your application's client id that allows device code flows.
scopes (list[str] | None): A list of scopes.
cdf_cluster (str | None): The CDF cluster where the CDF project is located. If provided, scopes will be set to
[f"https://{cdf_cluster}.cognitedata.com/IDENTITY https://{cdf_cluster}.cognitedata.com/user_impersonation openid profile"].
oauth_discovery_url (str | None): Standard OAuth discovery URL, should be where "/.well-known/openid-configuration" is found.
token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec

clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False
mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False
**token_custom_args (Any): Additional request parameters to pass to the authorization endpoint.
Examples:

>>> from cognite.client.credentials import OAuthDeviceCode
Expand All @@ -242,23 +278,53 @@ class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSeriali
... client_id="abcd",
... scopes=["https://greenfield.cognitedata.com/.default"],
... )

>>> from cognite.client.credentials import OAuthDeviceCode
>>> oauth_provider = OAuthDeviceCode(
... oauth_discovery_url="https://my-tenant.auth0.com/oauth",
... client_id="abcd",
... scopes=["IDENTITY", "user_impersonation"],
... )
"""

def __init__(
self,
authority_url: str,
authority_url: str | None,
client_id: str,
scopes: list[str],
scopes: list[str] | None = None,
cdf_cluster: str | None = None,
oauth_discovery_url: str | None = None,
token_cache_path: Path | None = None,
token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT,
clear_cache: bool = False,
mem_cache_only: bool = False,
**token_custom_args: Any,
) -> None:
super().__init__(token_expiry_leeway_seconds)
if not exactly_one_is_not_none(authority_url, oauth_discovery_url):
raise ValueError("Either 'authority_url' or 'oauth_discovery_url' must be provided.")
if not at_least_one_is_not_none(scopes, cdf_cluster):
raise ValueError("Either 'scopes' or 'cdf_cluster' must be provided.")
if not client_id:
raise ValueError("'client_id' must be provided.")
self.__authority_url = authority_url
self.__oauth_discovery_url = oauth_discovery_url
self.__client_id = client_id
self.__scopes = scopes
self.__scopes = scopes or [
f"https://{cdf_cluster}.cognitedata.com/IDENTITY https://{cdf_cluster}.cognitedata.com/user_impersonation openid profile"
gregertw marked this conversation as resolved.
Show resolved Hide resolved
]
self.__mem_cache_only = mem_cache_only
self.__token_custom_args = token_custom_args

self._token_cache_path = self._resolve_token_cache_path(token_cache_path, client_id)
self.__app = self._create_client_app(self._token_cache_path, client_id, authority_url)
self.__app = self._create_client_app(
self._token_cache_path,
client_id,
authority_url,
oauth_discovery_url,
clear_cache,
mem_cache_only,
)

def __getstate__(self) -> dict[str, Any]:
# PublicClientApplication is not picklable, temporarily remove:
Expand All @@ -269,12 +335,23 @@ def __getstate__(self) -> dict[str, Any]:

def __setstate__(self, state: dict[str, Any]) -> None:
super().__setstate__(state)
self.__app = self._create_client_app(self._token_cache_path, self.__client_id, self.__authority_url)
self.__app = self._create_client_app(
token_cache_path=self._token_cache_path,
client_id=self.__client_id,
authority_url=self.__authority_url,
oauth_discovery_url=self.__oauth_discovery_url,
clear_cache=False,
mem_cache_only=self.__mem_cache_only,
)

@property
def authority_url(self) -> str:
def authority_url(self) -> str | None:
return self.__authority_url

@property
def oauth_discovery_url(self) -> str | None:
return self.__oauth_discovery_url

@property
def client_id(self) -> str:
return self.__client_id
Expand All @@ -283,22 +360,100 @@ def client_id(self) -> str:
def scopes(self) -> list[str]:
return self.__scopes

def scope_string(self) -> str:
return " ".join(self.__scopes)

def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]:
gregertw marked this conversation as resolved.
Show resolved Hide resolved
"""Return a dictionary with the current token and expiry time."""
if self._token_cache_path.exists():
token = self._get_cached_token(self._token_cache_path)
else:
if _app := getattr(self, f"_{type(self).__name__}__app", None):
if _app.token_cache.has_state_changed:
with open(self._token_cache_path, "w+") as fh:
fh.write(_app.token_cache.serialize())
token = self._get_cached_token(self._token_cache_path)

if convert_timestamps:
if "AccessToken" in token:
for key, value in token["AccessToken"].items():
for subkey in ["expires_on", "extended_expires_on", "cached_at"]:
if subkey in value:
value[subkey] = datetime.fromtimestamp(int(value[subkey])).isoformat()
return token

def _refresh_access_token(self) -> tuple[str, float]:
# First check if a token cache exists on disk. If yes, find and use:
# - A valid access token.
# - A valid refresh token, and if so, use it automatically to redeem a new access token.
credentials = None
if accounts := self.__app.get_accounts():
credentials = self.__app.acquire_token_silent(scopes=self.__scopes, account=accounts[0])

# If we're unable to find (or acquire a new) access token, we initiate the device code auth flow:
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.REFRESH_TOKEN):
if "expires_on" in token and token["expires_on"] > time.time():
credentials = token
break
if credentials is not None:
credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", ""))
else:
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.ACCESS_TOKEN):
if expiry := int(token.get("expires_on", 0)) - time.time() > 0:
credentials = {
"access_token": token.get("secret"),
"expires_in": expiry,
}
break
# If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
# The msal device_code flow does not support setting the audience, so we need to handle it manually.
# We use the http client instantiated as part of the msal client, as well as the details found
# in oauth discovery.
if credentials is None:
device_flow = self.__app.initiate_device_flow(scopes=self.__scopes)
# print device code user instructions to screen
print(f"Device code: {device_flow['message']}") # noqa: T201
credentials = self.__app.acquire_token_by_device_flow(flow=device_flow)
data = {
"scope": self.scope_string(),
"client_id": self.client_id,
}
for key, value in self.__token_custom_args.items():
data[key] = value
try:
device_flow = self.__app.http_client.post(
self.__app.authority.device_authorization_endpoint,
data=data,
headers={
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8",
},
).json()
except Exception as e:
raise CogniteAuthError("Error initiating device flow") from e
if "verification_uri" in device_flow:
print( # noqa: T201
f"Visit {device_flow['verification_uri']} and enter the code: {device_flow.get('user_code', 'ERROR')}"
)
elif "message" in device_flow:
print(f"Device code: {device_flow.get('message', device_flow.get('user_code', 'ERROR'))}") # noqa: T201
else:
raise CogniteAuthError(
f"Error initiating device flow: {device_flow.get('error')} - {device_flow.get('error_description')}"
)
if "interval" not in device_flow:
# Set default interval according to standard
device_flow["interval"] = 5
if "expires_in" in device_flow:
# msal library uses expires_at instead of the standard expires_in
device_flow["expires_at"] = device_flow["expires_in"] + time.time()
# Poll for token
credentials = self.__app.client.obtain_token_by_device_flow(
flow=device_flow,
data=dict(
data,
code=device_flow.get(
"device_code"
), # Hack from msal library to get the code from the device flow, not standard
),
)

self._verify_credentials(credentials)
self.__app.token_cache.add(
dict(credentials, environment=self.__app.authority.instance),
)
return credentials["access_token"], time.time() + float(credentials["expires_in"])

@classmethod
Expand Down Expand Up @@ -333,6 +488,54 @@ def load(cls, config: dict[str, Any] | str) -> OAuthDeviceCode:
),
)

@classmethod
def default_for_azure_ad(
cls,
tenant_id: str,
client_id: str,
cdf_cluster: str,
token_cache_path: Path | None = None,
token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT,
clear_cache: bool = False,
mem_cache_only: bool = False,
) -> OAuthDeviceCode:
"""
Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite
app registration for device code flow. If you need device code flow with another app registration, instantiate
OAuthDeviceCode directly.

The default configuration creates the URLs based on the tenant id and cluster:

* Authority URL: "https://login.microsoftonline.com/{tenant_id}"
* Scopes: [f"https://{cdf_cluster}.cognitedata.com/.default"]

Args:
tenant_id (str): The Azure tenant id
client_id (str): An app registration that allows device code flow.
cdf_cluster (str): The CDF cluster where the CDF project is located.
token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec
clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False
mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False
Returns:
OAuthDeviceCode: An OAuthDeviceCode instance
"""
return cls(
authority_url=f"https://login.microsoftonline.com/{tenant_id}",
client_id=client_id, # Default application for CDF API for device code flow
scopes=[
f"https://{cdf_cluster}.cognitedata.com/IDENTITY",
f"https://{cdf_cluster}.cognitedata.com/user_impersonation",
"profile",
"openid",
],
token_cache_path=token_cache_path,
token_expiry_leeway_seconds=token_expiry_leeway_seconds,
clear_cache=clear_cache,
mem_cache_only=mem_cache_only,
audience=f"https://{cdf_cluster}.cognitedata.com",
)


class OAuthInteractive(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache):
"""OAuth credential provider for an interactive login flow.
Expand Down
Loading