From 22294abf68f17eefc00ec9b363bfcf1ca21f145a Mon Sep 17 00:00:00 2001
From: Andrey Anshin
Date: Thu, 28 Dec 2023 16:26:37 +0400
Subject: [PATCH] Use base aws classes in AWS Step Functions
 Operators/Sensors/Triggers (#36468)

---
 .../amazon/aws/operators/step_function.py     |  63 ++++----
 .../amazon/aws/sensors/step_function.py       |  36 ++---
 .../amazon/aws/triggers/step_function.py      |   9 +-
 .../operators/step_functions.rst              |   5 +
 .../aws/operators/test_step_function.py       | 143 ++++++++++--------
 .../amazon/aws/sensors/test_step_function.py  | 112 +++++++-------
 6 files changed, 190 insertions(+), 178 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 68324df731dec..e02de32bae589 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -22,15 +22,16 @@
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class StepFunctionStartExecutionOperator(BaseOperator):
+class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
     """
     An Operator that begins execution of an AWS Step Function State Machine.
 
@@ -50,10 +51,20 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the job to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
     ui_color = "#f9c915"
 
     def __init__(
@@ -62,8 +73,6 @@ def __init__(
         state_machine_arn: str,
         name: str | None = None,
         state_machine_input: dict | str | None = None,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
         waiter_max_attempts: int = 30,
         waiter_delay: int = 60,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
@@ -73,18 +82,12 @@ def __init__(
         self.state_machine_arn = state_machine_arn
         self.name = name
         self.input = state_machine_input
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
 
     def execute(self, context: Context):
-        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-        execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
-
-        if execution_arn is None:
+        if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
 
         self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
@@ -96,6 +99,8 @@ def execute(self, context: Context):
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
+                    botocore_config=self.botocore_config,
+                    verify=self.verify,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -110,7 +115,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
 
         return event["execution_arn"]
 
 
-class StepFunctionGetExecutionOutputOperator(BaseOperator):
+class StepFunctionGetExecutionOutputOperator(AwsBaseOperator[StepFunctionHook]):
     """
     An Operator that returns the output of an AWS Step Function State Machine execution.
 
@@ -121,30 +126,28 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
         :ref:`howto/operator:StepFunctionGetExecutionOutputOperator`
 
     :param execution_arn: ARN of the Step Function State Machine Execution
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("execution_arn",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("execution_arn")
     ui_color = "#f9c915"
 
-    def __init__(
-        self,
-        *,
-        execution_arn: str,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, *, execution_arn: str, **kwargs):
         super().__init__(**kwargs)
         self.execution_arn = execution_arn
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
 
     def execute(self, context: Context):
-        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-        execution_status = hook.describe_execution(self.execution_arn)
+        execution_status = self.hook.describe_execution(self.execution_arn)
         response = None
         if "output" in execution_status:
             response = json.loads(execution_status["output"])
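[Note, not part of the patch] With the switch to AwsBaseOperator[StepFunctionHook], the generic AWS
parameters (aws_conn_id, region_name, verify, botocore_config) are accepted by these operators directly.
A minimal usage sketch; the DAG id, ARN and config values below are hypothetical:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.step_function import (
        StepFunctionGetExecutionOutputOperator,
        StepFunctionStartExecutionOperator,
    )

    with DAG("example_step_function", start_date=datetime(2024, 1, 1), schedule=None):
        # Generic AWS parameters are handled by the AwsBaseOperator base class,
        # so they no longer need per-operator __init__ plumbing.
        start = StepFunctionStartExecutionOperator(
            task_id="start_execution",
            state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:example",
            region_name="us-east-1",
            botocore_config={"read_timeout": 42},
        )
        # The returned execution ARN is pushed to XCom and can feed the next task.
        get_output = StepFunctionGetExecutionOutputOperator(
            task_id="get_output",
            execution_arn=start.output,
        )
        start >> get_output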
diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py
index 053a751336268..5e0d3cfcf79cc 100644
--- a/airflow/providers/amazon/aws/sensors/step_function.py
+++ b/airflow/providers/amazon/aws/sensors/step_function.py
@@ -17,20 +17,20 @@
 from __future__ import annotations
 
 import json
-from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from deprecated import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class StepFunctionExecutionSensor(BaseSensorOperator):
+class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
     """
     Poll the Step Function State Machine Execution until it reaches a terminal state; fails if the task fails.
 
@@ -42,7 +42,16 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
         :ref:`howto/sensor:StepFunctionExecutionSensor`
 
     :param execution_arn: execution_arn to check the state of
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     INTERMEDIATE_STATES = ("RUNNING",)
@@ -53,22 +62,13 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
     )
     SUCCESS_STATES = ("SUCCEEDED",)
 
-    template_fields: Sequence[str] = ("execution_arn",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("execution_arn")
     ui_color = "#66c3ff"
 
-    def __init__(
-        self,
-        *,
-        execution_arn: str,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, *, execution_arn: str, **kwargs):
         super().__init__(**kwargs)
         self.execution_arn = execution_arn
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
 
     def poke(self, context: Context):
         execution_status = self.hook.describe_execution(self.execution_arn)
@@ -93,7 +93,3 @@ def poke(self, context: Context):
     def get_hook(self) -> StepFunctionHook:
         """Create and return a StepFunctionHook."""
         return self.hook
-
-    @cached_property
-    def hook(self) -> StepFunctionHook:
-        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
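[Note, not part of the patch] The sensor gains the same generic parameters through AwsBaseSensor. A sketch
with a hypothetical execution ARN; aws_conn_id=None falls back to the default boto3 credential resolution:

    from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor

    wait_for_execution = StepFunctionExecutionSensor(
        task_id="wait_for_execution",
        execution_arn="arn:aws:states:us-east-1:123456789012:execution:example:0f810aa4",
        aws_conn_id=None,
        poke_interval=60,  # standard BaseSensorOperator knob, unchanged by this patch
    )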
diff --git a/airflow/providers/amazon/aws/triggers/step_function.py b/airflow/providers/amazon/aws/triggers/step_function.py
index da0f186da9cef..6fe6af2218486 100644
--- a/airflow/providers/amazon/aws/triggers/step_function.py
+++ b/airflow/providers/amazon/aws/triggers/step_function.py
@@ -43,6 +43,7 @@ def __init__(
         waiter_max_attempts: int = 30,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
+        **kwargs,
     ) -> None:
         super().__init__(
             serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
@@ -56,7 +57,13 @@ def __init__(
             waiter_delay=waiter_delay,
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
+            **kwargs,
         )
 
     def hook(self) -> AwsGenericHook:
-        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        return StepFunctionHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
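[Note, not part of the patch] Because the trigger now forwards **kwargs to its AwsBaseWaiterTrigger base,
deferrable runs can carry verify and botocore_config as well. Roughly what the operator builds when
deferrable=True; all values here are hypothetical:

    from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger

    trigger = StepFunctionsExecutionCompleteTrigger(
        execution_arn="arn:aws:states:us-east-1:123456789012:execution:example:0f810aa4",
        waiter_delay=60,
        waiter_max_attempts=30,
        aws_conn_id="aws_default",
        region_name="us-east-1",
        verify=True,
        botocore_config={"read_timeout": 42},
    )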
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index 7736fa9b16747..5ab5d19e68290 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -28,6 +28,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py
index 91ccebf7c6e29..6845a7f98ad93 100644
--- a/tests/providers/amazon/aws/operators/test_step_function.py
+++ b/tests/providers/amazon/aws/operators/test_step_function.py
@@ -18,12 +18,10 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import MagicMock
 
 import pytest
 
-from airflow.exceptions import TaskDeferred
-from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.operators.step_function import (
     StepFunctionGetExecutionOutputOperator,
     StepFunctionStartExecutionOperator,
@@ -40,104 +38,106 @@
 INPUT = "{}"
 
 
+@pytest.fixture
+def mocked_context():
+    return mock.MagicMock(name="FakeContext")
+
+
 class TestStepFunctionGetExecutionOutputOperator:
     TASK_ID = "step_function_get_execution_output"
 
-    def setup_method(self):
-        self.mock_context = MagicMock()
-
     def test_init(self):
-        # Given / When
-        operator = StepFunctionGetExecutionOutputOperator(
+        op = StepFunctionGetExecutionOutputOperator(
             task_id=self.TASK_ID,
             execution_arn=EXECUTION_ARN,
             aws_conn_id=AWS_CONN_ID,
             region_name=REGION_NAME,
+            verify="/spam/egg.pem",
+            botocore_config={"read_timeout": 42},
         )
-
-        # Then
-        assert self.TASK_ID == operator.task_id
-        assert EXECUTION_ARN == operator.execution_arn
-        assert AWS_CONN_ID == operator.aws_conn_id
-        assert REGION_NAME == operator.region_name
-
-    @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
-    @pytest.mark.parametrize("response", ["output", "error"])
-    def test_execute(self, mock_hook, response):
-        # Given
-        hook_response = {response: "{}"}
-
-        hook_instance = mock_hook.return_value
-        hook_instance.describe_execution.return_value = hook_response
-
-        operator = StepFunctionGetExecutionOutputOperator(
+        assert op.execution_arn == EXECUTION_ARN
+        assert op.hook.aws_conn_id == AWS_CONN_ID
+        assert op.hook._region_name == REGION_NAME
+        assert op.hook._verify == "/spam/egg.pem"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = StepFunctionGetExecutionOutputOperator(task_id=self.TASK_ID, execution_arn=EXECUTION_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+    @mock.patch.object(StepFunctionGetExecutionOutputOperator, "hook")
+    @pytest.mark.parametrize(
+        "response, expected_output",
+        [
+            pytest.param({"output": '{"foo": "bar"}'}, {"foo": "bar"}, id="output"),
+            pytest.param({"error": '{"spam": "egg"}'}, {"spam": "egg"}, id="error"),
+            pytest.param({"other": '{"baz": "qux"}'}, None, id="other"),
+        ],
+    )
+    def test_execute(self, mocked_hook, mocked_context, response, expected_output):
+        mocked_hook.describe_execution.return_value = response
+        op = StepFunctionGetExecutionOutputOperator(
             task_id=self.TASK_ID,
             execution_arn=EXECUTION_ARN,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME,
+            aws_conn_id=None,
         )
-
-        # When
-        result = operator.execute(self.mock_context)
-
-        # Then
-        assert {} == result
+        assert op.execute(mocked_context) == expected_output
+        mocked_hook.describe_execution.assert_called_once_with(EXECUTION_ARN)
 class TestStepFunctionStartExecutionOperator:
     TASK_ID = "step_function_start_execution_task"
 
-    def setup_method(self):
-        self.mock_context = MagicMock()
-
     def test_init(self):
-        # Given / When
-        operator = StepFunctionStartExecutionOperator(
+        op = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
             name=NAME,
             state_machine_input=INPUT,
             aws_conn_id=AWS_CONN_ID,
             region_name=REGION_NAME,
+            verify=False,
+            botocore_config={"read_timeout": 42},
         )
-
-        # Then
-        assert self.TASK_ID == operator.task_id
-        assert STATE_MACHINE_ARN == operator.state_machine_arn
-        assert NAME == operator.name
-        assert INPUT == operator.input
-        assert AWS_CONN_ID == operator.aws_conn_id
-        assert REGION_NAME == operator.region_name
-
-    @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
-    def test_execute(self, mock_hook):
-        # Given
+        assert op.state_machine_arn == STATE_MACHINE_ARN
+        assert op.name == NAME
+        assert op.input == INPUT
+        assert op.hook.aws_conn_id == AWS_CONN_ID
+        assert op.hook._region_name == REGION_NAME
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = StepFunctionStartExecutionOperator(task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    def test_execute(self, mocked_hook, mocked_context):
         hook_response = (
             "arn:aws:states:us-east-1:123456789012:execution:"
             "pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934"
         )
-
-        hook_instance = mock_hook.return_value
-        hook_instance.start_execution.return_value = hook_response
-
-        operator = StepFunctionStartExecutionOperator(
+        mocked_hook.start_execution.return_value = hook_response
+        op = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
             name=NAME,
             state_machine_input=INPUT,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME,
+            aws_conn_id=None,
         )
+        assert op.execute(mocked_context) == hook_response
+        mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
 
-        # When
-        result = operator.execute(self.mock_context)
-
-        # Then
-        assert hook_response == result
-
-    @mock.patch.object(StepFunctionHook, "start_execution")
-    def test_step_function_start_execution_deferrable(self, mock_start_execution):
-        mock_start_execution.return_value = "test-execution-arn"
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    def test_step_function_start_execution_deferrable(self, mocked_hook):
+        mocked_hook.start_execution.return_value = "test-execution-arn"
         operator = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
@@ -149,3 +149,14 @@ def test_step_function_start_execution_deferrable(self, mock_start_execution):
         )
         with pytest.raises(TaskDeferred):
             operator.execute(None)
+        mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
+
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    @pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")])
+    def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn):
+        mocked_hook.start_execution.return_value = execution_arn
+        op = StepFunctionStartExecutionOperator(
+            task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN, aws_conn_id=None
+        )
+        with pytest.raises(AirflowException, match="Failed to start State Machine execution"):
+            op.execute({})
execution"): + op.execute({}) diff --git a/tests/providers/amazon/aws/sensors/test_step_function.py b/tests/providers/amazon/aws/sensors/test_step_function.py index b6a47d49cd642..878691dc1cddc 100644 --- a/tests/providers/amazon/aws/sensors/test_step_function.py +++ b/tests/providers/amazon/aws/sensors/test_step_function.py @@ -19,7 +19,6 @@ import json from unittest import mock -from unittest.mock import MagicMock import pytest @@ -35,72 +34,63 @@ REGION_NAME = "us-west-2" -class TestStepFunctionExecutionSensor: - def setup_method(self): - self.mock_context = MagicMock() - - def test_init(self): - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME - ) - - assert TASK_ID == sensor.task_id - assert EXECUTION_ARN == sensor.execution_arn - assert AWS_CONN_ID == sensor.aws_conn_id - assert REGION_NAME == sensor.region_name - - @pytest.mark.parametrize("mock_status", ["FAILED", "TIMED_OUT", "ABORTED"]) - @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook") - def test_exceptions(self, mock_hook, mock_status): - hook_response = {"status": mock_status} - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response - - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME - ) - - with pytest.raises(AirflowException): - sensor.poke(self.mock_context) +@pytest.fixture +def mocked_context(): + return mock.MagicMock(name="FakeContext") - @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook") - def test_running(self, mock_hook): - hook_response = {"status": "RUNNING"} - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response - - sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME - ) - - assert not sensor.poke(self.mock_context) - - @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook") - def test_succeeded(self, mock_hook): - hook_response = {"status": "SUCCEEDED"} - - hook_instance = mock_hook.return_value - hook_instance.describe_execution.return_value = hook_response +class TestStepFunctionExecutionSensor: + def test_init(self): sensor = StepFunctionExecutionSensor( - task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME + task_id=TASK_ID, + execution_arn=EXECUTION_ARN, + aws_conn_id=AWS_CONN_ID, + region_name=REGION_NAME, + verify=True, + botocore_config={"read_timeout": 42}, ) - - assert sensor.poke(self.mock_context) - + assert sensor.execution_arn == EXECUTION_ARN + assert sensor.hook.aws_conn_id == AWS_CONN_ID + assert sensor.hook._region_name == REGION_NAME + assert sensor.hook._verify is True + assert sensor.hook._config is not None + assert sensor.hook._config.read_timeout == 42 + + sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN) + assert sensor.hook.aws_conn_id == "aws_default" + assert sensor.hook._region_name is None + assert sensor.hook._verify is None + assert sensor.hook._config is None + + @mock.patch.object(StepFunctionExecutionSensor, "hook") + @pytest.mark.parametrize("status", StepFunctionExecutionSensor.INTERMEDIATE_STATES) + def test_running(self, mocked_hook, status, mocked_context): + mocked_hook.describe_execution.return_value = {"status": status} + sensor = 
+        assert sensor.poke(mocked_context) is False
+
+    @mock.patch.object(StepFunctionExecutionSensor, "hook")
+    @pytest.mark.parametrize("status", StepFunctionExecutionSensor.SUCCESS_STATES)
+    def test_succeeded(self, mocked_hook, status, mocked_context):
+        mocked_hook.describe_execution.return_value = {"status": status}
+        sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
+        assert sensor.poke(mocked_context) is True
+
+    @mock.patch.object(StepFunctionExecutionSensor, "hook")
+    @pytest.mark.parametrize("status", StepFunctionExecutionSensor.FAILURE_STATES)
     @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, AirflowException, id="non-soft-fail"),
+        ],
     )
-    @mock.patch("airflow.providers.amazon.aws.hooks.step_function.StepFunctionHook.describe_execution")
-    def test_fail_poke(self, describe_execution, soft_fail, expected_exception):
+    def test_failure(self, mocked_hook, status, soft_fail, expected_exception, mocked_context):
+        output = {"test": "test"}
+        mocked_hook.describe_execution.return_value = {"status": status, "output": json.dumps(output)}
         sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
+            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None, soft_fail=soft_fail
         )
-        sensor.soft_fail = soft_fail
-        output = '{"test": "test"}'
-        describe_execution.return_value = {"status": "FAILED", "output": output}
-        message = f"Step Function sensor failed. State Machine Output: {json.loads(output)}"
+        message = f"Step Function sensor failed. State Machine Output: {output}"
         with pytest.raises(expected_exception, match=message):
-            sensor.poke(context={})
+            sensor.poke(mocked_context)
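[Note, not part of the patch] The reworked tests patch the hook property on the operator class rather than
the module-level StepFunctionHook, which keeps them independent of where the hook is instantiated. A
condensed sketch of the pattern, with hypothetical ARNs:

    from unittest import mock

    from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator

    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
    def test_start_execution_returns_arn(mocked_hook):
        # The mock replaces the cached hook property supplied by AwsBaseOperator.
        mocked_hook.start_execution.return_value = "arn:aws:states:us-east-1:123456789012:execution:sm:1"
        op = StepFunctionStartExecutionOperator(
            task_id="start",
            state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:sm",
            aws_conn_id=None,
        )
        assert op.execute({}) == "arn:aws:states:us-east-1:123456789012:execution:sm:1"
        mocked_hook.start_execution.assert_called_once_with(op.state_machine_arn, None, None)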