Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experimental remote state middleware #178

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions inngest/experimental/remote_state_middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Remote state middleware for Inngest. This middleware allows you to store state
where you want, rather than in Inngest's infrastructure. This is useful for:
- Reducing bandwidth to/from the Inngest server.
- Avoiding step output size limits.

NOT STABLE! This is an experimental feature and may change in the future. If
you'd like to use it, we recommend copying this package into your source code.
"""

from .in_memory_driver import InMemoryDriver
from .middleware import RemoteStateMiddleware, StateDriver
from .s3_driver import S3Driver

# Names re-exported as the public API of this package.
__all__ = [
    "InMemoryDriver",
    "RemoteStateMiddleware",
    "S3Driver",
    "StateDriver",
]
75 changes: 75 additions & 0 deletions inngest/experimental/remote_state_middleware/in_memory_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import secrets
import string
import typing

import pydantic

import inngest

from .middleware import StateDriver


class _StatePlaceholder(pydantic.BaseModel):
    """
    Stand-in for step output that is kept in the in-memory store. It is
    validated from the placeholder dict when steps are hydrated.
    """

    # Key under which the actual step output is stored.
    key: str


class InMemoryDriver(StateDriver):
    """
    Reference state driver that keeps step output in a dict owned by the driver
    instance. It probably has no use beyond serving as an example.
    """

    # Key whose presence flags step output as a remote state placeholder.
    _marker: typing.Final = "__REMOTE_STATE__"

    # Key naming the strategy (driver) that wrote a placeholder. This is useful
    # for knowing whether the official S3 driver was used.
    _strategy_marker: typing.Final = "__STRATEGY__"

    # Value this driver stores under the strategy marker.
    _strategy_identifier: typing.Final = "inngest/memory"

    def __init__(self) -> None:
        """
        Create a driver with an empty store.
        """

        # Maps generated keys to the step output saved under them.
        self._data: dict[str, object] = {}

    def _create_key(self) -> str:
        """
        Generate a random 32-character alphanumeric key.
        """

        alphabet = string.ascii_letters + string.digits
        picks = [secrets.choice(alphabet) for _ in range(32)]
        return "".join(picks)

    def load_steps(self, steps: inngest.StepMemos) -> None:
        """
        Replace placeholders written by this driver with the stored output.

        Args:
        ----
            steps: Steps that may need hydration. Modified in place.
        """

        for memo in steps.values():
            payload = memo.data

            # Only dicts that carry both markers and name this driver's
            # strategy are hydrated; everything else is left untouched.
            written_by_us = (
                isinstance(payload, dict)
                and self._marker in payload
                and self._strategy_marker in payload
                and payload[self._strategy_marker] == self._strategy_identifier
            )
            if not written_by_us:
                continue

            stored_key = _StatePlaceholder.model_validate(payload).key
            memo.data = self._data[stored_key]

    def save_step(
        self,
        run_id: str,
        value: object,
    ) -> dict[str, object]:
        """
        Keep a step's output in memory and return the placeholder that stands
        in for it.

        Args:
        ----
            run_id: Run ID. Not used by this driver.
            value: Step output.
        """

        stored_key = self._create_key()
        self._data[stored_key] = value

        # Markers first, then the fields needed to find the output again.
        placeholder: dict[str, object] = {
            self._marker: True,
            self._strategy_marker: self._strategy_identifier,
        }
        placeholder.update(_StatePlaceholder(key=stored_key).model_dump())

        return placeholder
129 changes: 129 additions & 0 deletions inngest/experimental/remote_state_middleware/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from __future__ import annotations

import typing

import inngest
from inngest._internal import server_lib


class StateDriver(typing.Protocol):
    """
    Interface a remote state store must implement to be used as the driver of
    the remote state middleware.
    """

    def load_steps(self, steps: inngest.StepMemos) -> None:
        """
        Load step output from the remote store. Nothing is returned, so
        implementations are expected to update the given steps in place.

        Args:
        ----
            steps: Steps whose output may need loading from the remote store.
        """

    def save_step(
        self,
        run_id: str,
        value: object,
    ) -> dict[str, object]:
        """
        Store a step's output and return the dict that takes its place.

        Args:
        ----
            run_id: Run ID.
            value: Output for an ended step.
        """


class RemoteStateMiddleware(inngest.MiddlewareSync):
    """
    Middleware that keeps step output in a store of your choosing (e.g. AWS S3)
    instead of sending it to Inngest. Because the output stays within your
    infrastructure, bandwidth to/from the Inngest server can drop drastically.
    """

    # Run ID captured in transform_input and handed to the driver when step
    # output is saved.
    _run_id: typing.Optional[str] = None

    def __init__(
        self,
        client: inngest.Inngest,
        raw_request: object,
        driver: StateDriver,
    ) -> None:
        """
        Args:
        ----
            client: Inngest client.
            raw_request: Framework/platform specific request object.
            driver: State driver.
        """

        super().__init__(client, raw_request)

        self._driver = driver

    @classmethod
    def factory(
        cls,
        driver: StateDriver,
    ) -> typing.Callable[[inngest.Inngest, object], RemoteStateMiddleware]:
        """
        Build a callable that creates the middleware with the given driver.
        Pass the result to an Inngest client or function.

        Args:
        ----
            driver: State driver.
        """

        return lambda client, raw_request: cls(client, raw_request, driver)

    def transform_input(
        self,
        ctx: inngest.Context,
        function: inngest.Function,
        steps: inngest.StepMemos,
    ) -> None:
        """
        Hydrate steps through the driver and remember the run ID.
        """

        self._driver.load_steps(steps)

        # Needed later by transform_output.
        self._run_id = ctx.run_id

    def transform_output(self, result: inngest.TransformOutputResult) -> None:
        """
        Save step output through the driver and replace it with whatever the
        driver returns.
        """

        step = result.step

        # Only output that belongs to a step run is stored remotely.
        if step is None or step.op is not server_lib.Opcode.STEP_RUN:
            return

        if result.has_output() is False:
            return

        run_id = self._run_id
        if run_id is None:
            # Not expected to happen, assuming transform_input ran first.
            raise Exception("missing run ID")

        result.output = self._driver.save_step(run_id, result.output)
124 changes: 124 additions & 0 deletions inngest/experimental/remote_state_middleware/s3_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

import json
import secrets
import string
import typing

import pydantic
import typing_extensions

import inngest

from .middleware import StateDriver

if typing.TYPE_CHECKING:
from mypy_boto3_s3 import S3Client


class _StateSurrogate(pydantic.BaseModel):
    """
    Stand-in for step output in what is sent back to Inngest. It holds just
    enough to fetch the actual state again.
    """

    # Bucket that holds the stored step output.
    bucket: str

    # Object key of the stored step output.
    key: str


class S3Driver(StateDriver):
    """
    State driver for the remote state middleware that stores step output as
    JSON objects in an S3 bucket.
    """

    # Key whose presence flags step output as a remote state placeholder.
    _marker: typing.Final = "__REMOTE_STATE__"

    # Key naming the strategy (driver) that wrote a placeholder. This is useful
    # for knowing whether the official S3 driver was used.
    _strategy_marker: typing.Final = "__STRATEGY__"

    # Value this driver stores under the strategy marker.
    _strategy_identifier: typing.Final = "inngest/s3"

    def __init__(
        self,
        *,
        bucket: str,
        client: S3Client,
    ) -> None:
        """
        Args:
        ----
            bucket: Bucket name to store remote state.
            client: Boto3 S3 client.
        """

        self._client = client
        self._bucket = bucket

    def _create_key(self) -> str:
        """
        Generate a random 32-character alphanumeric key.
        """

        alphabet = string.ascii_letters + string.digits
        picks = [secrets.choice(alphabet) for _ in range(32)]
        return "".join(picks)

    def _is_remote(
        self, data: object
    ) -> typing_extensions.TypeGuard[dict[str, object]]:
        """
        Tell whether data is a placeholder written by this driver: a dict that
        carries both markers and names this driver's strategy.
        """

        if not isinstance(data, dict):
            return False
        if self._marker not in data or self._strategy_marker not in data:
            return False
        return data[self._strategy_marker] == self._strategy_identifier

    def load_steps(self, steps: inngest.StepMemos) -> None:
        """
        Replace placeholders written by this driver with the JSON-decoded
        content of the S3 object they point to.

        Args:
        ----
            steps: Steps that may need hydration. Modified in place.
        """

        for memo in steps.values():
            payload = memo.data
            if not self._is_remote(payload):
                continue

            pointer = _StateSurrogate.model_validate(payload)

            # The bucket comes from the placeholder rather than from this
            # driver's configuration.
            response = self._client.get_object(
                Bucket=pointer.bucket,
                Key=pointer.key,
            )
            raw = response["Body"].read()
            memo.data = json.loads(raw.decode())

    def save_step(
        self,
        run_id: str,
        value: object,
    ) -> dict[str, object]:
        """
        Upload a step's output to S3 as JSON and return the placeholder that
        stands in for it.

        Args:
        ----
            run_id: Run ID. Becomes part of the object key.
            value: Step output.
        """

        object_key = f"inngest/remote_state/{run_id}/{self._create_key()}"
        self._client.put_object(
            Body=json.dumps(value),
            Bucket=self._bucket,
            Key=object_key,
        )

        # Markers first, then the fields needed to find the object again.
        surrogate: dict[str, object] = {
            self._marker: True,
            self._strategy_marker: self._strategy_identifier,
        }
        surrogate.update(
            _StateSurrogate(bucket=self._bucket, key=object_key).model_dump()
        )

        return surrogate
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ dependencies = [
extra = [
"Django==4.2",
"Flask==2.3.0",
"boto3-stubs[s3]==1.35.46 ",
"boto3==1.35.47",
"build==1.0.3",
"cryptography==42.0.5",
"django-types==0.19.1",
"fastapi==0.100.0",
"moto[s3,server]==5.0.18",
"mypy==1.10.0",
"pynacl==1.5.0",
"pytest==7.4.2",
Expand Down
Loading
Loading