diff --git a/src/lambda_codebase/account_processing/delete_default_vpc.py b/src/lambda_codebase/account_processing/delete_default_vpc.py index 9cd3309a5..586f278a4 100644 --- a/src/lambda_codebase/account_processing/delete_default_vpc.py +++ b/src/lambda_codebase/account_processing/delete_default_vpc.py @@ -5,7 +5,10 @@ Deletes the default VPC in a particular region """ import os + from aws_xray_sdk.core import patch_all +from botocore.exceptions import ClientError +import tenacity # ADF imports from logger import configure_logger @@ -25,12 +28,25 @@ def assume_role(account_id): "adf_delete_default_vpc", ) - +@tenacity.retry( + retry=tenacity.retry_if_exception_type(ClientError), + # Fail after 180 Sec of retrying + stop=tenacity.stop_after_delay(180), + # Wait 2^x * 1 second between each retry starting with 4s, max 10s intervals + wait=tenacity.wait_exponential(multiplier=1, min=4, max=10), +) def find_default_vpc(ec2_client): - vpc_response = ec2_client.describe_vpcs() - for vpc in vpc_response["Vpcs"]: - if vpc["IsDefault"] is True: - return vpc["VpcId"] + try: + vpc_response = ec2_client.describe_vpcs() + for vpc in vpc_response["Vpcs"]: + if vpc.get("IsDefault", False): + return vpc["VpcId"] + except ClientError as error: + LOGGER.debug( + "An error occurred while describing VPCs: %s", error + ) + raise + # If no default VPC found, return None return None diff --git a/src/lambda_codebase/account_processing/requirements.txt b/src/lambda_codebase/account_processing/requirements.txt index 2542bd380..0d3022cc7 100644 --- a/src/lambda_codebase/account_processing/requirements.txt +++ b/src/lambda_codebase/account_processing/requirements.txt @@ -1,2 +1,3 @@ aws-xray-sdk==2.13.0 pyyaml~=6.0.1 +tenacity==8.2.3 diff --git a/src/lambda_codebase/account_processing/tests/__init__.py b/src/lambda_codebase/account_processing/tests/__init__.py index 014883ae9..ffb650ebf 100644 --- a/src/lambda_codebase/account_processing/tests/__init__.py +++ b/src/lambda_codebase/account_processing/tests/__init__.py @@ -2,3 +2,12 @@ # SPDX-License-Identifier: MIT-0 # pylint: skip-file + +""" +__init__ for tests module +""" + +import sys +import os + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) diff --git a/src/lambda_codebase/account_processing/tests/test_delete_default_vpc.py b/src/lambda_codebase/account_processing/tests/test_delete_default_vpc.py new file mode 100644 index 000000000..e5020a2a6 --- /dev/null +++ b/src/lambda_codebase/account_processing/tests/test_delete_default_vpc.py @@ -0,0 +1,48 @@ +# Copyright Amazon.com Inc. or its affiliates. +# SPDX-License-Identifier: MIT-0 + +""" +Tests the delete_default_vpc lambda +""" + +import unittest +from unittest.mock import MagicMock, patch +from delete_default_vpc import find_default_vpc +from botocore.exceptions import ClientError + + +class TestFindDefaultVPC(unittest.TestCase): + + @patch("tenacity.nap.time.sleep", MagicMock()) + @patch('delete_default_vpc.patch_all') + # pylint: disable=unused-argument + def test_find_default_vpc(self, mock_patch_all): + # Create a mock ec2_client + mock_ec2_client = MagicMock() + + # Define the side effects for describe_vpcs method + side_effects = [ + ClientError({'Error': {'Code': 'MockTestError'}}, 'describe_vpcs'), + ClientError({'Error': {'Code': 'MockTestError'}}, 'describe_vpcs'), + {"Vpcs": [ + {"VpcId": "vpc-123", "IsDefault": False}, + {"VpcId": "vpc-456", "IsDefault": True}, + {"VpcId": "vpc-789", "IsDefault": False} + ]} + ] + + # Set side_effect for the mock ec2_client.describe_vpcs + mock_ec2_client.describe_vpcs.side_effect = side_effects + + # Call the function with the mock ec2_client + default_vpc_id = find_default_vpc(mock_ec2_client) + + # Check if the correct default VPC ID is returned + self.assertEqual(default_vpc_id, "vpc-456") + + # Check if describe_vpcs method is called 3 times + self.assertEqual(mock_ec2_client.describe_vpcs.call_count, 3) + + +if __name__ == '__main__': + unittest.main()