From d9631cb154b43f3059488671b17f2b4f4249be97 Mon Sep 17 00:00:00 2001 From: Greger Teigre Wedel Date: Mon, 23 Sep 2024 13:11:01 +0200 Subject: [PATCH] Refactor OAuthDeviceCode to support non-Entra IdPs (#1892) Co-authored-by: anders-albert --- CHANGELOG.md | 4 + cognite/client/_version.py | 2 +- cognite/client/credentials.py | 263 ++++++++++++++++-- pyproject.toml | 2 +- tests/tests_unit/test_credential_providers.py | 10 +- 5 files changed, 251 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c28ac6ec4..6ed70cfb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ Changes are grouped as follows - `Fixed` for any bug fixes. - `Security` in case of vulnerabilities. +## [7.62.1] - 2024-09-23 +### Changed +- Support for `OAuthDeviceCode` now supports non Entra IdPs + ## [7.62.0] - 2024-09-19 ### Added - All `update` methods now accept a new parameter `mode` that controls how non-update objects should be diff --git a/cognite/client/_version.py b/cognite/client/_version.py index 31a9fd931..35203968c 100644 --- a/cognite/client/_version.py +++ b/cognite/client/_version.py @@ -1,4 +1,4 @@ from __future__ import annotations -__version__ = "7.62.0" +__version__ = "7.62.1" __api_subversion__ = "20230101" diff --git a/cognite/client/credentials.py b/cognite/client/credentials.py index ed5a24421..a6ee426e4 100644 --- a/cognite/client/credentials.py +++ b/cognite/client/credentials.py @@ -2,10 +2,12 @@ import atexit import inspect +import json import tempfile import threading import time from abc import abstractmethod +from datetime import datetime from pathlib import Path from types import MappingProxyType from typing import Any, Callable, Protocol, runtime_checkable @@ -15,7 +17,7 @@ from requests_oauthlib import OAuth2Session from cognite.client.exceptions import CogniteAuthError -from cognite.client.utils._auxiliary import load_resource_to_dict +from cognite.client.utils._auxiliary import at_least_one_is_not_none, exactly_one_is_not_none, load_resource_to_dict _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT = 30 # Do not change without also updating all the docstrings using it @@ -192,12 +194,15 @@ def authorization_header(self) -> tuple[str, str]: class _WithMsalSerializableTokenCache: @staticmethod - def _create_serializable_token_cache(cache_path: Path) -> SerializableTokenCache: + def _create_serializable_token_cache(cache_path: Path, clear_cache: bool = False) -> SerializableTokenCache: token_cache = SerializableTokenCache() if cache_path.exists(): - with cache_path.open() as fh: - token_cache.deserialize(fh.read()) + if clear_cache: + cache_path.unlink(missing_ok=True) + else: + with cache_path.open() as fh: + token_cache.deserialize(fh.read()) def __at_exit() -> None: if token_cache.has_state_changed: @@ -211,29 +216,61 @@ def __at_exit() -> None: def _resolve_token_cache_path(token_cache_path: Path | None, client_id: str) -> Path: return token_cache_path or Path(tempfile.gettempdir()) / f"cognitetokencache.{client_id}.bin" - def _create_client_app(self, token_cache_path: Path, client_id: str, authority_url: str) -> PublicClientApplication: + def _create_client_app( + self, + token_cache_path: Path, + client_id: str, + authority_url: str | None = None, + oauth_discovery_url: str | None = None, + clear_cache: bool = False, + mem_cache_only: bool = False, + ) -> PublicClientApplication: from cognite.client.config import global_config - # In addition to caching in memory, we also cache the token on disk so it can be reused across processes: - serializable_token_cache = self._create_serializable_token_cache(token_cache_path) + if authority_url and oauth_discovery_url: + raise ValueError( + "Only one of 'authority_url' (for MS Entra) or 'oauth_discovery_url' (for other IdPs) should be provided." + ) + + # In addition to caching in memory, we also cache the token on disk so it can be reused across processes. + if mem_cache_only: + serializable_token_cache = SerializableTokenCache() + else: + serializable_token_cache = self._create_serializable_token_cache(token_cache_path, clear_cache) return PublicClientApplication( client_id=client_id, authority=authority_url, token_cache=serializable_token_cache, verify=not global_config.disable_ssl, + oidc_authority=oauth_discovery_url, + # These two must be set to `False` to support non-Entra authorities. + instance_discovery=False, + validate_authority=False, ) + @staticmethod + def _get_cached_token(cache_path: Path) -> dict[str, Any]: + if not cache_path.exists(): + return {} + token = json.loads(cache_path.read_text()) + return token + class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache): """OAuth credential provider for the device code login flow. Args: - authority_url (str): OAuth authority url - client_id (str): Your application's client id. - scopes (list[str]): A list of scopes. + authority_url (str | None): MS Entra OAuth authority url, typically "https://login.microsoftonline.com/{tenant_id}" + client_id (str): Your application's client id that allows device code flows. + scopes (list[str] | None): A list of scopes. + cdf_cluster (str | None): The CDF cluster where the CDF project is located. If provided, scopes will be set to + [f"https://{cdf_cluster}.cognitedata.com/IDENTITY https://{cdf_cluster}.cognitedata.com/user_impersonation openid profile"]. + oauth_discovery_url (str | None): Standard OAuth discovery URL, should be where "/.well-known/openid-configuration" is found. token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin. token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec - + clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False + mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False + **token_custom_args (Any): Additional request parameters to pass to the authorization endpoint. Examples: >>> from cognite.client.credentials import OAuthDeviceCode @@ -242,23 +279,59 @@ class OAuthDeviceCode(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSeriali ... client_id="abcd", ... scopes=["https://greenfield.cognitedata.com/.default"], ... ) + + Create credentials with auth0 + + >>> from cognite.client.credentials import OAuthDeviceCode + >>> oauth_provider = OAuthDeviceCode( + ... authority_url=None, + ... oauth_discovery_url="https://my-tenant.auth0.com/oauth", + ... client_id="abcd", + ... scopes=["IDENTITY", "user_impersonation"], + ... ) """ def __init__( self, - authority_url: str, + authority_url: str | None, client_id: str, - scopes: list[str], + scopes: list[str] | None = None, + cdf_cluster: str | None = None, + oauth_discovery_url: str | None = None, token_cache_path: Path | None = None, token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT, + clear_cache: bool = False, + mem_cache_only: bool = False, + **token_custom_args: Any, ) -> None: super().__init__(token_expiry_leeway_seconds) + if not exactly_one_is_not_none(authority_url, oauth_discovery_url): + raise ValueError("Either 'authority_url' or 'oauth_discovery_url' must be provided, and not both.") + if not at_least_one_is_not_none(scopes, cdf_cluster): + raise ValueError("Either 'scopes' or 'cdf_cluster' must be provided.") + if not client_id: + raise ValueError("'client_id' must be provided.") self.__authority_url = authority_url + self.__oauth_discovery_url = oauth_discovery_url self.__client_id = client_id - self.__scopes = scopes + self.__scopes = scopes or [ + f"https://{cdf_cluster}.cognitedata.com/IDENTITY", + f"https://{cdf_cluster}.cognitedata.com/user_impersonation", + "openid", + "profile", + ] + self.__mem_cache_only = mem_cache_only + self.__token_custom_args = token_custom_args self._token_cache_path = self._resolve_token_cache_path(token_cache_path, client_id) - self.__app = self._create_client_app(self._token_cache_path, client_id, authority_url) + self.__app = self._create_client_app( + self._token_cache_path, + client_id, + authority_url, + oauth_discovery_url, + clear_cache, + mem_cache_only, + ) def __getstate__(self) -> dict[str, Any]: # PublicClientApplication is not picklable, temporarily remove: @@ -269,12 +342,23 @@ def __getstate__(self) -> dict[str, Any]: def __setstate__(self, state: dict[str, Any]) -> None: super().__setstate__(state) - self.__app = self._create_client_app(self._token_cache_path, self.__client_id, self.__authority_url) + self.__app = self._create_client_app( + token_cache_path=self._token_cache_path, + client_id=self.__client_id, + authority_url=self.__authority_url, + oauth_discovery_url=self.__oauth_discovery_url, + clear_cache=False, + mem_cache_only=self.__mem_cache_only, + ) @property - def authority_url(self) -> str: + def authority_url(self) -> str | None: return self.__authority_url + @property + def oauth_discovery_url(self) -> str | None: + return self.__oauth_discovery_url + @property def client_id(self) -> str: return self.__client_id @@ -283,22 +367,100 @@ def client_id(self) -> str: def scopes(self) -> list[str]: return self.__scopes + def scope_string(self) -> str: + return " ".join(self.__scopes) + + def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]: + """Return a dictionary with the current token and expiry time.""" + if self._token_cache_path.exists(): + token = self._get_cached_token(self._token_cache_path) + else: + if _app := getattr(self, f"_{type(self).__name__}__app", None): + if _app.token_cache.has_state_changed: + with open(self._token_cache_path, "w+") as fh: + fh.write(_app.token_cache.serialize()) + token = self._get_cached_token(self._token_cache_path) + + if convert_timestamps: + if "AccessToken" in token: + for key, value in token["AccessToken"].items(): + for subkey in ["expires_on", "extended_expires_on", "cached_at"]: + if subkey in value: + value[subkey] = datetime.fromtimestamp(int(value[subkey])).isoformat() + return token + def _refresh_access_token(self) -> tuple[str, float]: # First check if a token cache exists on disk. If yes, find and use: # - A valid access token. # - A valid refresh token, and if so, use it automatically to redeem a new access token. credentials = None - if accounts := self.__app.get_accounts(): - credentials = self.__app.acquire_token_silent(scopes=self.__scopes, account=accounts[0]) - - # If we're unable to find (or acquire a new) access token, we initiate the device code auth flow: + for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.REFRESH_TOKEN): + if "expires_on" in token and token["expires_on"] > time.time(): + credentials = token + break + if credentials is not None: + credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", "")) + else: + for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.ACCESS_TOKEN): + if expiry := int(token.get("expires_on", 0)) - time.time() > 0: + credentials = { + "access_token": token.get("secret"), + "expires_in": expiry, + } + break + # If we're unable to find (or acquire a new) access token, we initiate the device code auth flow. + # The msal device_code flow does not support setting the audience, so we need to handle it manually. + # We use the http client instantiated as part of the msal client, as well as the details found + # in oauth discovery. if credentials is None: - device_flow = self.__app.initiate_device_flow(scopes=self.__scopes) - # print device code user instructions to screen - print(f"Device code: {device_flow['message']}") # noqa: T201 - credentials = self.__app.acquire_token_by_device_flow(flow=device_flow) + data = { + "scope": self.scope_string(), + "client_id": self.client_id, + } + for key, value in self.__token_custom_args.items(): + data[key] = value + try: + device_flow = self.__app.http_client.post( + self.__app.authority.device_authorization_endpoint, + data=data, + headers={ + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8", + }, + ).json() + except Exception as e: + raise CogniteAuthError("Error initiating device flow") from e + if "verification_uri" in device_flow: + print( # noqa: T201 + f"Visit {device_flow['verification_uri']} and enter the code: {device_flow.get('user_code', 'ERROR')}" + ) + elif "message" in device_flow: + print(f"Device code: {device_flow.get('message', device_flow.get('user_code', 'ERROR'))}") # noqa: T201 + else: + raise CogniteAuthError( + f"Error initiating device flow: {device_flow.get('error')} - {device_flow.get('error_description')}" + ) + if "interval" not in device_flow: + # Set default interval according to standard + device_flow["interval"] = 5 + if "expires_in" in device_flow: + # msal library uses expires_at instead of the standard expires_in + device_flow["expires_at"] = device_flow["expires_in"] + time.time() + # Poll for token + credentials = self.__app.client.obtain_token_by_device_flow( + flow=device_flow, + data=dict( + data, + code=device_flow.get( + "device_code" + ), # Hack from msal library to get the code from the device flow, not standard + ), + ) self._verify_credentials(credentials) + self.__app.token_cache.add( + dict(credentials, environment=self.__app.authority.instance), + ) return credentials["access_token"], time.time() + float(credentials["expires_in"]) @classmethod @@ -326,13 +488,62 @@ def load(cls, config: dict[str, Any] | str) -> OAuthDeviceCode: return cls( authority_url=loaded["authority_url"], client_id=loaded["client_id"], - scopes=loaded["scopes"], + scopes=loaded.get("scopes"), + cdf_cluster=loaded.get("cdf_cluster"), token_cache_path=Path(token_cache_path) if token_cache_path else None, token_expiry_leeway_seconds=int( loaded.get("token_expiry_leeway_seconds", _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT) ), ) + @classmethod + def default_for_azure_ad( + cls, + tenant_id: str, + client_id: str, + cdf_cluster: str, + token_cache_path: Path | None = None, + token_expiry_leeway_seconds: int = _TOKEN_EXPIRY_LEEWAY_SECONDS_DEFAULT, + clear_cache: bool = False, + mem_cache_only: bool = False, + ) -> OAuthDeviceCode: + """ + Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite + app registration for device code flow. If you need device code flow with another app registration, instantiate + OAuthDeviceCode directly. + + The default configuration creates the URLs based on the tenant id and cluster: + + * Authority URL: "https://login.microsoftonline.com/{tenant_id}" + * Scopes: [f"https://{cdf_cluster}.cognitedata.com/.default"] + + Args: + tenant_id (str): The Azure tenant id + client_id (str): An app registration that allows device code flow. + cdf_cluster (str): The CDF cluster where the CDF project is located. + token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin. + token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec + clear_cache (bool): If True, the token cache will be cleared on initialization. Default: False + mem_cache_only (bool): If True, the token cache will only be stored in memory. Default: False + Returns: + OAuthDeviceCode: An OAuthDeviceCode instance + """ + return cls( + authority_url=f"https://login.microsoftonline.com/{tenant_id}", + client_id=client_id, # Default application for CDF API for device code flow + scopes=[ + f"https://{cdf_cluster}.cognitedata.com/IDENTITY", + f"https://{cdf_cluster}.cognitedata.com/user_impersonation", + "profile", + "openid", + ], + token_cache_path=token_cache_path, + token_expiry_leeway_seconds=token_expiry_leeway_seconds, + clear_cache=clear_cache, + mem_cache_only=mem_cache_only, + audience=f"https://{cdf_cluster}.cognitedata.com", + ) + class OAuthInteractive(_OAuthCredentialProviderWithTokenRefresh, _WithMsalSerializableTokenCache): """OAuth credential provider for an interactive login flow. diff --git a/pyproject.toml b/pyproject.toml index 6e43cfb26..abddbd574 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "cognite-sdk" -version = "7.62.0" +version = "7.62.1" description = "Cognite Python SDK" readme = "README.md" documentation = "https://cognite-sdk-python.readthedocs-hosted.com" diff --git a/tests/tests_unit/test_credential_providers.py b/tests/tests_unit/test_credential_providers.py index bce49dc4f..ba272b866 100644 --- a/tests/tests_unit/test_credential_providers.py +++ b/tests/tests_unit/test_credential_providers.py @@ -1,6 +1,6 @@ from types import MappingProxyType from typing import ClassVar -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from oauthlib.oauth2 import InvalidClientIdError @@ -97,7 +97,13 @@ class TestOAuthDeviceCode: @patch("cognite.client.credentials.PublicClientApplication") @pytest.mark.parametrize("expires_in", (1000, "1001")) # some IDPs return as string def test_access_token_generated(self, mock_public_client, expires_in): - mock_public_client().acquire_token_silent.return_value = { + mock_response = Mock() + mock_response.json.return_value = { + "user_code": "ABCDEF", + "message": "Follow the link and enter the code", + } + mock_public_client().http_client.post.return_value = mock_response + mock_public_client().client.obtain_token_by_device_flow.return_value = { "access_token": "azure_token", "expires_in": expires_in, }