From 398f0aa14ac2bf8dfaedd3452198dc2e09752e38 Mon Sep 17 00:00:00 2001 From: zilto Date: Mon, 9 Sep 2024 11:56:11 -0400 Subject: [PATCH] added Select object for transitions --- burr/core/__init__.py | 3 ++- burr/core/action.py | 45 +++++++++++++++++++++++++++++++++++++++ burr/core/application.py | 5 ++++- burr/core/graph.py | 41 +++++++++++++++++++++++++---------- tests/core/test_action.py | 34 +++++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 13 deletions(-) diff --git a/burr/core/__init__.py b/burr/core/__init__.py index 3d61c144..763a717f 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -1,4 +1,4 @@ -from burr.core.action import Action, Condition, Result, action, default, expr, when +from burr.core.action import Action, Condition, Result, Select, action, default, expr, when from burr.core.application import ( Application, ApplicationBuilder, @@ -18,6 +18,7 @@ "default", "expr", "Result", + "Select", "State", "when", ] diff --git a/burr/core/action.py b/burr/core/action.py index 61381afc..1e42616a 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -17,6 +17,7 @@ List, Optional, Protocol, + Sequence, Tuple, TypeVar, Union, @@ -378,6 +379,50 @@ def __invert__(self): # exists = Condition.exists +# TODO type `resolver` to prevent user-facing type-mismatch +# e.g., a user provided `def foo(state: State, actions: list)` +# would be too restrictive for a `Sequence` type +class Select(Function): + def __init__( + self, + keys: List[str], + resolver: Callable[[State, Sequence[Action]], str], + name: str = None, + ): + self._keys = keys + self._resolver = resolver + self._name = name + # TODO add a `default` kwarg; + # could an Action, action_name: str, or action_idx: int + # `default` value could be returned if `_resolver` returns None + + @property + def name(self) -> str: + return self._name + + @property + def reads(self) -> list[str]: + return self._keys + + @property + def resolver(self) -> Callable[[State, Sequence[Action]], str]: + return self._resolver + + def __repr__(self) -> str: + return f"select: {self._name}" + + def _validate(self, state: State): + missing_keys = set(self._keys) - set(state.keys()) + if missing_keys: + raise ValueError( + f"Missing keys in state required by condition: {self} {', '.join(missing_keys)}" + ) + + def run(self, state: State, possible_actions: Sequence[Action]) -> str: + self._validate(state) + return self._resolver(state, possible_actions) + + class Result(Action): def __init__(self, *fields: str): """Represents a result action. This is purely a convenience class to diff --git a/burr/core/application.py b/burr/core/application.py index 6ae01f7d..fd611281 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -16,6 +16,7 @@ List, Literal, Optional, + Sequence, Set, Tuple, TypeVar, @@ -33,6 +34,7 @@ Condition, Function, Reducer, + Select, SingleStepAction, SingleStepStreamingAction, StreamingAction, @@ -2005,7 +2007,8 @@ def with_actions( def with_transitions( self, *transitions: Union[ - Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition] + Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]], + Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]], ], ) -> "ApplicationBuilder": """Adds transitions to the application. Transitions are specified as tuples of either: diff --git a/burr/core/graph.py b/burr/core/graph.py index ff70a4b1..9ea43683 100644 --- a/burr/core/graph.py +++ b/burr/core/graph.py @@ -3,10 +3,10 @@ import inspect import logging import pathlib -from typing import Any, Callable, List, Literal, Optional, Set, Tuple, Union +from typing import Any, Callable, List, Literal, Optional, Sequence, Set, Tuple, Union from burr import telemetry -from burr.core.action import Action, Condition, create_action, default +from burr.core.action import Action, Condition, Select, create_action, default from burr.core.state import State from burr.core.validation import BASE_ERROR_MESSAGE, assert_set @@ -118,6 +118,13 @@ def get_next_node( return self._action_map[entrypoint] possibilities = self._adjacency_map[prior_step] for next_action, condition in possibilities: + # When `Select` is used, all possibilities have the same `condition` attached. + # Hitting a `Select` will necessarily exit the for loop + if isinstance(condition, Select): + possible_actions = [self._action_map[p[0]] for p in possibilities] + selected_action = condition.run(state, possible_actions) + return self._action_map[selected_action] + if condition.run(state)[Condition.KEY]: return self._action_map[next_action] return None @@ -235,7 +242,7 @@ class GraphBuilder: def __init__(self): """Initializes the graph builder.""" - self.transitions: Optional[List[Tuple[str, str, Condition]]] = None + self.transitions: Optional[List[Tuple[str, str, Union[Condition, Select]]]] = None self.actions: Optional[List[Action]] = None def with_actions( @@ -269,7 +276,8 @@ def with_actions( def with_transitions( self, *transitions: Union[ - Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition] + Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]]], + Tuple[Union[str, Sequence[str]], Union[str, Sequence[str]], Union[Condition, Select]], ], ) -> "GraphBuilder": """Adds transitions to the graph. Transitions are specified as tuples of either: @@ -291,14 +299,25 @@ def with_transitions( condition = conditions[0] else: condition = default - if not isinstance(from_, list): + # check required because issubclass(str, Sequence) == True + if isinstance(from_, Sequence) and not isinstance(from_, str): + from_ = [*from_] + else: from_ = [from_] - for action in from_: - if not isinstance(action, str): - raise ValueError(f"Transition source must be a string, not {action}") - if not isinstance(to_, str): - raise ValueError(f"Transition target must be a string, not {to_}") - self.transitions.append((action, to_, condition)) + if isinstance(to_, Sequence) and not isinstance(to_, str): + if not isinstance(condition, Select): + raise ValueError( + "Transition with multiple targets require a `Select` condition." + ) + else: + to_ = [to_] + for source in from_: + for target in to_: + if not isinstance(source, str): + raise ValueError(f"Transition source must be a string, not {source}") + if not isinstance(target, str): + raise ValueError(f"Transition target must be a string, not {to_}") + self.transitions.append((source, target, condition)) return self def build(self) -> Graph: diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 6a477973..0fa4f3e6 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -12,6 +12,7 @@ Function, Input, Result, + Select, SingleStepAction, SingleStepStreamingAction, StreamingAction, @@ -200,6 +201,39 @@ def test_condition_lmda(): # assert cond.run(State({"foo" : "bar"})) == {Condition.KEY: False} +def test_select_constant(): + select = Select([], resolver=lambda *args: "foo") + selected_action = select.run(State(), []) + + assert selected_action == "foo" + + +def test_select_determistic(): + @action(reads=[], writes=[]) + def bar(state): + return state + + @action(reads=[], writes=[]) + def baz(state): + return state + + def length_resolver(state: State, actions: list[Action]) -> str: + foo = state["foo"] + action_idx = len(foo) % len(actions) + return actions[action_idx].name + + foo1 = "len=3" # % 2 = 1 + foo2 = "len_is_8" # % 2 = 0 + actions = [create_action(bar, "bar"), create_action(baz, "baz")] + select = Select(["foo"], resolver=length_resolver) + + selected_1 = select.run(State({"foo": foo1}), possible_actions=actions) + assert selected_1 == actions[len(foo1) % len(actions)].name + + selected_2 = select.run(State({"foo": foo2}), possible_actions=actions) + assert selected_2 == actions[len(foo2) % len(actions)].name + + def test_result(): result = Result("foo", "bar") assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == {