Skip to content

Commit

Permalink
Formalize in-memory driver
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 23, 2024
1 parent 7255aa0 commit c3b76d1
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 68 deletions.
18 changes: 18 additions & 0 deletions inngest/experimental/remote_state_middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
Remote state middleware for Inngest. This middleware allows you to store state
where you want, rather than in Inngest's infrastructure. This is useful for:
- Reducing bandwidth to/from the Inngest server.
- Avoiding step output size limits.
NOT STABLE! This is an experimental feature and may change in the future. If
you'd like to use it, we recommend copying this file into your source code.
"""

from .in_memory_driver import InMemoryDriver
from .middleware import RemoteStateMiddleware, StateDriver

__all__ = [
"InMemoryDriver",
"RemoteStateMiddleware",
"StateDriver",
]
66 changes: 66 additions & 0 deletions inngest/experimental/remote_state_middleware/in_memory_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import secrets
import string
import typing

import pydantic

import inngest

from .middleware import StateDriver


class _StatePlaceholder(pydantic.BaseModel):
__REMOTE_STATE__: typing.Literal[True] = True
key: str


class InMemoryDriver(StateDriver):
"""
In-memory driver for remote state middleware.
"""

# Marker to indicate that the data is stored remotely.
_marker: typing.Final = "__REMOTE_STATE__"

def __init__(self) -> None: # noqa: D107
self._data: dict[str, object] = {}

def _create_key(self) -> str:
chars = string.ascii_letters + string.digits
return "".join(secrets.choice(chars) for _ in range(32))

def load_steps(self, steps: inngest.StepMemos) -> None:
"""
Hydrate steps with remote state if necessary.
"""

for step in steps.values():
if not isinstance(step.data, dict):
continue
if self._marker not in step.data:
continue

try:
placeholder = _StatePlaceholder.model_validate(step.data)
except pydantic.ValidationError:
continue

step.data = self._data[placeholder.key]

def save_step(
self,
value: object,
) -> dict[str, object]:
"""
Save a step's output to the remote store and return a placeholder.
"""

key = self._create_key()
self._data[key] = value

placeholder: dict[str, object] = {
self._marker: True,
**_StatePlaceholder(key=key).model_dump(),
}

return placeholder
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
"""
Remote state middleware for Inngest.
NOT STABLE! This is an experimental feature and may change in the future. If
you'd like to use it, we recommend copying this file into your source code.
"""

from __future__ import annotations

import typing
Expand All @@ -17,33 +10,32 @@ class StateDriver(typing.Protocol):
Protocol for the state driver.
"""

def read(self, key: str) -> object:
def load_steps(self, steps: inngest.StepMemos) -> None:
"""
Retrieve the value associated with the key.
Args:
----
key: Key returned from `create`.
steps: Steps whose output may need to be loaded from the remote store.
"""

...

def write(self, value: object) -> str:
def save_step(
self,
value: object,
) -> dict[str, object]:
"""
Store the value and return a key to retrieve it later.
Args:
----
value: Value to store.
value: Output for an ended step.
"""

...


# Marker to indicate that the data is stored remotely.
_marker: typing.Final = "__REMOTE_STATE__"


class RemoteStateMiddleware(inngest.MiddlewareSync):
"""
Middleware that reads/writes step output in a custom store (e.g. AWS S3).
Expand Down Expand Up @@ -105,36 +97,17 @@ def transform_input(
Inject remote state.
"""

for step in steps.values():
if not _is_external(step.data):
continue

if not isinstance(step.data, dict):
continue

key = step.data.get("key")
if key is None:
continue

step.data = self._driver.read(key)
self._driver.load_steps(steps)

def transform_output(self, result: inngest.TransformOutputResult) -> None:
"""
Store step output externally and replace with a marker and key.
"""

if result.has_output() and result.step is not None:
result.output = {
_marker: True,
"key": self._driver.write(result.output),
}


def _is_external(value: object) -> bool:
if not isinstance(value, dict):
return False
if result.step is None:
return None

if value.get(_marker) is not True:
return False
if result.has_output() is False:
return None

return True
result.output = self._driver.save_step(result.output)
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import inngest
from inngest._internal import server_lib

from . import base, step_output
from . import base, step_failed, step_output

_modules = (step_output,)
_modules = (
step_failed,
step_output,
)


def create_async_cases(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Ensure step and function output is encrypted and decrypted correctly
"""

import json

import inngest
import tests.helper
from inngest._internal import server_lib
from inngest.experimental import remote_state_middleware

from . import base


class _State(base.BaseState):
event: inngest.Event
events: list[inngest.Event]


def create(
client: inngest.Inngest,
framework: server_lib.Framework,
is_sync: bool,
) -> base.Case:
test_name = base.create_test_name(__file__)
event_name = base.create_event_name(framework, test_name)
fn_id = base.create_fn_id(test_name)
state = _State()
driver = remote_state_middleware.InMemoryDriver()

@client.create_function(
fn_id=fn_id,
middleware=[
remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
def fn_sync(
ctx: inngest.Context,
step: inngest.StepSync,
) -> str:
state.run_id = ctx.run_id

def _step() -> str:
raise Exception("oh no")

try:
step.run("step_1", _step)
except Exception as e:
print("hi", str(e))
return str(e)

return "unreachable"

@client.create_function(
fn_id=fn_id,
middleware=[
remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
async def fn_async(
ctx: inngest.Context,
step: inngest.Step,
) -> str:
state.run_id = ctx.run_id

state.run_id = ctx.run_id

def _step() -> str:
raise Exception("oh no")

try:
await step.run("step_1", _step)
except Exception as e:
return str(e)

return "unreachable"

async def run_test(self: base.TestClass) -> None:
self.client.send_sync(inngest.Event(name=event_name))

run_id = state.wait_for_run_id()
run = tests.helper.client.wait_for_run_status(
run_id,
tests.helper.RunStatus.COMPLETED,
)

# Ensure that step_1 output is encrypted and its value is correct
output = json.loads(
tests.helper.client.get_step_output(
run_id=run_id,
step_id="step_1",
)
)
assert isinstance(output, dict)
assert output.get("data") is None
error = output.get("error")
assert isinstance(error, dict)

# Ensure that the error data was not remotely stored.
assert driver._marker not in error
assert error.get("message") == "oh no"

assert run.output is not None
assert json.loads(run.output) == "oh no"

if is_sync:
fn = fn_sync
else:
fn = fn_async

return base.Case(
fn=fn,
run_test=run_test,
name=test_name,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,15 @@
"""

import json
import random
import string

import inngest
import tests.helper
from inngest._internal import server_lib
from inngest.experimental.remote_state_middleware import (
RemoteStateMiddleware,
StateDriver,
)
from inngest.experimental import remote_state_middleware

from . import base


class _Driver(StateDriver):
def __init__(self) -> None:
self._data: dict[str, object] = {}

def _create_key(self) -> str:
chars = string.ascii_letters + string.digits
return "".join(random.choice(chars) for _ in range(32))

def read(self, key: str) -> object:
return self._data[key]

def write(self, value: object) -> str:
key = self._create_key()
self._data[key] = value
return key


class _State(base.BaseState):
event: inngest.Event
events: list[inngest.Event]
Expand All @@ -48,11 +26,13 @@ def create(
event_name = base.create_event_name(framework, test_name)
fn_id = base.create_fn_id(test_name)
state = _State()
driver = _Driver()
driver = remote_state_middleware.InMemoryDriver()

@client.create_function(
fn_id=fn_id,
middleware=[RemoteStateMiddleware.factory(driver)],
middleware=[
remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
Expand All @@ -78,7 +58,9 @@ def _step_2() -> list[inngest.JSON]:

@client.create_function(
fn_id=fn_id,
middleware=[RemoteStateMiddleware.factory(driver)],
middleware=[
remote_state_middleware.RemoteStateMiddleware.factory(driver)
],
retries=0,
trigger=inngest.TriggerEvent(event=event_name),
)
Expand Down

0 comments on commit c3b76d1

Please sign in to comment.