Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(low-code): add profile assertion flow to oauth authenticator component #236

57 changes: 52 additions & 5 deletions airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
#

from dataclasses import InitVar, dataclass, field
from typing import Any, List, Mapping, Optional, Union
from typing import Any, List, Mapping, MutableMapping, Optional, Union

import pendulum

from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.message import MessageRepository, NoopMessageRepository
Expand Down Expand Up @@ -44,10 +45,10 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
message_repository (MessageRepository): the message repository used to emit logs on HTTP requests
"""

client_id: Union[InterpolatedString, str]
client_secret: Union[InterpolatedString, str]
config: Mapping[str, Any]
parameters: InitVar[Mapping[str, Any]]
client_id: Optional[Union[InterpolatedString, str]] = None
client_secret: Optional[Union[InterpolatedString, str]] = None
token_refresh_endpoint: Optional[Union[InterpolatedString, str]] = None
refresh_token: Optional[Union[InterpolatedString, str]] = None
scopes: Optional[List[str]] = None
Expand All @@ -66,6 +67,8 @@ class DeclarativeOauth2Authenticator(AbstractOauth2Authenticator, DeclarativeAut
grant_type_name: Union[InterpolatedString, str] = "grant_type"
grant_type: Union[InterpolatedString, str] = "refresh_token"
message_repository: MessageRepository = NoopMessageRepository()
profile_assertion: Optional[DeclarativeAuthenticator] = None
use_profile_assertion: Optional[Union[InterpolatedBoolean, str, bool]] = False
lazebnyi marked this conversation as resolved.
Show resolved Hide resolved

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__init__()
Expand Down Expand Up @@ -99,7 +102,12 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.grant_type_name = InterpolatedString.create(
self.grant_type_name, parameters=parameters
)
self.grant_type = InterpolatedString.create(self.grant_type, parameters=parameters)
self.grant_type = InterpolatedString.create(
"urn:ietf:params:oauth:grant-type:jwt-bearer"
if self.use_profile_assertion
else self.grant_type,
parameters=parameters,
)
self._refresh_request_body = InterpolatedMapping(
self.refresh_request_body or {}, parameters=parameters
)
Expand All @@ -115,6 +123,13 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
if self.token_expiry_date
else pendulum.now().subtract(days=1) # type: ignore # substract does not have type hints
)
self.use_profile_assertion = (
InterpolatedBoolean(self.use_profile_assertion, parameters=parameters)
if isinstance(self.use_profile_assertion, str)
else self.use_profile_assertion
)
self.assertion_name = "assertion"

if self.access_token_value is not None:
self._access_token_value = InterpolatedString.create(
self.access_token_value, parameters=parameters
Expand All @@ -126,9 +141,20 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self._access_token_value if self.access_token_value else None
)

if not self.use_profile_assertion and any(
client_creds is None for client_creds in [self.client_id, self.client_secret]
):
raise ValueError(
"OAuthAuthenticator configuration error: Both 'client_id' and 'client_secret' are required for the "
"basic OAuth flow."
)
if self.profile_assertion is None and self.use_profile_assertion:
raise ValueError(
"OAuthAuthenticator configuration error: 'profile_assertion' is required when using the profile assertion flow."
)
if self.get_grant_type() == "refresh_token" and self._refresh_token is None:
raise ValueError(
"OAuthAuthenticator needs a refresh_token parameter if grant_type is set to `refresh_token`"
"OAuthAuthenticator configuration error: A 'refresh_token' is required when the 'grant_type' is set to 'refresh_token'."
)

def get_token_refresh_endpoint(self) -> Optional[str]:
Expand Down Expand Up @@ -192,6 +218,27 @@ def get_token_expiry_date(self) -> pendulum.DateTime:
def set_token_expiry_date(self, value: Union[str, int]) -> None:
self._token_expiry_date = self._parse_token_expiration_date(value)

def get_assertion_name(self) -> str:
return self.assertion_name

def get_assertion(self) -> str:
if self.profile_assertion is None:
raise ValueError("profile_assertion is not set")
return self.profile_assertion.token

def build_refresh_request_body(self) -> Mapping[str, Any]:
"""
Returns the request body to set on the refresh request

Override to define additional parameters
"""
payload: MutableMapping[str, Any] = {
self.get_grant_type_name(): self.get_grant_type(),
self.get_assertion_name(): self.get_assertion(),
}

return payload if self.use_profile_assertion else super().build_refresh_request_body()

@property
def access_token(self) -> str:
if self._access_token is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1058,8 +1058,6 @@ definitions:
type: object
required:
- type
- client_id
- client_secret
properties:
type:
type: string
Expand Down Expand Up @@ -1254,6 +1252,15 @@ definitions:
default: []
examples:
- ["invalid_grant", "invalid_permissions"]
profile_assertion:
title: Profile Assertion
description: The authenticator being used to authenticate the client authenticator.
"$ref": "#/definitions/JwtAuthenticator"
use_profile_assertion:
lazebnyi marked this conversation as resolved.
Show resolved Hide resolved
title: Use Profile Assertion
description: Enable using profile assertion as a flow for OAuth authorization.
type: boolean
default: false
$parameters:
type: object
additionalProperties: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,8 @@ class OAuthAuthenticator(BaseModel):
examples=["custom_app_id"],
title="Client ID Property Name",
)
client_id: str = Field(
...,
client_id: Optional[str] = Field(
None,
description="The OAuth client ID. Fill it in the user inputs.",
examples=["{{ config['client_id }}", "{{ config['credentials']['client_id }}"],
title="Client ID",
Expand All @@ -508,8 +508,8 @@ class OAuthAuthenticator(BaseModel):
examples=["custom_app_secret"],
title="Client Secret Property Name",
)
client_secret: str = Field(
...,
client_secret: Optional[str] = Field(
None,
description="The OAuth client secret. Fill it in the user inputs.",
examples=[
"{{ config['client_secret }}",
Expand Down Expand Up @@ -614,6 +614,16 @@ class OAuthAuthenticator(BaseModel):
description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.",
title="Token Updater",
)
profile_assertion: Optional[JwtAuthenticator] = Field(
None,
description="The authenticator being used to authenticate the client authenticator.",
title="Profile Assertion",
)
use_profile_assertion: Optional[bool] = Field(
False,
description="Enable using profile assertion as a flow for OAuth authorization.",
title="Use Profile Assertion",
)
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


Expand Down Expand Up @@ -719,7 +729,7 @@ class HttpResponseFilter(BaseModel):
class TypesMap(BaseModel):
target_type: Union[str, List[str]]
current_type: Union[str, List[str]]
condition: Optional[str]
condition: Optional[str] = None


class SchemaTypeIdentifier(BaseModel):
Expand Down Expand Up @@ -797,14 +807,11 @@ class DpathFlattenFields(BaseModel):
field_path: List[str] = Field(
...,
description="A path to field that needs to be flattened.",
examples=[
["data"],
["data", "*", "field"],
],
examples=[["data"], ["data", "*", "field"]],
title="Field Path",
)
delete_origin_value: Optional[bool] = Field(
False,
None,
description="Whether to delete the origin value or keep it. Default is False.",
title="Delete Origin Value",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,12 @@ def create_no_pagination(
def create_oauth_authenticator(
self, model: OAuthAuthenticatorModel, config: Config, **kwargs: Any
) -> DeclarativeOauth2Authenticator:
profile_assertion = (
self._create_component_from_model(model.profile_assertion, config=config)
if model.profile_assertion
else None
)

if model.refresh_token_updater:
# ignore type error because fixing it would have a lot of dependencies, revisit later
return DeclarativeSingleUseRefreshTokenOauth2Authenticator( # type: ignore
Expand Down Expand Up @@ -1997,6 +2003,7 @@ def create_oauth_authenticator(
config=config,
parameters=model.parameters or {},
message_repository=self._message_repository,
profile_assertion=profile_assertion,
)

def create_offset_increment(
Expand Down
45 changes: 44 additions & 1 deletion unit_tests/sources/declarative/auth/test_oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import base64
import json
import logging
from unittest.mock import Mock

Expand All @@ -12,14 +14,16 @@
from requests import Response

from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
from airbyte_cdk.utils.airbyte_secrets_utils import filter_secrets

LOGGER = logging.getLogger(__name__)

resp = Response()

config = {
"refresh_endpoint": "refresh_end",
"refresh_endpoint": "https://refresh_endpoint.com",
"client_id": "some_client_id",
"client_secret": "some_client_secret",
"token_expiry_date": pendulum.now().subtract(days=2).to_rfc3339_string(),
Expand Down Expand Up @@ -412,6 +416,45 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
assert "access_token" == token
assert oauth.get_token_expiry_date() == pendulum.parse(next_day)

def test_profile_assertion(self, mocker):
with HttpMocker() as http_mocker:
jwt = JwtAuthenticator(
config={},
parameters={},
secret_key="test",
algorithm="HS256",
token_duration=1000,
typ="JWT",
iss="iss",
)

mocker.patch(
"airbyte_cdk.sources.declarative.auth.jwt.JwtAuthenticator.token",
new_callable=lambda: "token",
)

oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="https://refresh_endpoint.com/",
config=config,
parameters={},
profile_assertion=jwt,
use_profile_assertion=True,
)
http_mocker.post(
HttpRequest(
url="https://refresh_endpoint.com/",
body="grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer&assertion=token",
),
HttpResponse(body=json.dumps({"access_token": "access_token", "expires_in": 1000})),
)

token = oauth.refresh_access_token()

assert ("access_token", 1000) == token

filtered = filter_secrets("access_token")
assert filtered == "****"

def test_error_handling(self, mocker):
oauth = DeclarativeOauth2Authenticator(
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
Expand Down
Loading