diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py index dd2b3057b..0a9b15bc0 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py @@ -25,6 +25,13 @@ _NOOP_MESSAGE_REPOSITORY = NoopMessageRepository() +class ResponseKeysMaxRecurtionReached(AirbyteTracedException): + """ + Raised when the max level of recursion is reached, when trying to + find-and-get the target key, during the `_make_handled_request` + """ + + class AbstractOauth2Authenticator(AuthBase): """ Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator @@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques request.headers.update(self.get_auth_header()) return request + @property + def _is_access_token_flow(self) -> bool: + return self.get_token_refresh_endpoint() is None and self.access_token is not None + + @property + def token_expiry_is_time_of_expiration(self) -> bool: + """ + Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. + """ + + return False + + @property + def token_expiry_date_format(self) -> Optional[str]: + """ + Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires + """ + + return None + def get_auth_header(self) -> Mapping[str, Any]: """HTTP header to set on the requests""" token = self.access_token if self._is_access_token_flow else self.get_access_token() return {"Authorization": f"Bearer {token}"} - @property - def _is_access_token_flow(self) -> bool: - return self.get_token_refresh_endpoint() is None and self.access_token is not None - def get_access_token(self) -> str: """Returns the access token""" if self.token_has_expired(): @@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None: headers = self.get_refresh_request_headers() return headers if headers else None + def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + """ + Returns the refresh token and its expiration datetime + + :return: a tuple of (access_token, token_lifespan) + """ + response_json = self._make_handled_request() + self._ensure_access_token_in_response(response_json) + + return ( + self._extract_access_token(response_json), + self._extract_token_expiry_date(response_json), + ) + + # ---------------- + # PRIVATE METHODS + # ---------------- + def _wrap_refresh_token_exception( self, exception: requests.exceptions.RequestException ) -> bool: + """ + Wraps and handles exceptions that occur during the refresh token process. + + This method checks if the provided exception is related to a refresh token error + by examining the response status code and specific error content. + + Args: + exception (requests.exceptions.RequestException): The exception raised during the request. + + Returns: + bool: True if the exception is related to a refresh token error, False otherwise. + """ try: if exception.response is not None: exception_content = exception.response.json() @@ -131,7 +184,24 @@ def _wrap_refresh_token_exception( ), max_time=300, ) - def _get_refresh_access_token_response(self) -> Any: + def _make_handled_request(self) -> Any: + """ + Makes a handled HTTP request to refresh an OAuth token. + + This method sends a POST request to the token refresh endpoint with the necessary + headers and body to obtain a new access token. It handles various exceptions that + may occur during the request and logs the response for troubleshooting purposes. + + Returns: + Mapping[str, Any]: The JSON response from the token refresh endpoint. + + Raises: + DefaultBackoffException: If the response status code is 429 (Too Many Requests) + or any 5xx server error. + AirbyteTracedException: If the refresh token is invalid or expired, prompting + re-authentication. + Exception: For any other exceptions that occur during the request. + """ try: response = requests.request( method="POST", @@ -139,22 +209,10 @@ def _get_refresh_access_token_response(self) -> Any: data=self.build_refresh_request_body(), headers=self.build_refresh_request_headers(), ) - if response.ok: - response_json = response.json() - # Add the access token to the list of secrets so it is replaced before logging the response - # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen... - access_key = response_json.get(self.get_access_token_name()) - if not access_key: - raise Exception( - "Token refresh API response was missing access token {self.get_access_token_name()}" - ) - add_to_secrets(access_key) - self._log_response(response) - return response_json - else: - # log the response even if the request failed for troubleshooting purposes - self._log_response(response) - response.raise_for_status() + # log the response even if the request failed for troubleshooting purposes + self._log_response(response) + response.raise_for_status() + return response.json() except requests.exceptions.RequestException as e: if e.response is not None: if e.response.status_code == 429 or e.response.status_code >= 500: @@ -168,17 +226,34 @@ def _get_refresh_access_token_response(self) -> Any: except Exception as e: raise Exception(f"Error while refreshing access token: {e}") from e - def refresh_access_token(self) -> Tuple[str, Union[str, int]]: + def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None: """ - Returns the refresh token and its expiration datetime + Ensures that the access token is present in the response data. - :return: a tuple of (access_token, token_lifespan) - """ - response_json = self._get_refresh_access_token_response() + This method attempts to extract the access token from the provided response data. + If the access token is not found, it raises an exception indicating that the token + refresh API response was missing the access token. If the access token is found, + it adds the token to the list of secrets to ensure it is replaced before logging + the response. + + Args: + response_data (Mapping[str, Any]): The response data from which to extract the access token. - return response_json[self.get_access_token_name()], response_json[ - self.get_expires_in_name() - ] + Raises: + Exception: If the access token is not found in the response data. + ResponseKeysMaxRecurtionReached: If the maximum recursion depth is reached while extracting the access token. + """ + try: + access_key = self._extract_access_token(response_data) + if not access_key: + raise Exception( + "Token refresh API response was missing access token {self.get_access_token_name()}" + ) + # Add the access token to the list of secrets so it is replaced before logging the response + # An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen... + add_to_secrets(access_key) + except ResponseKeysMaxRecurtionReached as e: + raise e def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime: """ @@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim f"Invalid expires_in value: {value}. Expected number of seconds when no format specified." ) - @property - def token_expiry_is_time_of_expiration(self) -> bool: + def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any: """ - Indicates that the Token Expiry returns the date until which the token will be valid, not the amount of time it will be valid. + Extracts the access token from the given response data. + + Args: + response_data (Mapping[str, Any]): The response data from which to extract the access token. + + Returns: + str: The extracted access token. """ + return self._find_and_get_value_from_response(response_data, self.get_access_token_name()) - return False + def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any: + """ + Extracts the refresh token from the given response data. - @property - def token_expiry_date_format(self) -> Optional[str]: + Args: + response_data (Mapping[str, Any]): The response data from which to extract the refresh token. + + Returns: + str: The extracted refresh token. """ - Format of the datetime; exists it if expires_in is returned as the expiration datetime instead of seconds until it expires + return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name()) + + def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any: + """ + Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data. + + Args: + response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date. + + Returns: + str: The extracted token_expiry_date. """ + return self._find_and_get_value_from_response(response_data, self.get_expires_in_name()) + + def _find_and_get_value_from_response( + self, + response_data: Mapping[str, Any], + key_name: str, + max_depth: int = 5, + current_depth: int = 0, + ) -> Any: + """ + Recursively searches for a specified key in a nested dictionary or list and returns its value if found. + + Args: + response_data (Mapping[str, Any]): The response data to search through, which can be a dictionary or a list. + key_name (str): The key to search for in the response data. + max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5. + current_depth (int, optional): The current depth of the recursion. Defaults to 0. + + Returns: + Any: The value associated with the specified key if found, otherwise None. + + Raises: + AirbyteTracedException: If the maximum recursion depth is reached without finding the key. + """ + if current_depth > max_depth: + # this is needed to avoid an inf loop, possible with a very deep nesting observed. + message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response." + raise ResponseKeysMaxRecurtionReached( + internal_message=message, message=message, failure_type=FailureType.config_error + ) + + if isinstance(response_data, dict): + # get from the root level + if key_name in response_data: + return response_data[key_name] + + # get from the nested object + for _, value in response_data.items(): + result = self._find_and_get_value_from_response( + value, key_name, max_depth, current_depth + 1 + ) + if result is not None: + return result + + # get from the nested array object + elif isinstance(response_data, list): + for item in response_data: + result = self._find_and_get_value_from_response( + item, key_name, max_depth, current_depth + 1 + ) + if result is not None: + return result return None + @property + def _message_repository(self) -> Optional[MessageRepository]: + """ + The implementation can define a message_repository if it wants debugging logs for HTTP requests + """ + return _NOOP_MESSAGE_REPOSITORY + + def _log_response(self, response: requests.Response) -> None: + """ + Logs the HTTP response using the message repository if it is available. + + Args: + response (requests.Response): The HTTP response to log. + """ + if self._message_repository: + self._message_repository.log_message( + Level.DEBUG, + lambda: format_http_message( + response, + "Refresh token", + "Obtains access token", + self._NO_STREAM_NAME, + is_auxiliary=True, + ), + ) + + # ---------------- + # ABSTR METHODS + # ---------------- + @abstractmethod def get_token_refresh_endpoint(self) -> Optional[str]: """Returns the endpoint to refresh the access token""" @@ -295,23 +473,3 @@ def access_token(self) -> str: @abstractmethod def access_token(self, value: str) -> str: """Setter for the access token""" - - @property - def _message_repository(self) -> Optional[MessageRepository]: - """ - The implementation can define a message_repository if it wants debugging logs for HTTP requests - """ - return _NOOP_MESSAGE_REPOSITORY - - def _log_response(self, response: requests.Response) -> None: - if self._message_repository: - self._message_repository.log_message( - Level.DEBUG, - lambda: format_http_message( - response, - "Refresh token", - "Obtains access token", - self._NO_STREAM_NAME, - is_auxiliary=True, - ), - ) diff --git a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py index 5cbe17e0a..2ff2f60e9 100644 --- a/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py +++ b/airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py @@ -51,7 +51,7 @@ def __init__( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = "", refresh_token_error_values: Tuple[str, ...] = (), - ): + ) -> None: self._token_refresh_endpoint = token_refresh_endpoint self._client_secret_name = client_secret_name self._client_secret = client_secret @@ -175,7 +175,7 @@ def __init__( refresh_token_error_status_codes: Tuple[int, ...] = (), refresh_token_error_key: str = "", refresh_token_error_values: Tuple[str, ...] = (), - ): + ) -> None: """ Args: connector_config (Mapping[str, Any]): The full connector configuration @@ -196,18 +196,12 @@ def __init__( token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration message_repository (MessageRepository): the message repository used to emit logs on HTTP requests and control message on config update """ - self._client_id = ( - client_id # type: ignore[assignment] # Incorrect type for assignment - if client_id is not None - else dpath.get(connector_config, ("credentials", "client_id")) # type: ignore[arg-type] + self._connector_config = connector_config + self._client_id: str = self._get_config_value_by_path( + ("credentials", "client_id"), client_id ) - self._client_secret = ( - client_secret # type: ignore[assignment] # Incorrect type for assignment - if client_secret is not None - else dpath.get( - connector_config, # type: ignore[arg-type] - ("credentials", "client_secret"), - ) + self._client_secret: str = self._get_config_value_by_path( + ("credentials", "client_secret"), client_secret ) self._client_id_name = client_id_name self._client_secret_name = client_secret_name @@ -222,9 +216,9 @@ def __init__( super().__init__( token_refresh_endpoint=token_refresh_endpoint, client_id_name=self._client_id_name, - client_id=self.get_client_id(), + client_id=self._client_id, client_secret_name=self._client_secret_name, - client_secret=self.get_client_secret(), + client_secret=self._client_secret, refresh_token=self.get_refresh_token(), refresh_token_name=self._refresh_token_name, scopes=scopes, @@ -242,51 +236,62 @@ def __init__( refresh_token_error_values=refresh_token_error_values, ) - def get_refresh_token_name(self) -> str: - return self._refresh_token_name - - def get_client_id(self) -> str: - return self._client_id - - def get_client_secret(self) -> str: - return self._client_secret - @property def access_token(self) -> str: - return dpath.get( # type: ignore[return-value] - self._connector_config, # type: ignore[arg-type] - self._access_token_config_path, - default="", - ) + """ + Retrieve the access token from the configuration. + + Returns: + str: The access token. + """ + return self._get_config_value_by_path(self._access_token_config_path) # type: ignore[return-value] @access_token.setter def access_token(self, new_access_token: str) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._access_token_config_path, - new_access_token, - ) + """ + Sets a new access token. + + Args: + new_access_token (str): The new access token to be set. + """ + self._set_config_value_by_path(self._access_token_config_path, new_access_token) def get_refresh_token(self) -> str: - return dpath.get( # type: ignore[return-value] - self._connector_config, # type: ignore[arg-type] - self._refresh_token_config_path, - default="", - ) + """ + Retrieve the refresh token from the configuration. + + This method fetches the refresh token using the configuration path specified + by `_refresh_token_config_path`. + + Returns: + str: The refresh token as a string. + """ + return self._get_config_value_by_path(self._refresh_token_config_path) # type: ignore[return-value] def set_refresh_token(self, new_refresh_token: str) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._refresh_token_config_path, - new_refresh_token, - ) + """ + Updates the refresh token in the configuration. + + Args: + new_refresh_token (str): The new refresh token to be set. + """ + self._set_config_value_by_path(self._refresh_token_config_path, new_refresh_token) def get_token_expiry_date(self) -> AirbyteDateTime: - expiry_date = dpath.get( - self._connector_config, # type: ignore[arg-type] - self._token_expiry_date_config_path, - default="", - ) + """ + Retrieves the token expiry date from the configuration. + + This method fetches the token expiry date from the configuration using the specified path. + If the expiry date is an empty string, it returns the current date and time minus one day. + Otherwise, it parses the expiry date string into an AirbyteDateTime object. + + Returns: + AirbyteDateTime: The parsed or calculated token expiry date. + + Raises: + TypeError: If the result is not an instance of AirbyteDateTime. + """ + expiry_date = self._get_config_value_by_path(self._token_expiry_date_config_path) result = ( ab_datetime_now() - timedelta(days=1) if expiry_date == "" @@ -296,14 +301,15 @@ def get_token_expiry_date(self) -> AirbyteDateTime: return result raise TypeError("Invalid datetime conversion") - def set_token_expiry_date( # type: ignore[override] - self, - new_token_expiry_date: AirbyteDateTime, - ) -> None: - dpath.new( - self._connector_config, # type: ignore[arg-type] - self._token_expiry_date_config_path, - str(new_token_expiry_date), + def set_token_expiry_date(self, new_token_expiry_date: AirbyteDateTime) -> None: # type: ignore[override] + """ + Sets the token expiry date in the configuration. + + Args: + new_token_expiry_date (AirbyteDateTime): The new expiry date for the token. + """ + self._set_config_value_by_path( + self._token_expiry_date_config_path, str(new_token_expiry_date) ) def token_has_expired(self) -> bool: @@ -315,6 +321,16 @@ def get_new_token_expiry_date( access_token_expires_in: str, token_expiry_date_format: str | None = None, ) -> AirbyteDateTime: + """ + Calculate the new token expiry date based on the provided expiration duration or format. + + Args: + access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format. + token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None. + + Returns: + AirbyteDateTime: The calculated expiry date of the access token. + """ if token_expiry_date_format: return ab_datetime_parse(access_token_expires_in) else: @@ -336,27 +352,82 @@ def get_access_token(self) -> str: self.access_token = new_access_token self.set_refresh_token(new_refresh_token) self.set_token_expiry_date(new_token_expiry_date) - # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message - # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the - # message directly in the console, this is needed - if not isinstance(self._message_repository, NoopMessageRepository): - self._message_repository.emit_message( - create_connector_config_control_message(self._connector_config) # type: ignore[arg-type] - ) - else: - emit_configuration_as_airbyte_control_message(self._connector_config) # type: ignore[arg-type] + self._emit_control_message() return self.access_token - def refresh_access_token( # type: ignore[override] # Signature doesn't match base class - self, - ) -> Tuple[str, str, str]: - response_json = self._get_refresh_access_token_response() + def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override] + """ + Refreshes the access token by making a handled request and extracting the necessary token information. + + Returns: + Tuple[str, str, str]: A tuple containing the new access token, token expiry date, and refresh token. + """ + response_json = self._make_handled_request() return ( - response_json[self.get_access_token_name()], - response_json[self.get_expires_in_name()], - response_json[self.get_refresh_token_name()], + self._extract_access_token(response_json), + self._extract_token_expiry_date(response_json), + self._extract_refresh_token(response_json), + ) + + def _set_config_value_by_path(self, config_path: Union[str, Sequence[str]], value: Any) -> None: + """ + Set a value in the connector configuration at the specified path. + + Args: + config_path (Union[str, Sequence[str]]): The path within the configuration where the value should be set. + This can be a string representing a single key or a sequence of strings representing a nested path. + value (Any): The value to set at the specified path in the configuration. + + Returns: + None + """ + dpath.new(self._connector_config, config_path, value) # type: ignore[arg-type] + + def _get_config_value_by_path( + self, config_path: Union[str, Sequence[str]], default: Optional[str] = None + ) -> str | Any: + """ + Retrieve a value from the connector configuration using a specified path. + + Args: + config_path (Union[str, Sequence[str]]): The path to the desired configuration value. This can be a string or a sequence of strings. + default (Optional[str], optional): The default value to return if the specified path does not exist in the configuration. Defaults to None. + + Returns: + Any: The value from the configuration at the specified path, or the default value if the path does not exist. + """ + return dpath.get( + self._connector_config, # type: ignore[arg-type] + config_path, + default=default if default is not None else "", ) + def _emit_control_message(self) -> None: + """ + Emits a control message based on the connector configuration. + + This method checks if the message repository is not a NoopMessageRepository. + If it is not, it emits a message using the message repository. Otherwise, + it falls back to emitting the configuration as an Airbyte control message + directly to the console for backward compatibility. + + Note: + The function `emit_configuration_as_airbyte_control_message` has been deprecated + in favor of the package `airbyte_cdk.sources.message`. + + Raises: + TypeError: If the argument types are incorrect. + """ + # FIXME emit_configuration_as_airbyte_control_message as been deprecated in favor of package airbyte_cdk.sources.message + # Usually, a class shouldn't care about the implementation details but to keep backward compatibility where we print the + # message directly in the console, this is needed + if not isinstance(self._message_repository, NoopMessageRepository): + self._message_repository.emit_message( + create_connector_config_control_message(self._connector_config) # type: ignore[arg-type] + ) + else: + emit_configuration_as_airbyte_control_message(self._connector_config) # type: ignore[arg-type] + @property def _message_repository(self) -> MessageRepository: """ diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index c00a7e2f1..b0c91ce30 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -600,7 +600,7 @@ def test_config_update() -> None: "expires_in": 3600, } with patch( - "airbyte_cdk.sources.streams.http.requests_native_auth.SingleUseRefreshTokenOauth2Authenticator._get_refresh_access_token_response", + "airbyte_cdk.sources.streams.http.requests_native_auth.SingleUseRefreshTokenOauth2Authenticator._make_handled_request", return_value=refresh_request_response, ): output = handle_connector_builder_request( diff --git a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py index 808126988..d756931c8 100644 --- a/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py +++ b/unit_tests/sources/streams/http/requests_native_auth/test_requests_native_auth.py @@ -22,6 +22,9 @@ SingleUseRefreshTokenOauth2Authenticator, TokenAuthenticator, ) +from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import ( + ResponseKeysMaxRecurtionReached, +) from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.datetime_helpers import AirbyteDateTime, ab_datetime_now, ab_datetime_parse @@ -258,7 +261,7 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, int) assert ("access_token", 1000) == (token, expires_in) - # Test with expires_in as str + # Test with expires_in as str(int) mocker.patch.object( resp, "json", return_value={"access_token": "access_token", "expires_in": "2000"} ) @@ -267,7 +270,7 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2000") == (token, expires_in) - # Test with expires_in as str + # Test with expires_in as datetime(str) mocker.patch.object( resp, "json", @@ -278,6 +281,78 @@ def test_refresh_access_token(self, mocker): assert isinstance(expires_in, str) assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in) + # Test with nested access_token and expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={"data": {"access_token": "access_token_nested", "expires_in": "2001"}}, + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, str) + assert ("access_token_nested", "2001") == (token, expires_in) + + # Test with multiple nested levels access_token and expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={ + "data": { + "scopes": ["one", "two", "three"], + "data2": { + "not_access_token": "test_non_access_token_value", + "data3": { + "some_field": "test_value", + "expires_at": "2800", + "data4": { + "data5": { + "access_token": "access_token_deeply_nested", + "expires_in": "2002", + } + }, + }, + }, + } + }, + ) + token, expires_in = oauth.refresh_access_token() + + assert isinstance(expires_in, str) + assert ("access_token_deeply_nested", "2002") == (token, expires_in) + + # Test with max nested levels access_token and expires_in as str(int) + mocker.patch.object( + resp, + "json", + return_value={ + "data": { + "scopes": ["one", "two", "three"], + "data2": { + "not_access_token": "test_non_access_token_value", + "data3": { + "some_field": "test_value", + "expires_at": "2800", + "data4": { + "data5": { + # this is the edge case, but worth testing. + "data6": { + "access_token": "access_token_super_deeply_nested", + "expires_in": "2003", + } + } + }, + }, + }, + } + }, + ) + with pytest.raises(ResponseKeysMaxRecurtionReached) as exc_info: + oauth.refresh_access_token() + error_message = "The maximum level of recursion is reached. Couldn't find the speficied `access_token` in the response." + assert exc_info.value.internal_message == error_message + assert exc_info.value.message == error_message + assert exc_info.value.failure_type == FailureType.config_error + def test_refresh_access_token_when_headers_provided(self, mocker): expected_headers = { "Authorization": "Bearer some_access_token", @@ -594,6 +669,11 @@ def test_given_message_repository_when_get_access_token_then_log_request( "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.format_http_message", return_value="formatted json", ) + # patching the `expires_in` + mocker.patch( + "airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth.AbstractOauth2Authenticator._find_and_get_value_from_response", + return_value="7200", + ) authenticator.token_has_expired = mocker.Mock(return_value=True) authenticator.get_access_token() @@ -608,7 +688,7 @@ def test_refresh_access_token(self, mocker, connector_config): client_secret=connector_config["credentials"]["client_secret"], ) - authenticator._get_refresh_access_token_response = mocker.Mock( + authenticator._make_handled_request = mocker.Mock( return_value={ authenticator.get_access_token_name(): "new_access_token", authenticator.get_expires_in_name(): "42",