Skip to content

Commit

Permalink
Ensure that all results all dictionaries
Browse files Browse the repository at this point in the history
This is a common mistake. We need them to be dictionaries as that's what
the framework expects, but it breaks in serialization (too late)
  • Loading branch information
elijahbenizzy committed Mar 22, 2024
1 parent e244dab commit 5b817c6
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 12 deletions.
35 changes: 28 additions & 7 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,15 @@ class Transition:
SEQUENCE_ID = "__SEQUENCE_ID"


def _run_function(function: Function, state: State, inputs: Dict[str, Any]) -> dict:
def _validate_result(result: dict, name: str) -> None:
if not isinstance(result, dict):
raise ValueError(
f"Action {name} returned a non-dict result: {result}. "
f"All results must be dictionaries."
)


def _run_function(function: Function, state: State, inputs: Dict[str, Any], name: str) -> dict:
"""Runs a function, returning the result of running the function.
Note this restricts the keys in the state to only those that the
function reads.
Expand All @@ -66,21 +74,27 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any]) -> d
"""
if function.is_async():
raise ValueError(
f"Cannot run async: {function} "
f"Cannot run async: {name} "
"in non-async context. Use astep()/aiterate()/arun() "
"instead...)"
)
state_to_use = state.subset(*function.reads)
function.validate_inputs(inputs)
return function.run(state_to_use, **inputs)
result = function.run(state_to_use, **inputs)
_validate_result(result, name)
return result


async def _arun_function(function: Function, state: State, inputs: Dict[str, Any]) -> dict:
async def _arun_function(
function: Function, state: State, inputs: Dict[str, Any], name: str
) -> dict:
"""Runs a function, returning the result of running the function.
Async version of the above."""
state_to_use = state.subset(*function.reads)
function.validate_inputs(inputs)
return await function.run(state_to_use, **inputs)
result = await function.run(state_to_use, **inputs)
_validate_result(result, name)
return result


def _state_update(state_to_modify: State, modified_state: State) -> State:
Expand Down Expand Up @@ -194,7 +208,9 @@ def _run_single_step_action(
# TODO -- guard all reads/writes with a subset of the state
action.validate_inputs(inputs)
result, new_state = action.run_and_update(state, **inputs)
_validate_result(result, action.name)
out = result, _state_update(state, new_state)
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return out

Expand All @@ -205,6 +221,7 @@ def _run_single_step_streaming_action(
action.validate_inputs(inputs)
generator = action.stream_run_and_update(state, **inputs)
result, state = yield from generator
_validate_result(result, action.name)
_validate_reducer_writes(action, state, action.name)
return result, state

Expand All @@ -215,6 +232,7 @@ def _run_multi_step_streaming_action(
action.validate_inputs(inputs)
generator = action.stream_run(state, **inputs)
result = yield from generator
_validate_result(result, action.name)
new_state = _run_reducer(action, state, result, action.name)
return result, _state_update(state, new_state)

Expand All @@ -226,6 +244,7 @@ async def _arun_single_step_action(
state_to_use = state
action.validate_inputs(inputs)
result, new_state = await action.run_and_update(state_to_use, **inputs)
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)

Expand Down Expand Up @@ -333,7 +352,7 @@ def _step(
if next_action.single_step:
result, new_state = _run_single_step_action(next_action, self._state, inputs)
else:
result = _run_function(next_action, self._state, inputs)
result = _run_function(next_action, self._state, inputs, name=next_action.name)
new_state = _run_reducer(next_action, self._state, result, next_action.name)

new_state = self._update_internal_state_value(new_state, next_action)
Expand Down Expand Up @@ -437,7 +456,9 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
next_action, self._state, inputs=inputs
)
else:
result = await _arun_function(next_action, self._state, inputs=inputs)
result = await _arun_function(
next_action, self._state, inputs=inputs, name=next_action.name
)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = self._update_internal_state_value(new_state, next_action)
self._set_state(new_state)
Expand Down
133 changes: 128 additions & 5 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,19 +187,40 @@ class BrokenStepException(Exception):
)


async def incorrect(x):
return "not a dict"


base_action_incorrect_result_type = PassedInAction(
reads=[],
writes=[],
fn=lambda x: "not a dict",
update_fn=lambda result, state: state,
inputs=[],
)

base_action_incorrect_result_type_async = PassedInActionAsync(
reads=[],
writes=[],
fn=incorrect,
update_fn=lambda result, state: state,
inputs=[],
)


def test__run_function():
"""Tests that we can run a function"""
action = base_counter_action
state = State({})
result = _run_function(action, state, inputs={})
result = _run_function(action, state, inputs={}, name=action.name)
assert result == {"count": 1}


def test__run_function_with_inputs():
"""Tests that we can run a function"""
action = base_counter_action_with_inputs
state = State({})
result = _run_function(action, state, inputs={"additional_increment": 1})
result = _run_function(action, state, inputs={"additional_increment": 1}, name=action.name)
assert result == {"count": 2}


Expand All @@ -208,7 +229,15 @@ def test__run_function_cant_run_async():
action = base_counter_action_async
state = State({})
with pytest.raises(ValueError, match="async"):
_run_function(action, state, inputs={})
_run_function(action, state, inputs={}, name=action.name)


def test__run_function_incorrect_result_type():
"""Tests that we can run an async function"""
action = base_action_incorrect_result_type
state = State({})
with pytest.raises(ValueError, match="returned a non-dict"):
_run_function(action, state, inputs={}, name=action.name)


def test__run_reducer_modifies_state():
Expand Down Expand Up @@ -243,15 +272,25 @@ async def test__arun_function():
"""Tests that we can run an async function"""
action = base_counter_action_async
state = State({})
result = await _arun_function(action, state, inputs={})
result = await _arun_function(action, state, inputs={}, name=action.name)
assert result == {"count": 1}


async def test__arun_function_incorrect_result_type():
"""Tests that we can run an async function"""
action = base_action_incorrect_result_type_async
state = State({})
with pytest.raises(ValueError, match="returned a non-dict"):
await _arun_function(action, state, inputs={}, name=action.name)


async def test__arun_function_with_inputs():
"""Tests that we can run an async function"""
action = base_counter_action_with_inputs_async
state = State({})
result = await _arun_function(action, state, inputs={"additional_increment": 1})
result = await _arun_function(
action, state, inputs={"additional_increment": 1}, name=action.name
)
assert result == {"count": 2}


Expand Down Expand Up @@ -376,6 +415,24 @@ def inputs(self) -> list[str]:
return ["additional_increment"]


class SingleStepActionIncorrectResultType(SingleStepAction):
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return "not a dict", state

@property
def reads(self) -> list[str]:
return []

@property
def writes(self) -> list[str]:
return []


class SingleStepActionIncorrectResultTypeAsync(SingleStepActionIncorrectResultType):
async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return "not a dict", state


class SingleStepCounterAsync(SingleStepCounter):
async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
await asyncio.sleep(0.0001) # just so we can make this *truly* async
Expand Down Expand Up @@ -438,6 +495,39 @@ def writes(self) -> list[str]:
return ["count", "tracker"]


class StreamingActionIncorrectResultType(StreamingAction):
def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]:
yield {}
return "not a dict"

@property
def reads(self) -> list[str]:
return []

@property
def writes(self) -> list[str]:
return []

def update(self, result: dict, state: State) -> State:
return state


class StreamingSingleStepActionIncorrectResultType(SingleStepStreamingAction):
def stream_run_and_update(
self, state: State, **run_kwargs
) -> Generator[dict, None, Tuple[dict, State]]:
yield {}
return "not a dict", state

@property
def reads(self) -> list[str]:
return []

@property
def writes(self) -> list[str]:
return []


base_single_step_counter = SingleStepCounter()
base_single_step_counter_async = SingleStepCounterAsync()
base_single_step_counter_with_inputs = SingleStepCounterWithInputs()
Expand All @@ -446,6 +536,9 @@ def writes(self) -> list[str]:
base_streaming_counter = StreamingCounter()
base_streaming_single_step_counter = SingleStepStreamingCounter()

base_single_step_action_incorrect_result_type = SingleStepActionIncorrectResultType()
base_single_step_action_incorrect_result_type_async = SingleStepActionIncorrectResultTypeAsync()


def test__run_single_step_action():
action = base_single_step_counter.with_name("counter")
Expand All @@ -458,6 +551,20 @@ def test__run_single_step_action():
assert state.subset("count", "tracker").get_all() == {"count": 2, "tracker": [1, 2]}


def test__run_single_step_action_incorrect_result_type():
action = base_single_step_action_incorrect_result_type.with_name("counter")
state = State({"count": 0, "tracker": []})
with pytest.raises(ValueError, match="returned a non-dict"):
_run_single_step_action(action, state, inputs={})


async def test__arun_single_step_action_incorrect_result_type():
action = base_single_step_action_incorrect_result_type_async.with_name("counter")
state = State({"count": 0, "tracker": []})
with pytest.raises(ValueError, match="returned a non-dict"):
await _arun_single_step_action(action, state, inputs={})


def test__run_single_step_action_with_inputs():
action = base_single_step_counter_with_inputs.with_name("counter")
state = State({"count": 0, "tracker": []})
Expand Down Expand Up @@ -531,6 +638,22 @@ def test__run_multistep_streaming_action():
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}


def test__run_streaming_action_incorrect_result_type():
action = StreamingActionIncorrectResultType()
state = State()
with pytest.raises(ValueError, match="returned a non-dict"):
gen = _run_multi_step_streaming_action(action, state, inputs={})
collections.deque(gen, maxlen=0) # exhaust the generator


def test__run_single_step_streaming_action_incorrect_result_type():
action = StreamingSingleStepActionIncorrectResultType()
state = State()
with pytest.raises(ValueError, match="returned a non-dict"):
gen = _run_single_step_streaming_action(action, state, inputs={})
collections.deque(gen, maxlen=0) # exhaust the generator


def test__run_single_step_streaming_action():
action = base_streaming_single_step_counter.with_name("counter")
state = State({"count": 0, "tracker": []})
Expand Down

0 comments on commit 5b817c6

Please sign in to comment.