diff --git a/samcli/local/docker/lambda_build_container.py b/samcli/local/docker/lambda_build_container.py index c2c20e54ad..e291ec79e8 100644 --- a/samcli/local/docker/lambda_build_container.py +++ b/samcli/local/docker/lambda_build_container.py @@ -9,12 +9,21 @@ from typing import List from uuid import uuid4 +from samcli.commands.exceptions import UserException +from samcli.lib.utils.architecture import X86_64, ARM64 from samcli.commands._utils.experimental import get_enabled_experimental_flags from samcli.local.docker.container import Container LOG = logging.getLogger(__name__) +class InvalidArchitectureForImage(UserException): + """ + Raised when architecture that is provided for the image is invalid + """ + pass + + class LambdaBuildContainer(Container): """ Class to manage Build containers that are capable of building AWS Lambda functions. @@ -297,4 +306,8 @@ def get_image_tag(architecture): str Image tag """ + if architecture not in [X86_64, ARM64]: + raise InvalidArchitectureForImage( + f"'{architecture}' is not a valid architecture, it should be either '{X86_64}' or '{ARM64}'" + ) return f"{LambdaBuildContainer._IMAGE_TAG}-{architecture}" diff --git a/tests/unit/local/docker/test_lambda_build_container.py b/tests/unit/local/docker/test_lambda_build_container.py index c72e808f13..f372c3f4ed 100644 --- a/tests/unit/local/docker/test_lambda_build_container.py +++ b/tests/unit/local/docker/test_lambda_build_container.py @@ -11,7 +11,7 @@ from parameterized import parameterized from samcli.lib.utils.architecture import X86_64, ARM64 -from samcli.local.docker.lambda_build_container import LambdaBuildContainer +from samcli.local.docker.lambda_build_container import LambdaBuildContainer, InvalidArchitectureForImage class TestLambdaBuildContainer_init(TestCase): @@ -217,6 +217,10 @@ class TestLambdaBuildContainer_get_image_tag(TestCase): def test_must_get_image_tag(self, architecture, expected_image_tag): self.assertEqual(expected_image_tag, LambdaBuildContainer.get_image_tag(architecture)) + def test_must_raise_an_error_for_invalid_architecture(self): + with self.assertRaises(InvalidArchitectureForImage): + LambdaBuildContainer.get_image_tag("invalid-architecture") + class TestLambdaBuildContainer_get_entrypoint(TestCase): def test_must_get_entrypoint(self):