From addb081d1ae85b5d7a51787960a5f10cb2475ecb Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Sun, 4 Feb 2024 23:15:04 -0800 Subject: [PATCH] WIP --- api_examples/code_generation.py | 77 ---- api_examples/gpt_4.py | 75 ---- api_examples/model_training.py | 76 ---- burr/core/__init__.py | 16 +- burr/core/{function.py => action.py} | 135 +++--- burr/core/application.py | 522 ++++++++++++++++------- burr/core/implementations.py | 34 ++ burr/core/state.py | 190 +++++++-- burr/implementations.py | 21 - burr/integrations/base.py | 6 + burr/integrations/hamilton.py | 124 ++++-- burr/integrations/streamlit.py | 270 ++++++++++++ burr/lifecycle/__init__.py | 25 ++ burr/lifecycle/base.py | 100 +++++ burr/lifecycle/default.py | 56 +++ burr/lifecycle/internal.py | 187 ++++++++ examples/counter/application.py | 45 ++ examples/counter/requirements.txt | 1 + examples/counter/streamlit_app.py | 64 +++ examples/cowsay/application.py | 103 +++++ examples/cowsay/requirements.txt | 2 + examples/cowsay/streamlit_app.py | 69 +++ examples/gpt/app.py | 427 ------------------ examples/gpt/application.py | 284 ++++++++++++ examples/gpt/capabilities.py | 128 ------ examples/gpt/requirements.txt | 5 - examples/gpt/run.py | 93 ---- examples/gpt/server.py | 110 ----- examples/gpt/streamlit_app.py | 92 ++++ examples/ml_training.py | 110 +++++ examples/simulation.py | 0 requirements-test.txt | 2 + tests/core/test_action.py | 117 +++++ tests/core/test_application.py | 471 ++++++++++++++++++-- tests/core/test_function.py | 37 -- tests/core/test_implementations.py | 15 + tests/core/test_state.py | 23 +- tests/integrations/test_burr_hamilton.py | 98 +++++ tests/pytest.ini | 2 + 39 files changed, 2845 insertions(+), 1367 deletions(-) delete mode 100644 api_examples/code_generation.py delete mode 100644 api_examples/gpt_4.py delete mode 100644 api_examples/model_training.py rename burr/core/{function.py => action.py} (62%) create mode 100644 burr/core/implementations.py delete mode 100644 burr/implementations.py create mode 100644 burr/integrations/base.py create mode 100644 burr/integrations/streamlit.py create mode 100644 burr/lifecycle/__init__.py create mode 100644 burr/lifecycle/base.py create mode 100644 burr/lifecycle/default.py create mode 100644 burr/lifecycle/internal.py create mode 100644 examples/counter/application.py create mode 100644 examples/counter/requirements.txt create mode 100644 examples/counter/streamlit_app.py create mode 100644 examples/cowsay/application.py create mode 100644 examples/cowsay/requirements.txt create mode 100644 examples/cowsay/streamlit_app.py delete mode 100644 examples/gpt/app.py create mode 100644 examples/gpt/application.py delete mode 100644 examples/gpt/capabilities.py delete mode 100644 examples/gpt/requirements.txt delete mode 100644 examples/gpt/run.py delete mode 100644 examples/gpt/server.py create mode 100644 examples/gpt/streamlit_app.py create mode 100644 examples/ml_training.py create mode 100644 examples/simulation.py create mode 100644 requirements-test.txt create mode 100644 tests/core/test_action.py delete mode 100644 tests/core/test_function.py create mode 100644 tests/core/test_implementations.py create mode 100644 tests/integrations/test_burr_hamilton.py create mode 100644 tests/pytest.ini diff --git a/api_examples/code_generation.py b/api_examples/code_generation.py deleted file mode 100644 index d72a3251..00000000 --- a/api_examples/code_generation.py +++ /dev/null @@ -1,77 +0,0 @@ -from burr.core import ApplicationBuilder, Condition -from 
burr.implementations import Placeholder - - -def display(response): - pass - - -def main(): - # Help me figure out which loss function to use for my neural network - prompt = Placeholder("prompt", reads=["prompt"], writes=["chat_history"]) - plan = Placeholder("plan", reads=["chat_history", "prompt", "search_history"], writes=["plan", "todo"]) - - web_search = Placeholder("web_search", reads=["plan"], writes=["search_history"]) - stack_overflow_search = Placeholder("stack_overflow_search", reads=["plan"], writes=["responses", "search_history"]) - mathematica = Placeholder("mathematica", reads=["plan"], writes=["responses", "search_history"]) - - synthesizer = Placeholder("synthesizer", reads=["responses"], writes=["response"]) - response = Placeholder("response", reads=["response"], writes=["chat_history"]) - - agent = ( - ApplicationBuilder() - .with_state(prompt=input(...), responses=[], response=None, chat_history=[...], search_history=[]) - .with_transition(prompt, plan) - .with_transition(plan, web_search, Condition.expr("'web_search' == todo")) # ["web_search", "stack_overflow_search"] - .with_transition(plan, stack_overflow_search, Condition.expr("'stack_overflow_search' == todo")) - .with_transition(plan, mathematica, Condition.expr("'mathematica' == todo")) - .with_transition([web_search, stack_overflow_search, mathematica], plan) - # done - .with_transition(plan, synthesizer, Condition.expr("todo is None")) - .with_transition(synthesizer, response) - .with_transition(response, prompt) - .build() - ) - agent.visualize("./out.png", include_conditions=True) - - # response = Placeholder("response", reads=["response", "safe"], writes=["chat_history"]) - # error = Placeholder("error", reads=["error"], writes=["chat_history", "error"]) - # - # agent = ( - # AgentBuilder() - # .with_state(chat_history=[]) - # .with_transition(prompt, check_safety) - # .with_transition( - # check_safety, decide_mode, Condition.expr("safe") - # ) # if safe, decide what to do next - # .with_transition( - # check_safety, response, Condition.expr("not safe") - # ) # if not safe, go to output - # .with_transition(decide_mode, generate_image, Condition.when(mode="image")) - # .with_transition(decide_mode, generate_text, Condition.when(mode="text")) - # .with_transition(decide_mode, generate_code, Condition.when(mode="code")) - # .with_transition(decide_mode, bing_search, Condition.when(mode="search")) - # .with_transition( - # [generate_image, generate_text, generate_code, bing_search], - # response, - # Condition.expr("not error"), - # ) - # .with_transition( - # [generate_image, generate_text, generate_code, bing_search], - # error, - # Condition.expr("bool(error)"), - # ) - # .with_transition(response, prompt) - # .with_transition(error, prompt) - # .build() - # ) - # agent.visualize("./out", include_conditions=False, include_state=True) - # response, error = agent.run(["response", "error"], terminate_when="any_complete") - # if response is not None: - # display(response) - # else: - # display(error) - - -if __name__ == "__main__": - main() diff --git a/api_examples/gpt_4.py b/api_examples/gpt_4.py deleted file mode 100644 index 5c33a69b..00000000 --- a/api_examples/gpt_4.py +++ /dev/null @@ -1,75 +0,0 @@ -from burr.core import ApplicationBuilder, Condition -from burr.implementations import Placeholder - - -def display(response): - pass - - -def main(): - prompt = Placeholder("prompt", reads=["prompt"], writes=["chat_history", "error"]) - check_safety = Placeholder("check_safety", reads=["prompt"], 
writes=["safe"]) - decide_mode = Placeholder("decide_mode", reads=["chat_history", "prompt"], writes=["mode"]) - - generate_image = Placeholder( - "generate_image", - reads=["mode", "chat_history"], - writes=["response", "error"], - ) - generate_text = Placeholder( - "generate_text", - reads=["mode", "chat_history"], - writes=["response", "error"], - ) - generate_code = Placeholder( - "generate_code", - reads=["mode", "chat_history"], - writes=["response", "error"], - ) - bing_search = Placeholder( - "bing_search", - reads=["mode", "chat_history"], - writes=["response", "error"], - ) - - response = Placeholder("response", reads=["response", "safe"], writes=["chat_history"]) - error = Placeholder("error", reads=["error"], writes=["chat_history", "error"]) - - agent = ( - ApplicationBuilder() - .with_state(chat_history=[]) - .with_transition(prompt, check_safety) - .with_transition( - check_safety, decide_mode, Condition.expr("safe") - ) # if safe, decide what to do next - .with_transition( - check_safety, response, Condition.expr("not safe") - ) # if not safe, go to output - .with_transition(decide_mode, generate_image, Condition.when(mode="image")) - .with_transition(decide_mode, generate_text, Condition.when(mode="text")) - .with_transition(decide_mode, generate_code, Condition.when(mode="code")) - .with_transition(decide_mode, bing_search, Condition.when(mode="search")) - .with_transition( - [generate_image, generate_text, generate_code, bing_search], - response, - Condition.expr("not error"), - ) - .with_transition( - [generate_image, generate_text, generate_code, bing_search], - error, - Condition.expr("bool(error)"), - ) - .with_transition(response, prompt) - .with_transition(error, prompt) - .build() - ) - agent.visualize("./out", include_conditions=False, include_state=True) - response, error = agent.run(["response", "error"], gate="any_complete") - if response is not None: - display(response) - else: - display(error) - - -if __name__ == "__main__": - main() diff --git a/api_examples/model_training.py b/api_examples/model_training.py deleted file mode 100644 index c641bd7d..00000000 --- a/api_examples/model_training.py +++ /dev/null @@ -1,76 +0,0 @@ -from hamilton import driver - -from burr.core import Result, Condition -from burr.core.application import ApplicationBuilder -from burr.integrations.hamilton import Hamilton, from_state, from_value, update_state, append_state - - -def main(): - # dr = driver.Driver(...) 
- # train_model = Hamilton( - # "train_model", - # dr, - # inputs={ - # "model": state("model", missing="dont_include"), - # "epochs": state("epochs"), - # "dataset_path": "./dataset.csv" - # }, - # overrides={ - # "training_data": state("training_data", missing="dont_include") - # }, - # final_vars=["model", "epochs", "training_data", "metrics"], - # update=lambda state, results: ( - # state - # .update(model=results["model"]) - # .append(metrics=results["metrics"]) - # .update(training_data=results["training_data"]) - # .update(epochs=state["epochs"] + 1) - # ) - # ) - # train_model = Placeholder( - # "train_model", - # reads=["model", "epochs", "training_data"], - # writes=["model", "epochs", "training_data", "metrics"], - # ) - train_model = Hamilton( - name="train_model", - driver=driver.Driver(...), - inputs={ - "model": from_state("model", missing="drop"), - "epochs": from_state("epochs"), - "dataset_path": from_value("./dataset.csv") - }, - outputs={ - "model": update_state("model"), - "epochs": update_state("epochs"), - "training_data": update_state("training_data"), - "metrics": append_state("metrics"), - }, - ) - terminal = Result("result", fields=["model", "metrics"]) - - agent = ( - ApplicationBuilder() - .with_state(epochs=0) - .with_transition( - train_model, - terminal, - Condition.expr("epochs > 100 or metrics[-1]['loss'] < 0.1"), - ) - .with_loop( - train_model, - ) - .build() - ) - state, result, = agent.run(["result"]) - agent.visualize("./out", include_conditions=True) - for step in iter(agent): - report(step) # call out to w&b - if step.name == "result": - break - result, = agent.run(["result"]) - # nprint(result) - - -if __name__ == "__main__": - main() diff --git a/burr/core/__init__.py b/burr/core/__init__.py index a2034519..52747b3c 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -1,3 +1,15 @@ +from burr.core.action import Action, Condition, Result, default, expr, when +from burr.core.application import Application, ApplicationBuilder from burr.core.state import State -from burr.core.function import Function, Condition, DEFAULT, Result -from burr.core.application import Application, Transition, ApplicationBuilder + +__all__ = [ + "Action", + "ApplicationBuilder", + "Condition", + "Result", + "default", + "when", + "expr", + "Application", + "State", +] diff --git a/burr/core/function.py b/burr/core/action.py similarity index 62% rename from burr/core/function.py rename to burr/core/action.py index 95d65d68..fe1b8fc5 100644 --- a/burr/core/function.py +++ b/burr/core/action.py @@ -1,101 +1,92 @@ import abc import ast import copy -from typing import Callable, List, Optional +import inspect +from typing import Callable, List from burr.core.state import State class Function(abc.ABC): - def __init__(self, name: Optional[str]): - """Represents a function in a state machine. This is the base class from which: - 1. Custom functions - 2. Conditions - 3. Results - - All extend this class. Note that name is optional so that APIs can set the - name on these actions as part of instantiation. - When they're used, they must have a name set. 
+ @property + @abc.abstractmethod + def reads(self) -> list[str]: + pass - :param name: - """ - self._name = name + @abc.abstractmethod + def run(self, state: State) -> dict: + pass - def with_name(self, name: str) -> "Function": - """Returns a copy of the given function with the given name.""" - copied = copy.copy(self) - copied.set_name(name) - return copied + def is_async(self): + return inspect.iscoroutinefunction(self.run) - def set_name(self, name: str): - self._name = name +class Reducer(abc.ABC): @property - def name(self) -> str: - """Gives the name of this action. This should be unique - across your agent.""" - return self._name - @abc.abstractmethod - def run(self, state: State) -> dict: - """Runs the action, given a state. Returns the result (a dictionary) - of running the action. - - :param state: - :return: - """ + def writes(self) -> list[str]: pass @abc.abstractmethod def update(self, result: dict, state: State) -> State: - """Updates the state given the result of running the action. + pass - Note that if this attempts to access anything not in self.reads() - or writes to anything not in self.writes(), it will fail. +class Action(Function, Reducer, abc.ABC): + def __init__(self): + """Represents an action in a state machine. This is the base class from which: - :param result: Result of running the action - :param state: State to update - :return: The updated state + 1. Custom actions + 2. Conditions + 3. Results + + All extend this class. Note that name is optional so that APIs can set the + name on these actions as part of instantiation. + When they're used, they must have a name set. """ + self._name = None - @property - @abc.abstractmethod - def reads(self) -> list[str]: - """The list of keys in the state that this uses as an input. + def with_name(self, name: str) -> "Action": + """Returns a copy of the given action with the given name. Why do we need this? + We instantiate actions without names, and then set them later. This is a way to + make the API cleaner/consolidate it, and the ApplicationBuilder will end up handling it + for you, in the with_actions(...) method, which is the only way to use actions. - :return: List of keys in the state + Note they can also take in names in the constructor for testing, but otherwise this is + not something users will ever have to think about. + + :param name: Name to set + :return: A new action with the given name """ + if self._name is not None: + raise ValueError( + f"Name of {self} already set to {self._name} -- cannot set name to {name}" + ) + # TODO -- ensure that we're not mutating anything later on + # If we are, we may want to copy more intelligently + new_action = copy.copy(self) + new_action._name = name + return new_action @property - @abc.abstractmethod - def writes(self) -> list[str]: - """The list of keys in the state that this writes to. - - :return: - """ - pass + def name(self) -> str: + """Gives the name of this action. 
This should be unique + across your agent.""" + return self._name def __repr__(self): - return f"{self.name}({', '.join(self.reads)}) -> {', '.join(self.writes)}" + read_repr = ", ".join(self.reads) if self.reads else "{}" + write_repr = ", ".join(self.writes) if self.writes else "{}" + return f"{self.name}: {read_repr} -> {write_repr}" -class PureFunction(Function, abc.ABC): - def update(self, result: dict, state: State) -> State: - return state - - @property - def writes(self) -> list[str]: - return [] - - -class Condition(PureFunction): +class Condition(Function): KEY = "PROCEED" def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str = None): - super().__init__(name=name) # TODO -- find a better way of making this unique self._resolver = resolver self._keys = keys + self._name = name @staticmethod def expr(expr: str) -> "Condition": @@ -117,7 +108,6 @@ def visit_Name(self, node): # Compile the expression into a callable function def condition_func(state: State) -> bool: - __globals = state.get_all() # we can get all becuase externally we will subset return eval(compile(tree, "", "eval"), {}, __globals) @@ -126,9 +116,6 @@ def condition_func(state: State) -> bool: def run(self, state: State) -> dict: return {Condition.KEY: self._resolver(state)} - def update(self, result: dict, state: State) -> State: - return state - @property def reads(self) -> list[str]: return self._keys @@ -148,13 +135,19 @@ def condition_func(state: State) -> bool: name = f"{', '.join(f'{key}={value}' for key, value in sorted(kwargs.items()))}" return Condition(keys, condition_func, name=name) + @property + def name(self) -> str: + return self._name -DEFAULT = Condition([], lambda _: True, name="default") +default = Condition([], lambda _: True, name="default") +when = Condition.when +expr = Condition.expr -class Result(Function): - def __init__(self, name: str, fields: list[str]): - super().__init__(name) + +class Result(Action): + def __init__(self, fields: list[str]): + super(Result, self).__init__() self._fields = fields def run(self, state: State) -> dict: @@ -169,4 +162,4 @@ def reads(self) -> list[str]: @property def writes(self) -> list[str]: - return self._fields + return [] diff --git a/burr/core/application.py b/burr/core/application.py index 9200ff03..ea03d6d0 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -1,11 +1,12 @@ import collections import dataclasses -import dataclasses import logging -from typing import Any, List, Literal, Tuple, Union, Optional +from typing import Any, AsyncGenerator, Generator, List, Literal, Optional, Set, Tuple, Union -from burr.core.function import Function, Condition, DEFAULT +from burr.core.action import Action, Condition, Function, Reducer, default from burr.core.state import State +from burr.lifecycle.base import LifecycleAdapter +from burr.lifecycle.internal import LifecycleAdapterSet logger = logging.getLogger(__name__) @@ -14,8 +15,8 @@ class Transition: """Internal, utility class""" - from_: Function - to: Function + from_: Action + to: Action condition: Condition @@ -24,7 +25,7 @@ class Transition: PRIOR_STEP = "__PRIOR_STEP" -def run_function(function: Function, state: State) -> dict: +def _run_function(function: Function, state: State) -> dict: """Runs a function, returning the result of running the function. Note this restricts the keys in the state to only those that the function reads. 
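(For orientation -- a rough usage sketch of the builder and action APIs introduced by this patch; the Counter action, field names, and transition conditions below are illustrative examples, not part of the diff.)

    from burr.core import Action, ApplicationBuilder, Result, State, default, expr

    class Counter(Action):
        @property
        def reads(self) -> list[str]:
            return ["count"]

        @property
        def writes(self) -> list[str]:
            return ["count"]

        def run(self, state: State) -> dict:
            # only reads what it declared in `reads`
            return {"count": state["count"] + 1}

        def update(self, result: dict, state: State) -> State:
            # only writes what it declared in `writes`
            return state.update(**result)

    app = (
        ApplicationBuilder()
        .with_state(count=0)
        .with_actions(counter=Counter(), result=Result(["count"]))
        .with_transitions(
            ("counter", "counter", expr("count < 10")),  # evaluated in order
            ("counter", "result", default),  # fallback transition
        )
        .with_entrypoint("counter")
        .build()
    )
    final_state, results = app.run(until=["result"])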
@@ -33,86 +34,209 @@ def run_function(function: Function, state: State) -> dict: :param state: :return: """ + if function.is_async(): + raise ValueError( + f"Cannot run async: {function} " + "in non-async context. Use astep()/aiterate()/arun() " + "instead...)" + ) state_to_use = state.subset(*function.reads) return function.run(state_to_use) -def update_state(function: Function, state: State, result: dict) -> State: - """Runs the function updater, returning the new state. Note this restricts the - keys in the state to only those that the function writes. +async def _arun_function(function: Function, state: State) -> dict: + """Runs a function, returning the result of running the function. + Async version of the above. :param function: :param state: + :return: + """ + state_to_use = state.subset(*function.reads) + return await function.run(state_to_use) + + +def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> State: + """Runs the reducer, returning the new state. Note this restricts the + keys in the state to only those that the function writes. + + :param reducer: + :param state: :param result: :return: """ - state_to_use = state.subset(*function.writes) - new_state = function.update(result, state_to_use) + state_to_use = state.subset(*reducer.writes) + new_state = reducer.update(result, state_to_use) keys_in_new_state = set(new_state.keys()) - extra_keys = keys_in_new_state - set(function.writes) + extra_keys = keys_in_new_state - set(reducer.writes) if extra_keys: raise ValueError( - f"Function {function.name} wrote to keys {extra_keys} that it did not declare!" + f"Action {name} attempted to write to keys {extra_keys} " + f"that it did not declare. It declared: ({reducer.writes})!" ) - return state.merge(new_state.update(**{PRIOR_STEP: function.name})) + return state.merge(new_state.update(**{PRIOR_STEP: name})) class Application: - def __init__(self, - functions: List[Function], - transitions: List[Transition], - state: State, - initial_step: str): - self.function_map = { - fn.name: fn for fn in functions - } - self.adjacency_map = Application._create_adjacency_map(transitions) - # self.state = Application._initialize_state(initial_step, state) - self.transitions = transitions - self.functions = functions - self.initial_step = initial_step - self.state = state - - @staticmethod - def _create_adjacency_map(transitions: List[Transition]) -> dict: - adjacency_map = collections.defaultdict(list) - for transition in transitions: - from_ = transition.from_ - to = transition.to - adjacency_map[from_.name].append((to.name, transition.condition)) - return adjacency_map - - def _set_state(self, new_state: State): - self.state = new_state - - def get_next_function(self) -> Optional[Function]: - if not self.state.has(PRIOR_STEP): - return self.function_map[self.initial_step] - possibilities = self.adjacency_map[self.state[PRIOR_STEP]] - for next_function, condition in possibilities: - if condition.run(self.state)[Condition.KEY]: - return self.function_map[next_function] - return None - - def step(self) -> Tuple[Function, dict, State]: + def __init__( + self, + actions: List[Action], + transitions: List[Transition], + state: State, + initial_step: str, + adapter_set: Optional[LifecycleAdapterSet] = None, + ): + self._action_map = {action.name: action for action in actions} + self._adjacency_map = Application._create_adjacency_map(transitions) + self._transitions = transitions + self._actions = actions + self._initial_step = initial_step + self._state = state + 
        self._adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet()
+
+    def step(self) -> Optional[Tuple[Action, dict, State]]:
         """Performs a single step, advancing the state machine along. This returns a tuple of
         the action that was run, the result of running the action, and the new state.
 
-        :return: Tuple[Function, dict, State]
+        Use this if you just want to do something with the state and not rely on generators.
+        E.g. press forward/backwards, human in the loop, etc... Odds are this is not
+        the method you want -- you'll want iterate() (if you want to see the state/results
+        along the way), or run() (if you just want the final state/results).
+
+        :return: Tuple[Action, dict, State] -- the action that was just run, the result of running it, and the new state
+        """
+        next_action = self.get_next_action()
+        self._adapter_set.call_all_lifecycle_hooks_sync(
+            "pre_run_step", action=next_action, state=self._state
+        )
+        if next_action is None:
+            return None
+        exc = None
+        result = None
+        new_state = self._state
+        try:
+            result = _run_function(next_action, self._state)
+            new_state = _run_reducer(next_action, self._state, result, next_action.name)
+            self._set_state(new_state)
+        except Exception as e:
+            exc = e
+            raise e
+        finally:
+            self._adapter_set.call_all_lifecycle_hooks_sync(
+                "post_run_step", action=next_action, state=new_state, result=result, exception=exc
+            )
+        return next_action, result, new_state
+
+    async def astep(self) -> Optional[Tuple[Action, dict, State]]:
+        """Asynchronous version of step.
+
+        :return: Tuple of the action that was run, the result of running it, and the new state
+            (or None if there is no next action).
         """
-        next_function = self.get_next_function()
-        if next_function is None:
-            raise StopIteration("No next function!")
-        result = run_function(next_function, self.state)
-        new_state = update_state(next_function, self.state, result)
+        next_action = self.get_next_action()
+        if next_action is None:
+            return None
+        await self._adapter_set.call_all_lifecycle_hooks_async(
+            "pre_run_step", action=next_action, state=self._state
+        )
+        exc = None
+        result = None
+        new_state = self._state
+        try:
+            if not next_action.is_async():
+                # we can just delegate to the synchronous version, it will block the event loop,
+                # but that's safer than assuming it's OK to launch a thread
+                # TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function
+                # which this is supposed to be it's OK).
+                # this delegates hooks to the synchronous version, so we'll call all of them as well
+                return self.step()
+            result = await _arun_function(next_action, self._state)
+            new_state = _run_reducer(next_action, self._state, result, next_action.name)
+        except Exception as e:
+            exc = e
+            raise e
+        finally:
+            await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
+                "post_run_step", action=next_action, state=new_state, result=result, exception=exc
+            )
         self._set_state(new_state)
-        return next_function, result, new_state
+        return next_action, result, new_state
+
+    def iterate(
+        self, *, until: list[str], gate: TerminationCondition = "any_complete"
+    ) -> Generator[Tuple[Action, dict, State], None, Tuple[State, List[dict]]]:
+        """Returns a generator that calls step() in a row, enabling you to see the state
+        of the system as it updates. Note this returns a generator, and also the final result
+        (for convenience).
+
+        :param until: The list of actions to run until -- these must be strings
+            that match the names of the actions.
+        :param gate: The gate to run until.
This can be "any_complete" or "all_complete" + :return: Each iteration returns the result of running `step` + The final result is the current state + results in the order that they were specified. + """ + + if gate != "any_complete": + raise NotImplementedError( + "Only any_complete is supported for now -- " + "please reach out to the developers to unblock other gate types!" + ) + + results: List[Optional[dict]] = [None for _ in until] + index_by_name = {name: index for index, name in enumerate(until)} + seen_results = set() + + condition = ( + lambda: any(item in seen_results for item in until) + if gate == "any_complete" + else lambda: len(seen_results) == len(until) + ) + + while not condition(): + result = self.step() + if result is None: + break + action, result, state = result + if action.name in until: + seen_results.add(action.name) + results[index_by_name[action.name]] = result + yield action, result, state + return self._state, results + + async def aiterate( + self, *, until: list[str], gate: TerminationCondition = "any_complete" + ) -> AsyncGenerator[Tuple[Action, dict, State], Tuple[State, List[dict]]]: + """Returns a generator that calls step() in a row, enabling you to see the state + of the system as it updates. This is the asynchronous version so it has no capability of t + + :param until: The list of actions to run until -- these must be strings + that match the names of the action. + :param gate: The gate to run until. This can be "any_complete" or "all_complete" + :return: Each iteration returns the result of running `step` + The final result is the current state + results in the order that they were specified. + """ + + seen_results = set() + condition = ( + lambda: any(item in seen_results for item in until) + if gate == "any_complete" + else lambda: len(seen_results) == len(until) + ) + + while not condition(): + result = await self.astep() + if result is None: + break + action, result, state = result + if action.name in until: + seen_results.add(action.name) + yield action, result, state def run( - self, - until: list[str], # TODO -- until_encountered - gate: TerminationCondition = "any_complete", + self, + *, + until: list[str], + gate: TerminationCondition = "any_complete", ) -> Tuple[State, List[dict]]: """ @@ -120,41 +244,46 @@ def run( :param until: :return: """ + gen = self.iterate(until=until, gate=gate) + while True: + try: + next(gen) + except StopIteration as e: + return e.value + + async def arun( + self, + *, + until: list[str], + gate: TerminationCondition = "any_complete", + ): + """Asynchronous version of run. + + :param gate: + :param until: + :return: + """ + state = self._state results: List[Optional[dict]] = [None for _ in until] index_by_name = {name: index for index, name in enumerate(until)} - seen_results = set() - - def _format_state(state: State, truncate_chars: int = 60) -> str: - out = ", ".join([f"{key}: {value}" for key, value in state.items()]) - if len(out) > truncate_chars: - out = out[:truncate_chars - 3] + "..." 
- return out - - condition = lambda: any(item in seen_results for item in until) \ - if gate == "any_complete" else lambda: len(seen_results) == len(until) - while not condition(): - print(f"|>> Running step: {self.get_next_function()}") - function, result, state = self.step() - print(f"|-->>> Ran step: {function.name} with result: {_format_state(result)}") - print(f"|-->>> New state: {_format_state(state)}") - if function.name in until: - seen_results.add(function.name) - results[index_by_name[function.name]] = result - return self.state, results + async for action, result, state in self.aiterate(until=until, gate=gate): + if action.name in until: + results[index_by_name[action.name]] = result + return state, results def visualize( - self, - output_file_path: str, - include_conditions: bool = False, - include_state: bool = False, - view: bool = False, - **graphviz_kwargs: Any, + self, + output_file_path: str, + include_conditions: bool = False, + include_state: bool = False, + view: bool = False, + **graphviz_kwargs: Any, ): try: import graphviz # noqa: F401 except ModuleNotFoundError: logger.exception( - " graphviz is required for visualizing the function graph. Install it with:" + " graphviz is required for visualizing the application graph. Install it with:" '\n\n pip install "dw-burr[visualization]" or pip install graphviz \n\n' ) return @@ -172,31 +301,126 @@ def visualize( else: digraph_attr[g_key] = g_value digraph = graphviz.Digraph(**digraph_attr) - for action in self.functions: + for action in self._actions: label = ( action.name if not include_state else f"{action.name}({', '.join(action.reads)}): {', '.join(action.writes)}" ) digraph.node(action.name, label=label, shape="box", style="rounded") - for transition in self.transitions: + for transition in self._transitions: condition = transition.condition digraph.edge( transition.from_.name, transition.to.name, - label=condition.name if include_conditions and condition is not DEFAULT else None, - style="dashed" if transition.condition is not DEFAULT else "solid", + label=condition.name if include_conditions and condition is not default else None, + style="dashed" if transition.condition is not default else "solid", ) digraph.render(output_file_path, view=view) return digraph + @staticmethod + def _create_adjacency_map(transitions: List[Transition]) -> dict: + adjacency_map = collections.defaultdict(list) + for transition in transitions: + from_ = transition.from_ + to = transition.to + adjacency_map[from_.name].append((to.name, transition.condition)) + return adjacency_map + + def _set_state(self, new_state: State): + self._state = new_state + + def get_next_action(self) -> Optional[Action]: + if PRIOR_STEP not in self._state: + return self._action_map[self._initial_step] + possibilities = self._adjacency_map[self._state[PRIOR_STEP]] + for next_action, condition in possibilities: + if condition.run(self._state)[Condition.KEY]: + return self._action_map[next_action] + return None + + def update_state(self, new_state: State): + """Updates state -- this is meant to be called if you need to do + anything with the state. For example: + 1. Reset it (after going through a loop) + 2. Store to some external source/log out + + :param new_state: + :return: + """ + self._state = new_state + + @property + def state(self) -> State: + """Gives the state. Recall that state is purely immutable + -- anything you do with this state will not be persisted unless you + subsequently call update_state. + + :return: The current state object. 
+ """ + return self._state + + @property + def actions(self) -> List[Action]: + return self._actions + + +def _assert_set(value: Optional[Any], field: str, method: str): + if value is None: + raise ValueError( + f"Must set {field} before building application! Do so with ApplicationBuilder.{method}" + ) + + +def _validate_transitions( + transitions: Optional[List[Tuple[str, str, Condition]]], actions: Set[str] +): + _assert_set(transitions, "_transitions", "with_transitions") + exhausted = {} # items for which we have seen a default transition + for from_, to, condition in transitions: + if from_ not in actions: + raise ValueError( + f"Transition source: {from_} not found in actions! " + f"Please add to actions using with_actions({from_}=...)" + ) + if to not in actions: + raise ValueError( + f"Transition target: {to} not found in actions! " + f"Please add to actions using with_actions({to}=...)" + ) + if condition.name == "default": # we have seen a default transition + if from_ in exhausted: + raise ValueError( + f"Transition {from_} -> {to} is redundant -- " + f"a default transition has already been set for {from_}" + ) + exhausted[from_] = True + return True + + +def _validate_start(start: Optional[str], actions: Set[str]): + _assert_set(start, "_start", "with_entrypoint") + if start not in actions: + raise ValueError( + f"Entrypoint: {start} not found in actions. Please add " + f"using with_actions({start}=...)" + ) + + +def _validate_actions(actions: Optional[List[Action]]): + _assert_set(actions, "_actions", "with_actions") + if len(actions) == 0: + raise ValueError("Must have at least one action in the application!") + @dataclasses.dataclass class ApplicationBuilder: - state: State = None + state: State = dataclasses.field(default_factory=State) transitions: List[Tuple[str, str, Condition]] = None - functions: List[Function] = None + actions: List[Action] = None start: str = None + lifecycle_adapters: List[LifecycleAdapter] = dataclasses.field(default_factory=list) def with_state(self, **kwargs) -> "ApplicationBuilder": if self.state is not None: @@ -205,27 +429,47 @@ def with_state(self, **kwargs) -> "ApplicationBuilder": self.state = State(kwargs) return self - def with_entrypoint(self, function: str) -> "ApplicationBuilder": - self.start = function - return self + def with_entrypoint(self, action: str) -> "ApplicationBuilder": + """Adds an entrypoint to the application. This is the action that will be run first. + This can only be called once. - def with_loop(self, from_to: str) -> "ApplicationBuilder": - if self.transitions is None: - self.transitions = [] - self.transitions.append((from_to, from_to, DEFAULT)) + :param action: + :return: + """ + # TODO -- validate only called once + self.start = action return self - def with_functions(self, **functions: Function): - if self.functions is None: - self.functions = [] - for key, value in functions.items(): - self.functions.append(value.with_name(key)) + def with_actions(self, **actions: Action) -> "ApplicationBuilder": + """Adds an action to the application. The actions are granted names (using the with_name) + method post-adding, using the kw argument. Thus, this is the only supported way to add actions. + + :param actions: Actions to add, keyed by name + :return: The application builder for future chaining. 
+ """ + if self.actions is None: + self.actions = [] + for key, value in actions.items(): + self.actions.append(value.with_name(key)) return self - def with_transitions(self, *transitions: Union[ - Tuple[Union[str, list[str]], str], - Tuple[Union[str, list[str]], str, Condition] - ]) -> "ApplicationBuilder": + def with_transitions( + self, + *transitions: Union[ + Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition] + ], + ) -> "ApplicationBuilder": + """Adds transitions to the application. Transitions are specified as tuples of either: + 1. (from, to, condition) + 2. (from, to) -- condition is set to DEFAULT (which is a fallback) + + Transitions will be evaluated in order of specification -- if one is met, the others will not be evaluated. + Note that one transition can be terminal -- the system doesn't have + + + :param transitions: Transitions to add + :return: The application builder for future chaining. + """ if self.transitions is None: self.transitions = [] for transition in transitions: @@ -233,58 +477,44 @@ def with_transitions(self, *transitions: Union[ if len(conditions) > 0: condition = conditions[0] else: - condition = DEFAULT + condition = default if not isinstance(from_, list): from_ = [from_] for action in from_: + if not isinstance(action, str): + raise ValueError(f"Transition source must be a string, not {action}") + if not isinstance(to_, str): + raise ValueError(f"Transition target must be a string, not {to_}") self.transitions.append((action, to_, condition)) - - # if self.transitions is None: - # self.transitions = [] - # self.transitions.extend(transitions) return self - # def with_transition( - # self, - # from_: Union[Function, list[Function]], - # to: Function, - # condition: Condition = DEFAULT, - # ) -> "ApplicationBuilder": - # if self.transitions is None: - # self.transitions = [] - # if isinstance(from_, list): - # for action in from_: - # self.transitions.append(Transition(action, to, condition)) - # else: - # self.transitions.append(Transition(from_, to, condition)) - # return self - - def _assert_set(self, field: str, method: str): - if getattr(self, field) is None: - raise ValueError(f"Must set {field} before building application! Do so with ApplicationBuilder.{method}") + def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder": + """Adds a lifecycle adapter to the application. This is a way to add hooks to the application so that + they are run at the appropriate times. You can use this to synchronize state out, log results, etc... + + :param adapter: Adapter to add + :return: The application builder for future chaining. 
+ """ + self.lifecycle_adapters.extend(adapters) + return self def build(self) -> Application: - self._assert_set("start", "with_start") - self._assert_set("state", "with_state") - self._assert_set("transitions", "with_transition") - fns_by_name = {fn.name: fn for fn in self.functions} - for from_, to, condition in self.transitions: - if from_ not in fns_by_name: - raise ValueError(f"Transition source: {from_} not found in functions!") - if to not in fns_by_name: - raise ValueError(f"Transition target: {to} not found in functions!") - # transactions_by_name[transition.from_.name] = transition.from_ - # transactions_by_name[transition.to.name] = transition.to + _validate_actions(self.actions) + actions_by_name = {action.name: action for action in self.actions} + all_actions = set(actions_by_name.keys()) + _validate_transitions(self.transitions, all_actions) + _validate_start(self.start, all_actions) return Application( - functions=self.functions, + actions=self.actions, transitions=[ Transition( - from_=fns_by_name[from_], - to=fns_by_name[to], + from_=actions_by_name[from_], + to=actions_by_name[to], condition=condition, ) for from_, to, condition in self.transitions ], state=self.state, initial_step=self.start, + adapter_set=LifecycleAdapterSet(*self.lifecycle_adapters), ) diff --git a/burr/core/implementations.py b/burr/core/implementations.py new file mode 100644 index 00000000..62c49993 --- /dev/null +++ b/burr/core/implementations.py @@ -0,0 +1,34 @@ +from burr.core import State +from burr.core.action import Action + + +class Placeholder(Action): + """This is a placeholder action -- you would expect it to break if you tried to run it. It is specifically + for the following workflow: + 1. Create your state machine out of placeholders to model it + 2. Visualize the state machine + 2. Replace the placeholders with real actions as you see fit + """ + + def __init__(self, reads: list[str], writes: list[str]): + super().__init__() + self._reads = reads + self._writes = writes + + def run(self, state: State) -> dict: + raise NotImplementedError( + f"This is a placeholder action and thus you are unable to run. Please implement: {self}!" + ) + + def update(self, result: dict, state: State) -> State: + raise NotImplementedError( + f"This is a placeholder action and thus cannot update state. Please implement: {self}!" + ) + + @property + def reads(self) -> list[str]: + return self._reads + + @property + def writes(self) -> list[str]: + return self._writes diff --git a/burr/core/state.py b/burr/core/state.py index 2cf481bf..a311887b 100644 --- a/burr/core/state.py +++ b/burr/core/state.py @@ -1,18 +1,133 @@ +import abc +import copy +import dataclasses import logging -from typing import Any, Dict, Mapping, Iterator +from typing import Any, Dict, Iterator, Mapping logger = logging.getLogger(__name__) +class StateDelta(abc.ABC): + """Represents a delta operation for state. This represents a transaction. 
+    Note it has the ability to mutate state in-place, but will be layered behind an immutable
+    state object."""
+
+    @classmethod
+    @abc.abstractmethod
+    def name(cls) -> str:
+        """Unique name of this operation for ser/deser"""
+        pass
+
+    def serialize(self) -> dict:
+        """Converts the state delta to a JSON object"""
+        if not dataclasses.is_dataclass(self):
+            raise TypeError("serialize method is only supported for dataclass instances")
+        return dataclasses.asdict(self)
+
+    def base_serialize(self) -> dict:
+        """Converts the state delta to a JSON object"""
+        return {"name": self.name(), "operation": self.serialize()}
+
+    @classmethod
+    def deserialize(cls, json_dict: dict) -> "StateDelta":
+        """Converts a JSON object to a state delta"""
+        if not dataclasses.is_dataclass(cls):
+            raise TypeError("deserialize method is only supported for dataclass types")
+        return cls(**json_dict)  # Assumes all fields in the dataclass match keys in json_dict
+
+    def base_deserialize(self, json_dict: dict) -> "StateDelta":
+        """Converts a JSON object to a state delta"""
+        return self.deserialize(json_dict)
+
+    @abc.abstractmethod
+    def reads(self) -> list[str]:
+        """Returns the keys that this state delta reads"""
+        pass
+
+    @abc.abstractmethod
+    def writes(self) -> list[str]:
+        """Returns the keys that this state delta writes"""
+        pass
+
+    @abc.abstractmethod
+    def apply_mutate(self, inputs: dict):
+        """Applies the state delta to the inputs"""
+        pass
+
+
+@dataclasses.dataclass
+class SetFields(StateDelta):
+    """State delta that sets fields in the state"""
+
+    values: Mapping[str, Any]
+
+    @classmethod
+    def name(cls) -> str:
+        return "set"
+
+    def reads(self) -> list[str]:
+        return list(self.values.keys())
+
+    def writes(self) -> list[str]:
+        return list(self.values.keys())
+
+    def apply_mutate(self, inputs: dict):
+        inputs.update(self.values)
+
+
+@dataclasses.dataclass
+class AppendFields(StateDelta):
+    """State delta that appends to fields in the state"""
+
+    values: Mapping[str, Any]
+
+    @classmethod
+    def name(cls) -> str:
+        return "append"
+
+    def reads(self) -> list[str]:
+        return list(self.values.keys())
+
+    def writes(self) -> list[str]:
+        return list(self.values.keys())
+
+    def apply_mutate(self, inputs: dict):
+        for key, value in self.values.items():
+            if key not in inputs:
+                inputs[key] = []
+            if not isinstance(inputs[key], list):
+                raise ValueError(f"Cannot append to non-list value {key}={inputs[key]}")
+            inputs[key].append(value)
+
+
+@dataclasses.dataclass
+class DeleteField(StateDelta):
+    """State delta that deletes fields from the state"""
+
+    keys: list[str]
+
+    @classmethod
+    def name(cls) -> str:
+        return "delete"
+
+    def reads(self) -> list[str]:
+        return list(self.keys)
+
+    def writes(self) -> list[str]:
+        return []
+
+    def apply_mutate(self, inputs: dict):
+        for key in self.keys:
+            inputs.pop(key, None)
+
+
 class State(Mapping):
     """An immutable state object. Things to consider:
-    1. Merging state back in -- how do we make this ergonomic
-    2. Adding hooks on change
-    3. Pulling/pushing to external places
-    4. Simultaneous writes/reads in the case of parallelism
-    5. Schema enforcement -- how to specify/manage? Should this be a
+    1. Adding hooks on change
+    2. Pulling/pushing to external places
+    3. Simultaneous writes/reads in the case of parallelism
+    4. Schema enforcement -- how to specify/manage? Should this be a
        dataclass when implemented?
-    6. 
How to implement this so its a nice wrapper over dict -- likely a subclass of sorts """ def __init__(self, initial_values: Dict[str, Any] = None): @@ -20,39 +135,46 @@ def __init__(self, initial_values: Dict[str, Any] = None): initial_values = dict() self._state = initial_values - def get(self, key: str, default: Any = None) -> any: - """Queries the value of a set of keys in the state""" - try: - return self[key] - except KeyError: - return default - - def has(self, key: str) -> bool: - """Checks if a key is in the state""" - return key in self._state - - def keys(self) -> list[str]: - """Returns the keys in the state""" - return list(self._state.keys()) + def apply_operation(self, operation: StateDelta) -> "State": + """Applies a given operation to the state, returning a new state""" + new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys + operation.apply_mutate( + new_state + ) # todo -- validate that the write keys are the only different ones + return State(new_state) def get_all(self) -> Dict[str, Any]: - """Returns the entire state""" - return self._state + """Returns the entire state, realize as a dictionary. This is a copy.""" + return dict(self) def update(self, **updates: Any) -> "State": """Updates the state with a set of key-value pairs""" - return State({**self.get_all(), **updates}) + return self.apply_operation(SetFields(updates)) def append(self, **updates: Any) -> "State": """For each pair specified, appends to a list in state.""" - new_state = self.get_all().copy() - for key, value in updates.items(): - if key not in new_state: - new_state[key] = [] - if not isinstance(new_state[key], list): - raise ValueError(f"Cannot append to non-list value {key}={new_state[key]}") - new_state[key].append(value) - return State(new_state) + + return self.apply_operation(AppendFields(updates)) + + def wipe(self, delete: list[str] = None, keep: list[str] = None): + """Wipes the state, either by deleting the keys in delete and keeping everything else + or keeping the keys in keep. and deleting everything else. If you pass nothing in + it will delete the whole thing. + + :param delete: + :param keep: + :return: + """ + if delete is not None and keep is not None: + raise ValueError( + f"You cannot specify both delete and keep -- not both! 
" + f"You have specified: delete={delete}, keep={keep}" + ) + if delete is not None: + fields_to_delete = delete + else: + fields_to_delete = [key for key in self._state if key not in keep] + return self.apply_operation(DeleteField(fields_to_delete)) def merge(self, other: "State") -> "State": """Merges two states together, overwriting the values in self @@ -61,7 +183,7 @@ def merge(self, other: "State") -> "State": def subset(self, *keys: str, ignore_missing: bool = True) -> "State": """Returns a subset of the state, with only the given keys""" - return State({key: self[key] for key in keys if self.has(key) or not ignore_missing}) + return State({key: self[key] for key in keys if key in self or not ignore_missing}) def __getitem__(self, __k: str) -> Any: return self._state[__k] @@ -73,4 +195,4 @@ def __iter__(self) -> Iterator[Any]: return iter(self._state) def __repr__(self): - return self.get_all().__repr__() # quick hack + return self.get_all().__repr__() # quick hack diff --git a/burr/implementations.py b/burr/implementations.py deleted file mode 100644 index 0f16a719..00000000 --- a/burr/implementations.py +++ /dev/null @@ -1,21 +0,0 @@ -from burr.core import State -from burr.core.function import Function - - -class Placeholder(Function): - def __init__(self, name: str, reads: list[str], writes: list[str]): - super().__init__(name) - self._reads = reads - self._writes = writes - - def run(self, state: State) -> dict: - return {key: ... for key in self.writes()} - - def update(self, result: dict, state: State) -> State: - return state.update(**result) - - def reads(self) -> list[str]: - return self._reads - - def writes(self) -> list[str]: - return self._writes diff --git a/burr/integrations/base.py b/burr/integrations/base.py new file mode 100644 index 00000000..b6b4d2f1 --- /dev/null +++ b/burr/integrations/base.py @@ -0,0 +1,6 @@ +def require_plugin(import_error: ImportError, libraries: list[str], plugin_name: str): + raise ImportError( + f"Missing plugin {plugin_name}! To use the {plugin_name} plugin, you must install the following libraries: {libraries}." + f"You can install this with dw-burr[{plugin_name}] or pip install {' '.join(libraries)} (replace with your " + f"package manager of choice)." + ) from import_error diff --git a/burr/integrations/hamilton.py b/burr/integrations/hamilton.py index 694cb9be..78c9527b 100644 --- a/burr/integrations/hamilton.py +++ b/burr/integrations/hamilton.py @@ -1,7 +1,7 @@ import dataclasses -from typing import Any, Callable, Literal, Tuple, Dict, Union, List +from typing import Any, Dict, Literal, Tuple, Union -from burr.core import Function, State +from burr.core import Action, State from hamilton.driver import Driver @@ -24,13 +24,26 @@ class LiteralSource: def from_state( - key: str, - missing: MissingAction = "error", + key: str, + missing: MissingAction = "error", ) -> StateSource: + """Indicates that an input should come from state. + Specify "missing" to allow for missing keys to be dropped or raise an error. + + :param key: Key in state to use + :param missing: What to do if the key is missing + :return: A StateSource object -- use by Hamilton(inputs=...) + """ return StateSource(key, missing) -def from_value(value: str) -> LiteralSource: +def from_value(value: Any) -> LiteralSource: + """Indicates that an input should come from a literal (variable/constant) value. + Use this if you just want to fix a parameter into the Hamilton DAG. + + :param value: Value to use + :return: A LiteralSource object -- use by Hamilton(inputs=...) 
+ """ return LiteralSource(value) @@ -41,45 +54,81 @@ class Output: def update_state(key: str) -> Output: + """At the update step of a Hamilton Action, call state.update to the key field of state. + Used with outputs= parameter of Hamilton(...) + + :param key: Field in state to udpate + :return: An Output object + """ return Output(key, "update") def append_state(key: str): + """At the update state of a Hamilton Action, call state.append to the key field of state. + Used with outputs= parameter of Hamilton(...) + + :param key: Field in state to append to + :return: An Output object + """ return Output(key, "append") DEFAULT_DRIVER = None -class Hamilton(Function): +class Hamilton(Action): @staticmethod def set_driver(driver: Driver): + """Default method if all the hamilton nodes are using the same driver. + Will set globally, so be careful. + + Note that the driver must have the default adapter (so that it returns a dict). + + :param driver: Driver to use + """ global DEFAULT_DRIVER DEFAULT_DRIVER = driver def __init__( - self, - inputs: Dict[str, Input], - outputs: Dict[str, Output], - driver: Driver = None, - name: str = None, + self, + inputs: Dict[str, Input], + outputs: Dict[str, Output], + driver: Driver = None, ): - super().__init__(name=name) + """Creates a Hamilton action. Allows youy to specify: + 1. How to wire state fields into hamilton inputs + 2 How to wire hamilton outputs into state fields + + Note that we o not distinguish between overrides and inputs -- we intelligently decide + which are which based on the driver's available variables. + + :param inputs: + :param outputs: + :param driver: + :param name: + """ + super(Hamilton, self).__init__() if driver is None and DEFAULT_DRIVER is None: - raise ValueError("Driver must be set before creating a Hamilton function. " - "You can do so with Hamilton.set_driver(...) to set it globally, " - "or pass in driver to the Hamilton(...) constructor.") + raise ValueError( + "Driver must be set before creating a Hamilton function. " + "You can do so with Hamilton.set_driver(...) to set it globally, " + "or pass in driver to the Hamilton(...) constructor." 
+ ) self._driver = driver if driver is not None else DEFAULT_DRIVER self._inputs = inputs self._outputs = outputs - def _extract_inputs_overrides(self, state: State, driver: Driver) -> Tuple[dict, dict]: + @property + def driver(self): + return self._driver + + def _extract_inputs_overrides(self, state: State) -> Tuple[dict, dict]: """Extracts the inputs and overrides from the state.""" def resolve_value(source: Input) -> Any: if isinstance(source, StateSource): - if state.has(source.state_key): - return state.get(source.state_key) + if source.state_key in state: + return state[source.state_key] else: if source.missing == "error": raise ValueError(f"Missing state key {source.state_key}") @@ -90,8 +139,13 @@ def resolve_value(source: Input) -> Any: inputs = {} overrides = {} - dr_vars = {node.name: node for node in driver.list_available_variables()} + dr_vars = {node.name: node for node in self._driver.list_available_variables()} for key, source in self._inputs.items(): + if key not in dr_vars: + raise ValueError( + f"Input {key} not available in driver -- " + f"available variables are: {list(self._driver.list_available_variables())}" + ) node = dr_vars[key] if node.is_external_input: inputs[node.name] = resolve_value(source) @@ -100,7 +154,12 @@ def resolve_value(source: Input) -> Any: return inputs, overrides def run(self, state: State) -> dict: - inputs, overrides = self._extract_inputs_overrides(state, self._driver) + """Runs a hamilton action, using the driver to execute the hamilton DAG. + + :param state: The state to use + :return: The results of the hamilton DAG + """ + inputs, overrides = self._extract_inputs_overrides(state) result = self._driver.raw_execute( list(self._outputs.keys()), overrides=overrides, @@ -109,11 +168,20 @@ def run(self, state: State) -> dict: return result def update(self, result: dict, state: State) -> State: + """Updates the state with the results of the hamilton action, as specified in the outputs. + + :param result: The results of the hamilton DAG + :param state: The state to update + """ update_values = { - self._outputs[key].key: value for key, value in result.items() if self._outputs[key].mode == "update" + self._outputs[key].key: value + for key, value in result.items() + if self._outputs[key].mode == "update" } append_values = { - self._outputs[key].key: value for key, value in result.items() if self._outputs[key].mode == "append" + self._outputs[key].key: value + for key, value in result.items() + if self._outputs[key].mode == "append" } return state.update(**update_values).append(**append_values) @@ -124,7 +192,9 @@ def reads(self) -> list[str]: 1. Parse the inputs and overrides to determine which state items are read 2. Return them """ - return [source.state_key for source in self._inputs.values() if isinstance(source, StateSource)] + return [ + source.state_key for source in self._inputs.values() if isinstance(source, StateSource) + ] @property def writes(self) -> list[str]: @@ -142,4 +212,10 @@ def visualize_step(self, **kwargs): inputs = {key: ... 
for key in self._inputs} overrides = inputs final_vars = list(self._outputs.keys()) - return dr.visualize_execution(final_vars=final_vars, inputs=inputs, overrides=overrides, bypass_validation=True, **kwargs) + return dr.visualize_execution( + final_vars=final_vars, + inputs=inputs, + overrides=overrides, + bypass_validation=True, + **kwargs, + ) diff --git a/burr/integrations/streamlit.py b/burr/integrations/streamlit.py new file mode 100644 index 00000000..9b5b481e --- /dev/null +++ b/burr/integrations/streamlit.py @@ -0,0 +1,270 @@ +import dataclasses +import inspect +import json +from typing import List, Optional + +from burr.core import Application +from burr.integrations.base import require_plugin +from burr.integrations.hamilton import Hamilton, StateSource + +try: + import colorsys + + import graphviz + import matplotlib.colors as mc + import streamlit as st +except ImportError as e: + require_plugin( + e, + ["streamlit", "graphviz", "colorsys", "matplotlib"], + "streamlit", + ) + + +@dataclasses.dataclass +class Record: + state: dict + action: str + result: dict + + +@dataclasses.dataclass +class AppState: + display_index: Optional[int] # index in the state/results dict + history: list[Record] + app: Application # we have to have this for registering the state machine -- + num_prior_nodes: int = 5 # view last 5 + + @property + def current_action(self) -> Optional[str]: + if self.display_index is None or len(self.history) == 0: + return None + return self.history[self.display_index].action + + @property + def prior_actions(self) -> List[str]: + if self.display_index is None or len(self.history) == 0: + return [] + return [ + record.action + for record in self.history[ + self.display_index - self.num_prior_nodes : self.display_index + ] + ] + + @property + def next_action(self) -> str: + return self.app.get_next_action().name + + @property + def max_index(self) -> int: + return len(self.history) - 1 + + @classmethod + def from_empty(cls, app: Application) -> "AppState": + return AppState(display_index=0, history=[], app=app) + + +def load_state_from_log_file(jsonl_log_file: str, app: Application) -> AppState: + """Initializes the state from a log file. This must have been logged using StateAndResultFullLogger. + Note that, currently, you must pass in an Application object (although that will be optoinal in the future). + + :param jsonl_log_file: Log file to load + :param app: Application object + :return: AppState + """ + out = [] + for i, line in enumerate(open(jsonl_log_file)): + json_line = json.loads(line) + record = Record( + state=json_line["state"], + action=json_line["action"], + result=json_line["result"] + # TODO -- add start time, end time + ) + out.append(record) + return AppState(display_index=len(out) - 1, history=out, app=app) + + +def update_state(new_state: AppState): + st.session_state.burr_state = new_state + + +def get_state(): + if "burr_state" not in st.session_state: + raise ValueError( + "No state found in streamlit session state. To initialize the state, call " + "initialize_state() as the first line from streamlit -- it will do nothing if the state is already initialized." 
+ ) + return st.session_state.get("burr_state") + + +def _modify_state_machine_digraph( + digraph: graphviz.Digraph, current_node: str = None, prior_nodes: list = [] +): + def lighten_color(color, amount=0.5): + if amount > 1: + amount = 1 + if amount < 0: + amount = 0 + try: + c = mc.cnames[color] + except KeyError: + c = color + c = colorsys.rgb_to_hls(*mc.to_rgb(c)) + lightened_color = colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) + return mc.to_hex(lightened_color) + + digraph.node(current_node, fillcolor="blue", style="rounded,filled", fontcolor="white") + seen = {current_node} + base_color = "lightblue" + for i, node in enumerate(prior_nodes): + if node not in seen: + seen.add(node) + lighter_color = lighten_color(base_color, amount=1 - ((i + 1) * 0.1)) + digraph.node(node, fillcolor=lighter_color, style="rounded,filled", fontcolor="black") + else: + continue + + digraph.attr(bgcolor="transparent") + + +def render_state_machine(state: AppState): + """Visualization of the current state machine. Highlights: + 1. Current node in blue (with white backgorund) + 2. Prior nodes in progressively lighter shades of blue + + Use this individually, or within the "render_explorer" view + + :param state: + :return: + """ + prior_nodes = state.prior_actions # grab the prior nodes + current_node = state.next_action + app = state.app + visualized = app.visualize(None, include_conditions=False, include_state=True) + if current_node is not None: + _modify_state_machine_digraph( + visualized, current_node=current_node, prior_nodes=prior_nodes + ) + st.graphviz_chart(visualized, use_container_width=True) + + +def render_action(state: AppState): + """Renders the current action, including the reads, writes, and the code for the action. + With Hamilton actions, it will also show the visualization of the action. + + This can be used individually (with a state object) or within the "render_explorer" view. 
+ + :param state: + :return: + """ + app: Application = state.app + current_node = state.current_action + actions = {action.name: action for action in app.actions} + if current_node is None: + st.markdown("No current action.") + return + action_object = actions[current_node] + is_hamilton = isinstance(action_object, Hamilton) + + def format_read(var): + out = f"- `{var}`" + if is_hamilton: + inputs = action_object._inputs # TODO -- don't use private variables + corresponding_input = { + k: v for k, v in inputs.items() if isinstance(v, StateSource) and v.state_key == var + } + if corresponding_input: + return f"- `state['{var}']` → `{list(corresponding_input.keys())[0]}`" + return out + + def format_write(var): + out = f"- `{var}`" + if is_hamilton: + outputs = action_object._outputs + corresponding_output = {k: v for k, v in outputs.items() if v.key == var} + k, v = list(corresponding_output.items())[0] + if corresponding_output: + out = f"- `state['{var}']`" + if v.mode == "update": + out += f"← `{k}`" + if v.mode == "append": + out += f" ← `state['{var}'].append({k})`" + # out += " (`.append()`)" + return out + + reads = "\n".join([format_read(var) for var in action_object.reads]) + writes = "\n".join([format_write(var) for var in action_object.writes]) + st.markdown(f"#### Reads: \n {reads}") + st.markdown(f"#### Writes: \n {writes}") + if is_hamilton: + digraph = action_object.visualize_step(show_legend=False) + st.graphviz_chart(digraph, use_container_width=False) + else: + code = inspect.getsource(action_object.__class__) + st.code(code, language="python") + + +def render_state_results(state: AppState): + """Render the state and results for the current state. This includes the state and the result of the action. + This can be used individually (with a state object) or within the "render_explorer" view. + + :param state: State object + :return: None + """ + if len(state.history) == 0: # empty, have not yet started + return + state_to_render = state.history[state.display_index].state + result_to_render = state.history[state.display_index].result + # if "chat_history" in state_to_render: + # del state_to_render["chat_history"] + st.header("State") + st.json(state_to_render, expanded=True) + st.header("Result") + st.json(result_to_render, expanded=True) + + +def set_slider_to_current(): + st.session_state.index_slider = get_state().max_index + + +def render_explorer(app_state: AppState): + """Renders the entire explorer, including the state machine, the action, and the state/results. 
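+    Typical usage (illustrative): call this once per streamlit rerun with the AppState held in
+    session state, e.g. render_explorer(app_state), as the example streamlit apps under examples/ do.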
+ + :param app_state: State of the app + :return: None + """ + total_state_length = len(app_state.history) + placeholder = st.empty() + placeholder.markdown("` `") + if total_state_length > 1: + slider_values = list(range(total_state_length)) + slider_strings = [record.action for record in app_state.history] + + def stringify(i): + return slider_strings[i] + + slider = st.select_slider( + "index", + options=slider_values, + label_visibility="hidden", + format_func=stringify, + key="index_slider", + ) + current_node_index = slider + # TODO -- consider a callback here instead + app_state.display_index = current_node_index + + with placeholder.container(height=800): + state_machine_view, step_view, data_view = st.tabs( + ["Application", "Action", "State/Results"] + ) + with st.container(): + with state_machine_view: + render_state_machine(app_state) + with step_view: + render_action(app_state) + with data_view: + render_state_results(app_state) + update_state(app_state) diff --git a/burr/lifecycle/__init__.py b/burr/lifecycle/__init__.py new file mode 100644 index 00000000..25e31ead --- /dev/null +++ b/burr/lifecycle/__init__.py @@ -0,0 +1,25 @@ +from burr.lifecycle.base import ( + LifecycleAdapter, + PostRunApplicationHook, + PostRunApplicationHookAsync, + PostRunStepHook, + PostRunStepHookAsync, + PreRunApplicationHook, + PreRunApplicationHookAsync, + PreRunStepHook, + PreRunStepHookAsync, +) +from burr.lifecycle.default import StateAndResultsFullLogger + +__all__ = [ + "PreRunStepHook", + "PreRunStepHookAsync", + "PostRunStepHook", + "PostRunStepHookAsync", + "PreRunApplicationHook", + "PreRunApplicationHookAsync", + "PostRunApplicationHook", + "PostRunApplicationHookAsync", + "LifecycleAdapter", + "StateAndResultsFullLogger", +] diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py new file mode 100644 index 00000000..5e3e4eff --- /dev/null +++ b/burr/lifecycle/base.py @@ -0,0 +1,100 @@ +import abc +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + # type-checking-only for a circular import + from burr.core import State, Action + +from burr.lifecycle.internal import lifecycle + + +@lifecycle.base_hook("pre_run_step") +class PreRunStepHook(abc.ABC): + @abc.abstractmethod + def pre_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any): + pass + + +@lifecycle.base_hook("pre_run_step") +class PreRunStepHookAsync(abc.ABC): + @abc.abstractmethod + async def pre_run_step(self, *, state: "State", action: "Action", **future_kwargs: Any): + pass + + +@lifecycle.base_hook("post_run_step") +class PostRunStepHook(abc.ABC): + @abc.abstractmethod + def post_run_step( + self, + *, + state: "State", + action: "Action", + result: Optional[dict], + exception: Exception, + **future_kwargs: Any + ): + pass + + +@lifecycle.base_hook("post_run_step") +class PostRunStepHookAsync(abc.ABC): + @abc.abstractmethod + async def post_run_step( + self, + *, + state: "State", + action: "Action", + result: Optional[dict], + exception: Exception, + **future_kwargs: Any + ): + pass + + +# THESE ARE NOT IN USE +# TODO -- implement/decide how to use them +@lifecycle.base_hook("pre_run_application") +class PreRunApplicationHook(abc.ABC): + @abc.abstractmethod + def pre_run_application(self, *, state: "State", **future_kwargs: Any): + pass + + +@lifecycle.base_hook("pre_run_application") +class PreRunApplicationHookAsync(abc.ABC): + @abc.abstractmethod + async def pre_run_application(self, *, state: "State", **future_kwargs): + pass + + 
+@lifecycle.base_hook("post_run_application") +class PostRunApplicationHook(abc.ABC): + @abc.abstractmethod + def post_run_application( + self, *, state: "State", until: list[str], results: list[dict], **future_kwargs + ): + pass + + +@lifecycle.base_hook("post_run_application") +class PostRunApplicationHookAsync(abc.ABC): + @abc.abstractmethod + async def post_run_application( + self, *, state: "State", until: list[str], results: list[dict], **future_kwargs + ): + pass + + +# strictly for typing -- this conflicts a bit with the lifecycle decorator above, but its fine for now +# This makes IDE completion/type-hinting easier +LifecycleAdapter = Union[ + PreRunStepHook, + PreRunStepHookAsync, + PostRunStepHook, + PostRunStepHookAsync, + PreRunApplicationHook, + PreRunApplicationHookAsync, + PostRunApplicationHook, + PostRunApplicationHookAsync, +] diff --git a/burr/lifecycle/default.py b/burr/lifecycle/default.py new file mode 100644 index 00000000..b1f4c57d --- /dev/null +++ b/burr/lifecycle/default.py @@ -0,0 +1,56 @@ +import datetime +import json +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional + +if TYPE_CHECKING: + from burr.core import State, Action + +from burr.lifecycle.base import PostRunStepHook, PreRunStepHook + + +def safe_json(obj: Any) -> str: + return json.dumps(obj, default=str) + + +class StateAndResultsFullLogger(PostRunStepHook, PreRunStepHook): + DONT_INCLUDE = object() # sentinel value + + def __init__( + self, + jsonl_path: str, + mode: Literal["append", "w"] = "append", + json_dump: Callable[[dict], str] = safe_json, + ): + if not jsonl_path.endswith(".jsonl"): + raise ValueError(f"jsonl_path must end with .jsonl. Got: {jsonl_path}") + self.jsonl_path = jsonl_path + open_mode = "a" if mode == "append" else "w" + self.f = open(jsonl_path, mode=open_mode) # open in append mode + self.tracker = [] # tracker to keep track of timing/whatnot + + def pre_run_step(self, **future_kwargs: Any): + self.tracker.append({"time": datetime.datetime.now()}) + + def post_run_step( + self, + *, + state: "State", + action: "Action", + result: Optional[dict], + exception: Exception, + **future_kwargs: Any, + ): + state_and_result = { + "state": state.get_all(), + "action": action.name, + "result": result, + "exception": str(exception), + "start_time": self.tracker[-1]["time"].isoformat(), + "end_time": datetime.datetime.now().isoformat(), + } + self.f.writelines([safe_json(state_and_result) + "\n"]) + + def __del__(self): + if hasattr(self, "f"): + # possible something fails beforehand + self.f.close() diff --git a/burr/lifecycle/internal.py b/burr/lifecycle/internal.py new file mode 100644 index 00000000..f70bc5b9 --- /dev/null +++ b/burr/lifecycle/internal.py @@ -0,0 +1,187 @@ +"""Base tooling, internal-facing, for lifecycle hooks. This is stolen from the +hamilton implementation, but significantly simplified.""" +import asyncio +import collections +import inspect +from typing import TYPE_CHECKING, Callable, Dict, List, Set, Tuple + +if TYPE_CHECKING: + # type-checking-only for a circular import + from burr.lifecycle.base import LifecycleAdapter + +SYNC_HOOK = "hooks" +ASYNC_HOOK = "async_hooks" + +REGISTERED_SYNC_HOOKS: Set[str] = set() +REGISTERED_ASYNC_HOOKS: Set[str] = set() + + +class InvalidLifecycleHook(Exception): + """Container exception to indicate that a lifecycle adapter is invalid.""" + + pass + + +def validate_hook_fn(fn: Callable): + """Validates that a function forms a valid hook. This means: + 1. Function returns nothing + 2. 
Function must consist of only kwarg-only arguments + + :param fn: The function to validate + :raises InvalidLifecycleAdapter: If the function is not a valid hook + """ + sig = inspect.signature(fn) + if ( + "future_kwargs" not in sig.parameters + or sig.parameters["future_kwargs"].kind != inspect.Parameter.VAR_KEYWORD + ): + raise InvalidLifecycleHook( + f"Lifecycle hooks must have a `**future_kwargs` argument. Method/hook {fn} does not." + ) + for param in sig.parameters.values(): + if param.name != "future_kwargs": + if param.kind != inspect.Parameter.KEYWORD_ONLY and param.name != "self": + raise InvalidLifecycleHook( + f"Lifecycle hooks can only have keyword-only arguments. " + f"Method/hook {fn} has argument {param} that is not keyword-only." + ) + + +class lifecycle: + """Container class for decorators to register hooks. + This is just a container so it looks clean (`@lifecycle.base_hook(...)`), but we could easily move it out. + What do these decorators do? + 1. We decorate a class with a method/hook/validator call + 2. This implies that there exists a function by that name + 3. We validate that that function has an appropriate signature + 4. We store this in the appropriate registry (see the constants above) + Then, when we want to perform a hook/method/validator, we can ask the AdapterLifecycleSet to do so. + It crawls up the MRO, looking to see which classes are in the registry, then elects which functions to run. + See LifecycleAdapterSet for more information. + """ + + @classmethod + def base_hook(cls, fn_name: str): + """Hooks get called at distinct stages of Hamilton's execution. + These can be layered together, and potentially coupled to other hooks. + + :param fn_name: Name of the function that will reside in the class we're decorating + """ + + def decorator(clazz): + fn = getattr(clazz, fn_name, None) + if fn is None: + raise ValueError( + f"Class {clazz} does not have a method {fn_name}, but is " + f'decorated with @lifecycle.base_hook("{fn_name}"). The parameter ' + f"to @lifecycle.base_hook must be the name " + f"of a method on the class." + ) + validate_hook_fn(fn) + if inspect.iscoroutinefunction(fn): + setattr(clazz, ASYNC_HOOK, fn_name) + REGISTERED_ASYNC_HOOKS.add(fn_name) + else: + setattr(clazz, SYNC_HOOK, fn_name) + REGISTERED_SYNC_HOOKS.add(fn_name) + return clazz + + return decorator + + +class LifecycleAdapterSet: + """An internal class that groups together all the lifecycle adapters. + This allows us to call methods through a delegation pattern, enabling us to add + whatever callbacks, logging, error-handling, etc... we need globally. While this + does increase the stack trace in an error, it should be pretty easy to figure out what's going on. + """ + + def __init__(self, *adapters: "LifecycleAdapter"): + """Initializes the adapter set. 
+ + :param adapters: Adapters to group together + """ + self._adapters = list(adapters) + self.sync_hooks, self.async_hooks = self._get_lifecycle_hooks() + + def _get_lifecycle_hooks( + self, + ) -> Tuple[Dict[str, List["LifecycleAdapter"]], Dict[str, List["LifecycleAdapter"]]]: + sync_hooks = collections.defaultdict(list) + async_hooks = collections.defaultdict(list) + for adapter in self.adapters: + for cls in inspect.getmro(adapter.__class__): + sync_hook = getattr(cls, SYNC_HOOK, None) + if sync_hook is not None: + if adapter not in sync_hooks[sync_hook]: + sync_hooks[sync_hook].append(adapter) + async_hook = getattr(cls, ASYNC_HOOK, None) + if async_hook is not None: + if adapter not in async_hooks[async_hook]: + async_hooks[async_hook].append(adapter) + return ( + {hook: adapters for hook, adapters in sync_hooks.items()}, + {hook: adapters for hook, adapters in async_hooks.items()}, + ) + + def _does_hook(self, hook_name: str, is_async: bool) -> bool: + """Whether or not a hook is implemented by any of the adapters in this group. + If this hook is not registered, this will raise a ValueError. + + :param hook_name: Name of the hook + :param is_async: Whether you want the async version or not + :return: True if this adapter set does this hook, False otherwise + """ + if is_async and hook_name not in REGISTERED_ASYNC_HOOKS: + raise ValueError( + f"Hook {hook_name} is not registered as an asynchronous lifecycle hook. " + f"Registered hooks are {REGISTERED_ASYNC_HOOKS}" + ) + if not is_async and hook_name not in REGISTERED_SYNC_HOOKS: + raise ValueError( + f"Hook {hook_name} is not registered as a synchronous lifecycle hook. " + f"Registered hooks are {REGISTERED_SYNC_HOOKS}" + ) + if not is_async: + return hook_name in self.sync_hooks + return hook_name in self.async_hooks + + def call_all_lifecycle_hooks_sync(self, hook_name: str, **kwargs): + """Calls all the lifecycle hooks in this group, by hook name (stage) + + :param hook_name: Name of the hooks to call + :param kwargs: Keyword arguments to pass into the hook + """ + if not self._does_hook(hook_name, False): + return + for adapter in self.sync_hooks[hook_name]: + getattr(adapter, hook_name)(**kwargs) + + async def call_all_lifecycle_hooks_async(self, hook_name: str, **kwargs): + """Calls all the lifecycle hooks in this group, by hook name (stage). + + :param hook_name: Name of the hook + :param kwargs: Keyword arguments to pass into the hook + """ + if not self._does_hook(hook_name, True): + return + futures = [] + for adapter in self.async_hooks[hook_name]: + futures.append(getattr(adapter, hook_name)(**kwargs)) + await asyncio.gather(*futures) + + async def call_all_lifecycle_hooks_sync_and_async(self, hook_name: str, **kwargs): + """Calls all the lifecycle hooks in this group, by hook name (stage). 
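+        Synchronous hooks are invoked first, then the asynchronous ones are awaited together.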
+ + :param hook_name: Name of the hook + """ + self.call_all_lifecycle_hooks_sync(hook_name, **kwargs) + await self.call_all_lifecycle_hooks_async(hook_name, **kwargs) + + @property + def adapters(self) -> List["LifecycleAdapter"]: + """Gives the adapters in this group + + :return: A list of adapters + """ + return self._adapters diff --git a/examples/counter/application.py b/examples/counter/application.py new file mode 100644 index 00000000..9fe3db26 --- /dev/null +++ b/examples/counter/application.py @@ -0,0 +1,45 @@ +import burr.core +from burr.core import Action, Result, State, default, expr +from burr.lifecycle import StateAndResultsFullLogger + + +class CounterAction(Action): + @property + def reads(self) -> list[str]: + return ["counter"] + + def run(self, state: State) -> dict: + return {"counter": state["counter"] + 1} + + @property + def writes(self) -> list[str]: + return ["counter"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + +def application(count_up_to: int = 10, log_file: str = None): + return ( + burr.core.ApplicationBuilder() + .with_state( + counter=0, + ) + .with_actions(counter=CounterAction(), result=Result(["counter"])) + .with_transitions( + ("counter", "counter", expr(f"counter < {count_up_to}")), + ("counter", "result", default), + ("result", "counter", expr("counter == 0")), # when we've reset, manually + ) + .with_entrypoint("counter") + .with_hooks(*[StateAndResultsFullLogger(log_file)] if log_file else []) + .build() + ) + + +if __name__ == "__main__": + app = application(log_file="counter.jsonl") + state, result = app.run(until=["result"]) + app.visualize(output_file_path="counter.png", include_conditions=True, view=True) + assert state["counter"] == 10 + print(state["counter"]) diff --git a/examples/counter/requirements.txt b/examples/counter/requirements.txt new file mode 100644 index 00000000..1670c9d8 --- /dev/null +++ b/examples/counter/requirements.txt @@ -0,0 +1 @@ +dw-burr[streamlit] diff --git a/examples/counter/streamlit_app.py b/examples/counter/streamlit_app.py new file mode 100644 index 00000000..6d9a7002 --- /dev/null +++ b/examples/counter/streamlit_app.py @@ -0,0 +1,64 @@ +import counter +import streamlit as st + +from burr.integrations.streamlit import ( + AppState, + Record, + get_state, + render_explorer, + set_slider_to_current, + update_state, +) + + +def counter_view(app_state: AppState): + application = app_state.app + button = st.button("Forward", use_container_width=True) + if button: + step_output = application.step() + if step_output is not None: + action, result, state = step_output + app_state.history.append(Record(state.get_all(), action.name, result)) + set_slider_to_current() + else: + application.update_state(application.state.update(counter=0)) + action, result, state = application.step() + app_state.history.append(Record(state.get_all(), action.name, result)) + + +def retrieve_state(): + if "burr_state" not in st.session_state: + state = AppState.from_empty(app=counter.application()) + else: + state = get_state() + return state + + +def main(): + st.set_page_config(layout="wide") + st.title("Counting numbers with Burr") + app_state = retrieve_state() # retrieve first so we can use for the ret of the step + columns = st.columns(2) + with columns[0]: + st.write( + "This is a simple counter app. It counts to 10, then loops back to 0. You can reset it at any time. 
" + "While we know that this is easy to do with a simple loop + streamlit, it highlights the state that Burr manages." + "Use the slider to rewind/see what happened in the past, and the visualizations to understand how we navigate " + "through the state machine!" + ) + counter_view(app_state) + with st.container(height=800): + md_lines = [] + for item in app_state.history: + if item.action == "counter": + md_lines.append(f"Counted to {item.state['counter']}!") + else: + md_lines.append("Looping back! ") + st.code("\n".join(md_lines), language=None) + with columns[1]: + render_explorer(app_state) + update_state(app_state) # update so the next iteration knows what to do + + +if __name__ == "__main__": + main() diff --git a/examples/cowsay/application.py b/examples/cowsay/application.py new file mode 100644 index 00000000..7f7f014d --- /dev/null +++ b/examples/cowsay/application.py @@ -0,0 +1,103 @@ +import random +import time +from typing import List, Optional + +import cowsay + +from burr.core import Action, Application, ApplicationBuilder, State, default, expr +from burr.lifecycle import PostRunStepHook + + +class CowSay(Action): + def __init__(self, say_what: List[Optional[str]]): + super(CowSay, self).__init__() + self.say_what = say_what + + @property + def reads(self) -> list[str]: + return [] + + def run(self, state: State) -> dict: + say_what = random.choice(self.say_what) + return { + "cow_said": cowsay.get_output_string("cow", say_what) if say_what is not None else None + } + + @property + def writes(self) -> list[str]: + return ["cow_said"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + +class CowShouldSay(Action): + @property + def reads(self) -> list[str]: + return [] + + def run(self, state: State) -> dict: + if not random.randint(0, 3): + return {"cow_should_speak": True} + return {"cow_should_speak": False} + + @property + def writes(self) -> list[str]: + return ["cow_should_speak"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + +class PrintWhatTheCowSaid(PostRunStepHook): + def post_run_step(self, *, state: "State", action: "Action", **future_kwargs): + if action.name != "cow_should_say" and state["cow_said"] is not None: + print(state["cow_said"]) + + +class CowCantSpeakFast(PostRunStepHook): + def __init__(self, sleep_time: float): + super(PostRunStepHook, self).__init__() + self.sleep_time = sleep_time + + def post_run_step(self, *, state: "State", action: "Action", **future_kwargs): + if action.name != "cow_should_say": # no need to print if we're not saying anything + time.sleep(self.sleep_time) + + +def application(in_terminal: bool = False) -> Application: + hooks = ( + [ + PrintWhatTheCowSaid(), + CowCantSpeakFast(sleep_time=2.0), + ] + if in_terminal + else [] + ) + return ( + ApplicationBuilder() + .with_state( + cow_said=None, + ) + .with_actions( + say_nothing=CowSay([None]), + say_hello=CowSay(["Hello world!", "What's up?", "Are you Aaron Burr, sir?"]), + cow_should_say=CowShouldSay(), + ) + .with_transitions( + ("cow_should_say", "say_hello", expr("cow_should_speak")), + ("say_hello", "cow_should_say", default), + ("cow_should_say", "say_nothing", expr("not cow_should_speak")), + ("say_nothing", "cow_should_say", default), + ) + .with_entrypoint("cow_should_say") + .with_hooks(*hooks) + .build() + ) + + +if __name__ == "__main__": + app = application(in_terminal=True) + app.visualize(output_file_path="cowsay.png", include_conditions=True, view=True) + while True: + 
state, result, action = app.step() diff --git a/examples/cowsay/requirements.txt b/examples/cowsay/requirements.txt new file mode 100644 index 00000000..3155b767 --- /dev/null +++ b/examples/cowsay/requirements.txt @@ -0,0 +1,2 @@ +cowsay +dw-burr[streamlit] diff --git a/examples/cowsay/streamlit_app.py b/examples/cowsay/streamlit_app.py new file mode 100644 index 00000000..8db01336 --- /dev/null +++ b/examples/cowsay/streamlit_app.py @@ -0,0 +1,69 @@ +import application as cowsay_application +import streamlit as st + +from burr.integrations.streamlit import ( + AppState, + Record, + get_state, + render_explorer, + set_slider_to_current, + update_state, +) + + +def cow_say_view(app_state: AppState): + application = app_state.app + button = st.button("Cow say?", use_container_width=True) + if button: + step_output = application.step() + action, result, state = step_output + app_state.history.append(Record(state.get_all(), action.name, result)) + set_slider_to_current() + + +def render_cow_said(record: Record): + cow_said = record.state.get("cow_said") + cow_should_speak = record.state.get("cow_should_speak") + with st.chat_message("cow", avatar="🐮" if cow_should_speak and cow_said else "💭"): + grey = "#A9A9A9" + if record.action == "cow_should_say": + if cow_should_speak: + st.markdown( + f"

<span style='color: {grey}'>Cow will speak!</span>", unsafe_allow_html=True
+                )
+            else:
+                st.markdown(
+                    f"<span style='color: {grey}'>Cow will not speak.</span>", unsafe_allow_html=True
+                )
+        else:
+            if cow_said is None:
+                st.markdown(f"<span style='color: {grey}'>...</span>
", unsafe_allow_html=True) + if cow_said is not None: + st.code(cow_said, language="plaintext") + + +def retrieve_state(): + if "burr_state" not in st.session_state: + state = AppState.from_empty(app=cowsay_application.application()) + else: + state = get_state() + return state + + +def main(): + st.set_page_config(layout="wide") + st.title("Talking cows with Burr") + app_state = retrieve_state() # retrieve first so we can use for the ret of the step + columns = st.columns(2) + with columns[0]: + cow_say_view(app_state) + with st.container(height=800): + for item in app_state.history: + render_cow_said(item) + with columns[1]: + render_explorer(app_state) + update_state(app_state) # update so the next iteration knows what to do + + +if __name__ == "__main__": + main() diff --git a/examples/gpt/app.py b/examples/gpt/app.py deleted file mode 100644 index ab49bc0f..00000000 --- a/examples/gpt/app.py +++ /dev/null @@ -1,427 +0,0 @@ -import dataclasses -import inspect -import logging -import time -from random import random -from typing import Generator, Optional - -import graphviz -import streamlit as st -from graphviz import Digraph -from hamilton import driver - -import burr.core -from burr import core -from burr.core import Result, ApplicationBuilder, Condition, State -from burr.integrations.hamilton import Hamilton, from_state, append_state, update_state, StateSource -from examples.gpt import capabilities -from examples.gpt.server import ChatItem - -from streamlit_float import float_init, float_parent -import streamlit_super_slider - -st.set_page_config(layout="wide") -st.title("Burr Demo") - -float_init() - - -def burr_app() -> burr.core.Application: - dr = driver.Driver({"provider": "openai"}, capabilities) # TODO -- add modules - Hamilton.set_driver(dr) - prompt = Hamilton( - inputs={"prompt": from_state("prompt")}, - outputs={"processed_prompt": append_state("chat_history")}, - ) - check_safety = Hamilton( - inputs={"prompt": from_state("prompt")}, - outputs={"safe": update_state("safe")}, - ) - decide_mode = Hamilton( - inputs={"prompt": from_state("prompt")}, - outputs={"mode": update_state("mode")}, - ) - generate_image = Hamilton( - inputs={"prompt": from_state("prompt")}, - outputs={"generated_image": update_state("response")}, - ) - generate_text = Hamilton( - inputs={"chat_history": from_state("chat_history")}, - outputs={"generated_text": update_state("response")}, - ) - response = Hamilton( - inputs={"response": from_state("response"), "safe": from_state("safe"), "mode": from_state("mode")}, - outputs={"processed_response": append_state("chat_history")}, - ) - error = Hamilton( - inputs={"error": from_state("error")}, - outputs={"processed_error": append_state("chat_history"), "error": update_state("error")}, - ) - output = Result("output", fields=["chat_history"]) - - app = ( - ApplicationBuilder() - .with_state(chat_history=[]) - .with_functions( - prompt=prompt, - check_safety=check_safety, - decide_mode=decide_mode, - generate_image=generate_image, - generate_text=generate_text, - response=response, - error=error, - output=output, - ) - .with_transitions( - ("prompt", "check_safety"), - ("check_safety", "decide_mode", Condition.expr("safe")), - ("check_safety", "response", Condition.expr("not safe")), - ("decide_mode", "generate_image", Condition.when(mode="image")), - ("decide_mode", "generate_text", Condition.when(mode="text")), - ("generate_image", "response"), - ("generate_text", "response"), - ("response", "output"), - ("error", "output"), - ("output", "prompt"), - ) - 
.with_entrypoint("prompt") - .build() - ) - return app - - -def initialize_state(): - # Initialize chat history - if "chat_history" not in st.session_state: - st.session_state.chat_history = [] - if "state_history" not in st.session_state: - st.session_state.state_history = [] - if "app" not in st.session_state: - st.session_state.app = burr_app() - if "running" not in st.session_state: - st.session_state.running = False - - -def render_message(message: ChatItem): - type_ = message["type"] - role = message["role"] - content = message["content"] - - with st.chat_message(role): - if type_ == "text": - st.markdown(content) - elif type_ == "image": - st.image(content) - - -def render_chat_history(): - # st.markdown(''' - # ''', unsafe_allow_html=True) - # container = st.container() - # container.markdown("", unsafe_allow_html=True) - - with st.container(): - for message in st.session_state.chat_history: - render_message(message) - current_node = _get_current_node() - if st.session_state.running: - st.status(current_node) - - -# def render_thoughts(): -# app: core.Application = st.session_state.app -# is_running = st.session_state.running -# next_step = app.get_next_function() -# -# if not is_running or next_step is None: -# return -# if next_step.name != "prompt": -# label = f"Running: {next_step.name}" -# status = st.status(label, state="running", expanded=False) -# with status: -# render_state() -# status.update(label=label, state="running", expanded=False) - -def _get_current_node() -> Optional[str]: - app: core.Application = st.session_state.app - is_running = st.session_state.running - next_step = app.get_next_function() - # print(f"next_step={next_step}") - if not is_running or next_step is None: - return "prompt" # back to the beginning - return next_step.name - - -def prompt(): - prompt = st.chat_input( - "I can generate an image or text! 
And I will choose intelligently...", - disabled=st.session_state.running, - key="chat_input" - ) - if prompt: - st.session_state.app.state = ( - st.session_state.app.state - .update(prompt=prompt) - ) # quick hack to reset it - st.session_state.running = True - - -# def _modify_digraph_attrs(digraph: graphviz.Digraph, current_node: str = None): -# digraph.node(current_node, fillcolor="lightblue", style="rounded,filled", fontcolor="white") -# digraph.attr(bgcolor="transparent") - -def _modify_digraph_attrs(digraph: graphviz.Digraph, current_node: str = None, prior_nodes: list = []): - # Define a function to generate a lighter color - # def lighten_color(color, amount=0.5): - # if amount > 1: - # amount = 1 - # import matplotlib.colors as mc - # import colorsys - # try: - # c = mc.cnames[color] - # except: - # c = color - # c = colorsys.rgb_to_hls(*mc.to_rgb(c)) - # lightened_color = colorsys.hls_to_rgb(1-c[0], amount * (1 - c[1]), 1-c[2]) - # return mc.to_hex(lightened_color) - import matplotlib.colors as mc - import colorsys - - def lighten_color(color, amount=0.5): - if amount > 1: - amount = 1 - if amount < 0: - amount = 0 - try: - c = mc.cnames[color] - except KeyError: - c = color - c = colorsys.rgb_to_hls(*mc.to_rgb(c)) - lightened_color = colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) - return mc.to_hex(lightened_color) - - # # Test the function with increasing amounts of lightening - # base_color = "lightblue" - # lightened_colors = [lighten_color(base_color, amount=i * 0.1) for i in range(5)] - # lightened_colors - - # Set attributes for the current node - digraph.node(current_node, fillcolor="blue", style="rounded,filled", fontcolor="white") - - # Set attributes for prior nodes with progressively lighter colors - seen = {current_node} - base_color = "lightblue" - print(current_node, prior_nodes) - for i, node in enumerate(prior_nodes): - if node not in seen: - seen.add(node) - else: - continue - lighter_color = lighten_color(base_color, amount=1 - ((i + 1) * 0.1)) - print(node, lighter_color) - digraph.node(node, fillcolor=lighter_color, style="rounded,filled", fontcolor="black") - - digraph.attr(bgcolor="transparent") - - -@st.cache_data -def state_machine_viz(current_node: str, slider_position: int, _prior_nodes: list = None): - if _prior_nodes is None: - _prior_nodes = [] - app = st.session_state.app - visualized = app.visualize(None, include_conditions=False, include_state=True) - _modify_digraph_attrs(visualized, current_node=current_node, prior_nodes=_prior_nodes) - return visualized - - -def render_state_machine(current_node: str, slider_position: int, prior_nodes: list[str]): - visualized = state_machine_viz(current_node, slider_position, prior_nodes) - st.graphviz_chart(visualized, use_container_width=True) - - -def render_state(slider_position: int): - # state_to_render = { - # key: value for key, value in st.session_state.app.state.get_all().items() if key != "chat_history" - # } - if len(st.session_state.state_history) == 0: - return - state_to_render = st.session_state.state_history[slider_position]["prior_state"].get_all() - result_to_render = st.session_state.state_history[slider_position]["result"] - - # if "chat_history" in state_to_render: - # del state_to_render["chat_history"] - st.header("State") - st.json(state_to_render, expanded=False) - st.header("Result") - st.json(result_to_render, expanded=False) - - -def render_step(current_node: str): - # current_node = _get_current_node() - app: core.Application = st.session_state.app - node_object: 
core.Function = app.function_map[current_node] - is_hamilton = isinstance(node_object, Hamilton) - - def format_read(var): - out = f"- `{var}`" - if is_hamilton: - inputs = node_object._inputs - corresponding_input = { - k: v for k, v in inputs.items() if isinstance(v, StateSource) and v.state_key == var - } - if corresponding_input: - return f"- `state['{var}']` → `{list(corresponding_input.keys())[0]}`" - return out - - def format_write(var): - out = f"- `{var}`" - if is_hamilton: - outputs = node_object._outputs - corresponding_output = { - k: v for k, v in outputs.items() if v.key == var - } - k, v = list(corresponding_output.items())[0] - if corresponding_output: - out = f"- `state['{var}']`" - if v.mode == "update": - out += f"← `{k}`" - if v.mode == "append": - out += f" ← `state['{var}'].append({k})`" - # out += " (`.append()`)" - return out - - reads = "\n".join([format_read(var) for var in node_object.reads]) - writes = "\n".join([format_write(var) for var in node_object.writes]) - st.markdown(f"#### Reads: \n {reads}") - st.markdown(f"#### Writes: \n {writes}") - if is_hamilton: - digraph = node_object.visualize_step(show_legend=False) - st.graphviz_chart(digraph, use_container_width=False) - else: - code = inspect.getsource(node_object.__class__) - st.code(code, language="python") - - -def explorer(default_expanded: bool) -> str: - is_running = st.session_state.running - # running_str = "running" if is_running else "complete" - - # label = f"Running: {current_node}" if is_running and current_node is not None else "Inspect..." - # status = st.status(label, state=running_str, expanded=True) - # with status: - # render_state() - slider = 0 - total_state_length = len(st.session_state.state_history) - current_node = _get_current_node() - prior_nodes = [ - st.session_state.state_history[i]["function"] for i in range(-total_state_length, slider + total_state_length - 1) - ] - placeholder = st.empty() - placeholder.markdown("` `") - if total_state_length > 1: - slider = st.slider("index", -total_state_length, 0, 0, key="rewind", label_visibility="hidden") - if slider < 0: - current_node_index = slider + total_state_length - current_node = st.session_state.state_history[current_node_index]["function"] - prior_nodes = [ - st.session_state.state_history[i]["function"] for i in range(current_node_index, current_node_index - 5, -1) - ] - - # prior_nodes = list(reversed(prior_nodes)) - - with placeholder.container(height=500): - - state_machine_view, step_view, data_view = st.tabs(["App", "Function", "State/Results"]) - # placeholder = st.empty() - # placeholder.header(f"node={current_node}") - # with st.container(): - with st.container(): - with state_machine_view: - render_state_machine(current_node=current_node, slider_position=slider, prior_nodes=prior_nodes) - with step_view: - render_step(current_node=current_node) - with data_view: - render_state(slider_position=slider) - return current_node - - -def stop(): - st.session_state.running = False - - -def forward_step(): - # import time - # time.sleep(10) - if st.session_state.running: - app: burr.core.Application = st.session_state.app - try: - prior_state = app.state - t = time.time() - function, result, state = app.step() - # print(f">>>>> function={function.name} in {time.time() - t} seconds") - st.session_state.chat_history = state["chat_history"] - st.session_state.state_history.append({ - "state": state, - "prior_state": prior_state, - "function": function.name, - "result": result - }) # just keep track of it all so we can look 
through it - if function.name == "output": - print("hit output, stopping") - stop() - st.rerun() - except StopIteration: - print("stop iteration, stopping") - stop() - except Exception as e: - print("exception") - stop() - raise e - - -def _print_status(): - print('------') - # print("running: burr=", st.session_state.get("app")) - # if st.session_state.get("app"): - # print("running: state=", st.session_state.get("app").state) - chat_history = st.session_state.get("chat_history") - if chat_history is not None: - print("running: chat_history=", len(chat_history)) - print("running: running=", st.session_state.get("running")) - - -def explainer(): - placeholder = st.empty() - current_node = explorer(default_expanded=True) - placeholder.markdown(f"`{current_node}`") - - -def chat(): - render_chat_history() - forward_step() - - -def main(): - # _print_status() - initialize_state() - col1, col2 = st.columns(2) - prompt() - # afterwards cause it relies on modified state... - with col2: - col2.float() - explainer() - with col1: - chat() - - -if __name__ == '__main__': - main() - # _print_status() - # initialize_state() - # render_chat_history() - # # render_thoughts() - # on_prompt() - # # if not st.session_state.get("running"): - # forward_step() diff --git a/examples/gpt/application.py b/examples/gpt/application.py new file mode 100644 index 00000000..318ffae7 --- /dev/null +++ b/examples/gpt/application.py @@ -0,0 +1,284 @@ +import abc +import functools + +import openai + +from burr.core import Action, ApplicationBuilder, State, default, expr, when + + +class PromptInput(Action): + @property + def reads(self) -> list[str]: + return ["prompt"] + + def run(self, state: State) -> dict: + return {"chat_record": {"role": "user", "content": state["prompt"], "type": "text"}} + + @property + def writes(self) -> list[str]: + return ["chat_history"] + + def update(self, result: dict, state: State) -> State: + return state.wipe(keep=["prompt", "chat_history"]).append( + chat_history=result["chat_record"] + ) + + +class SafetyCheck(Action): + @property + def reads(self) -> list[str]: + return ["prompt"] + + def run(self, state: State) -> dict: + if "unsafe" in state["prompt"]: + # quick for testing + return {"safe": False} + return {"safe": True} + + @property + def writes(self) -> list[str]: + return ["safe"] + + def update(self, result: dict, state: State) -> State: + return state.update(safe=result["safe"]) + + +MODES = [ + "answer_question", + "draw_image", + "generate_code", +] + + +@functools.lru_cache(maxsize=None) +def _get_openai_client(): + return openai.Client() + + +class ChooseMode(Action): + def __init__( + self, client: openai.Client = _get_openai_client(), model: str = "gpt-4", modes=tuple(MODES) + ): + super(ChooseMode, self).__init__() + self.client = client + self.model = model + self.modes = modes + + @property + def reads(self) -> list[str]: + return ["prompt"] + + def run(self, state: State) -> dict: + prompt = ( + f"You are a chatbot. You've been prompted this: {state['prompt']}. " + f"You have the capability of responding in the following modes: {' '.join(self.modes)}. " + "Please respond with *only* a single word representing the mode that most accurately" + " corresponds to the prompt. Fr instance, if the prompt is 'draw a picture of a cat', " + "the mode would be 'image'. If the prompt is 'what is the capital of France', the mode would be 'text'." + "If none of these modes apply, please respond with 'unknown'." 
+ ) + result = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt}, + ], + ) + content = result.choices[0].message.content + mode = content.lower() + if mode not in self.modes: + mode = "unknown" + return {"mode": mode} + + @property + def writes(self) -> list[str]: + return ["mode"] + + def update(self, result: dict, state: State) -> State: + return state.update(mode=result["mode"]) + + +class BaseChatCompletion(Action, abc.ABC): + @property + def reads(self) -> list[str]: + return ["prompt", "chat_history"] + + @abc.abstractmethod + def chat_response(self, state: State) -> dict: + pass + + def run(self, state: State) -> dict: + return {"response": self.chat_response(state)} + + @property + def writes(self) -> list[str]: + return ["response"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + +class DontKnowResponse(BaseChatCompletion): + def __init__(self, modes=tuple(MODES)): + super(DontKnowResponse, self).__init__() + self.modes = modes + + def chat_response(self, state: State) -> dict: + return { + "content": f"None of the modes I support: ({self.modes}) apply to your question. Please clarify?", + "type": "text", + "role": "assistant", + } + + +def _get_text_response(chat_history: list[dict], model: str, client: openai.Client) -> str: + chat_history_api_format = [ + { + "role": chat["role"], + "content": chat["content"], + } + for chat in chat_history + ] + result = client.chat.completions.create( + model=model, + messages=chat_history_api_format, + ) + return result.choices[0].message.content + + +class AnswerQuestionResponse(BaseChatCompletion): + def __init__(self, client: openai.Client = _get_openai_client(), model: str = "gpt-4"): + super(AnswerQuestionResponse, self).__init__() + self.client = client + self.model = model + + def chat_response(self, state: State) -> dict: + chat_history = state["chat_history"].copy() + chat_history[-1][ + "content" + ] = f"Please answer the following question: {chat_history[-1]['content']}" + response = _get_text_response(chat_history, self.model, self.client) + return {"content": response, "type": "text", "role": "assistant"} + + +class GenerateImageResponse(BaseChatCompletion): + def __init__(self, client: openai.Client = _get_openai_client(), model: str = "dall-e-2"): + super(GenerateImageResponse, self).__init__() + self.client = client + self.model = model + + def chat_response(self, state: State) -> dict: + result = self.client.images.generate( + model=self.model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1 + ) + return {"content": result.data[0].url, "type": "image", "role": "assistant"} + + +class GenerateCodeResponse(BaseChatCompletion): + def __init__(self, client: openai.Client = _get_openai_client(), model: str = "gpt-4"): + super(GenerateCodeResponse, self).__init__() + self.client = client + self.model = model + + def chat_response(self, state: State) -> dict: + chat_history = state["chat_history"].copy() + chat_history[-1]["content"] = ( + f"Please answer the following question, " + f"responding *only* with code, and nothing else: {chat_history[-1]['content']}" + ) + return { + "content": _get_text_response(state["chat_history"], self.model, self.client), + "type": "code", + "role": "assistant", + } + + +class Response(Action): + @property + def reads(self) -> list[str]: + return ["response", "safe", "mode"] + + def run(self, state: State) -> dict: 
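+        # if the safety check flagged the prompt as unsafe, short-circuit with a canned refusal;
+        # otherwise pass through whichever chat record the selected mode produced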
+ if not state["safe"]: + return { + "chat_record": { + "role": "assistant", + "content": "I'm sorry, I can't respond to that.", + "type": "text", + } + } + return {"chat_record": state["response"]} + + @property + def writes(self) -> list[str]: + return ["chat_history"] + + def update(self, result: dict, state: State) -> State: + return state.append(chat_history=result["chat_record"]) + + +class Error(Action): + @property + def reads(self) -> list[str]: + return ["error"] + + def run(self, state: State) -> dict: + return { + "chat_record": {"role": "assistant", "content": str(state["error"]), "type": "error"} + } + + @property + def writes(self) -> list[str]: + return ["chat_history"] + + def update(self, result: dict, state: State) -> State: + return state.append(chat_history=result["chat_record"]) + + +def application(): + return ( + ApplicationBuilder() + .with_actions( + prompt=PromptInput(), + check_safety=SafetyCheck(), + decide_mode=ChooseMode(), + generate_image=GenerateImageResponse(), + generate_code=GenerateCodeResponse(), + answer_question=AnswerQuestionResponse(), + prompt_for_more=DontKnowResponse(), + response=Response(), + error=Error(), + ) + .with_entrypoint("prompt") + .with_state(chat_history=[]) + .with_transitions( + ("prompt", "check_safety", default), + ("check_safety", "decide_mode", when(safe=True)), + ("check_safety", "response", default), + ("decide_mode", "generate_image", when(mode="draw_image")), + ("decide_mode", "generate_code", when(mode="generate_code")), + ("decide_mode", "answer_question", when(mode="answer_question")), + ("decide_mode", "prompt_for_more", default), + ( + ["generate_image", "answer_question", "generate_code", "prompt_for_more"], + "response", + when(error=None), + ), + ( + ["generate_image", "answer_question", "generate_code", "prompt_for_more"], + "error", + expr("error is not None"), + ), + ("response", "prompt", default), + ("error", "prompt", default), + ) + .build() + ) + + +if __name__ == "__main__": + app = application() # doing good data science is up to you... + # state, result = app.run(until=["result"]) + app.visualize(output_file_path="ml_training.png", include_conditions=False, view=True) + # assert state["counter"] == 10 + # print(state["counter"]) diff --git a/examples/gpt/capabilities.py b/examples/gpt/capabilities.py deleted file mode 100644 index 7ef83f6e..00000000 --- a/examples/gpt/capabilities.py +++ /dev/null @@ -1,128 +0,0 @@ -import random -from typing import Optional, Tuple, TypedDict - -import openai -from hamilton.function_modifiers import config - -ChatContents = TypedDict( - "ChatContents", - { - "role": str, - "content": str, - "type": str - } -) - - -def processed_prompt(prompt: str) -> dict: - return { - "role": "user", - "content": prompt, - "type": "text" - } - - -@config.when(provider="openai") -def client() -> openai.Client: - return openai.Client() - - -def text_model() -> str: - return "gpt-4" - - -def image_model() -> str: - return "dall-e-2" - - -def safe(prompt: str) -> bool: - # TODO -- decide whether this is safe or not - # if "unsafe" in prompt: - # return False - return True - - -def modes() -> Tuple[str, ...]: - return "text", "image" - - -def find_mode_prompt(prompt: str, modes: Tuple[str, ...]) -> str: - return (f"You are a chatbot. You've been prompted this: {prompt}. " - f"You have the capability of responding in the following modes: {' '.join(modes)}. " - "Please respond with *only* a single word representing the mode that most accurately" - " corresponds to the prompt. 
FOr instaance, if the prompt is 'draw a picture of a cat', " - "the mode would be 'image'. If the prompt is 'what is the capital of France', the mode would be 'text'.") - - -def suggested_mode(find_mode_prompt: str, client: openai.Client, text_model: str) -> str: - result = client.chat.completions.create( - model=text_model, - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": find_mode_prompt} - ] - ) - content = result.choices[0].message.content - return content - - -def mode(suggested_mode: str, modes: Tuple[str, ...]) -> str: - # TODO -- use instructor! - if suggested_mode.lower() not in modes: - return "text" # default to text - return suggested_mode.lower() - - -def generated_text(chat_history: list[dict], text_model: str, client: openai.Client) -> str: - chat_history_api_format = [ - { - "role": chat["role"], - "content": chat["content"], - } - for i, chat in enumerate(chat_history) - ] - result = client.chat.completions.create( - model=text_model, - messages=chat_history_api_format, - ) - return result.choices[0].message.content - - -# def generated_text_error(generated_text) -> dict: -# ... - - -def generated_image(prompt: str, image_model: str, client: openai.Client) -> str: - result = client.images.generate( - model=image_model, - prompt=prompt, - size="1024x1024", - quality="standard", - n=1 - ) - return result.data[0].url - - -def processed_response( - response: Optional[str], - mode: str, - safe: bool) -> ChatContents: - if not safe: - return { - "role": "assistant", - "content": "I'm sorry, I can't do that.", - "type": "text" - } - return { - "role": "assistant", - "type": mode, - "content": response - } - - -def processed_error(error: str) -> dict: - return { - "role": "assistant", - "error": error, - "type": "text" - } diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt deleted file mode 100644 index a6e1be9b..00000000 --- a/examples/gpt/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -openai -fastui -fastapi -sf-hamilton -uvicorn diff --git a/examples/gpt/run.py b/examples/gpt/run.py deleted file mode 100644 index 697efa43..00000000 --- a/examples/gpt/run.py +++ /dev/null @@ -1,93 +0,0 @@ -from hamilton import driver - -from burr.core import ApplicationBuilder, Condition, Result -from burr.implementations import Placeholder -from burr.integrations.hamilton import Hamilton, from_state, update_state, append_state -from examples.gpt import capabilities - - -def display(response): - pass - - -def main(): - dr = driver.Driver({"provider" : "openai"}, capabilities) # TODO -- add modules - Hamilton.set_driver(dr) - prompt = Hamilton( - name="prompt", - inputs={"prompt": from_state("prompt")}, - outputs={"processed_prompt": append_state("chat_history")}, - ) - check_safety = Hamilton( - name="check_safety", - inputs={"prompt": from_state("prompt")}, - outputs={"safe": update_state("safe")}, - ) - decide_mode = Hamilton( - name="decide_mode", - inputs={"prompt": from_state("prompt")}, - outputs={"mode": update_state("mode")}, - ) - generate_image = Hamilton( - name="generate_image", - inputs={"prompt": from_state("prompt")}, - outputs={"generated_image": update_state("response")}, - ) - generate_text = Hamilton( - name="generate_text", - inputs={"chat_history": from_state("chat_history")}, - outputs={"generated_text": update_state("response")}, - ) - response = Hamilton( - name="response", - inputs={"response": from_state("response"), "safe": from_state("safe"), "mode": from_state("mode")}, - 
outputs={"processed_response": append_state("chat_history")}, - ) - error = Hamilton( - name="error", - inputs={"error": from_state("error")}, - outputs={"processed_error": append_state("chat_history"), "error": update_state("error")}, - ) - output = Result("output", fields=["chat_history"]) - - agent = ( - ApplicationBuilder() - .with_state(chat_history=[], prompt="Draw an image of a turtle saying 'hello, world'") - .with_transition(prompt, check_safety) - .with_transition( - check_safety, decide_mode, Condition.expr("safe") - ) # if safe, decide what to do next - .with_transition( - check_safety, response, Condition.expr("not safe") - ) # if not safe, go to output - .with_transition(decide_mode, generate_image, Condition.when(mode="image")) - .with_transition(decide_mode, generate_text, Condition.when(mode="text")) - # .with_transition(decide_mode, generate_code, Condition.when(mode="code")) - # .with_transition(decide_mode, bing_search, Condition.when(mode="search")) - .with_transition( - [generate_image, generate_text], - response, - Condition.expr("True"), - ) - # .with_transition( - # [generate_image, generate_text], - # error, - # Condition.expr("bool(error)"), - # ) - .with_transition([response, error], output) - .with_transition(output, prompt) - .with_entrypoint(prompt) - .build() - ) - # agent.visualize("./out", include_conditions=True, include_state=True) - state, [response,] = agent.run(["output"], gate="any_complete") - import pprint - pprint.pprint(response) - # if response is not None: - # display(response) - # else: - # display(error) - - -if __name__ == "__main__": - main() diff --git a/examples/gpt/server.py b/examples/gpt/server.py deleted file mode 100644 index 24e7e65d..00000000 --- a/examples/gpt/server.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Annotated, Literal - -from fastapi import FastAPI -from fastapi.responses import HTMLResponse -from fastui import FastUI, AnyComponent, prebuilt_html, components as c -from fastui.events import GoToEvent -from fastui.forms import fastui_form -from pydantic import BaseModel, Field - -app = FastAPI() - - -class ChatItem(BaseModel): - content: str - role: str - type: Literal['text', 'image'] - - -class ChatHistory(BaseModel): - chat_history: list[ChatItem] = Field(default_factory=list) - - -class ChatForm(BaseModel): - prompt: str | None = Field( - None - ) - - -app.state.chat_history = ChatHistory() # initialize it - - -def chat_history_generator(): - pass - - -def global_page(*components: AnyComponent, title: str | None = None) -> list[AnyComponent]: - return [ - c.Navbar( - title='Burr Demo', - title_event=GoToEvent(url='/'), - links=[ - c.Link( - components=[c.Text(text='Chat')], - on_click=GoToEvent(url='/chat/'), - active='startswith:/chat', - ), - c.Link( - components=[c.Text(text='State')], - on_click=GoToEvent(url='/state/'), - active='startswith:/state', - ), - ], - ), - c.Page( - components=[ - *((c.Heading(text=title),) if title else ()), - *components, - ], - ), - c.Footer( - # extra_text='Burr Demo', - links=[ - c.Link( - components=[c.Text(text='Github')], on_click=GoToEvent(url='https://github.com/dagworks-inc/burr') - ), - c.Link(components=[c.Text(text='PyPI')], on_click=GoToEvent(url='https://pypi.org/project/dw-burr/')), - ], - ), - ] - - -@app.get("/api/", response_model=FastUI, response_model_exclude_none=True) -def home() -> list[AnyComponent]: - """ - Show a table of four users, `/api` is the endpoint the frontend will connect to - when a user visits `/` to fetch components to render. 
- """ - return global_page(title="Welcome!") - - -@app.get("/api/chat/", response_model=FastUI, response_model_exclude_none=True) -def chat() -> list[AnyComponent]: - cs = [ - # c.Heading(text='Large Form', level=2), - # c.Paragraph(text='Form with a lot of fields.'), - c.ModelForm(model=ChatForm, submit_url='/api/', display_mode="default"), - ] - - # input_form = c.ModelForm(model=ChatForm, submit_url='/api/forms/big'), - return global_page(*cs, title=None) - - -@app.get('/{path:path}') -async def html_landing() -> HTMLResponse: - """Simple HTML page which serves the React app, comes last as it matches all paths.""" - return HTMLResponse(prebuilt_html(title='Burr Demo')) -# -# -# @app.get("/api/chat/", response_model=FastUI, response_model_exclude_none=True) -# def user_profile(user_id: int) -> list[AnyComponent]: -# """ -# User profile page, the frontend will fetch this when the user visits `/user/{id}/`. -# """ -# return c.Page -# -# -# @app.get('/{path:path}') -# async def html_landing() -> HTMLResponse: -# """Simple HTML page which serves the React app, comes last as it matches all paths.""" -# return HTMLResponse(prebuilt_html(title='FastUI Demo')) diff --git a/examples/gpt/streamlit_app.py b/examples/gpt/streamlit_app.py new file mode 100644 index 00000000..9b952707 --- /dev/null +++ b/examples/gpt/streamlit_app.py @@ -0,0 +1,92 @@ +from typing import Optional + +import application as chatbot_application +import streamlit as st + +from burr.integrations.streamlit import ( + AppState, + Record, + get_state, + render_explorer, + set_slider_to_current, + update_state, +) + + +def render_chat_message(record: Record): + if record.action in ["prompt", "response"]: + with st.chat_message(record.result["chat_record"]["role"]): + content = record.result["chat_record"]["content"] + content_type = record.result["chat_record"]["type"] + if content_type == "image": + st.image(content) + elif content_type == "code": + st.code(content) + elif content_type == "text": + st.write(content) + elif content_type == "error": + st.error(content) + + +def retrieve_state(): + if "burr_state" not in st.session_state: + state = AppState.from_empty(app=chatbot_application.application()) + else: + state = get_state() + return state + + +def chatbot_step(app_state: AppState, prompt: Optional[str]) -> bool: + """Pushes state forward for the chatbot. Returns whether or not to rerun the app. + + :param app_state: State of the app + :param prompt: Prompt to set the chatbot to. If this is None it means it should continue and not be reset. 
+ :return: + """ + if prompt is not None: + # We need to update + app_state.app.update_state(app_state.app.state.update(prompt=prompt)) + st.session_state.running = True # set to running + # if its not running this is a no-op + if not st.session_state.get("running", False): + return False + application = app_state.app + step_output = application.step() + if step_output is None: + st.session_state.running = False + return False + action, result, state = step_output + app_state.history.append(Record(state.get_all(), action.name, result)) + set_slider_to_current() + if action.name == "response": + # we've gotten to the end + st.session_state.running = False + return True # run one last time + return True + + +def main(): + st.set_page_config(layout="wide") + st.title("GPT clone with Burr") + app_state = retrieve_state() # retrieve first so we can use for the ret of the step + columns = st.columns(2) + with columns[0]: + prompt = st.chat_input( + "...", disabled=st.session_state.get("running", False), key="chat_input" + ) + should_rerun = chatbot_step(app_state, prompt) + with st.container(height=800): + for item in app_state.history: + render_chat_message(item) + with columns[1]: + render_explorer(app_state) + update_state(app_state) # update so the next iteration knows what to do + # rerun moves the state machine forwards one + # print(st.session_state.should_rerun) + if should_rerun: + print("rerunning") + st.rerun() + + +if __name__ == "__main__": + main() diff --git a/examples/ml_training.py b/examples/ml_training.py new file mode 100644 index 00000000..5e58c66b --- /dev/null +++ b/examples/ml_training.py @@ -0,0 +1,110 @@ +import burr.core.application +from burr.core import Action, Condition, State, default + + +class ProcessDataAction(Action): + @property + def reads(self) -> list[str]: + return ["data_path"] + + def run(self, state: State) -> dict: + pass + + @property + def writes(self) -> list[str]: + return ["training_data", "evaluation_data"] + + def update(self, result: dict, state: State) -> State: + return state.update(training_data=result["training_data"]) + + +class TrainModel(Action): + @property + def reads(self) -> list[str]: + return ["training_data", "epochs"] + + def run(self, state: State) -> dict: + pass + + @property + def writes(self) -> list[str]: + return ["models", "training_metrics", "epochs"] + + def update(self, result: dict, state: State) -> State: + return state.update( + epochs=result["epochs"], # overwrite each epoch + ).append( + models=result["model"], # append -- note this can get big if your model is big + # so you'll want to overwrite but store the conditions, or log somewhere + metrics=result["metrics"], # append the metrics + ) + + +class ValidateModel(Action): + @property + def reads(self) -> list[str]: + return ["models", "evaluation_data"] + + def run(self, state: State) -> dict: + pass + + @property + def writes(self) -> list[str]: + return ["validation_metrics"] + + def update(self, result: dict, state: State) -> State: + return state.append(validation_metrics=result["validation_metrics"]) + + +class BestModel(Action): + @property + def reads(self) -> list[str]: + return ["validation_metrics", "models"] + + def run(self, state: State) -> dict: + pass + + @property + def writes(self) -> list[str]: + return ["best_model"] + + def update(self, result: dict, state: State) -> State: + return state.update(best_model=result["best_model"]) + + +def application(epochs: int) -> burr.core.application.Application: + return ( + burr.core.ApplicationBuilder() 
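+        # wiring: process_data -> train_model -> validate_model, looping back to train_model
+        # until state["epochs"] exceeds the requested count, at which point best_model runs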
+ .with_state( + data_path="data.csv", + epochs=10, + training_data=None, + evaluation_data=None, + models=[], + training_metrics=[], + validation_metrics=[], + best_model=None, + ) + .with_actions( + process_data=ProcessDataAction(), + train_model=TrainModel(), + validate_model=ValidateModel(), + best_model=BestModel(), + ) + .with_transitions( + ("process_data", "train_model", default), + ("train_model", "validate_model", default), + ("validate_model", "best_model", Condition.expr(f"epochs>{epochs}")), + ("validate_model", "train_model", default), + ) + .with_entrypoint("process_data") + .build() + ) + + +if __name__ == "__main__": + app = application(100) # doing good data science is up to you... + # state, result = app.run(until=["result"]) + app.visualize(output_file_path="ml_training.png", include_conditions=True, view=True) + # assert state["counter"] == 10 + # print(state["counter"]) diff --git a/examples/simulation.py b/examples/simulation.py new file mode 100644 index 00000000..e69de29b diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..ee4ba018 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,2 @@ +pytest +pytest-asyncio diff --git a/tests/core/test_action.py b/tests/core/test_action.py new file mode 100644 index 00000000..102a565b --- /dev/null +++ b/tests/core/test_action.py @@ -0,0 +1,117 @@ +from burr.core import State +from burr.core.action import Action, Condition, Function, Result, default + + +def test_is_async_true(): + class AsyncFunction(Function): + def reads(self) -> list[str]: + return [] + + async def run(self, state: State) -> dict: + return {} + + func = AsyncFunction() + assert func.is_async() + + +def test_is_async_false(): + class SyncFunction(Function): + def reads(self) -> list[str]: + return [] + + def run(self, state: State) -> dict: + return {} + + func = SyncFunction() + assert not func.is_async() + + +def test_with_name(): + class BasicAction(Action): + @property + def reads(self) -> list[str]: + return ["input_variable"] + + def run(self, state: State) -> dict: + return {"output_variable": state["input_variable"]} + + @property + def writes(self) -> list[str]: + return ["output_variable"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + action = BasicAction() + assert action.name is None # Nothing set initially + with_name = action.with_name("my_action") + assert with_name.name == "my_action" # Name set on copy + assert with_name.reads == action.reads + assert with_name.writes == action.writes + + +def test_condition(): + cond = Condition(["foo"], lambda state: state["foo"] == "bar", name="foo") + assert cond.name == "foo" + assert cond.reads == ["foo"] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} + + +def test_condition_when(): + cond = Condition.when(foo="bar") + assert cond.name == "foo=bar" + assert cond.reads == ["foo"] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} + + +def test_condition_when_complex(): + cond = Condition.when(foo="bar", baz="qux") + assert cond.name == "baz=qux, foo=bar" + assert sorted(cond.reads) == ["baz", "foo"] + assert cond.run(State({"foo": "bar", "baz": "qux"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz", "baz": "qux"})) == {Condition.KEY: False} + assert cond.run(State({"foo": "bar", "baz": "corge"})) == {Condition.KEY: False} + 
assert cond.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False} + + +def test_condition_default(): + cond = default + assert cond.name == "default" + assert cond.reads == [] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + + +def test_condition_expr(): + cond = Condition.expr("foo == 'bar'") + assert cond.name == "foo == 'bar'" + assert cond.reads == ["foo"] + assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} + + +def test_condition_expr_complex(): + cond = Condition.expr("foo == 'bar' and baz == 'qux'") + assert cond.name == "foo == 'bar' and baz == 'qux'" + assert sorted(cond.reads) == ["baz", "foo"] + assert cond.run(State({"foo": "bar", "baz": "qux"})) == {Condition.KEY: True} + assert cond.run(State({"foo": "baz", "baz": "qux"})) == {Condition.KEY: False} + assert cond.run(State({"foo": "bar", "baz": "corge"})) == {Condition.KEY: False} + assert cond.run(State({"foo": "baz", "baz": "corge"})) == {Condition.KEY: False} + + +def test_result(): + result = Result(fields=["foo", "bar"]) + assert result.run(State({"foo": "baz", "bar": "qux", "baz": "quux"})) == { + "foo": "baz", + "bar": "qux", + } + assert result.writes == [] # doesn't write anything + assert result.reads == ["foo", "bar"] + # no results + assert result.update( + {"foo": "baz", "bar": "qux"}, State({"foo": "baz", "bar": "qux", "baz": "quux"}) + ) == State( + {"foo": "baz", "bar": "qux", "baz": "quux"} + ) # no impact diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 1a909dfe..17b4aafa 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -1,25 +1,43 @@ -from typing import Callable +import asyncio +from typing import Awaitable, Callable import pytest from burr.core import State -from burr.core.application import Application, Transition, ApplicationBuilder -from burr.core.function import Function, DEFAULT, Condition -from burr.implementations import Result - - -class PassedInFunction(Function): - def __init__(self, - name: str, - reads: list[str], - writes: list[str], - fn: Callable[[State], dict], - update_fn: Callable[[dict, State], State]): +from burr.core.action import Action, Condition, Result, default +from burr.core.application import ( + Application, + ApplicationBuilder, + Transition, + _arun_function, + _assert_set, + _run_function, + _validate_actions, + _validate_start, + _validate_transitions, +) +from burr.lifecycle import ( + PostRunStepHook, + PostRunStepHookAsync, + PreRunStepHook, + PreRunStepHookAsync, + internal, +) + + +class PassedInAction(Action): + def __init__( + self, + reads: list[str], + writes: list[str], + fn: Callable[[State], dict], + update_fn: Callable[[dict, State], State], + ): + super(PassedInAction, self).__init__() self._reads = reads self._writes = writes self._fn = fn self._update_fn = update_fn - super().__init__(name) def run(self, state: State) -> dict: return self._fn(state) @@ -36,50 +54,220 @@ def writes(self) -> list[str]: return self._writes -counter_action = PassedInFunction( - name="counter", +class PassedInActionAsync(PassedInAction): + def __init__( + self, + reads: list[str], + writes: list[str], + fn: Callable[[State], Awaitable[dict]], + update_fn: Callable[[dict, State], State], + ): + super().__init__(reads=reads, writes=writes, fn=fn, update_fn=update_fn) # type: ignore + + async def run(self, state: State) -> dict: + return await self._fn(state) + + +base_counter_action = PassedInAction( 
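+    # shared fixture used throughout these tests: increments "counter" in state by one per run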
reads=["counter"], writes=["counter"], fn=lambda state: {"counter": state.get("counter", 0) + 1}, - update_fn=lambda result, state: state.update_state(**result), + update_fn=lambda result, state: state.update(**result), +) + + +async def _counter_update_async(state: State) -> dict: + await asyncio.sleep(0.0001) # just so we can make this *truly* async + # does not matter, but more accurately simulates an async function + return {"counter": state.get("counter", 0) + 1} + + +base_counter_action_async = PassedInActionAsync( + reads=["counter"], + writes=["counter"], + fn=_counter_update_async, + update_fn=lambda result, state: state.update(**result), ) +def test_run_function(): + """Tests that we can run a function""" + action = base_counter_action + state = State({}) + result = _run_function(action, state) + assert result == {"counter": 1} + + +def test_run_function_cant_run_async(): + """Tests that we can't run an async function""" + action = base_counter_action_async + state = State({}) + with pytest.raises(ValueError, match="async"): + _run_function(action, state) + + +async def test_a_run_function(): + """Tests that we can run an async function""" + action = base_counter_action_async + state = State({}) + result = await _arun_function(action, state) + assert result == {"counter": 1} + + def test_app_step(): + """Tests that we can run a step in an app""" + counter_action = base_counter_action.with_name("counter") + app = Application( + actions=[counter_action], + transitions=[Transition(counter_action, counter_action, default)], + state=State({}), + initial_step="counter", + ) + action, result, state = app.step() + assert action.name == "counter" + assert result == {"counter": 1} + + +def test_app_step_done(): + """Tests that when we cannot run a step, we return None""" + counter_action = base_counter_action.with_name("counter") + app = Application( + actions=[counter_action], transitions=[], state=State({}), initial_step="counter" + ) + app.step() + assert app.step() is None + + +async def test_app_astep(): + """Tests that we can run an async step in an app""" + counter_action = base_counter_action_async.with_name("counter_async") app = Application( - functions=[counter_action], - transitions=[Transition(counter_action, counter_action, DEFAULT)], + actions=[counter_action], + transitions=[Transition(counter_action, counter_action, default)], state=State({}), - initial_step="counter" + initial_step="counter_async", ) - function, result, state = app.step() - assert function.name == "counter" + action, result, state = await app.astep() + assert action.name == "counter_async" assert result == {"counter": 1} +async def test_app_astep_done(): + """Tests that when we cannot run a step, we return None""" + counter_action = base_counter_action_async.with_name("counter_async") + app = Application( + actions=[counter_action], transitions=[], state=State({}), initial_step="counter_async" + ) + await app.astep() + assert await app.astep() is None + + +# internal API def test_app_many_steps(): + counter_action = base_counter_action.with_name("counter") + app = Application( + actions=[counter_action], + transitions=[Transition(counter_action, counter_action, default)], + state=State({}), + initial_step="counter", + ) + action, result = None, None + for i in range(100): + action, result, state = app.step() + assert action.name == "counter" + assert result == {"counter": 100} + + +async def test_app_many_a_steps(): + counter_action = base_counter_action_async.with_name("counter_async") app = Application( 
- functions=[counter_action], - transitions=[Transition(counter_action, counter_action, DEFAULT)], + actions=[counter_action], + transitions=[Transition(counter_action, counter_action, default)], state=State({}), - initial_step="counter" + initial_step="counter_async", ) + action, result = None, None for i in range(100): - function, result, state = app.step() - assert function.name == "counter" + action, result, state = await app.astep() + assert action.name == "counter_async" assert result == {"counter": 100} +def test_app_iterate(): + result_action = Result(fields=["counter"]).with_name("result") + counter_action = base_counter_action.with_name("counter") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, counter_action, Condition.expr("counter < 10")), + Transition(counter_action, result_action, default), + ], + state=State({}), + initial_step="counter", + ) + res = [] + gen = app.iterate(until=["result"]) + counter = 0 + try: + while True: + action, result, state = next(gen) + if action.name == "counter": + assert state["counter"] == counter + 1 + assert result["counter"] == state["counter"] + counter = result["counter"] + else: + res.append(result) + assert state["counter"] == 10 + assert result["counter"] == 10 + except StopIteration as e: + stop_iteration_error = e + generator_result = stop_iteration_error.value + state, results = generator_result + assert state["counter"] == 10 + assert len(results) == 1 + (result,) = results + assert result["counter"] == 10 + + +async def test_app_a_iterate(): + result_action = Result(fields=["counter"]).with_name("result") + counter_action = base_counter_action_async.with_name("counter") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, counter_action, Condition.expr("counter < 10")), + Transition(counter_action, result_action, default), + ], + state=State({}), + initial_step="counter", + ) + res = [] + gen = app.aiterate(until=["result"]) + counter = 0 + # Note that we use an async-for loop cause the API is different, this doesn't + # return anything (async generators are not allowed to). 
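+    # So, unlike the synchronous iterate() test above, we cannot pull (state, results)
+    # out of StopIteration.value; the assertions happen inside the loop body instead.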
+ async for action, result, state in gen: + if action.name == "counter": + assert state["counter"] == counter + 1 + assert result["counter"] == state["counter"] + counter = result["counter"] + else: + res.append(result) + assert state["counter"] == 10 + assert result["counter"] == 10 + + def test_app_run(): - result_action = Result(name="result", fields=["counter"]) + result_action = Result(fields=["counter"]).with_name("result") + counter_action = base_counter_action.with_name("counter") app = Application( - functions=[counter_action, result_action], + actions=[counter_action, result_action], transitions=[ Transition(counter_action, counter_action, Condition.expr("counter < 10")), - Transition(counter_action, result_action, DEFAULT) + Transition(counter_action, result_action, default), ], state=State({}), - initial_step="counter" + initial_step="counter", ) state, results = app.run(until=["result"]) assert state["counter"] == 10 @@ -87,21 +275,228 @@ def test_app_run(): assert results[0]["counter"] == 10 -def test_application_builder_simple(): +async def test_app_a_run(): + result_action = Result(fields=["counter"]).with_name("result") + counter_action = base_counter_action_async.with_name("counter") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, counter_action, Condition.expr("counter < 10")), + Transition(counter_action, result_action, default), + ], + state=State({}), + initial_step="counter", + ) + state, results = await app.arun(until=["result"]) + assert state["counter"] == 10 + assert len(results) == 1 + (result,) = results + assert result["counter"] == 10 + + +async def test_app_a_run_async_and_sync(): + result_action = Result(fields=["counter"]).with_name("result") + counter_action_sync = base_counter_action_async.with_name("counter_sync") + counter_action_async = base_counter_action_async.with_name("counter_async") + app = Application( + actions=[counter_action_sync, counter_action_async, result_action], + transitions=[ + Transition(counter_action_sync, counter_action_async, Condition.expr("counter < 20")), + Transition(counter_action_async, counter_action_sync, default), + Transition(counter_action_sync, result_action, default), + ], + state=State({}), + initial_step="counter_sync", + ) + state, results = await app.arun(until=["result"]) + assert state["counter"] > 20 + assert len(results) == 1 + (result,) = results + assert result["counter"] > 20 + + +def test_app_set_state(): + counter_action = base_counter_action.with_name("counter") + app = Application( + actions=[counter_action], + transitions=[Transition(counter_action, counter_action, default)], + state=State(), + initial_step="counter", + ) + assert "counter" not in app.state # initial value + app.step() + assert app.state["counter"] == 1 # updated value + state = app.state + app.update_state(state.update(counter=2)) + assert app.state["counter"] == 2 # updated value + + +def test_app_get_next_step(): + counter_action_1 = base_counter_action.with_name("counter_1") + counter_action_2 = base_counter_action.with_name("counter_2") + counter_action_3 = base_counter_action.with_name("counter_3") + app = Application( + actions=[counter_action_1, counter_action_2, counter_action_3], + transitions=[ + Transition(counter_action_1, counter_action_2, default), + Transition(counter_action_2, counter_action_3, default), + Transition(counter_action_3, counter_action_1, default), + ], + state=State(), + initial_step="counter_1", + ) + # uninitialized -- counter_1 + assert 
app.get_next_action().name == "counter_1" + app.step() + # ran counter_1 -- counter_2 + assert app.get_next_action().name == "counter_2" + app.step() + # ran counter_2 -- counter_3 + assert app.get_next_action().name == "counter_3" + app.step() + # ran counter_3 -- back to counter_1 + assert app.get_next_action().name == "counter_1" + + +def test_application_builder_complete(): app = ( ApplicationBuilder() .with_state(counter=0) - .with_entrypoint(counter_action) - .with_transition(counter_action, counter_action, Condition.expr("counter < 10")) - .with_transition(counter_action, Result(name="result", fields=["counter"]), DEFAULT) + .with_actions(counter=base_counter_action, result=Result(fields=["counter"])) + .with_transitions( + ("counter", "counter", Condition.expr("counter < 10")), ("counter", "result") + ) + .with_entrypoint("counter") .build() ) - assert len(app.functions) == 2 - assert len(app.transitions) == 2 - assert app.get_next_function().name == "counter" + assert len(app._actions) == 2 + assert len(app._transitions) == 2 + assert app.get_next_action().name == "counter" + + +def test__validate_transitions_correct(): + _validate_transitions( + [("counter", "counter", Condition.expr("counter < 10")), ("counter", "result", default)], + {"counter", "result"}, + ) + + +def test__validate_transitions_missing_action(): + with pytest.raises(ValueError, match="not found"): + _validate_transitions( + [ + ("counter", "counter", Condition.expr("counter < 10")), + ("counter", "result", default), + ], + {"counter"}, + ) + + +def test__validate_transitions_redundant_transition(): + with pytest.raises(ValueError, match="redundant"): + _validate_transitions( + [ + ("counter", "counter", Condition.expr("counter < 10")), + ("counter", "result", default), + ("counter", "counter", default), # this is unreachable as we already have a default + ], + {"counter", "result"}, + ) + + +def test__validate_start_valid(): + _validate_start("counter", {"counter", "result"}) + + +def test__validate_start_not_found(): + with pytest.raises(ValueError, match="not found"): + _validate_start("counter", {"result"}) + + +def test__validate_actions_valid(): + _validate_actions([Result(["test"])]) + + +def test__validate_actions_empty(): + with pytest.raises(ValueError, match="at least one"): + _validate_actions([]) + + +def test__asset_set(): + _assert_set("foo", "foo", "bar") + + +def test__assert_set_unset(): + with pytest.raises(ValueError, match="foo"): + _assert_set(None, "foo", "bar") def test_application_builder_unset(): with pytest.raises(ValueError): ApplicationBuilder().build() + +def test_application_runs_hooks_sync(): + class ActionTracker(PreRunStepHook, PostRunStepHook): + def __init__(self): + self.pre_called = [] + self.post_called = [] + + def pre_run_step(self, *, action: Action, **future_kwargs): + self.pre_called.append(action.name) + + def post_run_step(self, *, action: Action, **future_kwargs): + self.post_called.append(action.name) + + tracker = ActionTracker() + counter_action = base_counter_action.with_name("counter") + result_action = Result(fields=["counter"]).with_name("result") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, result_action, Condition.expr("counter >= 10")), + Transition(counter_action, counter_action, default), + ], + state=State({}), + initial_step="counter", + adapter_set=internal.LifecycleAdapterSet(tracker), + ) + app.run(until=["result"]) + assert set(tracker.pre_called) == {"counter", "result"} + assert 
set(tracker.post_called) == {"counter", "result"} + assert len(tracker.pre_called) == 11 + assert len(tracker.post_called) == 11 + + +async def test_application_runs_hooks_async(): + class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync): + def __init__(self): + self.pre_called = [] + self.post_called = [] + + async def pre_run_step(self, *, action: Action, **future_kwargs): + await asyncio.sleep(0.0001) + self.pre_called.append(action.name) + + async def post_run_step(self, *, action: Action, **future_kwargs): + await asyncio.sleep(0.0001) + self.post_called.append(action.name) + + tracker = ActionTrackerAsync() + counter_action = base_counter_action.with_name("counter") + result_action = Result(fields=["counter"]).with_name("result") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, result_action, Condition.expr("counter >= 10")), + Transition(counter_action, counter_action, default), + ], + state=State({}), + initial_step="counter", + adapter_set=internal.LifecycleAdapterSet(tracker), + ) + await app.arun(until=["result"]) + assert set(tracker.pre_called) == {"counter", "result"} + assert set(tracker.post_called) == {"counter", "result"} + assert len(tracker.pre_called) == 11 + assert len(tracker.post_called) == 11 diff --git a/tests/core/test_function.py b/tests/core/test_function.py deleted file mode 100644 index c4601855..00000000 --- a/tests/core/test_function.py +++ /dev/null @@ -1,37 +0,0 @@ -from burr.core import State -from burr.core.function import Condition, DEFAULT - - -def test_condition(): - cond = Condition(["foo"], lambda state: state["foo"] == "bar", name="foo") - assert cond.name == "foo" - assert cond.reads == ["foo"] - assert cond.writes == [] - assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} - assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} - - -def test_condition_when(): - cond = Condition.when(foo="bar") - assert cond.name == "foo=bar" - assert cond.reads == ["foo"] - assert cond.writes == [] - assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} - assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} - - -def test_condition_default(): - cond = DEFAULT - assert cond.name == "default" - assert cond.reads == [] - assert cond.writes == [] - assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} - - -def test_condition_expr(): - cond = Condition.expr("foo == 'bar'") - assert cond.name == "foo == 'bar'" - assert cond.reads == ["foo"] - assert cond.writes == [] - assert cond.run(State({"foo": "bar"})) == {Condition.KEY: True} - assert cond.run(State({"foo": "baz"})) == {Condition.KEY: False} diff --git a/tests/core/test_implementations.py b/tests/core/test_implementations.py new file mode 100644 index 00000000..0fa5bc32 --- /dev/null +++ b/tests/core/test_implementations.py @@ -0,0 +1,15 @@ +import pytest + +from burr.core import State +from burr.core.implementations import Placeholder + + +def test_placedholder_action(): + action = Placeholder(reads=["foo"], writes=["bar"]).with_name("test") + assert action.reads == ["foo"] + assert action.writes == ["bar"] + with pytest.raises(NotImplementedError): + action.run(State({})) + + with pytest.raises(NotImplementedError): + action.update({}, State({})) diff --git a/tests/core/test_state.py b/tests/core/test_state.py index e2f6add3..fa813051 100644 --- a/tests/core/test_state.py +++ b/tests/core/test_state.py @@ -29,10 +29,10 @@ def test_state_get_missing_default(): assert state.get("baz", 
"qux") == "qux" -def test_state_has(): +def test_state_in(): state = State({"foo": "bar"}) - assert state.has("foo") - assert not state.has("baz") + assert "foo" in state + assert "baz" not in state def test_state_get_all(): @@ -63,3 +63,20 @@ def test_state_update(): state = State({"foo": "bar", "baz": "qux"}) updated = state.update(foo="baz") assert updated.get_all() == {"foo": "baz", "baz": "qux"} + + +def test_state_init(): + state = State({"foo": "bar", "baz": "qux"}) + assert state.get_all() == {"foo": "bar", "baz": "qux"} + + +def test_state_wipe_delete(): + state = State({"foo": "bar", "baz": "qux"}) + wiped = state.wipe(delete=["foo"]) + assert wiped.get_all() == {"baz": "qux"} + + +def test_state_wipe_keep(): + state = State({"foo": "bar", "baz": "qux"}) + wiped = state.wipe(keep=["foo"]) + assert wiped.get_all() == {"foo": "bar"} diff --git a/tests/integrations/test_burr_hamilton.py b/tests/integrations/test_burr_hamilton.py new file mode 100644 index 00000000..1e836ebc --- /dev/null +++ b/tests/integrations/test_burr_hamilton.py @@ -0,0 +1,98 @@ +import pytest + +from burr.core import State +from burr.integrations.hamilton import Hamilton, from_state, from_value, update_state +from hamilton import ad_hoc_utils, driver + + +def _incrementing_driver(): + def incremented_count(current_count: int) -> int: + return current_count + 1 + + def incremented_count_2(current_count: int, increment_by: int = 1) -> int: + return current_count + increment_by + + def sum_of_counts(incremented_count: int, incremented_count_2: int) -> int: + return incremented_count + incremented_count_2 + + mod = ad_hoc_utils.create_temporary_module( + incremented_count, incremented_count_2, sum_of_counts + ) + dr = driver.Driver({}, mod) + return dr + + +def test_set_driver(): + dr = _incrementing_driver() + Hamilton.set_driver(dr) + h = Hamilton({}, {}, driver=dr) + assert h.driver == dr + + +def test__extract_inputs_overrides(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count": from_state("count"), "incremented_count_2": from_value(2)}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + inputs, overrides = h._extract_inputs_overrides(State({"count": 0})) + assert inputs == {"current_count": 0} + assert overrides == {"incremented_count_2": 2} + + +def test__extract_inputs_overrides_missing_inputs(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count_not_present": from_state("count")}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + with pytest.raises(ValueError, match="not available"): + inputs, _ = h._extract_inputs_overrides(State({"count": 0})) + + +def test_reads(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count": from_state("count"), "incremented_count_2": from_value(2)}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + assert h.reads == ["count"] + + +def test_writes(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count": from_state("count"), "incremented_count_2": from_value(2)}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + assert h.writes == ["count"] + + +def test_run_step_with_multiple_inputs(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count": from_state("count"), "increment_by": from_value(5)}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + result = h.run(State({"count": 1})) + assert result == {"sum_of_counts": 8} + new_state = h.update(result, State({"count": 1})) + assert 
new_state.get_all() == {"count": 8} + + +def test_run_step_with_overrides(): + dr = _incrementing_driver() + h = Hamilton( + inputs={"current_count": from_state("count"), "incremented_count_2": from_value(2)}, + outputs={"sum_of_counts": update_state("count")}, + driver=dr, + ) + result = h.run(State({"count": 1})) + assert result == {"sum_of_counts": 4} + new_state = h.update(result, State({"count": 1})) + assert new_state.get_all() == {"count": 4} diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..40880458 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode=auto