Skip to content

Commit

Permalink
Fix aws assume role session creation when deferrable (apache#40051)
Browse files Browse the repository at this point in the history
* fixed aws connection issue for  assume_role when session created with defarrable=True
  • Loading branch information
gopidesupavan authored Jun 6, 2024
1 parent a31b10e commit 42a2b1a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
11 changes: 8 additions & 3 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def get_async_session(self):

return async_get_session()

def create_session(self, deferrable: bool = False) -> boto3.session.Session:
def create_session(
self, deferrable: bool = False
) -> boto3.session.Session | aiobotocore.session.AioSession:
"""Create boto3 or aiobotocore Session from connection config."""
if not self.conn:
self.log.info(
Expand Down Expand Up @@ -198,7 +200,7 @@ def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session

def _create_session_with_assume_role(
self, session_kwargs: dict[str, Any], deferrable: bool = False
) -> boto3.session.Session:
) -> boto3.session.Session | aiobotocore.session.AioSession:
if self.conn.assume_role_method == "assume_role_with_web_identity":
# Deferred credentials have no initial credentials
credential_fetcher = self._get_web_identity_credential_fetcher()
Expand Down Expand Up @@ -239,7 +241,10 @@ def _create_session_with_assume_role(
session._credentials = credentials
session.set_config_variable("region", self.basic_session.region_name)

return boto3.session.Session(botocore_session=session, **session_kwargs)
if not deferrable:
return boto3.session.Session(botocore_session=session, **session_kwargs)

return session

def _refresh_credentials(self) -> dict[str, Any]:
self.log.debug("Refreshing credentials")
Expand Down
7 changes: 6 additions & 1 deletion tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from unittest.mock import MagicMock, PropertyMock, mock_open
from uuid import UUID

import aiobotocore.session
import boto3
import botocore
import jinja2
Expand Down Expand Up @@ -247,6 +248,7 @@ def test_async_create_session_from_credentials(self, region_name, profile_name):
session_profile = async_session.get_config_variable("profile")

assert session_profile == profile_name
assert isinstance(async_session, aiobotocore.session.AioSession)

@pytest.mark.asyncio
async def test_async_create_a_session_from_credentials_without_token(self):
Expand All @@ -266,6 +268,7 @@ async def test_async_create_a_session_from_credentials_without_token(self):
assert cred.access_key == "test_aws_access_key_id"
assert cred.secret_key == "test_aws_secret_access_key"
assert cred.token is None
assert isinstance(async_session, aiobotocore.session.AioSession)

config_for_credentials_test = [
(
Expand Down Expand Up @@ -300,6 +303,7 @@ def test_get_credentials_from_role_arn(self, conn_id, conn_extra, region_name):
# Validate method of botocore credentials provider.
# It shouldn't be 'explicit' which refers in this case to initial credentials.
assert session.get_credentials().method == "sts-assume-role"
assert isinstance(session, boto3.session.Session)

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -330,14 +334,15 @@ def side_effect():
conn = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=extra)
sf = BaseSessionFactory(conn=conn)
session = sf.create_session(deferrable=True)
assert session.region_name == region_name
assert session.get_config_variable("region") == region_name
# Validate method of botocore credentials provider.
# It shouldn't be 'explicit' which refers in this case to initial credentials.
credentials = await session.get_credentials()

assert inspect.iscoroutinefunction(credentials.get_frozen_credentials)

assert credentials.method == "sts-assume-role"
assert isinstance(session, aiobotocore.session.AioSession)


class TestAwsBaseHook:
Expand Down

0 comments on commit 42a2b1a

Please sign in to comment.