Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(low-code): pass refresh headers to oauth #219

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
token_expiry_date_format str: format of the datetime; provide it if expires_in is returned in datetime instead of seconds
token_expiry_is_time_of_expiration bool: set True it if expires_in is returned as time of expiration instead of the number seconds until expiration
refresh_request_body (Optional[Mapping[str, Any]]): The request body to send in the refresh request
refresh_request_headers (Optional[Mapping[str, Any]]): The request headers to send in the refresh request
grant_type: The grant_type to request for access_token. If set to refresh_token, the refresh_token parameter has to be provided
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""
Expand All @@ -58,6 +59,7 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
access_token_value: Optional[Union[InterpolatedString, str]] = None
expires_in_name: Union[InterpolatedString, str] = "expires_in"
refresh_request_body: Optional[Mapping[str, Any]] = None
refresh_request_headers: Optional[Mapping[str, Any]] = None
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()

Expand Down Expand Up @@ -87,6 +89,9 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._refresh_request_body = InterpolatedMapping(
self.refresh_request_body or {}, parameters=parameters
)
self._refresh_request_headers = InterpolatedMapping(
self.refresh_request_headers or {}, parameters=parameters
)
self._token_expiry_date: pendulum.DateTime = (
pendulum.parse(
InterpolatedString.create(self.token_expiry_date, parameters=parameters).eval(
Expand Down Expand Up @@ -152,6 +157,9 @@ def get_grant_type(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
return self._refresh_request_body.eval(self.config)

def get_refresh_request_headers(self) -> Mapping[str, Any]:
return self._refresh_request_headers.eval(self.config)

def get_token_expiry_date(self) -> pendulum.DateTime:
return self._token_expiry_date # type: ignore # _token_expiry_date is a pendulum.DateTime. It is never None despite what mypy thinks

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,14 @@ definitions:
- applicationId: "{{ config['application_id'] }}"
applicationSecret: "{{ config['application_secret'] }}"
token: "{{ config['token'] }}"
refresh_request_headers:
title: Refresh Request Headers
description: Headers of the request sent to get a new access token.
type: object
additionalProperties: true
examples:
- Authorization: "<AUTH_TOKEN>"
Content-Type: "application/x-www-form-urlencoded"
scopes:
title: Scopes
description: List of scopes that should be granted to the access token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,17 @@ class OAuthAuthenticator(BaseModel):
],
title="Refresh Request Body",
)
refresh_request_headers: Optional[Dict[str, Any]] = Field(
None,
description="Headers of the request sent to get a new access token.",
examples=[
{
"Authorization": "<AUTH_TOKEN>",
"Content-Type": "application/x-www-form-urlencoded",
}
],
title="Refresh Request Headers",
)
scopes: Optional[List[str]] = Field(
None,
description="List of scopes that should be granted to the access token.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,9 @@ def create_oauth_authenticator(
refresh_request_body=InterpolatedMapping(
model.refresh_request_body or {}, parameters=model.parameters or {}
).eval(config),
refresh_request_headers=InterpolatedMapping(
model.refresh_request_headers or {}, parameters=model.parameters or {}
).eval(config),
scopes=model.scopes,
token_expiry_date_format=model.token_expiry_date_format,
message_repository=self._message_repository,
Expand All @@ -1916,6 +1919,7 @@ def create_oauth_authenticator(
expires_in_name=model.expires_in_name or "expires_in",
grant_type=model.grant_type or "refresh_token",
refresh_request_body=model.refresh_request_body,
refresh_request_headers=model.refresh_request_headers,
refresh_token=model.refresh_token,
scopes=model.scopes,
token_expiry_date=model.token_expiry_date,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ def build_refresh_request_body(self) -> Mapping[str, Any]:

return payload

def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
"""
Returns the request headers to set on the refresh request

"""
headers = self.get_refresh_request_headers()
lazebnyi marked this conversation as resolved.
Show resolved Hide resolved
return headers if headers else None

def _wrap_refresh_token_exception(
self, exception: requests.exceptions.RequestException
) -> bool:
Expand Down Expand Up @@ -128,6 +136,7 @@ def _get_refresh_access_token_response(self) -> Any:
method="POST",
url=self.get_token_refresh_endpoint(), # type: ignore # returns None, if not provided, but str | bytes is expected.
data=self.build_refresh_request_body(),
headers=self.build_refresh_request_headers(),
)
if response.ok:
response_json = response.json()
Expand Down Expand Up @@ -242,6 +251,10 @@ def get_expires_in_name(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
"""Returns the request body to set on the refresh request"""

@abstractmethod
def get_refresh_request_headers(self) -> Mapping[str, Any]:
"""Returns the request headers to set on the refresh request"""

@abstractmethod
def get_grant_type(self) -> str:
"""Returns grant_type specified for requesting access_token"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
access_token_name: str = "access_token",
expires_in_name: str = "expires_in",
refresh_request_body: Mapping[str, Any] | None = None,
refresh_request_headers: Mapping[str, Any] | None = None,
grant_type: str = "refresh_token",
token_expiry_is_time_of_expiration: bool = False,
refresh_token_error_status_codes: Tuple[int, ...] = (),
Expand All @@ -50,6 +51,7 @@ def __init__(
self._access_token_name = access_token_name
self._expires_in_name = expires_in_name
self._refresh_request_body = refresh_request_body
self._refresh_request_headers = refresh_request_headers
self._grant_type = grant_type

self._token_expiry_date = token_expiry_date or pendulum.now().subtract(days=1) # type: ignore [no-untyped-call]
Expand Down Expand Up @@ -84,6 +86,9 @@ def get_expires_in_name(self) -> str:
def get_refresh_request_body(self) -> Mapping[str, Any]:
return self._refresh_request_body # type: ignore [return-value]

def get_refresh_request_headers(self) -> Mapping[str, Any]:
return self._refresh_request_headers # type: ignore [return-value]

def get_grant_type(self) -> str:
return self._grant_type

Expand Down Expand Up @@ -129,6 +134,7 @@ def __init__(
expires_in_name: str = "expires_in",
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] | None = None,
refresh_request_headers: Mapping[str, Any] | None = None,
grant_type: str = "refresh_token",
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
Expand All @@ -151,6 +157,7 @@ def __init__(
expires_in_name (str, optional): Name of the name of the field that characterizes when the current access token will expire, used to parse the refresh token response. Defaults to "expires_in".
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
refresh_request_headers (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request headers. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object.
client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object.
Expand Down Expand Up @@ -191,6 +198,7 @@ def __init__(
access_token_name=access_token_name,
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
refresh_request_headers=refresh_request_headers,
grant_type=grant_type,
token_expiry_date_format=token_expiry_date_format,
token_expiry_is_time_of_expiration=token_expiry_is_time_of_expiration,
Expand Down
69 changes: 67 additions & 2 deletions unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,39 @@ def test_refresh_request_body(self):
}
assert body == expected

def test_refresh_request_headers(self):
"""
Request headers should match given configuration.
"""
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ parameters['refresh_token'] }}",
config=config,
token_expiry_date="{{ config['token_expiry_date'] }}",
refresh_request_headers={
"Authorization": "<TOKEN>",
"Content-Type": "application/x-www-form-urlencoded",
},
parameters=parameters,
)
headers = oauth.build_refresh_request_headers()
expected = {"Authorization": "<TOKEN>", "Content-Type": "application/x-www-form-urlencoded"}
assert headers == expected

oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ parameters['refresh_token'] }}",
config=config,
token_expiry_date="{{ config['token_expiry_date'] }}",
parameters=parameters,
)
headers = oauth.build_refresh_request_headers()
assert headers is None

def test_refresh_with_encode_config_params(self):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand Down Expand Up @@ -191,6 +224,36 @@ def test_refresh_access_token(self, mocker):
filtered = filter_secrets("access_token")
assert filtered == "****"

def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
"Content-Type": "application/x-www-form-urlencoded",
}
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
client_id="{{ config['client_id'] }}",
client_secret="{{ config['client_secret'] }}",
refresh_token="{{ config['refresh_token'] }}",
config=config,
scopes=["scope1", "scope2"],
token_expiry_date="{{ config['token_expiry_date'] }}",
refresh_request_headers=expected_headers,
parameters={},
)

resp.status_code = 200
mocker.patch.object(
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token = oauth.refresh_access_token()

assert ("access_token", 1000) == token

assert mocked_request.call_args.kwargs["headers"] == expected_headers

def test_refresh_access_token_missing_access_token(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand Down Expand Up @@ -371,7 +434,9 @@ def test_error_handling(self, mocker):
assert e.value.errno == 400


def mock_request(method, url, data):
def mock_request(method, url, data, headers):
if url == "refresh_end":
return resp
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,38 @@ def test_refresh_request_body(self):
}
assert body == expected

def test_refresh_request_headers(self):
"""
Request headers should match given configuration.
"""
oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
token_expiry_date=pendulum.now().add(days=3),
refresh_request_headers={
"Authorization": "Bearer some_refresh_token",
"Content-Type": "application/x-www-form-urlencoded",
},
)
headers = oauth.build_refresh_request_headers()
expected = {
"Authorization": "Bearer some_refresh_token",
"Content-Type": "application/x-www-form-urlencoded",
}
assert headers == expected

oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
token_expiry_date=pendulum.now().add(days=3),
)
headers = oauth.build_refresh_request_headers()
assert headers is None

def test_refresh_access_token(self, mocker):
oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
Expand Down Expand Up @@ -210,6 +242,35 @@ def test_refresh_access_token(self, mocker):
assert isinstance(expires_in, str)
assert ("access_token", "2022-04-24T00:00:00Z") == (token, expires_in)

def test_refresh_access_token_when_headers_provided(self, mocker):
expected_headers = {
"Authorization": "Bearer some_access_token",
"Content-Type": "application/x-www-form-urlencoded",
}
oauth = Oauth2Authenticator(
token_refresh_endpoint="refresh_end",
client_id="some_client_id",
client_secret="some_client_secret",
refresh_token="some_refresh_token",
scopes=["scope1", "scope2"],
token_expiry_date=pendulum.now().add(days=3),
refresh_request_headers=expected_headers,
)

resp.status_code = 200
mocker.patch.object(
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
)
mocked_request = mocker.patch.object(
requests, "request", side_effect=mock_request, autospec=True
)
token, expires_in = oauth.refresh_access_token()

assert isinstance(expires_in, int)
assert ("access_token", 1000) == (token, expires_in)

assert mocked_request.call_args.kwargs["headers"] == expected_headers

@pytest.mark.parametrize(
"expires_in_response, token_expiry_date_format, expected_token_expiry_date",
[
Expand Down Expand Up @@ -522,7 +583,9 @@ def test_refresh_access_token(self, mocker, connector_config):
)


def mock_request(method, url, data):
def mock_request(method, url, data, headers):
if url == "refresh_end":
return resp
raise Exception(f"Error while refreshing access token with request: {method}, {url}, {data}")
raise Exception(
f"Error while refreshing access token with request: {method}, {url}, {data}, {headers}"
)
Loading