Skip to content

Commit

Permalink
Add step.wait_for_event
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 25, 2023
1 parent 245b7e2 commit 28fc888
Show file tree
Hide file tree
Showing 24 changed files with 510 additions and 257 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ install: check-venv
@pip install '.[extra]' -c constraints.txt

itest: check-venv
@pytest tests
@pytest -n 4 -v tests

pre-commit: format-check lint type-check utest

Expand All @@ -27,4 +27,4 @@ type-check: check-venv
@mypy inngest tests

utest: check-venv
@pytest inngest
@pytest -v inngest
2 changes: 2 additions & 0 deletions examples/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
print_event,
send_event,
two_steps_and_sleep,
wait_for_event,
)

functions = [
Expand All @@ -14,6 +15,7 @@
print_event.fn,
send_event.fn,
two_steps_and_sleep.fn,
wait_for_event.fn,
]

__all__ = ["functions"]
14 changes: 14 additions & 0 deletions examples/functions/wait_for_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import inngest


@inngest.create_function(
inngest.FunctionOpts(id="wait_for_event", name="wait_for_event"),
inngest.TriggerEvent(event="app/wait_for_event"),
)
def fn(*, step: inngest.Step, **_kwargs: object) -> None:
res = step.wait_for_event(
"wait",
event="app/wait_for_event.fulfill",
timeout=inngest.Duration.second(2),
)
step.run("print-result", lambda: print(res))
2 changes: 2 additions & 0 deletions inngest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._internal.client import Inngest
from ._internal.const import Duration
from ._internal.errors import NonRetriableError
from ._internal.event import Event
from ._internal.frameworks import flask, tornado
Expand All @@ -14,6 +15,7 @@
__all__ = [
"BatchConfig",
"CancelConfig",
"Duration",
"Event",
"Function",
"FunctionOpts",
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .errors import (
InternalError,
InvalidBaseURL,
InvalidFunctionConfig,
InvalidConfig,
MissingFunction,
)
from .execution import Call, CallError
Expand Down Expand Up @@ -154,7 +154,7 @@ def call_function(
def get_function_configs(self, app_url: str) -> list[FunctionConfig]:
configs = [fn.get_config(app_url) for fn in self._fns.values()]
if len(configs) == 0:
raise InvalidFunctionConfig("no functions found")
raise InvalidConfig("no functions found")
return configs

def _parse_registration_response(
Expand Down
6 changes: 3 additions & 3 deletions inngest/_internal/comm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inngest

from .comm import CommHandler
from .errors import InvalidFunctionConfig
from .errors import InvalidConfig


class Test_get_function_configs(TestCase): # pylint: disable=invalid-name
Expand All @@ -29,7 +29,7 @@ def test_full_config(self) -> None:
batch_events=inngest.BatchConfig(max_size=2, timeout="1m"),
cancel=inngest.CancelConfig(
event="app/cancel",
if_expression="true",
if_exp="true",
timeout="1m",
),
id="fn",
Expand Down Expand Up @@ -60,5 +60,5 @@ def test_no_functions(self) -> None:
logger=self.client.logger,
)

with pytest.raises(InvalidFunctionConfig, match="no functions found"):
with pytest.raises(InvalidConfig, match="no functions found"):
comm.get_function_configs("http://foo.bar")
22 changes: 22 additions & 0 deletions inngest/_internal/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,28 @@
VERSION: Final = "0.1.0"


class Duration:
@classmethod
def second(cls, count: int = 1) -> int:
return count * 60 * 1000

@classmethod
def minute(cls, count: int = 1) -> int:
return count * cls.second(60)

@classmethod
def hour(cls, count: int = 1) -> int:
return count * cls.minute(60)

@classmethod
def day(cls, count: int = 1) -> int:
return count * cls.hour(24)

@classmethod
def week(cls, count: int = 1) -> int:
return count * cls.day(7)


class EnvKey(Enum):
BASE_URL = "INNGEST_BASE_URL"
EVENT_KEY = "INNGEST_EVENT_KEY"
Expand Down
4 changes: 2 additions & 2 deletions inngest/_internal/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, message: str | None = None) -> None:
)


class InvalidFunctionConfig(InternalError):
class InvalidConfig(InternalError):
status_code: int = 500

def __init__(self, message: str | None = None) -> None:
Expand All @@ -37,7 +37,7 @@ def __init__(self, message: str | None = None) -> None:
def from_validation_error(
cls,
err: ValidationError,
) -> InvalidFunctionConfig:
) -> InvalidConfig:
"""
Extract info from Pydantic's ValidationError and return our internal
InvalidFunctionConfig error.
Expand Down
2 changes: 2 additions & 0 deletions inngest/_internal/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class CallResponse(BaseModel):
id: str
name: str
op: Opcode
opts: dict[str, object] | None = None


class Opcode(Enum):
SLEEP = "Sleep"
STEP = "Step"
WAIT_FOR_EVENT = "WaitForEvent"
2 changes: 2 additions & 0 deletions inngest/_internal/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def test_serialization() -> None:
id="my_id",
name="my_name",
op=Opcode.STEP,
opts={},
).to_dict()

expectation = {
Expand All @@ -16,6 +17,7 @@ def test_serialization() -> None:
"id": "my_id",
"name": "my_name",
"op": "Step",
"opts": {},
}

assert actual == expectation
66 changes: 59 additions & 7 deletions inngest/_internal/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
from pydantic import ValidationError

from .client import Inngest
from .errors import (
InvalidFunctionConfig,
NonRetriableError,
UnserializableOutput,
)
from .errors import InvalidConfig, NonRetriableError, UnserializableOutput
from .event import Event
from .execution import Call, CallError, CallResponse, Opcode
from .function_config import (
Expand All @@ -27,7 +23,7 @@
TriggerCron,
TriggerEvent,
)
from .transforms import hash_step_id, to_iso_utc
from .transforms import hash_step_id, to_duration_str, to_iso_utc
from .types import BaseModel, EmptySentinel, T


Expand All @@ -53,7 +49,7 @@ def convert_validation_error(
self,
err: ValidationError,
) -> BaseException:
return InvalidFunctionConfig.from_validation_error(err)
return InvalidConfig.from_validation_error(err)


class Function:
Expand Down Expand Up @@ -86,6 +82,7 @@ def call(
id=out.hashed_id,
name=out.name,
op=out.op,
opts=out.opts,
)
]
except Exception as err:
Expand Down Expand Up @@ -146,12 +143,14 @@ def __init__(
hashed_id: str,
name: str,
op: Opcode,
opts: dict[str, object] | None = None,
) -> None:
self.data = data
self.display_name = display_name
self.hashed_id = hashed_id
self.name = name
self.op = op
self.opts = opts


class _Step:
Expand Down Expand Up @@ -228,6 +227,49 @@ def sleep_until(
op=Opcode.SLEEP,
)

def wait_for_event(
self,
id: str, # pylint: disable=redefined-builtin
*,
event: str,
if_exp: str | None = None,
timeout: int,
) -> Event | None:
"""
Args:
event: Event name.
if_exp: An expression to filter events.
timeout: The maximum number of milliseconds to wait for the event.
"""

id_count = self._step_id_counter.increment(id)
if id_count > 1:
id = f"{id}:{id_count - 1}"
hashed_id = hash_step_id(id)

memo = self._get_memo(hashed_id)
if memo is not EmptySentinel:
if memo is None:
# Timeout
return None

# Fulfilled by an event
return Event.model_validate(memo)

opts: dict[str, object] = {
"timeout": to_duration_str(timeout),
}
if if_exp is not None:
opts["if"] = if_exp

raise EarlyReturn(
hashed_id=hashed_id,
display_name=id,
name=event,
op=Opcode.WAIT_FOR_EVENT,
opts=opts,
)


class _FunctionHandler(Protocol):
def __call__(self, *, event: Event, step: Step) -> object:
Expand Down Expand Up @@ -256,6 +298,16 @@ def sleep_until(
) -> None:
...

def wait_for_event(
self,
id: str, # pylint: disable=redefined-builtin
*,
event: str,
if_exp: str | None = None,
timeout: int,
) -> Event | None:
...


class _StepIDCounter:
def __init__(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions inngest/_internal/function_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import Field, ValidationError

from .errors import InvalidFunctionConfig
from .errors import InvalidConfig
from .types import BaseModel

# A number > 0 followed by a time unit (s, m, h, d, w)
Expand All @@ -16,12 +16,12 @@ def convert_validation_error(
self,
err: ValidationError,
) -> BaseException:
return InvalidFunctionConfig.from_validation_error(err)
return InvalidConfig.from_validation_error(err)


class CancelConfig(_BaseConfig):
event: str
if_expression: str | None = None
if_exp: str | None = None
timeout: str | None = Field(default=None, pattern=TIME_PERIOD_REGEX)


Expand Down
17 changes: 17 additions & 0 deletions inngest/_internal/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
from datetime import datetime, timezone

from .const import Duration
from .errors import InvalidConfig
from .types import T


Expand Down Expand Up @@ -37,3 +39,18 @@ def to_iso_utc(value: datetime) -> str:
value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
+ "Z"
)


def to_duration_str(ms: int) -> str:
if ms < Duration.second():
raise InvalidConfig("duration must be at least 1 second")
if ms < Duration.minute():
return f"{ms // Duration.second()}s"
if ms < Duration.hour():
return f"{ms // Duration.minute()}m"
if ms < Duration.day():
return f"{ms // Duration.hour()}h"
if ms < Duration.week():
return f"{ms // Duration.day()}d"

return f"{ms // Duration.week()}w"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ extra = [
"mypy==1.6.1",
"pylint==3.0.1",
"pytest==7.4.2",
"pytest-xdist[psutil]==3.3.1",
"python-json-logger==2.0.7",
"toml==0.10.2",
"tornado==6.3.3",
Expand Down
20 changes: 1 addition & 19 deletions tests/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import Callable, Protocol
from typing import Protocol

import requests

Expand Down Expand Up @@ -36,20 +35,3 @@ def set_up(case: _FrameworkTestCase) -> None:

def tear_down(case: _FrameworkTestCase) -> None:
case.http_proxy.stop()


def wait_for(
assertion: Callable[[], None],
timeout: int = 5,
) -> None:
start = time.time()
while True:
try:
assertion()
return
except Exception as err:
timed_out = time.time() - start > timeout
if timed_out:
raise err

time.sleep(0.2)
Loading

0 comments on commit 28fc888

Please sign in to comment.