diff --git a/Makefile b/Makefile index 44e1ecd..1b1d7e0 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install: check-venv @pip install '.[extra]' -c constraints.txt itest: check-venv - @pytest tests + @pytest -n 4 -v tests pre-commit: format-check lint type-check utest @@ -27,4 +27,4 @@ type-check: check-venv @mypy inngest tests utest: check-venv - @pytest inngest + @pytest -v inngest diff --git a/examples/functions/__init__.py b/examples/functions/__init__.py index d09ef6e..ed07414 100644 --- a/examples/functions/__init__.py +++ b/examples/functions/__init__.py @@ -5,6 +5,7 @@ print_event, send_event, two_steps_and_sleep, + wait_for_event, ) functions = [ @@ -14,6 +15,7 @@ print_event.fn, send_event.fn, two_steps_and_sleep.fn, + wait_for_event.fn, ] __all__ = ["functions"] diff --git a/examples/functions/wait_for_event.py b/examples/functions/wait_for_event.py new file mode 100644 index 0000000..d486a7c --- /dev/null +++ b/examples/functions/wait_for_event.py @@ -0,0 +1,14 @@ +import inngest + + +@inngest.create_function( + inngest.FunctionOpts(id="wait_for_event", name="wait_for_event"), + inngest.TriggerEvent(event="app/wait_for_event"), +) +def fn(*, step: inngest.Step, **_kwargs: object) -> None: + res = step.wait_for_event( + "wait", + event="app/wait_for_event.fulfill", + timeout=inngest.Duration.second(2), + ) + step.run("print-result", lambda: print(res)) diff --git a/inngest/__init__.py b/inngest/__init__.py index 2a98205..7e70c64 100644 --- a/inngest/__init__.py +++ b/inngest/__init__.py @@ -1,4 +1,5 @@ from ._internal.client import Inngest +from ._internal.const import Duration from ._internal.errors import NonRetriableError from ._internal.event import Event from ._internal.frameworks import flask, tornado @@ -14,6 +15,7 @@ __all__ = [ "BatchConfig", "CancelConfig", + "Duration", "Event", "Function", "FunctionOpts", diff --git a/inngest/_internal/comm.py b/inngest/_internal/comm.py index e93b99d..525a44c 100644 --- a/inngest/_internal/comm.py +++ b/inngest/_internal/comm.py @@ -20,7 +20,7 @@ from .errors import ( InternalError, InvalidBaseURL, - InvalidFunctionConfig, + InvalidConfig, MissingFunction, ) from .execution import Call, CallError @@ -154,7 +154,7 @@ def call_function( def get_function_configs(self, app_url: str) -> list[FunctionConfig]: configs = [fn.get_config(app_url) for fn in self._fns.values()] if len(configs) == 0: - raise InvalidFunctionConfig("no functions found") + raise InvalidConfig("no functions found") return configs def _parse_registration_response( diff --git a/inngest/_internal/comm_test.py b/inngest/_internal/comm_test.py index ebf71a0..7035af4 100644 --- a/inngest/_internal/comm_test.py +++ b/inngest/_internal/comm_test.py @@ -8,7 +8,7 @@ import inngest from .comm import CommHandler -from .errors import InvalidFunctionConfig +from .errors import InvalidConfig class Test_get_function_configs(TestCase): # pylint: disable=invalid-name @@ -29,7 +29,7 @@ def test_full_config(self) -> None: batch_events=inngest.BatchConfig(max_size=2, timeout="1m"), cancel=inngest.CancelConfig( event="app/cancel", - if_expression="true", + if_exp="true", timeout="1m", ), id="fn", @@ -60,5 +60,5 @@ def test_no_functions(self) -> None: logger=self.client.logger, ) - with pytest.raises(InvalidFunctionConfig, match="no functions found"): + with pytest.raises(InvalidConfig, match="no functions found"): comm.get_function_configs("http://foo.bar") diff --git a/inngest/_internal/const.py b/inngest/_internal/const.py index 03aa397..a7c8a3c 100644 --- a/inngest/_internal/const.py +++ b/inngest/_internal/const.py @@ -8,6 +8,28 @@ VERSION: Final = "0.1.0" +class Duration: + @classmethod + def second(cls, count: int = 1) -> int: + return count * 60 * 1000 + + @classmethod + def minute(cls, count: int = 1) -> int: + return count * cls.second(60) + + @classmethod + def hour(cls, count: int = 1) -> int: + return count * cls.minute(60) + + @classmethod + def day(cls, count: int = 1) -> int: + return count * cls.hour(24) + + @classmethod + def week(cls, count: int = 1) -> int: + return count * cls.day(7) + + class EnvKey(Enum): BASE_URL = "INNGEST_BASE_URL" EVENT_KEY = "INNGEST_EVENT_KEY" diff --git a/inngest/_internal/errors.py b/inngest/_internal/errors.py index be4e644..a6cd22b 100644 --- a/inngest/_internal/errors.py +++ b/inngest/_internal/errors.py @@ -24,7 +24,7 @@ def __init__(self, message: str | None = None) -> None: ) -class InvalidFunctionConfig(InternalError): +class InvalidConfig(InternalError): status_code: int = 500 def __init__(self, message: str | None = None) -> None: @@ -37,7 +37,7 @@ def __init__(self, message: str | None = None) -> None: def from_validation_error( cls, err: ValidationError, - ) -> InvalidFunctionConfig: + ) -> InvalidConfig: """ Extract info from Pydantic's ValidationError and return our internal InvalidFunctionConfig error. diff --git a/inngest/_internal/execution.py b/inngest/_internal/execution.py index ef7db4e..fec231c 100644 --- a/inngest/_internal/execution.py +++ b/inngest/_internal/execution.py @@ -37,8 +37,10 @@ class CallResponse(BaseModel): id: str name: str op: Opcode + opts: dict[str, object] | None = None class Opcode(Enum): SLEEP = "Sleep" STEP = "Step" + WAIT_FOR_EVENT = "WaitForEvent" diff --git a/inngest/_internal/execution_test.py b/inngest/_internal/execution_test.py index feb0eaf..fff6cba 100644 --- a/inngest/_internal/execution_test.py +++ b/inngest/_internal/execution_test.py @@ -8,6 +8,7 @@ def test_serialization() -> None: id="my_id", name="my_name", op=Opcode.STEP, + opts={}, ).to_dict() expectation = { @@ -16,6 +17,7 @@ def test_serialization() -> None: "id": "my_id", "name": "my_name", "op": "Step", + "opts": {}, } assert actual == expectation diff --git a/inngest/_internal/function.py b/inngest/_internal/function.py index fd3d659..fb8759b 100644 --- a/inngest/_internal/function.py +++ b/inngest/_internal/function.py @@ -9,11 +9,7 @@ from pydantic import ValidationError from .client import Inngest -from .errors import ( - InvalidFunctionConfig, - NonRetriableError, - UnserializableOutput, -) +from .errors import InvalidConfig, NonRetriableError, UnserializableOutput from .event import Event from .execution import Call, CallError, CallResponse, Opcode from .function_config import ( @@ -27,7 +23,7 @@ TriggerCron, TriggerEvent, ) -from .transforms import hash_step_id, to_iso_utc +from .transforms import hash_step_id, to_duration_str, to_iso_utc from .types import BaseModel, EmptySentinel, T @@ -53,7 +49,7 @@ def convert_validation_error( self, err: ValidationError, ) -> BaseException: - return InvalidFunctionConfig.from_validation_error(err) + return InvalidConfig.from_validation_error(err) class Function: @@ -86,6 +82,7 @@ def call( id=out.hashed_id, name=out.name, op=out.op, + opts=out.opts, ) ] except Exception as err: @@ -146,12 +143,14 @@ def __init__( hashed_id: str, name: str, op: Opcode, + opts: dict[str, object] | None = None, ) -> None: self.data = data self.display_name = display_name self.hashed_id = hashed_id self.name = name self.op = op + self.opts = opts class _Step: @@ -228,6 +227,49 @@ def sleep_until( op=Opcode.SLEEP, ) + def wait_for_event( + self, + id: str, # pylint: disable=redefined-builtin + *, + event: str, + if_exp: str | None = None, + timeout: int, + ) -> Event | None: + """ + Args: + event: Event name. + if_exp: An expression to filter events. + timeout: The maximum number of milliseconds to wait for the event. + """ + + id_count = self._step_id_counter.increment(id) + if id_count > 1: + id = f"{id}:{id_count - 1}" + hashed_id = hash_step_id(id) + + memo = self._get_memo(hashed_id) + if memo is not EmptySentinel: + if memo is None: + # Timeout + return None + + # Fulfilled by an event + return Event.model_validate(memo) + + opts: dict[str, object] = { + "timeout": to_duration_str(timeout), + } + if if_exp is not None: + opts["if"] = if_exp + + raise EarlyReturn( + hashed_id=hashed_id, + display_name=id, + name=event, + op=Opcode.WAIT_FOR_EVENT, + opts=opts, + ) + class _FunctionHandler(Protocol): def __call__(self, *, event: Event, step: Step) -> object: @@ -256,6 +298,16 @@ def sleep_until( ) -> None: ... + def wait_for_event( + self, + id: str, # pylint: disable=redefined-builtin + *, + event: str, + if_exp: str | None = None, + timeout: int, + ) -> Event | None: + ... + class _StepIDCounter: def __init__(self) -> None: diff --git a/inngest/_internal/function_config.py b/inngest/_internal/function_config.py index bff3b43..9e19940 100644 --- a/inngest/_internal/function_config.py +++ b/inngest/_internal/function_config.py @@ -4,7 +4,7 @@ from pydantic import Field, ValidationError -from .errors import InvalidFunctionConfig +from .errors import InvalidConfig from .types import BaseModel # A number > 0 followed by a time unit (s, m, h, d, w) @@ -16,12 +16,12 @@ def convert_validation_error( self, err: ValidationError, ) -> BaseException: - return InvalidFunctionConfig.from_validation_error(err) + return InvalidConfig.from_validation_error(err) class CancelConfig(_BaseConfig): event: str - if_expression: str | None = None + if_exp: str | None = None timeout: str | None = Field(default=None, pattern=TIME_PERIOD_REGEX) diff --git a/inngest/_internal/transforms.py b/inngest/_internal/transforms.py index 8cbfc6a..bdc84c6 100644 --- a/inngest/_internal/transforms.py +++ b/inngest/_internal/transforms.py @@ -2,6 +2,8 @@ import re from datetime import datetime, timezone +from .const import Duration +from .errors import InvalidConfig from .types import T @@ -37,3 +39,18 @@ def to_iso_utc(value: datetime) -> str: value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" ) + + +def to_duration_str(ms: int) -> str: + if ms < Duration.second(): + raise InvalidConfig("duration must be at least 1 second") + if ms < Duration.minute(): + return f"{ms // Duration.second()}s" + if ms < Duration.hour(): + return f"{ms // Duration.minute()}m" + if ms < Duration.day(): + return f"{ms // Duration.hour()}h" + if ms < Duration.week(): + return f"{ms // Duration.day()}d" + + return f"{ms // Duration.week()}w" diff --git a/pyproject.toml b/pyproject.toml index c7da709..42768f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ extra = [ "mypy==1.6.1", "pylint==3.0.1", "pytest==7.4.2", + "pytest-xdist[psutil]==3.3.1", "python-json-logger==2.0.7", "toml==0.10.2", "tornado==6.3.3", diff --git a/tests/base.py b/tests/base.py index 05ec536..3fc1ddb 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,5 +1,4 @@ -import time -from typing import Callable, Protocol +from typing import Protocol import requests @@ -36,20 +35,3 @@ def set_up(case: _FrameworkTestCase) -> None: def tear_down(case: _FrameworkTestCase) -> None: case.http_proxy.stop() - - -def wait_for( - assertion: Callable[[], None], - timeout: int = 5, -) -> None: - start = time.time() - while True: - try: - assertion() - return - except Exception as err: - timed_out = time.time() - start > timeout - if timed_out: - raise err - - time.sleep(0.2) diff --git a/tests/cases.py b/tests/cases.py deleted file mode 100644 index 47211a5..0000000 --- a/tests/cases.py +++ /dev/null @@ -1,219 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -import inngest -from inngest._internal.errors import UnserializableOutput - -from .base import wait_for - - -class _BaseState: - def is_done(self) -> bool: - raise NotImplementedError() - - -@dataclass -class Case: - event_name: str - fn: inngest.Function - name: str - run_test: Callable[[object], None] - state: _BaseState - - -def _event_payload(client: inngest.Inngest, framework: str) -> Case: - name = "event_payload" - event_name = f"{framework}/{name}" - - class State(_BaseState): - event: inngest.Event | None = None - - def is_done(self) -> bool: - return self.event is not None - - state = State() - - @inngest.create_function( - inngest.FunctionOpts(id=name), - inngest.TriggerEvent(event=event_name), - ) - def fn(*, event: inngest.Event, **_kwargs: object) -> None: - state.event = event - - def run_test(_self: object) -> None: - client.send( - inngest.Event( - data={"foo": {"bar": "baz"}}, - name=event_name, - user={"a": {"b": "c"}}, - ) - ) - - def assertion() -> None: - assert state.event is not None - assert state.event.id != "" - assert state.event.name == event_name - assert state.event.data == {"foo": {"bar": "baz"}} - assert state.event.ts > 0 - assert state.event.user == {"a": {"b": "c"}} - - wait_for(assertion) - - return Case( - event_name=event_name, - fn=fn, - run_test=run_test, - state=state, - name=name, - ) - - -def _no_steps(client: inngest.Inngest, framework: str) -> Case: - name = "no_steps" - event_name = f"{framework}/{name}" - - class State(_BaseState): - counter = 0 - - def is_done(self) -> bool: - return self.counter == 1 - - state = State() - - @inngest.create_function( - inngest.FunctionOpts(id=name), - inngest.TriggerEvent(event=event_name), - ) - def fn(**_kwargs: object) -> None: - state.counter += 1 - - def run_test(_self: object) -> None: - client.send(inngest.Event(name=event_name)) - - def assertion() -> None: - assert state.is_done() - - wait_for(assertion) - - return Case( - event_name=event_name, - fn=fn, - run_test=run_test, - state=state, - name=name, - ) - - -def _two_steps(client: inngest.Inngest, framework: str) -> Case: - name = "two_steps" - event_name = f"{framework}/{name}" - - class State(_BaseState): - step_1_counter = 0 - step_2_counter = 0 - end_counter = 0 - - def is_done(self) -> bool: - return ( - self.step_1_counter == 1 - and self.step_2_counter == 1 - and self.end_counter == 1 - ) - - state = State() - - @inngest.create_function( - inngest.FunctionOpts(id=name), - inngest.TriggerEvent(event=event_name), - ) - def fn(*, step: inngest.Step, **_kwargs: object) -> None: - def step_1() -> str: - state.step_1_counter += 1 - return "hi" - - step.run("step_1", step_1) - - def step_2() -> None: - state.step_2_counter += 1 - - step.run("step_2", step_2) - state.end_counter += 1 - - def run_test(_self: object) -> None: - client.send(inngest.Event(name=event_name)) - - def assertion() -> None: - assert state.is_done() - - wait_for(assertion) - - return Case( - event_name=event_name, - fn=fn, - run_test=run_test, - state=state, - name=name, - ) - - -def _unserializable_step_output( - client: inngest.Inngest, - framework: str, -) -> Case: - name = "unserializable_step_output" - event_name = f"{framework}/{name}" - - class State(_BaseState): - error: BaseException | None = None - - def is_done(self) -> bool: - return self.error is not None - - state = State() - - @inngest.create_function( - inngest.FunctionOpts(id=name, retries=0), - inngest.TriggerEvent(event=event_name), - ) - def fn(*, step: inngest.Step, **_kwargs: object) -> None: - class Foo: - pass - - def step_1() -> Foo: - return Foo() - - try: - step.run("step_1", step_1) - except BaseException as err: - state.error = err - raise - - def run_test(_self: object) -> None: - client.send(inngest.Event(name=event_name)) - - def assertion() -> None: - assert state.is_done() - assert isinstance(state.error, UnserializableOutput) - assert ( - str(state.error) - == "Object of type Foo is not JSON serializable" - ) - - wait_for(assertion) - - return Case( - event_name=event_name, - fn=fn, - run_test=run_test, - state=state, - name=name, - ) - - -def create_cases(client: inngest.Inngest, framework: str) -> list[Case]: - return [ - _event_payload(client, framework), - _no_steps(client, framework), - _two_steps(client, framework), - _unserializable_step_output(client, framework), - ] diff --git a/tests/cases/__init__.py b/tests/cases/__init__.py new file mode 100644 index 0000000..784b9bd --- /dev/null +++ b/tests/cases/__init__.py @@ -0,0 +1,28 @@ +import inngest + +from . import ( + event_payload, + no_steps, + two_steps, + unserializable_step_output, + wait_for_event_fulfill, + wait_for_event_timeout, +) +from .base import Case + + +def create_cases(client: inngest.Inngest, framework: str) -> list[Case]: + return [ + case.create(client, framework) + for case in ( + event_payload, + no_steps, + two_steps, + unserializable_step_output, + wait_for_event_fulfill, + wait_for_event_timeout, + ) + ] + + +__all__ = ["create_cases"] diff --git a/tests/cases/base.py b/tests/cases/base.py new file mode 100644 index 0000000..36f9a5a --- /dev/null +++ b/tests/cases/base.py @@ -0,0 +1,36 @@ +import time +from dataclasses import dataclass +from typing import Callable + +import inngest + + +class BaseState: + def is_done(self) -> bool: + raise NotImplementedError() + + +@dataclass +class Case: + event_name: str + fn: inngest.Function + name: str + run_test: Callable[[object], None] + state: BaseState + + +def wait_for( + assertion: Callable[[], None], + timeout: int = 5, +) -> None: + start = time.time() + while True: + try: + assertion() + return + except Exception as err: + timed_out = time.time() - start > timeout + if timed_out: + raise err + + time.sleep(0.2) diff --git a/tests/cases/event_payload.py b/tests/cases/event_payload.py new file mode 100644 index 0000000..48c6dc1 --- /dev/null +++ b/tests/cases/event_payload.py @@ -0,0 +1,50 @@ +import inngest + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + event: inngest.Event | None = None + + def is_done(self) -> bool: + return self.event is not None + + +def create(client: inngest.Inngest, framework: str) -> Case: + name = "event_payload" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name), + inngest.TriggerEvent(event=event_name), + ) + def fn(*, event: inngest.Event, **_kwargs: object) -> None: + state.event = event + + def run_test(_self: object) -> None: + client.send( + inngest.Event( + data={"foo": {"bar": "baz"}}, + name=event_name, + user={"a": {"b": "c"}}, + ) + ) + + def assertion() -> None: + assert state.event is not None + assert state.event.id != "" + assert state.event.name == event_name + assert state.event.data == {"foo": {"bar": "baz"}} + assert state.event.ts > 0 + assert state.event.user == {"a": {"b": "c"}} + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + ) diff --git a/tests/cases/no_steps.py b/tests/cases/no_steps.py new file mode 100644 index 0000000..9047384 --- /dev/null +++ b/tests/cases/no_steps.py @@ -0,0 +1,39 @@ +import inngest + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + counter = 0 + + def is_done(self) -> bool: + return self.counter == 1 + + +def create(client: inngest.Inngest, framework: str) -> Case: + name = "no_steps" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name), + inngest.TriggerEvent(event=event_name), + ) + def fn(**_kwargs: object) -> None: + state.counter += 1 + + def run_test(_self: object) -> None: + client.send(inngest.Event(name=event_name)) + + def assertion() -> None: + assert state.is_done() + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + ) diff --git a/tests/cases/two_steps.py b/tests/cases/two_steps.py new file mode 100644 index 0000000..b9f9962 --- /dev/null +++ b/tests/cases/two_steps.py @@ -0,0 +1,55 @@ +import inngest + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + step_1_counter = 0 + step_2_counter = 0 + end_counter = 0 + + def is_done(self) -> bool: + return ( + self.step_1_counter == 1 + and self.step_2_counter == 1 + and self.end_counter == 1 + ) + + +def create(client: inngest.Inngest, framework: str) -> Case: + name = "two_steps" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name), + inngest.TriggerEvent(event=event_name), + ) + def fn(*, step: inngest.Step, **_kwargs: object) -> None: + def step_1() -> str: + state.step_1_counter += 1 + return "hi" + + step.run("step_1", step_1) + + def step_2() -> None: + state.step_2_counter += 1 + + step.run("step_2", step_2) + state.end_counter += 1 + + def run_test(_self: object) -> None: + client.send(inngest.Event(name=event_name)) + + def assertion() -> None: + assert state.is_done() + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + ) diff --git a/tests/cases/unserializable_step_output.py b/tests/cases/unserializable_step_output.py new file mode 100644 index 0000000..a06131b --- /dev/null +++ b/tests/cases/unserializable_step_output.py @@ -0,0 +1,58 @@ +import inngest +from inngest._internal.errors import UnserializableOutput + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + error: BaseException | None = None + + def is_done(self) -> bool: + return self.error is not None + + +def create( + client: inngest.Inngest, + framework: str, +) -> Case: + name = "unserializable_step_output" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name, retries=0), + inngest.TriggerEvent(event=event_name), + ) + def fn(*, step: inngest.Step, **_kwargs: object) -> None: + class Foo: + pass + + def step_1() -> Foo: + return Foo() + + try: + step.run("step_1", step_1) + except BaseException as err: + state.error = err + raise + + def run_test(_self: object) -> None: + client.send(inngest.Event(name=event_name)) + + def assertion() -> None: + assert state.is_done() + assert isinstance(state.error, UnserializableOutput) + assert ( + str(state.error) + == "Object of type Foo is not JSON serializable" + ) + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + ) diff --git a/tests/cases/wait_for_event_fulfill.py b/tests/cases/wait_for_event_fulfill.py new file mode 100644 index 0000000..15b89b3 --- /dev/null +++ b/tests/cases/wait_for_event_fulfill.py @@ -0,0 +1,63 @@ +import time + +import inngest + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + is_started = False + result: inngest.Event | None = None + + def is_done(self) -> bool: + return self.result is not None + + +def create( + client: inngest.Inngest, + framework: str, +) -> Case: + name = "wait_for_event_fulfill" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name, retries=0), + inngest.TriggerEvent(event=event_name), + ) + def fn(*, step: inngest.Step, **_kwargs: object) -> None: + state.is_started = True + + state.result = step.wait_for_event( + "wait", + event=f"{event_name}.fulfill", + timeout=inngest.Duration.minute(1), + ) + + def run_test(_self: object) -> None: + client.send(inngest.Event(name=event_name)) + + def assert_started() -> None: + assert state.is_started is True + time.sleep(0.5) + + wait_for(assert_started) + + client.send(inngest.Event(name=f"{event_name}.fulfill")) + + def assertion() -> None: + assert state.is_done() + assert isinstance(state.result, inngest.Event) + assert state.result.id != "" + assert state.result.name == f"{event_name}.fulfill" + assert state.result.ts > 0 + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + ) diff --git a/tests/cases/wait_for_event_timeout.py b/tests/cases/wait_for_event_timeout.py new file mode 100644 index 0000000..e360065 --- /dev/null +++ b/tests/cases/wait_for_event_timeout.py @@ -0,0 +1,47 @@ +import inngest + +from .base import BaseState, Case, wait_for + + +class _State(BaseState): + result: inngest.Event | None | str = "not_set" + + def is_done(self) -> bool: + print(self.result) + return self.result is None + + +def create( + client: inngest.Inngest, + framework: str, +) -> Case: + name = "wait_for_event_timeout" + event_name = f"{framework}/{name}" + state = _State() + + @inngest.create_function( + inngest.FunctionOpts(id=name, retries=0), + inngest.TriggerEvent(event=event_name), + ) + def fn(*, step: inngest.Step, **_kwargs: object) -> None: + state.result = step.wait_for_event( + "wait", + event=f"{event_name}.fulfill", + timeout=inngest.Duration.second(1), + ) + + def run_test(_self: object) -> None: + client.send(inngest.Event(name=event_name)) + + def assertion() -> None: + assert state.is_done() + + wait_for(assertion) + + return Case( + event_name=event_name, + fn=fn, + run_test=run_test, + state=state, + name=name, + )