Skip to content

Commit

Permalink
Fixes up bug in which single step async actions didn't work
Browse files Browse the repository at this point in the history
Everything worked except it treated them as sync actions. This treats
them as async actions. We also duplicate sync tests to work for async
actions.
  • Loading branch information
elijahbenizzy committed Mar 2, 2024
1 parent 9756adc commit 320f8da
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
3 changes: 3 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params, **run_kwargs)

def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
59 changes: 59 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Tuple

import pytest
Expand Down Expand Up @@ -199,6 +200,64 @@ def my_action(
assert result == {"output_variable": 1111}


async def test_function_based_action_async():
@action(reads=["input_variable"], writes=["output_variable"])
async def my_action(state: State) -> Tuple[dict, State]:
await asyncio.sleep(0.01)
return {"output_variable": state["input_variable"]}, state.update(
output_variable=state["input_variable"]
)

fn_based_action = create_action(my_action, name="my_action")
assert fn_based_action.is_async()
assert fn_based_action.single_step is True
assert fn_based_action.name == "my_action"
assert fn_based_action.reads == ["input_variable"]
assert fn_based_action.writes == ["output_variable"]
result, state = await fn_based_action.run_and_update(State({"input_variable": "foo"}))
assert result == {"output_variable": "foo"}
assert state.get_all() == {"input_variable": "foo", "output_variable": "foo"}


async def test_function_based_action_with_inputs_async():
@action(reads=["input_variable"], writes=["output_variable"])
async def my_action(state: State, bound_input: int, unbound_input: int) -> Tuple[dict, State]:
await asyncio.sleep(0.01)
res = state["input_variable"] + bound_input + unbound_input
return {"output_variable": res}, state.update(output_variable=res)

fn_based_action: SingleStepAction = create_action(
my_action.bind(bound_input=10), name="my_action"
)
assert fn_based_action.is_async()
assert fn_based_action.inputs == ["unbound_input"]
result, state = await fn_based_action.run_and_update(
State({"input_variable": 1}), unbound_input=100
)
assert state.get_all() == {"input_variable": 1, "output_variable": 111}
assert result == {"output_variable": 111}


async def test_function_based_action_with_defaults_async():
@action(reads=["input_variable"], writes=["output_variable"])
async def my_action(
state: State, bound_input: int, unbound_input: int, unbound_default_input: int = 1000
) -> Tuple[dict, State]:
await asyncio.sleep(0.01)
res = state["input_variable"] + bound_input + unbound_input + unbound_default_input
return {"output_variable": res}, state.update(output_variable=res)

fn_based_action: SingleStepAction = create_action(
my_action.bind(bound_input=10), name="my_action"
)
assert fn_based_action.inputs == ["unbound_input"]
result, state = await fn_based_action.run_and_update(
State({"input_variable": 1}), unbound_input=100
)
assert state.get_all() == {"input_variable": 1, "output_variable": 1111}
assert result == {"output_variable": 1111}


def test_create_action_class_api():
raw_action = BasicAction()
created_action = create_action(raw_action, name="my_action")
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_app_step_with_inputs_missing():
state=State({"count": 0, "tracker": []}),
initial_step="counter",
)
with pytest.raises(ValueError, match="Missing the following inputs"):
with pytest.raises(ValueError, match="missing required inputs"):
app.step(inputs={})


Expand Down Expand Up @@ -447,7 +447,7 @@ async def test_app_astep_with_inputs_missing():
state=State({"count": 0, "tracker": []}),
initial_step="counter_async",
)
with pytest.raises(ValueError, match="Missing the following inputs"):
with pytest.raises(ValueError, match="missing required inputs"):
await app.astep(inputs={})


Expand Down

0 comments on commit 320f8da

Please sign in to comment.