Skip to content

Commit

Permalink
fix(sqs): don't crash on multiple predefined queues with aws sts sess…
Browse files Browse the repository at this point in the history
…ion (#2224)

* chore(sqs): write the test case for multiple predefined queues with aws sts session

* fix(sqs): don't crash on multiple predefined queues with aws sts session

* refactor(sqs): make _new_predefined_queue_client_with_sts_session()
  • Loading branch information
iBluemind authored Jan 13, 2025
1 parent 4c64cdd commit 83b296f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
40 changes: 18 additions & 22 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,34 +766,30 @@ def sqs(self, queue=None):
return c

def _handle_sts_session(self, queue, q):
region = q.get('region', self.region)
if not hasattr(self, 'sts_expiration'): # STS token - token init
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
return self._new_predefined_queue_client_with_sts_session(queue, region)
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
return self._new_predefined_queue_client_with_sts_session(queue, region)
else: # STS token - ruse existing
if queue not in self._predefined_queue_clients:
return self._new_predefined_queue_client_with_sts_session(queue, region)
return self._predefined_queue_clients[queue]

def _new_predefined_queue_client_with_sts_session(self, queue, region):
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=region,
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c

def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
sts_policy = sts_client.assume_role(
Expand Down
28 changes: 28 additions & 0 deletions t/unit/transport/test_SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,34 @@ def test_sts_session_not_expired(self):
# Assert
mock_generate_sts_session_token.assert_not_called()

def test_sts_session_with_multiple_predefined_queues(self):
connection = Connection(transport=SQS.Transport, transport_options={
'predefined_queues': example_predefined_queues,
'sts_role_arn': 'test::arn'
})
channel = connection.channel()
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)

mock_generate_sts_session_token = Mock()
mock_new_sqs_client = Mock()
channel.new_sqs_client = mock_new_sqs_client
mock_generate_sts_session_token.return_value = {
'Expiration': datetime.utcnow() + timedelta(days=1),
'SessionToken': 123,
'AccessKeyId': 123,
'SecretAccessKey': 123
}

channel.generate_sts_session_token = mock_generate_sts_session_token

# Act
sqs(queue='queue-1')
sqs(queue='queue-2')

# Assert
mock_generate_sts_session_token.assert_called()
mock_new_sqs_client.assert_called()

def test_message_attribute(self):
message = 'my test message'
self.producer.publish(message, message_attributes={
Expand Down

0 comments on commit 83b296f

Please sign in to comment.