Skip to content

Commit

Permalink
Adds additional conditions/transitions
Browse files Browse the repository at this point in the history
condition.lmda is just a shortcut for Condition. This also has a
commented out condition.exists -- this is ready, but we do not support
optional state items, so it is not feasible yet.
  • Loading branch information
elijahbenizzy committed Aug 28, 2024
1 parent b51b221 commit f51fdcf
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
38 changes: 37 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,13 @@ def __repr__(self):
class Condition(Function):
KEY = "PROCEED"

def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str = None):
def __init__(
self,
keys: List[str],
resolver: Callable[[State], bool],
name: str = None,
optional_keys: List[str] = None,
):
"""Base condition class. Chooses keys to read from the state and a resolver function.
If you want a condition that defaults to true, use Condition.default or just default.
Expand All @@ -202,6 +208,7 @@ def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str
"""
self._resolver = resolver
self._keys = keys
self._optional_keys = optional_keys if optional_keys is not None else []
self._name = name

@staticmethod
Expand Down Expand Up @@ -240,6 +247,33 @@ def condition_func(state: State) -> bool:

return Condition(keys, condition_func, name=expr)

@staticmethod
def lmda(resolver: Callable[[State], bool], state_keys: List[str]) -> "Condition":
"""Returns a condition that evaluates the given function of State.
Note that this is just a simple wrapper over the Condition object.
This does not (yet) support optional (default) arguments.
:param fn:
:param state_keys:
:return:
"""
return Condition(state_keys, resolver, name=f"lmda_{resolver.__name__}_")

# TODO -- decide what to do with this when we have optional keys
# @staticmethod
# def exists(*keys: str) -> "Condition":
# """Returns a condition that checks if the given key exists in the state.
#
# :param key: Key to check for existence
# :return: A condition that checks if the given key exists in the state
# """
# return Condition(
# list(keys),
# lambda state: all(item in state for item in keys),
# name=f"exists_{'_and_'.join(sorted(keys))}"
# )

def _validate(self, state: State):
missing_keys = set(self._keys) - set(state.keys())
if missing_keys:
Expand Down Expand Up @@ -340,6 +374,8 @@ def __invert__(self):
default = Condition.default
when = Condition.when
expr = Condition.expr
lmda = Condition.lmda
# exists = Condition.exists


class Result(Action):
Expand Down
24 changes: 24 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,30 @@ def test_condition_or():
assert cond_or.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False}


def test_condition_lmda():
cond = Condition.lmda(lambda state: state["foo"] == "bar", ["foo"])
assert cond.reads == ["foo"]
assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True}
assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False}


# TODO -- add this in once we decide what to do with optional keys...
# def test_condition_exists_single():
# cond = Condition.exists("foo")
# assert cond.name == "exists_foo"
# assert cond.reads == ["foo"]
# assert cond.run(State({"foo": "baz", "bar": "baz"})) == {Condition.KEY: True}
# assert cond.run(State({})) == {Condition.KEY: False}
#
#
# def test_condition_exists_double():
# cond = Condition.exists("foo", "bar")
# assert cond.name == "exists_bar_and_foo"
# assert cond.reads == ["foo"]
# assert cond.run(State({"foo": "baz", "bar": "baz"})) == {Condition.KEY: True}
# assert cond.run(State({"foo" : "bar"})) == {Condition.KEY: False}


def test_result():
result = Result("foo", "bar")
assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == {
Expand Down

0 comments on commit f51fdcf

Please sign in to comment.