From f51fdcf9877af10a13a860d00e7e3df1e1a74ea7 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Tue, 27 Aug 2024 22:41:32 -0700 Subject: [PATCH] Adds additional conditions/transitions condition.lmda is just a shortcut for Condition. This also has a commented out condition.exists -- this is ready, but we do not support optional state items, so it is not feasible yet. --- burr/core/action.py | 38 +++++++++++++++++++++++++++++++++++++- tests/core/test_action.py | 24 ++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/burr/core/action.py b/burr/core/action.py index 5acabfe7..61381afc 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -186,7 +186,13 @@ def __repr__(self): class Condition(Function): KEY = "PROCEED" - def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str = None): + def __init__( + self, + keys: List[str], + resolver: Callable[[State], bool], + name: str = None, + optional_keys: List[str] = None, + ): """Base condition class. Chooses keys to read from the state and a resolver function. If you want a condition that defaults to true, use Condition.default or just default. @@ -202,6 +208,7 @@ def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str """ self._resolver = resolver self._keys = keys + self._optional_keys = optional_keys if optional_keys is not None else [] self._name = name @staticmethod @@ -240,6 +247,33 @@ def condition_func(state: State) -> bool: return Condition(keys, condition_func, name=expr) + @staticmethod + def lmda(resolver: Callable[[State], bool], state_keys: List[str]) -> "Condition": + """Returns a condition that evaluates the given function of State. + Note that this is just a simple wrapper over the Condition object. + + This does not (yet) support optional (default) arguments. + + :param fn: + :param state_keys: + :return: + """ + return Condition(state_keys, resolver, name=f"lmda_{resolver.__name__}_") + + # TODO -- decide what to do with this when we have optional keys + # @staticmethod + # def exists(*keys: str) -> "Condition": + # """Returns a condition that checks if the given key exists in the state. + # + # :param key: Key to check for existence + # :return: A condition that checks if the given key exists in the state + # """ + # return Condition( + # list(keys), + # lambda state: all(item in state for item in keys), + # name=f"exists_{'_and_'.join(sorted(keys))}" + # ) + def _validate(self, state: State): missing_keys = set(self._keys) - set(state.keys()) if missing_keys: @@ -340,6 +374,8 @@ def __invert__(self): default = Condition.default when = Condition.when expr = Condition.expr +lmda = Condition.lmda +# exists = Condition.exists class Result(Action): diff --git a/tests/core/test_action.py b/tests/core/test_action.py index b623b598..6a477973 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -176,6 +176,30 @@ def test_condition_or(): assert cond_or.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False} +def test_condition_lmda(): + cond = Condition.lmda(lambda state: state["foo"] == "bar", ["foo"]) + assert cond.reads == ["foo"] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} + + +# TODO -- add this in once we decide what to do with optional keys... +# def test_condition_exists_single(): +# cond = Condition.exists("foo") +# assert cond.name == "exists_foo" +# assert cond.reads == ["foo"] +# assert cond.run(State({"foo": "baz", "bar": "baz"})) == {Condition.KEY: True} +# assert cond.run(State({})) == {Condition.KEY: False} +# +# +# def test_condition_exists_double(): +# cond = Condition.exists("foo", "bar") +# assert cond.name == "exists_bar_and_foo" +# assert cond.reads == ["foo"] +# assert cond.run(State({"foo": "baz", "bar": "baz"})) == {Condition.KEY: True} +# assert cond.run(State({"foo" : "bar"})) == {Condition.KEY: False} + + def test_result(): result = Result("foo", "bar") assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == {