From 5a3f9d374f9d5dc3a1efc6aa85a8e7484ecc919b Mon Sep 17 00:00:00 2001 From: Aaron Harper Date: Wed, 23 Oct 2024 23:29:53 -0400 Subject: [PATCH] try --- Makefile | 2 +- .../remote_state_middleware/s3_driver.py | 15 +- tests/net.py | 43 ++-- .../cases/__init__.py | 4 +- .../cases/base.py | 3 + .../cases/step_output_aws.py | 191 ++++++++++++++++++ .../cases/step_output_s3.py | 155 -------------- .../cases/step_output_s3_old.py | 190 +++++++++++++++++ 8 files changed, 423 insertions(+), 180 deletions(-) create mode 100644 tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py delete mode 100644 tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py create mode 100644 tests/test_experimental/test_remote_state_middleware/cases/step_output_s3_old.py diff --git a/Makefile b/Makefile index f21bf469..eab342c1 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ install: check-venv @pip install -e '.[extra]' -c constraints.txt itest: check-venv - @pytest -n 4 -v tests + @pytest -svv tests/test_experimental/test_remote_state_middleware -k memory pre-commit: format-check lint type-check utest diff --git a/inngest/experimental/remote_state_middleware/s3_driver.py b/inngest/experimental/remote_state_middleware/s3_driver.py index 7763b854..036d2aec 100644 --- a/inngest/experimental/remote_state_middleware/s3_driver.py +++ b/inngest/experimental/remote_state_middleware/s3_driver.py @@ -4,6 +4,7 @@ import typing import boto3 +import mypy_boto3_s3 import pydantic import inngest @@ -33,16 +34,15 @@ class S3Driver(StateDriver): def __init__( # noqa: D107 self, *, + # aws_access_key_id: typing.Optional[str] = None, + # aws_secret_access_key: typing.Optional[str] = None, bucket: str, - endpoint_url: typing.Optional[str] = None, - region_name: str, + client: mypy_boto3_s3.S3Client, + # endpoint_url: typing.Optional[str] = None, + # region_name: str, ) -> None: self._bucket = bucket - self._client = boto3.client( - "s3", - endpoint_url=endpoint_url, - region_name=region_name, - ) + self._client = client def _create_key(self) -> str: chars = string.ascii_letters + string.digits @@ -84,7 +84,6 @@ def save_step( """ key = f"inngest/remote_state/{run_id}/{self._create_key()}" - self._client.create_bucket(Bucket=self._bucket) self._client.put_object( Body=json.dumps(value), Bucket=self._bucket, diff --git a/tests/net.py b/tests/net.py index f7da1fb1..ab183656 100644 --- a/tests/net.py +++ b/tests/net.py @@ -1,21 +1,36 @@ -import contextlib +import random import socket +import time import typing HOST: typing.Final = "0.0.0.0" -_min_port: typing.Final = 9000 -_max_port: typing.Final = 9999 + +_used_ports: set[int] = set() def get_available_port() -> int: - for port in range(_min_port, _max_port + 1): - with contextlib.closing( - socket.socket(socket.AF_INET, socket.SOCK_STREAM) - ) as sock: - try: - sock.bind((HOST, port)) - return port - except OSError: - continue - - raise Exception("failed to find available port") + start_time = time.time() + + while True: + if time.time() - start_time > 5: + raise Exception("timeout finding available port") + + port = random.randint(9000, 9999) + + if port in _used_ports: + continue + + if not _is_port_available(port): + continue + + _used_ports.add(port) + return port + + +def _is_port_available(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((HOST, port)) + return True + except OSError: + return False diff --git a/tests/test_experimental/test_remote_state_middleware/cases/__init__.py b/tests/test_experimental/test_remote_state_middleware/cases/__init__.py index bf5ef55e..0f9fdc94 100644 --- a/tests/test_experimental/test_remote_state_middleware/cases/__init__.py +++ b/tests/test_experimental/test_remote_state_middleware/cases/__init__.py @@ -1,12 +1,12 @@ import inngest from inngest._internal import server_lib -from . import base, step_failed, step_output_in_memory, step_output_s3 +from . import base, step_failed, step_output_aws, step_output_in_memory _modules = ( step_failed, step_output_in_memory, - step_output_s3, + step_output_aws, ) diff --git a/tests/test_experimental/test_remote_state_middleware/cases/base.py b/tests/test_experimental/test_remote_state_middleware/cases/base.py index 6f138d97..4f2f732f 100644 --- a/tests/test_experimental/test_remote_state_middleware/cases/base.py +++ b/tests/test_experimental/test_remote_state_middleware/cases/base.py @@ -18,6 +18,9 @@ class TestClass(typing.Protocol): client: inngest.Inngest + def addCleanup(self, func: typing.Callable) -> None: + ... + @dataclasses.dataclass class Case: diff --git a/tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py b/tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py new file mode 100644 index 00000000..a760995c --- /dev/null +++ b/tests/test_experimental/test_remote_state_middleware/cases/step_output_aws.py @@ -0,0 +1,191 @@ +""" +Ensure step and function output is encrypted and decrypted correctly +""" + +import json + +import inngest +import tests.helper +from inngest._internal import server_lib + +# from inngest.experimental import remote_state_middleware +from tests import net + +from . import base + +# import boto3 +# import moto +# import moto.server + + +class _State(base.BaseState): + event: inngest.Event + events: list[inngest.Event] + + +def create( + client: inngest.Inngest, + framework: server_lib.Framework, + is_sync: bool, +) -> base.Case: + print("Creating case") + test_name = base.create_test_name(__file__) + event_name = base.create_event_name(framework, test_name) + fn_id = base.create_fn_id(test_name) + state = _State() + + aws_port = net.get_available_port() + + # Start mock AWS server. + # aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port()) + # aws_server.start() + # aws_host, aws_port = aws_server.get_host_and_port() + aws_url = f"http://localhost:{aws_port}" + aws_access_key_id = "test" + aws_secret_access_key = "test" + aws_region = "us-east-1" + s3_bucket = "inngest" + + # s3_client = boto3.client( + # "s3", + # endpoint_url=aws_url, + # region_name=aws_region, + # aws_access_key_id=aws_access_key_id, + # aws_secret_access_key=aws_secret_access_key, + # ) + + # Create S3 driver. + # driver = remote_state_middleware.S3Driver( + # # aws_access_key_id=aws_access_key_id, + # # aws_secret_access_key=aws_secret_access_key, + # bucket=s3_bucket, + # client=s3_client, + # # endpoint_url=aws_url, + # # region_name=aws_region, + # ) + + @client.create_function( + fn_id=fn_id, + middleware=[ + # remote_state_middleware.RemoteStateMiddleware.factory(driver) + ], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + def fn_sync( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> str: + state.run_id = ctx.run_id + + def _step_1() -> str: + return "test string" + + step_1_output = step.run("step_1", _step_1) + assert step_1_output == "test string" + + def _step_2() -> list[inngest.JSON]: + return [{"a": {"b": 1}}] + + step_2_output = step.run("step_2", _step_2) + assert step_2_output == [{"a": {"b": 1}}] + + return "function output" + + @client.create_function( + fn_id=fn_id, + middleware=[ + # remote_state_middleware.RemoteStateMiddleware.factory(driver) + ], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + async def fn_async( + ctx: inngest.Context, + step: inngest.Step, + ) -> str: + state.run_id = ctx.run_id + + def _step_1() -> str: + return "test string" + + step_1_output = await step.run("step_1", _step_1) + assert step_1_output == "test string" + + def _step_2() -> list[inngest.JSON]: + return [{"a": {"b": 1}}] + + step_2_output = await step.run("step_2", _step_2) + assert step_2_output == [{"a": {"b": 1}}] + + return "function output" + + async def run_test(self: base.TestClass) -> None: + # aws_server = moto.server.ThreadedMotoServer(port=aws_port) + # aws_server.start() + # self.addCleanup(aws_server.stop) + # self.addCleanup(s3_client.close) + + # # Create bucket. + # print("Creating bucket") + + # s3_client.create_bucket(Bucket=s3_bucket) + # # client.close() + + # print("Running test") + # self.client.send_sync(inngest.Event(name=event_name)) + + # print("Waiting for run ID") + # run_id = state.wait_for_run_id() + # print("Waiting for run status") + # run = tests.helper.client.wait_for_run_status( + # run_id, + # tests.helper.RunStatus.COMPLETED, + # ) + + # print("Getting step output") + + # # Ensure that step_1 output is encrypted and its value is correct + # output = json.loads( + # tests.helper.client.get_step_output( + # run_id=run_id, + # step_id="step_1", + # ) + # ) + + # assert isinstance(output, dict) + # data = output.get("data") + # assert isinstance(data, dict) + + # # Ensure the step output is remotely stored. + # assert driver._marker in data + + # print("Getting step output") + # output = json.loads( + # tests.helper.client.get_step_output( + # run_id=run_id, + # step_id="step_2", + # ) + # ) + # assert isinstance(output, dict) + # data = output.get("data") + # assert isinstance(data, dict) + + # # Ensure the step output is remotely stored. + # assert driver._marker in data + + # assert run.output is not None + # assert json.loads(run.output) == "function output" + print("done") + # aws_server.stop() + + if is_sync: + fn = fn_sync + else: + fn = fn_async + + return base.Case( + fn=fn, + run_test=run_test, + name=test_name, + ) diff --git a/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py b/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py deleted file mode 100644 index 144035f3..00000000 --- a/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -Ensure step and function output is encrypted and decrypted correctly -""" - -import json - -import boto3 -import moto -import moto.server - -import inngest -import tests.helper -from inngest._internal import server_lib -from inngest.experimental import remote_state_middleware -from tests import net - -from . import base - - -class _State(base.BaseState): - event: inngest.Event - events: list[inngest.Event] - - -@moto.mock_aws -def create( - client: inngest.Inngest, - framework: server_lib.Framework, - is_sync: bool, -) -> base.Case: - test_name = base.create_test_name(__file__) - event_name = base.create_event_name(framework, test_name) - fn_id = base.create_fn_id(test_name) - state = _State() - - aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port()) - aws_server.start() - aws_host, aws_port = aws_server.get_host_and_port() - - conn = boto3.resource("s3", region_name="us-east-1") - conn.create_bucket(Bucket="inngest") - driver = remote_state_middleware.S3Driver( - bucket="inngest", - endpoint_url=f"http://{aws_host}:{aws_port}", - region_name="us-east-1", - ) - - driver.save_step("run_id", "value") - - @client.create_function( - fn_id=fn_id, - middleware=[ - remote_state_middleware.RemoteStateMiddleware.factory(driver) - ], - retries=0, - trigger=inngest.TriggerEvent(event=event_name), - ) - def fn_sync( - ctx: inngest.Context, - step: inngest.StepSync, - ) -> str: - state.run_id = ctx.run_id - - def _step_1() -> str: - return "test string" - - step_1_output = step.run("step_1", _step_1) - assert step_1_output == "test string" - - def _step_2() -> list[inngest.JSON]: - return [{"a": {"b": 1}}] - - step_2_output = step.run("step_2", _step_2) - assert step_2_output == [{"a": {"b": 1}}] - - return "function output" - - @client.create_function( - fn_id=fn_id, - middleware=[ - remote_state_middleware.RemoteStateMiddleware.factory(driver) - ], - retries=0, - trigger=inngest.TriggerEvent(event=event_name), - ) - async def fn_async( - ctx: inngest.Context, - step: inngest.Step, - ) -> str: - state.run_id = ctx.run_id - - def _step_1() -> str: - return "test string" - - step_1_output = await step.run("step_1", _step_1) - assert step_1_output == "test string" - - def _step_2() -> list[inngest.JSON]: - return [{"a": {"b": 1}}] - - step_2_output = await step.run("step_2", _step_2) - assert step_2_output == [{"a": {"b": 1}}] - - return "function output" - - async def run_test(self: base.TestClass) -> None: - self.client.send_sync(inngest.Event(name=event_name)) - - run_id = state.wait_for_run_id() - run = tests.helper.client.wait_for_run_status( - run_id, - tests.helper.RunStatus.COMPLETED, - ) - - # Ensure that step_1 output is encrypted and its value is correct - output = json.loads( - tests.helper.client.get_step_output( - run_id=run_id, - step_id="step_1", - ) - ) - - assert isinstance(output, dict) - data = output.get("data") - assert isinstance(data, dict) - - # Ensure the step output is remotely stored. - assert driver._marker in data - - output = json.loads( - tests.helper.client.get_step_output( - run_id=run_id, - step_id="step_2", - ) - ) - assert isinstance(output, dict) - data = output.get("data") - assert isinstance(data, dict) - - # Ensure the step output is remotely stored. - assert driver._marker in data - - assert run.output is not None - assert json.loads(run.output) == "function output" - - if is_sync: - fn = fn_sync - else: - fn = fn_async - - return base.Case( - fn=fn, - run_test=run_test, - name=test_name, - ) diff --git a/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3_old.py b/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3_old.py new file mode 100644 index 00000000..a243d3e6 --- /dev/null +++ b/tests/test_experimental/test_remote_state_middleware/cases/step_output_s3_old.py @@ -0,0 +1,190 @@ +""" +Ensure step and function output is encrypted and decrypted correctly +""" + +import json + +import boto3 +import moto +import moto.server + +import inngest +import tests.helper +from inngest._internal import server_lib +from inngest.experimental import remote_state_middleware +from tests import net + +from . import base + + +class _State(base.BaseState): + event: inngest.Event + events: list[inngest.Event] + + +def create( + client: inngest.Inngest, + framework: server_lib.Framework, + is_sync: bool, +) -> base.Case: + print("Creating case") + test_name = base.create_test_name(__file__) + event_name = base.create_event_name(framework, test_name) + fn_id = base.create_fn_id(test_name) + state = _State() + + aws_port = net.get_available_port() + + # Start mock AWS server. + # aws_server = moto.server.ThreadedMotoServer(port=net.get_available_port()) + # aws_server.start() + # aws_host, aws_port = aws_server.get_host_and_port() + aws_url = f"http://localhost:{aws_port}" + aws_access_key_id = "test" + aws_secret_access_key = "test" + aws_region = "us-east-1" + s3_bucket = "inngest" + + # s3_client = boto3.client( + # "s3", + # endpoint_url=aws_url, + # region_name=aws_region, + # aws_access_key_id=aws_access_key_id, + # aws_secret_access_key=aws_secret_access_key, + # ) + + # Create S3 driver. + # driver = remote_state_middleware.S3Driver( + # # aws_access_key_id=aws_access_key_id, + # # aws_secret_access_key=aws_secret_access_key, + # bucket=s3_bucket, + # client=s3_client, + # # endpoint_url=aws_url, + # # region_name=aws_region, + # ) + + @client.create_function( + fn_id=fn_id, + middleware=[ + # remote_state_middleware.RemoteStateMiddleware.factory(driver) + ], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + def fn_sync( + ctx: inngest.Context, + step: inngest.StepSync, + ) -> str: + state.run_id = ctx.run_id + + def _step_1() -> str: + return "test string" + + step_1_output = step.run("step_1", _step_1) + assert step_1_output == "test string" + + def _step_2() -> list[inngest.JSON]: + return [{"a": {"b": 1}}] + + step_2_output = step.run("step_2", _step_2) + assert step_2_output == [{"a": {"b": 1}}] + + return "function output" + + @client.create_function( + fn_id=fn_id, + middleware=[ + # remote_state_middleware.RemoteStateMiddleware.factory(driver) + ], + retries=0, + trigger=inngest.TriggerEvent(event=event_name), + ) + async def fn_async( + ctx: inngest.Context, + step: inngest.Step, + ) -> str: + state.run_id = ctx.run_id + + def _step_1() -> str: + return "test string" + + step_1_output = await step.run("step_1", _step_1) + assert step_1_output == "test string" + + def _step_2() -> list[inngest.JSON]: + return [{"a": {"b": 1}}] + + step_2_output = await step.run("step_2", _step_2) + assert step_2_output == [{"a": {"b": 1}}] + + return "function output" + + async def run_test(self: base.TestClass) -> None: + # aws_server = moto.server.ThreadedMotoServer(port=aws_port) + # aws_server.start() + # self.addCleanup(aws_server.stop) + # self.addCleanup(s3_client.close) + + # # Create bucket. + # print("Creating bucket") + + # s3_client.create_bucket(Bucket=s3_bucket) + # # client.close() + + # print("Running test") + # self.client.send_sync(inngest.Event(name=event_name)) + + # print("Waiting for run ID") + # run_id = state.wait_for_run_id() + # print("Waiting for run status") + # run = tests.helper.client.wait_for_run_status( + # run_id, + # tests.helper.RunStatus.COMPLETED, + # ) + + # print("Getting step output") + + # # Ensure that step_1 output is encrypted and its value is correct + # output = json.loads( + # tests.helper.client.get_step_output( + # run_id=run_id, + # step_id="step_1", + # ) + # ) + + # assert isinstance(output, dict) + # data = output.get("data") + # assert isinstance(data, dict) + + # # Ensure the step output is remotely stored. + # assert driver._marker in data + + # print("Getting step output") + # output = json.loads( + # tests.helper.client.get_step_output( + # run_id=run_id, + # step_id="step_2", + # ) + # ) + # assert isinstance(output, dict) + # data = output.get("data") + # assert isinstance(data, dict) + + # # Ensure the step output is remotely stored. + # assert driver._marker in data + + # assert run.output is not None + # assert json.loads(run.output) == "function output" + print("done") + # aws_server.stop() + + if is_sync: + fn = fn_sync + else: + fn = fn_async + + return base.Case( + fn=fn, + run_test=run_test, + name=test_name, + )