diff --git a/inngest/experimental/mocked/trigger.py b/inngest/experimental/mocked/trigger.py index 81578c4b..453704a3 100644 --- a/inngest/experimental/mocked/trigger.py +++ b/inngest/experimental/mocked/trigger.py @@ -48,6 +48,11 @@ def trigger( stack: list[str] = [] steps: dict[str, object] = {} planned = set[str]() + attempt = 0 + + max_attempt = 4 + if fn._opts.retries is not None: + max_attempt = fn._opts.retries while True: step_id: typing.Optional[str] = None @@ -57,9 +62,9 @@ def trigger( logger = unittest.mock.Mock() request = server_lib.ServerRequest( ctx=server_lib.ServerRequestCtx( - attempt=0, + attempt=attempt, disable_immediate_execution=True, - run_id="abc123", + run_id="test", stack=server_lib.ServerRequestCtxStack(stack=stack), ), event=event[0], @@ -110,11 +115,15 @@ def trigger( ) if res.error: - return _Result( - error=res.error, - output=None, - status=Status.FAILED, - ) + if attempt >= max_attempt: + return _Result( + error=res.error, + output=None, + status=Status.FAILED, + ) + + attempt += 1 + continue if res.multi: for step in res.multi: @@ -122,6 +131,17 @@ def trigger( # Unreachable continue + if step.error: + if attempt >= max_attempt: + return _Result( + error=step.error, + output=None, + status=Status.FAILED, + ) + + attempt += 1 + continue + if step.step.display_name in step_stubs: stub = step_stubs[step.step.display_name] if stub is Timeout: diff --git a/inngest/experimental/mocked/trigger_test.py b/inngest/experimental/mocked/trigger_test.py index 21532361..7901945a 100644 --- a/inngest/experimental/mocked/trigger_test.py +++ b/inngest/experimental/mocked/trigger_test.py @@ -245,3 +245,87 @@ def fn( with pytest.raises(UnstubbedStepError): trigger(fn, inngest.Event(name="test"), client) + + def test_retry_step(self) -> None: + counter = 0 + + @client.create_function( + fn_id="test", + trigger=inngest.TriggerEvent(event="test"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> str: + def a() -> str: + nonlocal counter + counter += 1 + if counter < 2: + raise Exception("oh no") + return "hi" + + return step.run("a", a) + + res = trigger(fn, inngest.Event(name="test"), client) + assert res.status is Status.COMPLETED + assert res.output == "hi" + + def test_fail_step(self) -> None: + @client.create_function( + fn_id="test", + retries=0, + trigger=inngest.TriggerEvent(event="test"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> None: + def a() -> None: + raise Exception("oh no") + + step.run("a", a) + + res = trigger(fn, inngest.Event(name="test"), client) + assert res.status is Status.FAILED + assert res.output is None + assert isinstance(res.error, Exception) + assert str(res.error) == "oh no" + + def test_retry_fn(self) -> None: + counter = 0 + + @client.create_function( + fn_id="test", + trigger=inngest.TriggerEvent(event="test"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> str: + nonlocal counter + counter += 1 + if counter < 2: + raise Exception("oh no") + return "hi" + + res = trigger(fn, inngest.Event(name="test"), client) + assert res.status is Status.COMPLETED + assert res.output == "hi" + + def test_fail_fn(self) -> None: + @client.create_function( + fn_id="test", + retries=0, + trigger=inngest.TriggerEvent(event="test"), + ) + def fn( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> None: + raise Exception("oh no") + + res = trigger(fn, inngest.Event(name="test"), client) + assert res.status is Status.FAILED + assert res.output is None + assert isinstance(res.error, Exception) + assert str(res.error) == "oh no"