diff --git a/burr/core/application.py b/burr/core/application.py index 30c92b9a..f9aa54a2 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -7,6 +7,7 @@ import logging import pprint import uuid +from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import AbstractContextManager from typing import ( TYPE_CHECKING, @@ -497,6 +498,7 @@ class ApplicationContext(AbstractContextManager): partition_key: Optional[str] sequence_id: Optional[int] tracker: Optional["TrackingClient"] + parallel_executor_factory: Callable[[], Executor] @staticmethod def get() -> Optional["ApplicationContext"]: @@ -683,6 +685,10 @@ def post_run_step( StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any]) +def _create_default_executor() -> Executor: + return ThreadPoolExecutor() + + class Application(Generic[ApplicationStateType]): def __init__( self, @@ -697,6 +703,7 @@ def __init__( fork_parent_pointer: Optional[burr_types.ParentPointer] = None, spawning_parent_pointer: Optional[burr_types.ParentPointer] = None, tracker: Optional["TrackingClient"] = None, + parallel_executor_factory: Optional[Executor] = None, ): """Instantiates an Application. This is an internal API -- use the builder! @@ -731,6 +738,11 @@ def __init__( self._set_sequence_id(sequence_id) self._builder = builder self._parent_pointer = fork_parent_pointer + self._parallel_executor_factory = ( + parallel_executor_factory + if parallel_executor_factory is not None + else _create_default_executor + ) self._dependency_factory = { "__tracer": functools.partial( visibility.tracing.TracerFactory, @@ -780,6 +792,7 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationConte tracker=self._tracker, partition_key=self._partition_key, sequence_id=sequence_id, + parallel_executor_factory=self._parallel_executor_factory, ) def _step( @@ -862,7 +875,7 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, A BASE_ERROR_MESSAGE + f"Inputs starting with a double underscore ({starting_with_double_underscore}) " f"are reserved for internal use/injected inputs." - "Please do not use keys" + "Please do not directly pass keys starting with a double underscore." ) inputs = inputs.copy() processed_inputs = {} @@ -1922,6 +1935,7 @@ def __init__(self): self.graph_builder = None self.prebuilt_graph = None self.typing_system = None + self._parallel_executor_factory = None def with_identifiers( self, app_id: str = None, partition_key: str = None, sequence_id: int = None @@ -2015,6 +2029,33 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]": self.prebuilt_graph = graph return self + def with_parallel_executor(self, executor_factory: lambda: Executor): + """Assigns a default executor to be used for recursive/parallel sub-actions. This effectively allows + for executing multiple Burr apps in parallel. See https://burr.dagworks.io/pull/concepts/parallelism/ + for more details. + + This will default to a simple threadpool executor, meaning that you will be bound by the number of threads + your computer can handle. If you want to use a more advanced executor, you can pass it in here -- any subclass + of concurrent.futures.Executor will work. + + If you specify executors for specific tasks, this will default to that. + + Note that, if you are using asyncio, you cannot specify an executor. It will default to using + asyncio.gather with asyncio's event loop. + + :param executor: + :return: + """ + if self._parallel_executor_factory is not None: + raise ValueError( + BASE_ERROR_MESSAGE + + "You have already set an executor. You cannot set multiple executors. Current executor is:" + f"{self._parallel_executor_factory}" + ) + + self._parallel_executor_factory = executor_factory + return self + def _ensure_no_prebuilt_graph(self): if self.prebuilt_graph is not None: raise ValueError( @@ -2365,4 +2406,5 @@ def build(self) -> Application[StateType]: if self.spawn_from_app_id is not None else None ), + parallel_executor_factory=self._parallel_executor_factory, ) diff --git a/burr/core/graph.py b/burr/core/graph.py index ff70a4b1..61bc4961 100644 --- a/burr/core/graph.py +++ b/burr/core/graph.py @@ -31,7 +31,6 @@ def _validate_actions(actions: Optional[List[Action]]): def _validate_transitions( transitions: Optional[List[Tuple[str, str, Condition]]], actions: Set[str] ): - assert_set(transitions, "_transitions", "with_transitions") exhausted = {} # items for which we have seen a default transition for from_, to, condition in transitions: if from_ not in actions: @@ -235,7 +234,7 @@ class GraphBuilder: def __init__(self): """Initializes the graph builder.""" - self.transitions: Optional[List[Tuple[str, str, Condition]]] = None + self.transitions: Optional[List[Tuple[str, str, Condition]]] = [] self.actions: Optional[List[Action]] = None def with_actions( @@ -283,8 +282,6 @@ def with_transitions( :param transitions: Transitions to add :return: The application builder for future chaining. """ - if self.transitions is None: - self.transitions = [] for transition in transitions: from_, to_, *conditions = transition if len(conditions) > 0: diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py new file mode 100644 index 00000000..47d5f071 --- /dev/null +++ b/burr/core/parallelism.py @@ -0,0 +1,706 @@ +import abc +import asyncio +import dataclasses +import inspect +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Tuple, TypeVar, Union + +from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State +from burr.core.action import SingleStepAction +from burr.core.graph import GraphBuilder + +SubgraphType = Union[Action, Callable, "RunnableGraph"] + + +@dataclasses.dataclass +class RunnableGraph: + """Contains a graph with information it needs to run. + This is a bit more than a graph -- we have entrypoints + halt_after points. + This is the core element of a recursive action -- your recursive generators can yield these + (as well as actions/functions, which both get turned into single-node graphs...) + """ + + graph: Graph + entrypoint: str + halt_after: List[str] + + @staticmethod + def create(from_: SubgraphType) -> "RunnableGraph": + """Creates a RunnableGraph from a callable/action. This will create a single-node runnable graph, + so we can wrap it up in a task. + + :param from_: Callable or Action to wrap + :return: RunnableGraph + """ + if isinstance(from_, RunnableGraph): + return from_ + if isinstance(from_, Action): + assert ( + from_.name is not None + ), "Action must have a name to be run, internal error, reach out to devs" + graph = GraphBuilder().with_actions(from_).build() + (action,) = graph.actions + return RunnableGraph(graph=graph, entrypoint=action.name, halt_after=[action.name]) + + +@dataclasses.dataclass +class SubGraphTask: + """Task to run a subgraph. Has runtime-spefici information, like inputs, state, and + the application ID. This is the lower-level component -- the user will only directly interact + with this if they use the TaskBasedParallelAction interface, which produces a generator of these. + """ + + graph: RunnableGraph + inputs: Dict[str, Any] + state: State + application_id: str + + def _create_app(self, parent_context: ApplicationContext) -> Application: + return ( + ApplicationBuilder() + .with_graph(self.graph.graph) + .with_entrypoint(self.graph.entrypoint) + .with_state(self.state) + .with_spawning_parent( + app_id=parent_context.app_id, + sequence_id=parent_context.sequence_id, + partition_key=parent_context.partition_key, + ) + .with_tracker(parent_context.tracker.copy()) # We have to copy + # TODO -- handle persistence... + .with_identifiers( + app_id=self.application_id, + partition_key=parent_context.partition_key, # cascade the partition key + ) + .build() + ) + + def run( + self, + parent_context: ApplicationContext, + ) -> State: + """Runs the task -- this simply executes it b y instantiating a sub-application""" + app = self._create_app(parent_context) + action, result, state = app.run( + halt_after=self.graph.halt_after, + inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")}, + ) + return state + + async def arun(self, parent_context: ApplicationContext): + app = self._create_app(parent_context) + action, result, state = await app.arun( + halt_after=self.graph.halt_after, + inputs={key: value for key, value in self.inputs.items() if not key.startswith("__")}, + ) + return state + + +def _stable_app_id_hash(app_id: str, child_key: str) -> str: + """Gives a stable hash for an application. Given the parent app_id and a child key, + this will give a hash that will be stable across runs. + + :param app_id: + :param additional_key: + :return: + """ + ... + + +class TaskBasedParallelAction(SingleStepAction): + """The base class for actions that run a set of tasks in parallel and reduce the results. + This is more power-user mode -- if you need fine-grained control over the set of tasks + your parallel action utilizes, then this is for you. If not, you'll want to see: + + - :py:class:`MapActionsAndStates` -- a cartesian product of actions/states + - :py:class:`MapActions` -- a map of actions over a single state + - :py:class:`MapStates` -- a map of a single action over multiple states + + If you're unfamiliar about where to start, you'll want to see the docs on :ref:`parallelism `. + + This is responsible for two things: + + 1. Creating a set of tasks to run in parallel + 2. Reducing the results of those tasks into a single state for the action to return. + + The following example shows how to call a set of prompts over a set of different models in parallel and return the result. + + .. code-block:: python + + from burr.core import action, state, ApplicationContext + from burr.core.parallelism import MapStates, RunnableGraph + from typing import Callable, Generator, List + + @action(reads=["prompt", "model"], writes=["llm_output"]) + def query_llm(state: State, model: str) -> State: + # TODO -- implement _query_my_llm to call litellm or something + return state.update(llm_output=_query_my_llm(prompt=state["prompt"], model=model)) + + class MultipleTaskExample(TaskBasedParallelAction): + def tasks(state: State, context: ApplicationContext) -> Generator[SubGraphTask, None, None]: + for prompt in state["prompts"]: + for action in [ + query_llm.bind(model="gpt-4").with_name("gpt_4_answer"), + query_llm.bind(model="o1").with_name("o1_answer"), + query_llm.bind(model="claude").with_name("claude_answer"), + ] + yield SubGraphTask( + action=action, # can be a RunnableGraph as well + state=state.update(prompt=prompt), + inputs={}, + # stable hash -- up to you to ensure uniqueness + application_id=hashlib.sha256(context.application_id + action.name + prompt).hexdigest(), + # a few other parameters we might add -- see advanced usage -- failure conditions, etc... + ) + + def reduce(self, states: Generator[State, None, None]) -> State: + all_llm_outputs = [] + for state in states: + all_llm_outputs.append( + { + "output" : state["llm_output"], + "model" : state["model"], + "prompt" : state["prompt"], + } + ) + return state.update(all_llm_outputs=all_llm_outputs) + """ + + def __init__(self): + super().__init__() + + def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]: + """Runs and updates. This is not user-facing, so do not override it. + This runs all actions in parallel (using the supplied executor, from the context), + and then reduces the results. + + :param state: Input state + :param run_kwargs: Additional inputs (runtime inputs) + :return: The results, updated state tuple. The results are empty, but we may add more in the future. + """ + + def _run_and_update(): + context: ApplicationContext = run_kwargs.get("__context") + if context is None: + raise ValueError("This action requires a context to run") + state_without_internals = state.wipe( + delete=[item for item in state.keys() if item.startswith("__")] + ) + task_generator = self.tasks(state_without_internals, context, run_kwargs) + + def execute_task(task): + return task.run(run_kwargs["__context"]) + + with context.parallel_executor_factory() as executor: + # Directly map the generator to the executor + results = list(executor.map(execute_task, task_generator)) + + def state_generator() -> Generator[Any, None, None]: + yield from results + + return {}, self.reduce(state_without_internals, state_generator()) + + async def _arun_and_update(): + context: ApplicationContext = run_kwargs.get("__context") + if context is None: + raise ValueError("This action requires a context to run") + state_without_internals = state.wipe( + delete=[item for item in state.keys() if item.startswith("__")] + ) + task_generator = self.tasks(state_without_internals, context, run_kwargs) + + # TODO -- run in parallel + async def state_generator(): + """This makes it easier on the user -- if they don't have an async generator we can still exhause it + This way we run through all of the task generators. These correspond to the task generation capabilities above (the map*/task generation stuff) + """ + if inspect.isasyncgen(task_generator): + coroutines = [task.arun(context) async for task in task_generator] + else: + coroutines = [task.arun(context) for task in task_generator] + results = await asyncio.gather(*coroutines) + # TODO -- yield in order... + for result in results: + yield result + + return {}, await self.reduce(state_without_internals, state_generator()) + + if self.is_async(): + return _arun_and_update() # type: ignore + return _run_and_update() + + def is_async(self) -> bool: + """This says whether or not the action is async. Note you have to override this if you have async tasks + and want to use asyncio.gather on them. Otherwise leave this blank. + + :return: Whether or not the action is async + """ + return False + + @property + def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]: + """Inputs from this -- if you want to override you'll want to call super() + first so you get these inputs. + + :return: the list of inputs that will populate kwargs. + """ + return ["__context"] # TODO -- add any additional input + + @abc.abstractmethod + def tasks( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubGraphTask, None, None]: + """Creates all tasks that this action will run, given the state/inputs. + This produces a generator of SubGraphTasks that will be run in parallel. + + :param state: State prior to action's execution + :param context: Context for the action + :yield: SubGraphTasks to run + """ + pass + + @abc.abstractmethod + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + """Reduces the states from the tasks into a single state. + + :param states: State outputs from the subtasks + :return: Reduced state + """ + pass + + @property + @abc.abstractmethod + def writes(self) -> list[str]: + pass + + @property + @abc.abstractmethod + def reads(self) -> list[str]: + pass + + +class MapActionsAndStates(TaskBasedParallelAction): + """Base class to run a cartesian product of actions x states. + + For example, if you want to run the following: + + - n prompts + - m models + + This will make it easy to do. If you need fine-grained control, you can use the :py:class:`TaskBasedParallelAction`, + which allows you to specify the tasks individually. If you just want to vary actions/states (and not both), use + :py:class:`MapActions` or :py:class:`MapStates` implementations. + + The following shows how to run a set of prompts over a set of models in parallel and return the results. + + .. code-block:: python + + from burr.core import action, state + from burr.core.parallelism import MapActionsAndStates, RunnableGraph + from typing import Callable, Generator, List + + @action(reads=["prompt", "model"], writes=["llm_output"]) + def query_llm(state: State, model: str) -> State: + # TODO -- implement _query_my_llm to call litellm or something + return state.update(llm_output=_query_my_llm(prompt=state["prompt"], model=model)) + + class TestModelsOverPrompts(MapActionsAndStates): + + def actions(self, state: State) -> Generator[Action | Callable | RunnableGraph, None, None]: + # make sure to add a name to the action + # This is not necessary for subgraphs, as actions will already have names + for action in [ + query_llm.bind(model="gpt-4").with_name("gpt_4_answer"), + query_llm.bind(model="o1").with_name("o1_answer"), + query_llm.bind(model="claude").with_name("claude_answer"), + ] + yield action + + def states(self, state: State) -> Generator[State, None, None]: + for prompt in [ + "What is the meaning of life?", + "What is the airspeed velocity of an unladen swallow?", + "What is the best way to cook a steak?", + ]: + yield state.update(prompt=prompt) + + def reduce(self, states: Generator[State, None, None]) -> State: + all_llm_outputs = [] + for state in states: + all_llm_outputs.append( + { + "output" : state["llm_output"], + "model" : state["model"], + "prompt" : state["prompt"], + } + ) + return state.update(all_llm_outputs=all_llm_outputs) + + def reads() -> List[str]: + return ["prompts"] + + def writes() -> List[str]: + return ["all_llm_outputs"] + + """ + + @abc.abstractmethod + def actions( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubgraphType, None, None]: + """Yields actions to run in parallel. These will be merged with the states as a cartesian product. + + :param state: Input state at the time of running the "parent" action. + :param inputs: Runtime Inputs to the action + :return: Generator of actions to run + """ + pass + + @abc.abstractmethod + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + """Yields states to run in parallel. These will be merged with the actions as a cartesian product. + + :param state: Input state at the time of running the "parent" action. + :param context: Context for the action + :param inputs: Runtime Inputs to the action + :return: Generator of states to run + """ + pass + + def tasks( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubGraphTask, None, None]: + """Takes the cartesian product of actions and states, creating tasks for each. + + :param state: Input state at the time of running the "parent" action. + :param context: Context for the action + :param inputs: Runtime Inputs to the action + :return: Generator of tasks to run + """ + for i, action in enumerate(self.actions(state, context, inputs)): + for j, state in enumerate(self.states(state, context, inputs)): + key = f"{i}-{j}" # this is a stable hash for now but will not handle caching + # TODO -- allow for custom hashes that will indicate stability (user is responsible) + yield SubGraphTask( + graph=RunnableGraph.create(action), + inputs=inputs, + state=state, + application_id=_stable_app_id_hash(context.app_id, key), + ) + + @abc.abstractmethod + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + """Reduces the states from the tasks into a single state. + + :param states: State outputs from the subtasks + :return: Reduced state + """ + pass + + +class MapActions(MapActionsAndStates, abc.ABC): + """Base class to run a set of actions over the same state. Actions can be functions (decorated with @action), + action objects, or subdags implemented as :py:class:`RunnableGraph` objects. With this, you can do the following: + + 1. Specify the actions to run + 2. Specify the state to run the actions over + 3. Reduce the results into a single state + + This is useful, for example, to run different LLMs over the same set of prompts, + + Here is an example (with some pseudocode) of doing just that: + + .. code-block:: python + + from burr.core import action, state + from burr.core.parallelism import MapActions, RunnableGraph + from typing import Callable, Generator, List + + @action(reads=["prompt", "model"], writes=["llm_output"]) + def query_llm(state: State, model: str) -> State: + # TODO -- implement _query_my_llm to call litellm or something + return state.update(llm_output=_query_my_llm(prompt=state["prompt"], model=model)) + + class TestMultipleModels(MapActions): + + def actions(self, state: State) -> Generator[Action | Callable | RunnableGraph, None, None]: + # Make sure to add a name to the action if you use bind() with a function, + # note that these can be different actions, functions, etc... + # in this case we're using `.bind()` to create multiple actions, but we can use some mix of + # subgraphs, functions, action objects, etc... + for action in [ + query_llm.bind(model="gpt-4").with_name("gpt_4_answer"), + query_llm.bind(model="o1").with_name("o1_answer"), + query_llm.bind(model="claude").with_name("claude_answer"), + ] + yield action + + def state(self, state: State) -> State: + return state.update(prompt="What is the meaning of life?") + + def reduce(self, states: Generator[State, None, None]) -> State: + all_llm_outputs = [] + for state in states: + all_llm_outputs.append(state["llm_output"]) + return state.update(all_llm_outputs=all_llm_outputs) + + def reads() -> List[str]: + return ["prompt"] # we're just running this on a single prompt, for multiple actions + + def writes() -> List[str]: + return ["all_llm_outputs"] + + """ + + @abc.abstractmethod + def actions( + self, state: State, inputs: Dict[str, Any], context: ApplicationContext + ) -> Generator[SubgraphType, None, None]: + """Gives all actions to map over, given the state/inputs. + + :param state: State at the time of running the action + :param inputs: Runtime Inputs to the action + :param context: Context for the action + :return: Generator of actions to run + """ + + @abc.abstractmethod + def state(self, state: State, inputs: Dict[str, Any]): + """Gives the state for each of the actions + + :param state: State at the time of running the action + :param inputs: Runtime inputs to the action + :return: State for the action + """ + pass + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + """Just converts the state into a generator of 1, so we can use the superclass. This is internal.""" + yield self.state(state, inputs) + + @abc.abstractmethod + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + """Reduces the task's results into a single state. Runs through all outputs + and combines them together, to form the final state for the action. + + :param states: State outputs from the subtasks + :return: Reduced state + """ + pass + + +class MapStates(MapActionsAndStates, abc.ABC): + """Base class to run a single action over a set of states. States are given as + updates (manipulations) of the action's input state, specified by the `states` + generator. + + With this, you can do the following: + + 1. Specify the states to run + 2. Specify the action to run over all the states + 3. Reduce the results into a single state + + This is useful, for example, to run different prompts over the same LLM, + + Here is an example (with some pseudocode) of doing just that: + + .. code-block:: python + + from burr.core import action, state + from burr.core.parallelism import MapStates, RunnableGraph + from typing import Callable, Generator, List + + @action(reads=["prompt"], writes=["llm_output"]) + def query_llm(state: State) -> State: + return state.update(llm_output=_query_my_llm(prompt=state["prompt"])) + + class TestMultiplePrompts(MapStates): + + def action(self) -> Action | Callable | RunnableGraph: + # make sure to add a name to the action + # This is not necessary for subgraphs, as actions will already have names + return query_llm.with_name("query_llm") + + def states(self, state: State) -> Generator[State, None, None]: + # You could easily have a list_prompts upstream action that writes to "prompts" in state + # And loop through those + # This hardcodes for simplicity + for prompt in [ + "What is the meaning of life?", + "What is the airspeed velocity of an unladen swallow?", + "What is the best way to cook a steak?", + ]: + yield state.update(prompt=prompt) + + + def reduce(self, states: Generator[State, None, None]) -> State: + all_llm_outputs = [] + for state in states: + all_llm_outputs.append(state["llm_output"]) + return state.update(all_llm_outputs=all_llm_outputs) + + def reads() -> List[str]: + return ["prompts"] + + def writes() -> List[str]: + return ["all_llm_outputs"] + """ + + @abc.abstractmethod + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + """Generates all states to map over, given the state and inputs. + Each state will be an update to the input state. + + For instance, you may want to take an input state that has a list field, and expand it + into a set of states, each with a different value from the list. + + For example: + + .. code-block:: python + + def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]: + for item in state["multiple_fields"]: + yield state.update(individual_field=item) + + :param state: Initial state + :param context: Context for the action + :param inputs: Runtime inputs to the action + :return: Generator of states to run + """ + pass + + @abc.abstractmethod + def action(self, state: State, inputs: Dict[str, Any]) -> SubgraphType: + """The single action to apply to each state. + This can be a function (decorated with `@action`, action object, or subdag). + + :param state: State to run the action over + :param inputs: Runtime inputs to the action + :return: Action to run + """ + pass + + def actions( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubgraphType, None, None]: + """Maps the action over each state generated by the `states` method. + Internally used, do not implement.""" + yield self.action(state, inputs) + + @abc.abstractmethod + def reduce(self, state: State, results: Generator[State, None, None]) -> State: + """Reduces the task's results + + :param results: + :return: + """ + pass + + +GenType = TypeVar("GenType") +ReturnType = TypeVar("ReturnType") + +SyncOrAsyncGenerator = Union[Generator[GenType, None, None], AsyncGenerator[GenType, None]] +SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType], List[GenType], GenType] + + +class PassThroughMapActionsAndStates(MapActionsAndStates): + def __init__( + self, + action: Union[ + SubgraphType, + List[SubgraphType], + Callable[ + [State, ApplicationContext, Dict[str, Any]], SyncOrAsyncGenerator[SubgraphType] + ], + ], + state: Callable[[State, ApplicationContext, Dict[str, Any]], SyncOrAsyncGenerator[State]], + reducer: Callable[[State, SyncOrAsyncGenerator[State]], State], + reads: List[str], + writes: List[str], + inputs: List[str], + ): + super().__init__() + self._action_or_generator = action + self._state_or_generator = state + self._reducer = reducer + self._reads = reads + self._writes = writes + self._inputs = inputs + + def actions( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubgraphType, None, None]: + if isinstance(self._action_or_generator, list): + for action in self._action_or_generator: + yield action + return + if isinstance(self._action_or_generator, SubgraphType): + yield self._action_or_generator + else: + gen = self._action_or_generator(state, context, inputs) + if inspect.isasyncgen(gen): + + async def gen(): + async for item in self._action_or_generator(state, context, inputs): + yield item + + return gen() + else: + yield from self._action_or_generator(state, context, inputs) + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + gen = self._state_or_generator(state, context, inputs) + if isinstance(gen, State): + yield gen + if inspect.isasyncgen(gen): + + async def gen(): + async for item in self._state_or_generator(state, context, inputs): + yield item + + return gen() + else: + yield from gen + + def reduce(self, state: State, states: SyncOrAsyncGenerator[State]) -> State: + return self._reducer(state, states) + + @property + def writes(self) -> list[str]: + return self._writes + + @property + def reads(self) -> list[str]: + return self._reads + + +def map_reduce_action( + # action: Optional[SubgraphType]=None, + action: Union[ + SubgraphType, + List[SubgraphType], + Callable[ + [State, ApplicationContext, Dict[str, Any]], + SyncOrAsyncGeneratorOrItemOrList[SubgraphType], + ], + ], + state: Callable[ + [State, ApplicationContext, Dict[str, Any]], SyncOrAsyncGeneratorOrItemOrList[State] + ], + reducer: Callable[[State, SyncOrAsyncGenerator[State]], State], + reads: List[str], + writes: List[str], + inputs: List[str], +): + """Experimental API for creating a map-reduce action easily. We'll be improving this.""" + return PassThroughMapActionsAndStates( + action=action, state=state, reducer=reducer, reads=reads, writes=writes, inputs=inputs + ) diff --git a/burr/core/state.py b/burr/core/state.py index addad13e..85e89a8d 100644 --- a/burr/core/state.py +++ b/burr/core/state.py @@ -428,6 +428,13 @@ def subset(self, *keys: str, ignore_missing: bool = True) -> "State[StateType]": ) def __getitem__(self, __k: str) -> Any: + if __k not in self._state: + raise KeyError( + f"Key \"{__k}\" not found in state. Keys state knows about are: {[key for key in self._state.keys() if not key.startswith('__')]}. " + "If you hit this within the context of an application, you want to " + "(a) ensure that an upstream action has produced this state/it is set as an initial state value and " + "(b) ensure that your action declares this as a read key." + ) return self._state[__k] def __len__(self) -> int: diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst index 476904c8..e30791ab 100644 --- a/docs/concepts/parallelism.rst +++ b/docs/concepts/parallelism.rst @@ -1,3 +1,5 @@ +.. _parallelism: + =========== Parallelism =========== diff --git a/docs/reference/index.rst b/docs/reference/index.rst index f6e96eec..1deeb28e 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -21,6 +21,7 @@ need functionality that is not publicly exposed, please open an issue and we can tracking visibility lifecycle + parallelism typing integrations/index telemetry diff --git a/docs/reference/integrations/index.rst b/docs/reference/integrations/index.rst index 0d15cda1..08ae18af 100644 --- a/docs/reference/integrations/index.rst +++ b/docs/reference/integrations/index.rst @@ -13,3 +13,4 @@ Integrations -- we will be adding more traceloop langchain pydantic + haystack diff --git a/docs/reference/parallelism.rst b/docs/reference/parallelism.rst new file mode 100644 index 00000000..a21ec51e --- /dev/null +++ b/docs/reference/parallelism.rst @@ -0,0 +1,27 @@ +.. _parallelismref: + +=========== +Parallelism +=========== + +Tools to make sub-actions/sub-graphs easier to work with. Read the docs on :ref:`parallelism` for more information. + +.. autoclass:: burr.core.parallelism.RunnableGraph + :members: + +.. autoclass:: burr.core.parallelism.SubGraphTask + :members: + +.. autoclass:: burr.core.parallelism.TaskBasedParallelAction + :members: + +.. autoclass:: burr.core.parallelism.MapActionsAndStates + :members: + +.. autoclass:: burr.core.parallelism.MapActions + :members: + +.. autoclass:: burr.core.parallelism.MapStates + :members: + +.. automethod:: burr.core.parallelism.map_reduce_action diff --git a/examples/recursive/statemachine.png b/examples/recursive/statemachine.png index 0856bd5c..9f7e4dad 100644 Binary files a/examples/recursive/statemachine.png and b/examples/recursive/statemachine.png differ diff --git a/tests/core/test_parallelism.py b/tests/core/test_parallelism.py new file mode 100644 index 00000000..b3c84fb0 --- /dev/null +++ b/tests/core/test_parallelism.py @@ -0,0 +1,945 @@ +import asyncio +import dataclasses +import datetime +from random import random +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union + +import pytest + +from burr.common import types as burr_types +from burr.core import ( + Action, + ApplicationBuilder, + ApplicationContext, + ApplicationGraph, + State, + action, +) +from burr.core.action import Input, Result +from burr.core.graph import GraphBuilder +from burr.core.parallelism import ( + MapActions, + MapActionsAndStates, + MapStates, + RunnableGraph, + SubGraphTask, + TaskBasedParallelAction, + map_reduce_action, +) +from burr.tracking.base import SyncTrackingClient +from burr.visibility import ActionSpan + +old_action = action + + +async def sleep_random(): + await asyncio.sleep(random()) + + +# Single action/callable subgraph +@action(reads=["input_number", "number_to_add"], writes=["output_number"]) +def simple_single_fn_subgraph( + state: State, additional_number: int = 1, identifying_number: int = 1000 +) -> State: + return state.update( + output_number=state["input_number"] + + state["number_to_add"] + + additional_number + + identifying_number + ) + + +# Single action/callable subgraph +@action(reads=["input_number", "number_to_add"], writes=["output_number"]) +async def simple_single_fn_subgraph_async( + state: State, additional_number: int = 1, identifying_number: int = 1000 +) -> State: + await sleep_random() + return state.update( + output_number=state["input_number"] + + state["number_to_add"] + + additional_number + + identifying_number + ) + + +class ClassBasedAction(Action): + def __init__(self, identifying_number: int, name: str = "class_based_action"): + super().__init__() + self._name = name + self.identifying_number = identifying_number + + @property + def reads(self) -> list[str]: + return ["input_number", "number_to_add"] + + def run(self, state: State, **run_kwargs) -> dict: + return { + "output_number": state["input_number"] + + state["number_to_add"] + + run_kwargs.get("additional_number", 1) + + self.identifying_number + } + + @property + def writes(self) -> list[str]: + return ["output_number"] + + def update(self, result: dict, state: State) -> State: + return state.update(**result) + + +class ClassBasedActionAsync(ClassBasedAction): + async def run(self, state: State, **run_kwargs) -> dict: + await sleep_random() + return super().run(state, **run_kwargs) + + +@action(reads=["input_number"], writes=["current_number"]) +def entry_action_for_subgraph(state: State) -> State: + return state.update(current_number=state["input_number"]) + + +@action(reads=["current_number", "number_to_add"], writes=["current_number"]) +def add_number_to_add(state: State) -> State: + return state.update(current_number=state["current_number"] + state["number_to_add"]) + + +@action(reads=["current_number"], writes=["current_number"]) +def add_additional_number_to_add( + state: State, additional_number: int = 1, identifying_number: int = 3000 +) -> State: + return state.update( + current_number=state["current_number"] + additional_number + identifying_number + ) # 1000 is the one that marks this as different + + +@action(reads=["current_number"], writes=["output_number"]) +def final_result(state: State) -> State: + return state.update(output_number=state["current_number"]) + + +@action(reads=["input_number"], writes=["current_number"]) +async def entry_action_for_subgraph_async(state: State) -> State: + await sleep_random() + return entry_action_for_subgraph(state) + + +@action(reads=["current_number", "number_to_add"], writes=["current_number"]) +async def add_number_to_add_async(state: State) -> State: + await sleep_random() + return add_number_to_add(state) + + +@action(reads=["current_number"], writes=["current_number"]) +async def add_additional_number_to_add_async( + state: State, additional_number: int = 1, identifying_number: int = 3000 +) -> State: + await sleep_random() + return add_additional_number_to_add( + state, additional_number=additional_number, identifying_number=identifying_number + ) # 1000 is the one that marks this as different + + +@action(reads=["current_number"], writes=["output_number"]) +async def final_result_async(state: State) -> State: + await sleep_random() + return final_result(state) + + +SubGraphType = Union[Action, Callable, RunnableGraph] + + +def create_full_subgraph(identifying_number: int = 0) -> SubGraphType: + return RunnableGraph( + graph=( + GraphBuilder() + .with_actions( + entry_action_for_subgraph, + add_number_to_add, + add_additional_number_to_add.bind(identifying_number=identifying_number), + final_result, + ) + .with_transitions( + ("entry_action_for_subgraph", "add_number_to_add"), + ("add_number_to_add", "add_additional_number_to_add"), + ("add_additional_number_to_add", "final_result"), + ) + .build() + ), + entrypoint="entry_action_for_subgraph", + halt_after=["final_result"], + ) + + +def create_full_subgraph_async(identifying_number: int = 0) -> SubGraphType: + return RunnableGraph( + graph=GraphBuilder() + .with_actions( + entry_action_for_subgraph=entry_action_for_subgraph_async, + add_number_to_add=add_number_to_add_async, + add_additional_number_to_add=add_additional_number_to_add_async.bind( + identifying_number=identifying_number + ), + final_result=final_result_async, + ) + .with_transitions( + ("entry_action_for_subgraph", "add_number_to_add"), + ("add_number_to_add", "add_additional_number_to_add"), + ("add_additional_number_to_add", "final_result"), + ) + .build(), + entrypoint="entry_action_for_subgraph", + halt_after=["final_result"], + ) + + +FULL_SUBGRAPH: SubGraphType = create_full_subgraph(identifying_number=3000) +FULL_SUBGRAPH_ASYNC: SubGraphType = create_full_subgraph_async(identifying_number=3000) + + +@dataclasses.dataclass +class RecursiveActionTracked: + state_before: Optional[State] + state_after: Optional[State] + action: Action + app_id: str + partition_key: str + sequence_id: int + children: List["RecursiveActionTracked"] = dataclasses.field(default_factory=list) + + +class RecursiveActionTracker(SyncTrackingClient): + """Simple test tracking client for a recursive action""" + + def __init__(self, events: List[RecursiveActionTracked]): + self.events = events + + def copy(self): + """Quick way to copy from the current state. This assumes linearity (which is true in this case, as parallelism is delegated)""" + if self.events: + current_event = self.events[-1] + if current_event.state_after is not None: + raise ValueError("Don't copy if you're not in the middle of an event") + return RecursiveActionTracker(current_event.children) + raise ValueError("Don't copy if you're not in the middle of an event") + + def post_application_create( + self, + *, + app_id: str, + partition_key: Optional[str], + state: "State", + application_graph: "ApplicationGraph", + parent_pointer: Optional[burr_types.ParentPointer], + spawning_parent_pointer: Optional[burr_types.ParentPointer], + **future_kwargs: Any, + ): + pass + + def pre_run_step( + self, + *, + app_id: str, + partition_key: str, + sequence_id: int, + state: "State", + action: "Action", + inputs: Dict[str, Any], + **future_kwargs: Any, + ): + self.events.append( + RecursiveActionTracked( + state_before=state, + state_after=None, + action=action, + app_id=app_id, + partition_key=partition_key, + sequence_id=sequence_id, + ) + ) + + def post_run_step( + self, + *, + app_id: str, + partition_key: str, + sequence_id: int, + state: "State", + action: "Action", + result: Optional[Dict[str, Any]], + exception: Exception, + **future_kwargs: Any, + ): + self.events[-1].state_after = state + + def pre_start_span( + self, + *, + action: str, + action_sequence_id: int, + span: "ActionSpan", + span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + def post_end_span( + self, + *, + action: str, + action_sequence_id: int, + span: "ActionSpan", + span_dependencies: list[str], + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + def do_log_attributes( + self, + *, + attributes: Dict[str, Any], + action: str, + action_sequence_id: int, + span: Optional["ActionSpan"], + tags: dict, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + def pre_start_stream( + self, + *, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + def post_stream_item( + self, + *, + item: Any, + item_index: int, + stream_initialize_time: datetime.datetime, + first_stream_item_start_time: datetime.datetime, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + def post_end_stream( + self, + *, + action: str, + sequence_id: int, + app_id: str, + partition_key: Optional[str], + **future_kwargs: Any, + ): + pass + + +def _group_events_by_app_id( + events: List[RecursiveActionTracked], +) -> Dict[str, List[RecursiveActionTracked]]: + grouped_events = {} + for event in events: + if event.app_id not in grouped_events: + grouped_events[event.app_id] = [] + grouped_events[event.app_id].append(event) + return grouped_events + + +def test_e2e_map_actions_sync_subgraph(): + """Tests map actions over multiple action types (runnable graph, function, action class...)""" + + class MapActionsAllApproaches(MapActions): + def actions( + self, state: State, inputs: Dict[str, Any], context: ApplicationContext + ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: + for graph_ in [ + simple_single_fn_subgraph.bind(identifying_number=1000), + ClassBasedAction(2000), + create_full_subgraph(3000), + ]: + yield graph_ + + def state(self, state: State, inputs: Dict[str, Any]): + return state.update(input_number=state["input_number_in_state"], number_to_add=10) + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_number_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_number_in_state"), + map_action=MapActionsAllApproaches(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = app.run( + halt_after=["final_action"], inputs={"input_number_in_state": 100} + ) + assert state["output_numbers_in_state"] == [1111, 2111, 3111] # esnsure order correct + assert len(events) == 3 # three parent actions + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 3 # three unique App IDs, one for each launching subgraph + + +async def test_e2e_map_actions_async_subgraph(): + """Tests map actions over multiple action types (runnable graph, function, action class...)""" + + class MapActionsAllApproachesAsync(MapActions): + def actions( + self, state: State, inputs: Dict[str, Any], context: ApplicationContext + ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: + for graph_ in [ + simple_single_fn_subgraph_async.bind(identifying_number=1000), + ClassBasedActionAsync(2000), + create_full_subgraph_async(3000), + ]: + yield graph_ + + def is_async(self) -> bool: + return True + + def state(self, state: State, inputs: Dict[str, Any]): + return state.update(input_number=state["input_number_in_state"], number_to_add=10) + + async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + async for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_number_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_number_in_state"), + map_action=MapActionsAllApproachesAsync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = await app.arun( + halt_after=["final_action"], inputs={"input_number_in_state": 100} + ) + assert state["output_numbers_in_state"] == [1111, 2111, 3111] # ensure order correct + assert len(events) == 3 # three parent actions + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 3 # three unique App IDs, one for each launching subgraph + + +@pytest.mark.parametrize( + "action", + [ + simple_single_fn_subgraph.bind(identifying_number=0), + ClassBasedAction(0), + create_full_subgraph(0), + ], +) +def test_e2e_map_states_sync_subgraph(action: SubGraphType): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class MapStatesSync(MapStates): + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for input_number in state["input_numbers_in_state"]: + yield state.update(input_number=input_number, number_to_add=10) + + def action( + self, state: State, inputs: Dict[str, Any] + ) -> Union[Action, Callable, RunnableGraph]: + return action + + def is_async(self) -> bool: + return False + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=MapStatesSync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = app.run( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [111, 211, 311] # ensure order correct + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 3 + + +@pytest.mark.parametrize( + "action", + [ + simple_single_fn_subgraph_async.bind(identifying_number=0), + ClassBasedActionAsync(0), + create_full_subgraph_async(0), + ], +) +async def test_e2e_map_states_async_subgraph(action: SubGraphType): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class MapStatesAsync(MapStates): + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for input_number in state["input_numbers_in_state"]: + yield state.update(input_number=input_number, number_to_add=10) + + def action( + self, state: State, inputs: Dict[str, Any] + ) -> Union[Action, Callable, RunnableGraph]: + return action + + def is_async(self) -> bool: + return True + + async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + async for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=MapStatesAsync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = await app.arun( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [111, 211, 311] # ensure order correct + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 3 + + +def test_e2e_map_actions_and_states_sync(): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class MapStatesAsync(MapActionsAndStates): + def actions( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: + for graph_ in [ + simple_single_fn_subgraph.bind(identifying_number=1000), + ClassBasedAction(2000), + create_full_subgraph(3000), + ]: + yield graph_ + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[State, None, None]: + for input_number in state["input_numbers_in_state"]: + yield state.update(input_number=input_number, number_to_add=10) + + def is_async(self) -> bool: + return False + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=MapStatesAsync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = app.run( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [ + 1111, + 1211, + 1311, + 2111, + 2211, + 2311, + 3111, + 3211, + 3311, + ] + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states + + +async def test_e2e_map_actions_and_states_async(): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class MapStatesAsync(MapActionsAndStates): + def actions( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[Union[Action, Callable, RunnableGraph], None, None]: + for graph_ in [ + simple_single_fn_subgraph_async.bind(identifying_number=1000), + ClassBasedActionAsync(2000), + create_full_subgraph_async(3000), + ]: + yield graph_ + + def states( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> AsyncGenerator[State, None]: + for input_number in state["input_numbers_in_state"]: + yield state.update(input_number=input_number, number_to_add=10) + + def is_async(self) -> bool: + return True + + async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + async for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=MapStatesAsync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = await app.arun( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [ + 1111, + 1211, + 1311, + 2111, + 2211, + 2311, + 3111, + 3211, + 3311, + ] + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states + + +def test_task_level_API_e2e_sync(): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class TaskBasedAction(TaskBasedParallelAction): + def tasks( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> Generator[SubGraphTask, None, None]: + for j, action in enumerate( + [ + simple_single_fn_subgraph.bind(identifying_number=1000), + ClassBasedAction(2000), + create_full_subgraph(3000), + ] + ): + for i, input_number in enumerate(state["input_numbers_in_state"]): + yield SubGraphTask( + graph=RunnableGraph.create(action), + inputs={}, + state=state.update(input_number=input_number, number_to_add=10), + application_id=f"{i}_{j}", + ) + + def reduce(self, state: State, states: Generator[State, None, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=TaskBasedAction(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = app.run( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [ + 1111, + 1211, + 1311, + 2111, + 2211, + 2311, + 3111, + 3211, + 3311, + ] + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states + + +async def test_task_level_API_e2e_async(): + """Tests the map states action with a subgraph that is run in parallel. + Collatz conjecture over different starting points""" + + class TaskBasedActionAsync(TaskBasedParallelAction): + async def tasks( + self, state: State, context: ApplicationContext, inputs: Dict[str, Any] + ) -> AsyncGenerator[SubGraphTask, None]: + for j, action in enumerate( + [ + simple_single_fn_subgraph.bind(identifying_number=1000), + ClassBasedAction(2000), + create_full_subgraph(3000), + ] + ): + for i, input_number in enumerate(state["input_numbers_in_state"]): + yield SubGraphTask( + graph=RunnableGraph.create(action), + inputs={}, + state=state.update(input_number=input_number, number_to_add=10), + application_id=f"{i}_{j}", + ) + + async def reduce(self, state: State, states: AsyncGenerator[State, None]) -> State: + # TODO -- ensure that states is in the correct order... + # Or decide to key it? + new_state = state + async for output_state in states: + new_state = new_state.append(output_numbers_in_state=output_state["output_number"]) + return new_state + + @property + def writes(self) -> list[str]: + return ["output_numbers_in_state"] + + @property + def reads(self) -> list[str]: + return ["input_numbers_in_state"] + + def is_async(self) -> bool: + return True + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=TaskBasedActionAsync(), + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = await app.arun( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [ + 1111, + 1211, + 1311, + 2111, + 2211, + 2311, + 3111, + 3211, + 3311, + ] + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states + + +def test_map_reduce_function_e2e(): + mre = map_reduce_action( + action=[ + simple_single_fn_subgraph.bind(identifying_number=1000), + ClassBasedAction(2000), + create_full_subgraph(3000), + ], + reads=["input_numbers_in_state"], + writes=["output_numbers_in_state"], + state=lambda state, context, inputs: ( + state.update(input_number=input_number, number_to_add=10) + for input_number in state["input_numbers_in_state"] + ), + inputs=[], + reducer=lambda state, states: state.extend( + output_numbers_in_state=[output_state["output_number"] for output_state in states] + ), + ) + + app = ( + ApplicationBuilder() + .with_actions( + initial_action=Input("input_numbers_in_state"), + map_action=mre, + final_action=Result("output_numbers_in_state"), + ) + .with_transitions(("initial_action", "map_action"), ("map_action", "final_action")) + .with_entrypoint("initial_action") + .with_tracker(RecursiveActionTracker(events := [])) + .build() + ) + action, result, state = app.run( + halt_after=["final_action"], inputs={"input_numbers_in_state": [100, 200, 300]} + ) + assert state["output_numbers_in_state"] == [ + 1111, + 1211, + 1311, + 2111, + 2211, + 2311, + 3111, + 3211, + 3311, + ] + assert len(events) == 3 + _, map_event, __ = events + grouped_events = _group_events_by_app_id(map_event.children) + assert len(grouped_events) == 9 # cartesian product of 3 actions and 3 states