From 320f8da7a83023b34fae52f1c5d0781215117bc4 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sat, 2 Mar 2024 10:58:42 -0800 Subject: [PATCH] Fixes up bug in which single step async actions didn't work Everything worked except it treated them as sync actions. This treats them as async actions. We also duplicate sync tests to work for async actions. --- burr/core/action.py | 3 ++ tests/core/test_action.py | 59 ++++++++++++++++++++++++++++++++++ tests/core/test_application.py | 4 +-- 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 33d51e00..eb7fc693 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -408,6 +408,9 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: return self._fn(state, **self._bound_params, **run_kwargs) + def is_async(self) -> bool: + return inspect.iscoroutinefunction(self._fn) + def _validate_action_function(fn: Callable): """Validates that an action has the signature: (state: State) -> Tuple[dict, State] diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 682402df..45b9be73 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -1,3 +1,4 @@ +import asyncio from typing import Tuple import pytest @@ -199,6 +200,64 @@ def my_action( assert result == {"output_variable": 1111} +async def test_function_based_action_async(): + @action(reads=["input_variable"], writes=["output_variable"]) + async def my_action(state: State) -> Tuple[dict, State]: + await asyncio.sleep(0.01) + return {"output_variable": state["input_variable"]}, state.update( + output_variable=state["input_variable"] + ) + + fn_based_action = create_action(my_action, name="my_action") + assert fn_based_action.is_async() + assert fn_based_action.single_step is True + assert fn_based_action.name == "my_action" + assert fn_based_action.reads == ["input_variable"] + assert fn_based_action.writes == ["output_variable"] + result, state = await fn_based_action.run_and_update(State({"input_variable": "foo"})) + assert result == {"output_variable": "foo"} + assert state.get_all() == {"input_variable": "foo", "output_variable": "foo"} + + +async def test_function_based_action_with_inputs_async(): + @action(reads=["input_variable"], writes=["output_variable"]) + async def my_action(state: State, bound_input: int, unbound_input: int) -> Tuple[dict, State]: + await asyncio.sleep(0.01) + res = state["input_variable"] + bound_input + unbound_input + return {"output_variable": res}, state.update(output_variable=res) + + fn_based_action: SingleStepAction = create_action( + my_action.bind(bound_input=10), name="my_action" + ) + assert fn_based_action.is_async() + assert fn_based_action.inputs == ["unbound_input"] + result, state = await fn_based_action.run_and_update( + State({"input_variable": 1}), unbound_input=100 + ) + assert state.get_all() == {"input_variable": 1, "output_variable": 111} + assert result == {"output_variable": 111} + + +async def test_function_based_action_with_defaults_async(): + @action(reads=["input_variable"], writes=["output_variable"]) + async def my_action( + state: State, bound_input: int, unbound_input: int, unbound_default_input: int = 1000 + ) -> Tuple[dict, State]: + await asyncio.sleep(0.01) + res = state["input_variable"] + bound_input + unbound_input + unbound_default_input + return {"output_variable": res}, state.update(output_variable=res) + + fn_based_action: SingleStepAction = create_action( + my_action.bind(bound_input=10), name="my_action" + ) + assert fn_based_action.inputs == ["unbound_input"] + result, state = await fn_based_action.run_and_update( + State({"input_variable": 1}), unbound_input=100 + ) + assert state.get_all() == {"input_variable": 1, "output_variable": 1111} + assert result == {"output_variable": 1111} + + def test_create_action_class_api(): raw_action = BasicAction() created_action = create_action(raw_action, name="my_action") diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 3814978c..d5acad22 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -379,7 +379,7 @@ def test_app_step_with_inputs_missing(): state=State({"count": 0, "tracker": []}), initial_step="counter", ) - with pytest.raises(ValueError, match="Missing the following inputs"): + with pytest.raises(ValueError, match="missing required inputs"): app.step(inputs={}) @@ -447,7 +447,7 @@ async def test_app_astep_with_inputs_missing(): state=State({"count": 0, "tracker": []}), initial_step="counter_async", ) - with pytest.raises(ValueError, match="Missing the following inputs"): + with pytest.raises(ValueError, match="missing required inputs"): await app.astep(inputs={})