Skip to content

Commit

Permalink
Add S3 driver
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r committed Oct 24, 2024
1 parent c3b76d1 commit 8dc6d8e
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 39 deletions.
2 changes: 2 additions & 0 deletions inngest/experimental/remote_state_middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

from .in_memory_driver import InMemoryDriver
from .middleware import RemoteStateMiddleware, StateDriver
from .s3_driver import S3Driver

__all__ = [
"InMemoryDriver",
"RemoteStateMiddleware",
"S3Driver",
"StateDriver",
]
21 changes: 15 additions & 6 deletions inngest/experimental/remote_state_middleware/in_memory_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,24 @@


class _StatePlaceholder(pydantic.BaseModel):
__REMOTE_STATE__: typing.Literal[True] = True
key: str


class InMemoryDriver(StateDriver):
"""
In-memory driver for remote state middleware.
In-memory driver for remote state middleware. This probably doesn't have any
use besides being a reference.
"""

# Marker to indicate that the data is stored remotely.
_marker: typing.Final = "__REMOTE_STATE__"

# Marker to indicate which strategy was used. This is useful for knowing
# whether the official S3 driver was used.
_strategy_marker: typing.Final = "__STRATEGY__"

_strategy_identifier: typing.Final = "inngest/memory"

def __init__(self) -> None: # noqa: D107
self._data: dict[str, object] = {}

Expand All @@ -39,16 +45,18 @@ def load_steps(self, steps: inngest.StepMemos) -> None:
continue
if self._marker not in step.data:
continue

try:
placeholder = _StatePlaceholder.model_validate(step.data)
except pydantic.ValidationError:
if self._strategy_marker not in step.data:
continue
if step.data[self._strategy_marker] != self._strategy_identifier:
continue

placeholder = _StatePlaceholder.model_validate(step.data)

step.data = self._data[placeholder.key]

def save_step(
self,
run_id: str,
value: object,
) -> dict[str, object]:
"""
Expand All @@ -60,6 +68,7 @@ def save_step(

placeholder: dict[str, object] = {
self._marker: True,
self._strategy_marker: self._strategy_identifier,
**_StatePlaceholder(key=key).model_dump(),
}

Expand Down
14 changes: 13 additions & 1 deletion inngest/experimental/remote_state_middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def load_steps(self, steps: inngest.StepMemos) -> None:

def save_step(
self,
run_id: str,
value: object,
) -> dict[str, object]:
"""
Store the value and return a key to retrieve it later.
Args:
----
run_id: Run ID.
value: Output for an ended step.
"""

Expand All @@ -43,6 +45,8 @@ class RemoteStateMiddleware(inngest.MiddlewareSync):
output is stored within your infrastructure rather than Inngest's.
"""

_run_id: typing.Optional[str] = None

def __init__(
self,
client: inngest.Inngest,
Expand Down Expand Up @@ -98,6 +102,7 @@ def transform_input(
"""

self._driver.load_steps(steps)
self._run_id = ctx.run_id

def transform_output(self, result: inngest.TransformOutputResult) -> None:
"""
Expand All @@ -110,4 +115,11 @@ def transform_output(self, result: inngest.TransformOutputResult) -> None:
if result.has_output() is False:
return None

result.output = self._driver.save_step(result.output)
if self._run_id is None:
# Unreachable
raise Exception("missing run ID")

result.output = self._driver.save_step(
self._run_id,
result.output,
)
100 changes: 100 additions & 0 deletions inngest/experimental/remote_state_middleware/s3_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json
import secrets
import string
import typing

import boto3
import pydantic

import inngest

from .middleware import StateDriver


class _StatePlaceholder(pydantic.BaseModel):
bucket: str
key: str


class S3Driver(StateDriver):
"""
S3 driver for remote state middleware.
"""

# Marker to indicate that the data is stored remotely.
_marker: typing.Final = "__REMOTE_STATE__"

# Marker to indicate which strategy was used. This is useful for knowing
# whether the official S3 driver was used.
_strategy_marker: typing.Final = "__STRATEGY__"

_strategy_identifier: typing.Final = "inngest/s3"

def __init__( # noqa: D107
self,
*,
bucket: str,
endpoint_url: typing.Optional[str] = None,
region_name: str,
) -> None:
self._bucket = bucket
self._client = boto3.client(
"s3",
endpoint_url=endpoint_url,
region_name=region_name,
)

def _create_key(self) -> str:
chars = string.ascii_letters + string.digits
return "".join(secrets.choice(chars) for _ in range(32))

def load_steps(self, steps: inngest.StepMemos) -> None:
"""
Hydrate steps with remote state if necessary.
"""

for step in steps.values():
if not isinstance(step.data, dict):
continue
if self._marker not in step.data:
continue
if self._strategy_marker not in step.data:
continue
if step.data[self._strategy_marker] != self._strategy_identifier:
continue

placeholder = _StatePlaceholder.model_validate(step.data)

step.data = json.loads(
self._client.get_object(
Bucket=placeholder.bucket,
Key=placeholder.key,
)["Body"]
.read()
.decode()
)

def save_step(
self,
run_id: str,
value: object,
) -> dict[str, object]:
"""
Save a step's output to the remote store and return a placeholder.
"""

key = f"inngest/remote_state/{run_id}/{self._create_key()}"
self._client.create_bucket(Bucket=self._bucket)
self._client.put_object(
Body=json.dumps(value),
Bucket=self._bucket,
Key=key,
)

placeholder: dict[str, object] = {
self._marker: True,
self._strategy_marker: self._strategy_identifier,
**_StatePlaceholder(bucket=self._bucket, key=key).model_dump(),
}

return placeholder
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ dependencies = [
extra = [
"Django==4.2",
"Flask==2.3.0",
"boto3-stubs[s3]==1.35.46 ",
"boto3==1.35.47",
"build==1.0.3",
"cryptography==42.0.5",
"django-types==0.19.1",
"fastapi==0.100.0",
"moto[s3,server]==5.0.18",
"mypy==1.10.0",
"pynacl==1.5.0",
"pytest==7.4.2",
Expand Down
43 changes: 14 additions & 29 deletions tests/net.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
import random
import contextlib
import socket
import time
import typing

HOST: typing.Final = "0.0.0.0"

_used_ports: set[int] = set()
_min_port: typing.Final = 9000
_max_port: typing.Final = 9999


def get_available_port() -> int:
start_time = time.time()

while True:
if time.time() - start_time > 5:
raise Exception("timeout finding available port")

port = random.randint(9000, 9999)

if port in _used_ports:
continue

if not _is_port_available(port):
continue

_used_ports.add(port)
return port


def _is_port_available(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((HOST, port))
return True
except OSError:
return False
for port in range(_min_port, _max_port + 1):
with contextlib.closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
try:
sock.bind((HOST, port))
return port
except OSError:
continue

raise Exception("failed to find available port")
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import inngest
from inngest._internal import server_lib

from . import base, step_failed, step_output
from . import base, step_failed, step_output_in_memory, step_output_s3

_modules = (
step_failed,
step_output,
step_output_in_memory,
step_output_s3,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _step() -> str:
try:
step.run("step_1", _step)
except Exception as e:
print("hi", str(e))
return str(e)

return "unreachable"
Expand Down
Loading

0 comments on commit 8dc6d8e

Please sign in to comment.