Skip to content

Commit

Permalink
Adds additional safeguards around writing state
Browse files Browse the repository at this point in the history
If your action declares a state write that it does not apply, it will
now error out. This is backwards compatible, as that was always
considered undefined behavior, and bad practice.

Note that this does not mess with deleting state in actions, as that is
not declared in the writes() field. We will want to handle it later.
  • Loading branch information
elijahbenizzy committed Mar 22, 2024
1 parent 2cee6e2 commit e244dab
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 3 deletions.
17 changes: 16 additions & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def _state_update(state_to_modify: State, modified_state: State) -> State:
return state_to_modify.merge(modified_state).wipe(delete=deleted_keys_filtered)


def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None:
required_writes = reducer.writes
missing_writes = set(reducer.writes) - state.keys()
if len(missing_writes) > 0:
raise ValueError(
f"State is missing write keys after running: {name}. Missing keys are: {missing_writes}. "
f"Has writes: {required_writes}"
)


def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State:
"""Runs the reducer, returning the new state. Note this restricts the
keys in the state to only those that the function writes.
Expand All @@ -132,6 +142,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
f"Action {name} attempted to write to keys {extra_keys} "
f"that it did not declare. It declared: ({reducer.writes})!"
)
_validate_reducer_writes(reducer, new_state, name)
return _state_update(state, new_state)


Expand Down Expand Up @@ -184,6 +195,7 @@ def _run_single_step_action(
action.validate_inputs(inputs)
result, new_state = action.run_and_update(state, **inputs)
out = result, _state_update(state, new_state)
_validate_reducer_writes(action, new_state, action.name)
return out


Expand All @@ -192,7 +204,9 @@ def _run_single_step_streaming_action(
) -> Generator[dict, None, Tuple[dict, State]]:
action.validate_inputs(inputs)
generator = action.stream_run_and_update(state, **inputs)
return (yield from generator)
result, state = yield from generator
_validate_reducer_writes(action, state, action.name)
return result, state


def _run_multi_step_streaming_action(
Expand All @@ -212,6 +226,7 @@ async def _arun_single_step_action(
state_to_use = state
action.validate_inputs(inputs)
result, new_state = await action.run_and_update(state_to_use, **inputs)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)


Expand Down
107 changes: 105 additions & 2 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections
import logging
from typing import Awaitable, Callable, Generator, Tuple

Expand All @@ -8,6 +9,7 @@
from burr.core.action import (
Action,
Condition,
Reducer,
Result,
SingleStepAction,
SingleStepStreamingAction,
Expand Down Expand Up @@ -253,6 +255,107 @@ async def test__arun_function_with_inputs():
assert result == {"count": 2}


def test_run_reducer_errors_missing_writes():
class BrokenReducer(Reducer):
def update(self, result: dict, state: State) -> State:
return state.update(present_value=1)

@property
def writes(self) -> list[str]:
return ["missing_value", "present_value"]

reducer = BrokenReducer()
state = State()
with pytest.raises(ValueError, match="missing_value"):
_run_reducer(reducer, state, {}, "broken_reducer")


def test_run_single_step_action_errors_missing_writes():
class BrokenAction(SingleStepAction):
@property
def reads(self) -> list[str]:
return []

def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return {"present_value": 1}, state.update(present_value=1)

@property
def writes(self) -> list[str]:
return ["missing_value", "present_value"]

action = BrokenAction()
state = State()
with pytest.raises(ValueError, match="missing_value"):
_run_single_step_action(action, state, inputs={})


async def test_arun_single_step_action_errors_missing_writes():
class BrokenAction(SingleStepAction):
@property
def reads(self) -> list[str]:
return []

async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
await asyncio.sleep(0.0001) # just so we can make this *truly* async
return {"present_value": 1}, state.update(present_value=1)

@property
def writes(self) -> list[str]:
return ["missing_value", "present_value"]

action = BrokenAction()
state = State()
with pytest.raises(ValueError, match="missing_value"):
await _arun_single_step_action(action, state, inputs={})


def test_run_single_step_streaming_action_errors_missing_write():
class BrokenAction(SingleStepStreamingAction):
def stream_run_and_update(
self, state: State, **run_kwargs
) -> Generator[dict, None, Tuple[dict, State]]:
yield {}
return {"present_value": 1}, state.update(present_value=1)

@property
def reads(self) -> list[str]:
return []

@property
def writes(self) -> list[str]:
return ["missing_value", "present_value"]

action = BrokenAction()
state = State()
with pytest.raises(ValueError, match="missing_value"):
gen = _run_single_step_streaming_action(action, state, inputs={})
collections.deque(gen, maxlen=0) # exhaust the generator


def test_run_multi_step_streaming_action_errors_missing_write():
class BrokenAction(StreamingAction):
def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]:
yield {}
return {"present_value": 1}

def update(self, result: dict, state: State) -> State:
return state.update(present_value=1)

@property
def reads(self) -> list[str]:
return []

@property
def writes(self) -> list[str]:
return ["missing_value", "present_value"]

action = BrokenAction()
state = State()
with pytest.raises(ValueError, match="missing_value"):
gen = _run_multi_step_streaming_action(action, state, inputs={})
collections.deque(gen, maxlen=0) # exhaust the generator


class SingleStepCounter(SingleStepAction):
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
result = {"count": state["count"] + 1 + sum([0] + list(run_kwargs.values()))}
Expand Down Expand Up @@ -398,11 +501,11 @@ def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:

@property
def reads(self) -> list[str]:
return ["to_delete"]
return []

@property
def writes(self) -> list[str]:
return ["to_delete"]
return []


def test__run_single_step_action_deletes_state():
Expand Down

0 comments on commit e244dab

Please sign in to comment.