Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: (OAuthAuthenticator) - get the access_token, refresh_token, expires_in recursively from response #285

Merged
merged 7 commits into from
Jan 30, 2025
270 changes: 214 additions & 56 deletions airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
# Module-level repository instance handed out by the default `_message_repository` property.
_NOOP_MESSAGE_REPOSITORY = NoopMessageRepository()


class ResponseKeysMaxRecurtionReached(AirbyteTracedException):
    """
    Signals that the recursive key lookup over the token refresh response
    went deeper than the allowed maximum nesting level.

    It is raised by `_find_and_get_value_from_response` once the current
    depth exceeds `max_depth`.
    """


class AbstractOauth2Authenticator(AuthBase):
"""
Abstract class for an OAuth authenticators that implements the OAuth token refresh flow. The authenticator
Expand Down Expand Up @@ -53,15 +60,31 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
request.headers.update(self.get_auth_header())
return request

@property
def _is_access_token_flow(self) -> bool:
    """Tell whether a plain access token is used, i.e. no refresh endpoint is set but a token is."""
    has_refresh_endpoint = self.get_token_refresh_endpoint() is not None
    return not has_refresh_endpoint and self.access_token is not None

@property
def token_expiry_is_time_of_expiration(self) -> bool:
    """
    Whether the token expiry value is the moment until which the token stays valid,
    rather than the length of time it remains valid. This default reports False.
    """
    return False

@property
def token_expiry_date_format(self) -> Optional[str]:
    """
    Datetime format of the expiry value; set it when `expires_in` is returned as the
    expiration datetime instead of the number of seconds until expiry. This default
    reports None.
    """
    return None

def get_auth_header(self) -> Mapping[str, Any]:
    """Build the `Authorization` HTTP header to set on outgoing requests."""
    if self._is_access_token_flow:
        # No refresh endpoint is configured: use the stored token as-is.
        bearer = self.access_token
    else:
        bearer = self.get_access_token()
    return {"Authorization": f"Bearer {bearer}"}

@property
def _is_access_token_flow(self) -> bool:
    """Report True only when there is no token refresh endpoint and an access token is set."""
    if self.get_token_refresh_endpoint() is not None:
        return False
    return self.access_token is not None

def get_access_token(self) -> str:
"""Returns the access token"""
if self.token_has_expired():
Expand Down Expand Up @@ -107,9 +130,39 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
headers = self.get_refresh_request_headers()
return headers if headers else None

def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
    """
    Request a fresh access token and return it together with its expiry value.

    :return: a tuple of (access_token, token_lifespan)
    """
    payload = self._make_handled_request()
    # Validate the payload before anything is read out of it.
    self._ensure_access_token_in_response(payload)

    access_token = self._extract_access_token(payload)
    token_lifespan = self._extract_token_expiry_date(payload)
    return access_token, token_lifespan

# ----------------
# PRIVATE METHODS
# ----------------

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
) -> bool:
"""
Wraps and handles exceptions that occur during the refresh token process.

This method checks if the provided exception is related to a refresh token error
by examining the response status code and specific error content.

Args:
exception (requests.exceptions.RequestException): The exception raised during the request.

Returns:
bool: True if the exception is related to a refresh token error, False otherwise.
"""
try:
if exception.response is not None:
exception_content = exception.response.json()
Expand All @@ -131,30 +184,35 @@ def _wrap_refresh_token_exception(
),
max_time=300,
)
def _get_refresh_access_token_response(self) -> Any:
def _make_handled_request(self) -> Any:
"""
Makes a handled HTTP request to refresh an OAuth token.

This method sends a POST request to the token refresh endpoint with the necessary
headers and body to obtain a new access token. It handles various exceptions that
may occur during the request and logs the response for troubleshooting purposes.

Returns:
Mapping[str, Any]: The JSON response from the token refresh endpoint.

Raises:
DefaultBackoffException: If the response status code is 429 (Too Many Requests)
or any 5xx server error.
AirbyteTracedException: If the refresh token is invalid or expired, prompting
re-authentication.
Exception: For any other exceptions that occur during the request.
"""
try:
response = requests.request(
method="POST",
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
data=self.build_refresh_request_body(),
headers=self.build_refresh_request_headers(),
)
if response.ok:
response_json = response.json()
# Add the access token to the list of secrets so it is replaced before logging the response
# An argument could be made to remove the prevous access key from the list of secrets, but unmasking values seems like a security incident waiting to happen...
access_key = response_json.get(self.get_access_token_name())
if not access_key:
raise Exception(
"Token refresh API response was missing access token {self.get_access_token_name()}"
)
add_to_secrets(access_key)
self._log_response(response)
return response_json
else:
# log the response even if the request failed for troubleshooting purposes
self._log_response(response)
response.raise_for_status()
# log the response even if the request failed for troubleshooting purposes
self._log_response(response)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
if e.response is not None:
if e.response.status_code == 429 or e.response.status_code >= 500:
Expand All @@ -168,17 +226,34 @@ def _get_refresh_access_token_response(self) -> Any:
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e

def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) -> None:
    """
    Ensures that the access token is present in the response data.

    The access token is looked up via `_extract_access_token`. When it is found, it is
    added to the list of secrets so it is replaced before the response is logged.

    Args:
        response_data (Mapping[str, Any]): The response data from which to extract the access token.

    Raises:
        Exception: If the access token is missing (or falsy) in the response data.
        ResponseKeysMaxRecurtionReached: Propagated unchanged from the extraction step when the
            maximum recursion depth is reached while searching for the access token.
    """
    access_key = self._extract_access_token(response_data)
    if not access_key:
        # The message must be an f-string, otherwise the placeholder is emitted literally.
        raise Exception(
            f"Token refresh API response was missing access token {self.get_access_token_name()}"
        )
    # Add the access token to the list of secrets so it is replaced before logging the response.
    # An argument could be made to remove the previous access key from the list of secrets,
    # but unmasking values seems like a security incident waiting to happen...
    add_to_secrets(access_key)

def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
"""
Expand Down Expand Up @@ -206,22 +281,125 @@ def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTim
f"Invalid expires_in value: {value}. Expected number of seconds when no format specified."
)

def _extract_access_token(self, response_data: Mapping[str, Any]) -> Any:
    """
    Pull the access token out of the token refresh response.

    Args:
        response_data (Mapping[str, Any]): The response data to search.

    Returns:
        Any: Whatever `_find_and_get_value_from_response` yields for the key named by
        `get_access_token_name()`.
    """
    token_key = self.get_access_token_name()
    return self._find_and_get_value_from_response(response_data, token_key)
def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
    """
    Pull the refresh token out of the token refresh response.

    Args:
        response_data (Mapping[str, Any]): The response data to search.

    Returns:
        Any: Whatever `_find_and_get_value_from_response` yields for the key named by
        `get_refresh_token_name()`.
    """
    token_key = self.get_refresh_token_name()
    return self._find_and_get_value_from_response(response_data, token_key)

def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
    """
    Pull the token expiry value (e.g. `expires_in` or `expires_at`) out of the response.

    Args:
        response_data (Mapping[str, Any]): The response data to search.

    Returns:
        Any: Whatever `_find_and_get_value_from_response` yields for the key named by
        `get_expires_in_name()`.
    """
    expiry_key = self.get_expires_in_name()
    return self._find_and_get_value_from_response(response_data, expiry_key)

def _find_and_get_value_from_response(
    self,
    response_data: Mapping[str, Any],
    key_name: str,
    max_depth: int = 5,
    current_depth: int = 0,
) -> Any:
    """
    Recursively searches for a specified key in a nested dictionary or list and returns its value if found.

    A key present at the current level wins over nested occurrences. Otherwise nested
    values are searched depth-first in iteration order, and the first result that is not
    None is returned. A key whose value is None is therefore treated like a missing key
    by the enclosing levels.

    Args:
        response_data (Mapping[str, Any]): The response data to search through. Nested calls may
            also receive a list or a scalar; scalars simply yield None.
        key_name (str): The key to search for in the response data.
        max_depth (int, optional): The maximum depth to search for the key to avoid infinite recursion. Defaults to 5.
        current_depth (int, optional): The current depth of the recursion. Defaults to 0.

    Returns:
        Any: The value associated with the specified key if found, otherwise None.

    Raises:
        ResponseKeysMaxRecurtionReached: If the search descends past `max_depth` before the
            key is found.
    """
    if current_depth > max_depth:
        # Guard against runaway recursion on very deeply nested responses.
        message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
        raise ResponseKeysMaxRecurtionReached(
            internal_message=message, message=message, failure_type=FailureType.config_error
        )

    if isinstance(response_data, dict):
        # Prefer a match at the current level before looking at nested objects.
        if key_name in response_data:
            return response_data[key_name]
        children = response_data.values()
    elif isinstance(response_data, list):
        children = response_data
    else:
        # Scalars (and any other type) cannot contain the key.
        return None

    for child in children:
        found = self._find_and_get_value_from_response(
            child, key_name, max_depth, current_depth + 1
        )
        if found is not None:
            return found

    return None

@property
def _message_repository(self) -> Optional[MessageRepository]:
    """
    Repository used for debug logs of HTTP requests; an implementation can define
    its own when it wants such logs. This default returns the module-level instance.
    """
    return _NOOP_MESSAGE_REPOSITORY

def _log_response(self, response: requests.Response) -> None:
    """
    Log the HTTP response through the message repository, when one is available.

    Args:
        response (requests.Response): The HTTP response to log.
    """
    if not self._message_repository:
        return

    # The message is built lazily: the repository receives a callable, not a string.
    self._message_repository.log_message(
        Level.DEBUG,
        lambda: format_http_message(
            response,
            "Refresh token",
            "Obtains access token",
            self._NO_STREAM_NAME,
            is_auxiliary=True,
        ),
    )

# ----------------
# ABSTRACT METHODS
# ----------------

@abstractmethod
def get_token_refresh_endpoint(self) -> Optional[str]:
    """Return the endpoint used to refresh the access token."""
Expand Down Expand Up @@ -295,23 +473,3 @@ def access_token(self) -> str:
@abstractmethod
def access_token(self, value: str) -> str:
"""Setter for the access token"""

@property
def _message_repository(self) -> Optional[MessageRepository]:
    """
    Message repository for debugging logs of HTTP requests; an implementation may
    define its own. The default hands back the module-level instance.
    """
    return _NOOP_MESSAGE_REPOSITORY

def _log_response(self, response: requests.Response) -> None:
    """Log the given HTTP response at debug level if a message repository is available."""
    if not self._message_repository:
        return

    # Pass a callable so the message is only formatted when the repository asks for it.
    self._message_repository.log_message(
        Level.DEBUG,
        lambda: format_http_message(
            response,
            "Refresh token",
            "Obtains access token",
            self._NO_STREAM_NAME,
            is_auxiliary=True,
        ),
    )
Loading
Loading