Skip to content

Commit

Permalink
Fix step.invoke data is not encrypted (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
amh4r authored Dec 8, 2024
1 parent 02bdcbb commit 9a91065
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 14 deletions.
39 changes: 39 additions & 0 deletions inngest/experimental/encryption_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import nacl.utils

import inngest
from inngest._internal import server_lib

# Marker to indicate that the data is encrypted
_encryption_marker: typing.Final = "__ENCRYPTED__"
Expand Down Expand Up @@ -176,6 +177,31 @@ def _decrypt_event_data(
data.keys(), key=lambda k: k != self._event_encryption_field
)

if _is_encrypted(data):
# Event data has top-level encryption. However, there may also be
# unencrypted fields (like "_inngest" if this is an invoke event).

decrypted = self._decrypt(data)
if not isinstance(decrypted, dict):
raise Exception("decrypted data is not a dict")

# Need to type cast because mypy thinks it's a `dict[str, object]`.
decrypted = typing.cast(
typing.Mapping[str, inngest.JSON], decrypted
)

# This should be empty if this isn't an invoke event.
unencrypted_data = {
k: v
for k, v in data.items()
if k not in (_encryption_marker, _strategy_marker, "data")
}

return {
**unencrypted_data,
**decrypted,
}

# Iterate over all the keys, decrypting the first encrypted field found.
# It's possible that the event producer uses a different encryption
# field
Expand Down Expand Up @@ -236,6 +262,19 @@ def transform_output(self, result: inngest.TransformOutputResult) -> None:
if result.has_output():
result.output = self._encrypt(result.output)

# Encrypt invoke data if present.
if (
result.step is not None
and result.step.op is server_lib.Opcode.INVOKE
and result.step.opts is not None
):
payload = result.step.opts.get("payload", {})
if isinstance(payload, dict):
data = payload.get("data")
if data is not None:
payload["data"] = self._encrypt(data)
result.step.opts["payload"] = payload


def _is_encrypted(value: object) -> bool:
if not isinstance(value, dict):
Expand Down
5 changes: 4 additions & 1 deletion tests/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def query(self, query: Query) -> typing.Union[Response, Error]:
return Error(message=f"failed to parse response as JSON: {e}")

if gql_res.errors is not None:
return Error(message="GraphQL error", response=gql_res)
msg = "GraphQL error"
if len(gql_res.errors) > 0:
msg += f": {gql_res.errors[0].get('message')}"
return Error(message=msg, response=gql_res)

return gql_res

Expand Down
7 changes: 7 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import enum
import json
import time
import typing

import pydantic

import inngest
from inngest._internal import types
from inngest.experimental import dev_server

Expand Down Expand Up @@ -153,6 +155,9 @@ def wait_for_run_status(
query = """
query GetRun($run_id: ID!) {
functionRun(query: { functionRunId: $run_id }) {
event {
raw
}
id
output
status
Expand All @@ -168,6 +173,7 @@ def wait_for_run_status(
if not isinstance(run, dict):
raise Exception("unexpected response")
if run["status"] == status.value:
run["event"] = json.loads(run["event"]["raw"])
return _Run.model_validate(run)

if any(run["status"] == s.value for s in ended_statuses):
Expand All @@ -185,6 +191,7 @@ def wait_for_run_status(


class _Run(types.BaseModel):
event: inngest.Event
id: str
output: typing.Optional[str]
status: RunStatus
Expand Down
63 changes: 51 additions & 12 deletions tests/test_experimental/test_encryption_middleware/cases/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Ensure that invoke works.
"""

import typing

import nacl.encoding
import nacl.hash
import nacl.secret
Expand All @@ -26,8 +28,8 @@


class _State(base.BaseState):
event: inngest.Event
events: list[inngest.Event]
child_run_id: typing.Optional[str] = None
child_event: typing.Optional[inngest.Event] = None


def create(
Expand All @@ -49,8 +51,13 @@ def create(
def child_fn_sync(
ctx: inngest.Context,
step: inngest.StepSync,
) -> str:
return f"Hello, {ctx.event.data['name']}!"
) -> dict[str, str]:
state.child_run_id = ctx.run_id
state.child_event = ctx.event

return {
"msg": f"Hello, {ctx.event.data['name']}!",
}

@client.create_function(
fn_id=fn_id,
Expand All @@ -69,8 +76,8 @@ def fn_sync(
function=child_fn_sync,
data={"name": "Alice"},
)
assert isinstance(result, str)
assert result == "Hello, Alice!"
assert isinstance(result, dict)
assert result["msg"] == "Hello, Alice!"

@client.create_function(
fn_id=f"{fn_id}/child",
Expand All @@ -81,8 +88,13 @@ def fn_sync(
async def child_fn_async(
ctx: inngest.Context,
step: inngest.Step,
) -> str:
return f"Hello, {ctx.event.data['name']}!"
) -> dict[str, str]:
state.child_run_id = ctx.run_id
state.child_event = ctx.event

return {
"msg": f"Hello, {ctx.event.data['name']}!",
}

@client.create_function(
fn_id=fn_id,
Expand All @@ -96,13 +108,13 @@ async def fn_async(
) -> None:
state.run_id = ctx.run_id

result = step.invoke(
result = await step.invoke(
"invoke",
function=child_fn_sync,
function=child_fn_async,
data={"name": "Alice"},
)
assert isinstance(result, str)
assert result == "Hello, Alice!"
assert isinstance(result, dict)
assert result["msg"] == "Hello, Alice!"

async def run_test(self: base.TestClass) -> None:
self.client.send_sync(inngest.Event(name=event_name))
Expand All @@ -113,6 +125,33 @@ async def run_test(self: base.TestClass) -> None:
tests.helper.RunStatus.COMPLETED,
)

assert state.child_event is not None

# Ensure we stripped the encryption fields (encryption marker, strategy
# marker, and encrypted data).
assert sorted(state.child_event.data.keys()) == ["_inngest", "name"]

assert state.child_run_id is not None
child_run = tests.helper.client.wait_for_run_status(
state.child_run_id,
tests.helper.RunStatus.COMPLETED,
)

# Ensure the stored event has the encryption fields.
assert sorted(child_run.event.data.keys()) == [
"__ENCRYPTED__",
"__STRATEGY__",
"_inngest",
"data",
]

# Ensure the data is encrypted.
encrypted_data = child_run.event.data["data"]
assert isinstance(encrypted_data, str)
assert enc.decrypt(encrypted_data.encode("utf-8")) == {
"name": "Alice",
}

if is_sync:
fn = [child_fn_sync, fn_sync]
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import threading
import unittest

import fastapi
import uvicorn

import inngest
import inngest.fast_api
from inngest._internal import server_lib
from inngest.experimental import dev_server
from tests import base, net

from . import cases

_framework = server_lib.Framework.FAST_API
_app_id = f"{_framework.value}-encryption-middleware"

_client = inngest.Inngest(
api_base_url=dev_server.server.origin,
app_id=_app_id,
event_api_base_url=dev_server.server.origin,
is_production=False,
)

_cases = cases.create_async_cases(_client, _framework)
_fns: list[inngest.Function] = []
for case in _cases:
if isinstance(case.fn, list):
_fns.extend(case.fn)
else:
_fns.append(case.fn)


class TestEncryptionMiddleware(unittest.IsolatedAsyncioTestCase):
client = _client
app_thread: threading.Thread

@classmethod
def setUpClass(cls) -> None:
super().setUpClass()

port = net.get_available_port()

def start_app() -> None:
app = fastapi.FastAPI()
inngest.fast_api.serve(
app,
_client,
_fns,
)
uvicorn.run(app, host="0.0.0.0", port=port, log_level="warning")

# Start FastAPI in a thread instead of using their test client, since
# their test client doesn't seem to actually run requests in parallel
# (this is evident in the flakiness of our asyncio race test). If we fix
# this issue, we can go back to their test client
cls.app_thread = threading.Thread(daemon=True, target=start_app)
cls.app_thread.start()
base.register(port)

@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass()
cls.app_thread.join(timeout=1)


for case in _cases:
test_name = f"test_{case.name}"
setattr(TestEncryptionMiddleware, test_name, case.run_test)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_function/cases/batch_that_needs_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def run_test(self: base.TestClass) -> None:
)
self.client.send_sync(events)

run_id = state.wait_for_run_id()
run_id = state.wait_for_run_id(timeout=datetime.timedelta(seconds=10))
tests.helper.client.wait_for_run_status(
run_id,
tests.helper.RunStatus.COMPLETED,
Expand Down

0 comments on commit 9a91065

Please sign in to comment.