From e244dab4c2c0fc2b090f7ac8bc685c8a858789a7 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 21 Mar 2024 10:44:16 -0700 Subject: [PATCH] Adds additional safeguards around writing state If your action declares a state write that it does not apply, it will now error out. This is backwards compatible, as that was always considered undefined behavior, and bad practice. Note that this does not mess with deleting state in actions, as that is not declared in the writes() field. We will want to handle it later. --- burr/core/application.py | 17 +++++- tests/core/test_application.py | 107 ++++++++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 3 deletions(-) diff --git a/burr/core/application.py b/burr/core/application.py index d5061755..64bfb39a 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -113,6 +113,16 @@ def _state_update(state_to_modify: State, modified_state: State) -> State: return state_to_modify.merge(modified_state).wipe(delete=deleted_keys_filtered) +def _validate_reducer_writes(reducer: Reducer, state: State, name: str) -> None: + required_writes = reducer.writes + missing_writes = set(reducer.writes) - state.keys() + if len(missing_writes) > 0: + raise ValueError( + f"State is missing write keys after running: {name}. Missing keys are: {missing_writes}. " + f"Has writes: {required_writes}" + ) + + def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State: """Runs the reducer, returning the new state. Note this restricts the keys in the state to only those that the function writes. @@ -132,6 +142,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta f"Action {name} attempted to write to keys {extra_keys} " f"that it did not declare. It declared: ({reducer.writes})!" ) + _validate_reducer_writes(reducer, new_state, name) return _state_update(state, new_state) @@ -184,6 +195,7 @@ def _run_single_step_action( action.validate_inputs(inputs) result, new_state = action.run_and_update(state, **inputs) out = result, _state_update(state, new_state) + _validate_reducer_writes(action, new_state, action.name) return out @@ -192,7 +204,9 @@ def _run_single_step_streaming_action( ) -> Generator[dict, None, Tuple[dict, State]]: action.validate_inputs(inputs) generator = action.stream_run_and_update(state, **inputs) - return (yield from generator) + result, state = yield from generator + _validate_reducer_writes(action, state, action.name) + return result, state def _run_multi_step_streaming_action( @@ -212,6 +226,7 @@ async def _arun_single_step_action( state_to_use = state action.validate_inputs(inputs) result, new_state = await action.run_and_update(state_to_use, **inputs) + _validate_reducer_writes(action, new_state, action.name) return result, _state_update(state, new_state) diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 62e0187b..04ea9980 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1,4 +1,5 @@ import asyncio +import collections import logging from typing import Awaitable, Callable, Generator, Tuple @@ -8,6 +9,7 @@ from burr.core.action import ( Action, Condition, + Reducer, Result, SingleStepAction, SingleStepStreamingAction, @@ -253,6 +255,107 @@ async def test__arun_function_with_inputs(): assert result == {"count": 2} +def test_run_reducer_errors_missing_writes(): + class BrokenReducer(Reducer): + def update(self, result: dict, state: State) -> State: + return state.update(present_value=1) + + @property + def writes(self) -> list[str]: + return ["missing_value", "present_value"] + + reducer = BrokenReducer() + state = State() + with pytest.raises(ValueError, match="missing_value"): + _run_reducer(reducer, state, {}, "broken_reducer") + + +def test_run_single_step_action_errors_missing_writes(): + class BrokenAction(SingleStepAction): + @property + def reads(self) -> list[str]: + return [] + + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + return {"present_value": 1}, state.update(present_value=1) + + @property + def writes(self) -> list[str]: + return ["missing_value", "present_value"] + + action = BrokenAction() + state = State() + with pytest.raises(ValueError, match="missing_value"): + _run_single_step_action(action, state, inputs={}) + + +async def test_arun_single_step_action_errors_missing_writes(): + class BrokenAction(SingleStepAction): + @property + def reads(self) -> list[str]: + return [] + + async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + await asyncio.sleep(0.0001) # just so we can make this *truly* async + return {"present_value": 1}, state.update(present_value=1) + + @property + def writes(self) -> list[str]: + return ["missing_value", "present_value"] + + action = BrokenAction() + state = State() + with pytest.raises(ValueError, match="missing_value"): + await _arun_single_step_action(action, state, inputs={}) + + +def test_run_single_step_streaming_action_errors_missing_write(): + class BrokenAction(SingleStepStreamingAction): + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[dict, None, Tuple[dict, State]]: + yield {} + return {"present_value": 1}, state.update(present_value=1) + + @property + def reads(self) -> list[str]: + return [] + + @property + def writes(self) -> list[str]: + return ["missing_value", "present_value"] + + action = BrokenAction() + state = State() + with pytest.raises(ValueError, match="missing_value"): + gen = _run_single_step_streaming_action(action, state, inputs={}) + collections.deque(gen, maxlen=0) # exhaust the generator + + +def test_run_multi_step_streaming_action_errors_missing_write(): + class BrokenAction(StreamingAction): + def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: + yield {} + return {"present_value": 1} + + def update(self, result: dict, state: State) -> State: + return state.update(present_value=1) + + @property + def reads(self) -> list[str]: + return [] + + @property + def writes(self) -> list[str]: + return ["missing_value", "present_value"] + + action = BrokenAction() + state = State() + with pytest.raises(ValueError, match="missing_value"): + gen = _run_multi_step_streaming_action(action, state, inputs={}) + collections.deque(gen, maxlen=0) # exhaust the generator + + class SingleStepCounter(SingleStepAction): def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: result = {"count": state["count"] + 1 + sum([0] + list(run_kwargs.values()))} @@ -398,11 +501,11 @@ def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: @property def reads(self) -> list[str]: - return ["to_delete"] + return [] @property def writes(self) -> list[str]: - return ["to_delete"] + return [] def test__run_single_step_action_deletes_state():