From c038e3760b85fe9f1e692de655b04635a9f9288c Mon Sep 17 00:00:00 2001 From: Aaron Harper Date: Wed, 23 Oct 2024 23:26:26 -0400 Subject: [PATCH] Fixes --- .../remote_state_middleware/__init__.py | 2 +- .../remote_state_middleware/s3_driver.py | 72 ++++++++++++------- tests/conftest.py | 1 + tests/net.py | 43 +++++++---- .../test_encryption_middleware/test_flask.py | 6 +- .../cases/__init__.py | 4 +- .../{step_output_s3.py => step_output_aws.py} | 34 +++++---- .../test_flask.py | 6 +- 8 files changed, 108 insertions(+), 60 deletions(-) rename tests/test_experimental/test_remote_state_middleware/cases/{step_output_s3.py => step_output_aws.py} (85%) diff --git a/inngest/experimental/remote_state_middleware/__init__.py b/inngest/experimental/remote_state_middleware/__init__.py index f84dda2a..fc3bf9bd 100644 --- a/inngest/experimental/remote_state_middleware/__init__.py +++ b/inngest/experimental/remote_state_middleware/__init__.py @@ -5,7 +5,7 @@ - Avoiding step output size limits. NOT STABLE! This is an experimental feature and may change in the future. If -you'd like to use it, we recommend copying this file into your source code. +you'd like to use it, we recommend copying this package into your source code. """ from .in_memory_driver import InMemoryDriver diff --git a/inngest/experimental/remote_state_middleware/s3_driver.py b/inngest/experimental/remote_state_middleware/s3_driver.py index 7763b854..17fe6067 100644 --- a/inngest/experimental/remote_state_middleware/s3_driver.py +++ b/inngest/experimental/remote_state_middleware/s3_driver.py @@ -1,17 +1,27 @@ +from __future__ import annotations + import json import secrets import string import typing -import boto3 import pydantic +import typing_extensions import inngest from .middleware import StateDriver +if typing.TYPE_CHECKING: + from mypy_boto3_s3 import S3Client + + +class _StateSurrogate(pydantic.BaseModel): + """ + Replaces step output sent back to Inngest. Its data is sufficient to + retrieve the actual state. + """ -class _StatePlaceholder(pydantic.BaseModel): bucket: str key: str @@ -30,45 +40,55 @@ class S3Driver(StateDriver): _strategy_identifier: typing.Final = "inngest/s3" - def __init__( # noqa: D107 + def __init__( self, *, bucket: str, - endpoint_url: typing.Optional[str] = None, - region_name: str, + client: S3Client, ) -> None: + """ + Args: + ---- + bucket: Bucket name to store remote state. + client: Boto3 S3 client. + """ + self._bucket = bucket - self._client = boto3.client( - "s3", - endpoint_url=endpoint_url, - region_name=region_name, - ) + self._client = client def _create_key(self) -> str: chars = string.ascii_letters + string.digits return "".join(secrets.choice(chars) for _ in range(32)) + def _is_remote( + self, data: object + ) -> typing_extensions.TypeGuard[dict[str, object]]: + return ( + isinstance(data, dict) + and self._marker in data + and self._strategy_marker in data + and data[self._strategy_marker] == self._strategy_identifier + ) + def load_steps(self, steps: inngest.StepMemos) -> None: """ Hydrate steps with remote state if necessary. + + Args: + ---- + steps: Steps that may need hydration. """ for step in steps.values(): - if not isinstance(step.data, dict): - continue - if self._marker not in step.data: - continue - if self._strategy_marker not in step.data: - continue - if step.data[self._strategy_marker] != self._strategy_identifier: + if not self._is_remote(step.data): continue - placeholder = _StatePlaceholder.model_validate(step.data) + surrogate = _StateSurrogate.model_validate(step.data) step.data = json.loads( self._client.get_object( - Bucket=placeholder.bucket, - Key=placeholder.key, + Bucket=surrogate.bucket, + Key=surrogate.key, )["Body"] .read() .decode() @@ -81,20 +101,24 @@ def save_step( ) -> dict[str, object]: """ Save a step's output to the remote store and return a placeholder. + + Args: + ---- + run_id: Run ID. + value: Step output. """ key = f"inngest/remote_state/{run_id}/{self._create_key()}" - self._client.create_bucket(Bucket=self._bucket) self._client.put_object( Body=json.dumps(value), Bucket=self._bucket, Key=key, ) - placeholder: dict[str, object] = { + surrogate = { self._marker: True, self._strategy_marker: self._strategy_identifier, - **_StatePlaceholder(bucket=self._bucket, key=key).model_dump(), + **_StateSurrogate(bucket=self._bucket, key=key).model_dump(), } - return placeholder + return surrogate diff --git a/tests/conftest.py b/tests/conftest.py index a2758c83..33e5163b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,4 +10,5 @@ def pytest_configure(config: pytest.Config) -> None: def pytest_unconfigure(config: pytest.Config) -> None: + print("pytest_unconfigure") dev_server.server.stop() diff --git a/tests/net.py b/tests/net.py index f7da1fb1..ab183656 100644 --- a/tests/net.py +++ b/tests/net.py @@ -1,21 +1,36 @@ -import contextlib +import random import socket +import time import typing HOST: typing.Final = "0.0.0.0" -_min_port: typing.Final = 9000 -_max_port: typing.Final = 9999 + +_used_ports: set[int] = set() def get_available_port() -> int: - for port in range(_min_port, _max_port + 1): - with contextlib.closing( - socket.socket(socket.AF_INET, socket.SOCK_STREAM) - ) as sock: - try: - sock.bind((HOST, port)) - return port - except OSError: - continue - - raise Exception("failed to find available port") + start_time = time.time() + + while True: + if time.time() - start_time > 5: + raise Exception("timeout finding available port") + + port = random.randint(9000, 9999) + + if port in _used_ports: + continue + + if not _is_port_available(port): + continue + + _used_ports.add(port) + return port + + +def _is_port_available(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((HOST, port)) + return True + except OSError: + return False diff --git a/tests/test_experimental/test_encryption_middleware/test_flask.py b/tests/test_experimental/test_encryption_middleware/test_flask.py index ea3079ee..46f1dadd 100644 --- a/tests/test_experimental/test_encryption_middleware/test_flask.py +++ b/tests/test_experimental/test_encryption_middleware/test_flask.py @@ -14,7 +14,7 @@ from . import cases _framework = server_lib.Framework.FLASK -_app_id = f"{_framework.value}-functions" +_app_id = f"{_framework.value}-encryption-middleware" _client = inngest.Inngest( api_base_url=dev_server.server.origin, @@ -32,7 +32,7 @@ _fns.append(case.fn) -class TestFunctions(unittest.IsolatedAsyncioTestCase): +class TestEncryptionMiddleware(unittest.IsolatedAsyncioTestCase): app: flask.testing.FlaskClient client: inngest.Inngest dev_server_port: int @@ -78,7 +78,7 @@ def on_proxy_request( for case in _cases: test_name = f"test_{case.name}" - setattr(TestFunctions, test_name, case.run_test) + setattr(TestEncryptionMiddleware, test_name, case.run_test) if __name__ == "__main__": diff --git a/tests/test_experimental/test_remote_state_middleware/cases/__init__.py b/tests/test_experimental/test_remote_state_middleware/cases/__init__.py index bf5ef55e..0f9fdc94 100644 --- a/tests/test_experimental/test_remote_state_middleware/cases/__init__.py +++ b/tests/test_experimental/test_remote_state_middleware/cases/__init__.py @@ -1,12 +1,12 @@ import inngest from inngest._internal import server_lib -from . import base, step_failed, step_output_in_memory, step_output_s3 +from . import base, step_failed, step_output_aws, step_output_in_memory _modules = ( step_failed, step_output_in_memory, - step_output_s3, + step_output_aws, ) diff --git a/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py b/tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py similarity index 85% rename from tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py rename to tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py index 144035f3..3a1d6b7d 100644 --- a/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py +++ b/tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py @@ -22,7 +22,6 @@ class _State(base.BaseState): events: list[inngest.Event] -@moto.mock_aws def create( client: inngest.Inngest, framework: server_lib.Framework, @@ -33,20 +32,26 @@ def create( fn_id = base.create_fn_id(test_name) state = _State() - aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port()) - aws_server.start() - aws_host, aws_port = aws_server.get_host_and_port() + aws_port = net.get_available_port() + aws_url = f"http://localhost:{aws_port}" + aws_access_key_id = "test" + aws_secret_access_key = "test" + aws_region = "us-east-1" + s3_bucket = "inngest" + + s3_client = boto3.client( + "s3", + endpoint_url=aws_url, + region_name=aws_region, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="inngest") driver = remote_state_middleware.S3Driver( - bucket="inngest", - endpoint_url=f"http://{aws_host}:{aws_port}", - region_name="us-east-1", + bucket=s3_bucket, + client=s3_client, ) - driver.save_step("run_id", "value") - @client.create_function( fn_id=fn_id, middleware=[ @@ -104,15 +109,18 @@ def _step_2() -> list[inngest.JSON]: return "function output" async def run_test(self: base.TestClass) -> None: - self.client.send_sync(inngest.Event(name=event_name)) + aws_server = moto.server.ThreadedMotoServer(port=aws_port) + aws_server.start() + s3_client.create_bucket(Bucket=s3_bucket) + + self.client.send_sync(inngest.Event(name=event_name)) run_id = state.wait_for_run_id() run = tests.helper.client.wait_for_run_status( run_id, tests.helper.RunStatus.COMPLETED, ) - # Ensure that step_1 output is encrypted and its value is correct output = json.loads( tests.helper.client.get_step_output( run_id=run_id, diff --git a/tests/test_experimental/test_remote_state_middleware/test_flask.py b/tests/test_experimental/test_remote_state_middleware/test_flask.py index ea3079ee..342b52d0 100644 --- a/tests/test_experimental/test_remote_state_middleware/test_flask.py +++ b/tests/test_experimental/test_remote_state_middleware/test_flask.py @@ -14,7 +14,7 @@ from . import cases _framework = server_lib.Framework.FLASK -_app_id = f"{_framework.value}-functions" +_app_id = f"{_framework.value}-remote-state-middleware" _client = inngest.Inngest( api_base_url=dev_server.server.origin, @@ -32,7 +32,7 @@ _fns.append(case.fn) -class TestFunctions(unittest.IsolatedAsyncioTestCase): +class TestRemoteStateMiddleware(unittest.IsolatedAsyncioTestCase): app: flask.testing.FlaskClient client: inngest.Inngest dev_server_port: int @@ -78,7 +78,7 @@ def on_proxy_request( for case in _cases: test_name = f"test_{case.name}" - setattr(TestFunctions, test_name, case.run_test) + setattr(TestRemoteStateMiddleware, test_name, case.run_test) if __name__ == "__main__":