Skip to content

Commit

Permalink
Additional updates to parallelism docs
Browse files Browse the repository at this point in the history
More updates to make API consistent:

1. Showing inputs for class-based action (repeat but good to hammer
   home)
2. Making API consistent
  • Loading branch information
elijahbenizzy committed Dec 4, 2024
1 parent 2b5985a commit ddd675a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
5 changes: 5 additions & 0 deletions docs/concepts/actions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ There are two APIs for defining actions: class-based and function-based. They ar
- use the function-based API when you want to write something quick and terse that reads from a fixed set of state variables
- use the class-based API when you want to leverage inheritance or parameterize the action in more powerful ways

.. _functionbasedactions:

----------------------
Function-based actions
----------------------
Expand Down Expand Up @@ -127,6 +129,9 @@ injected into your Burr Actions. This is done by adding ``__context`` to the act
Class-Based Actions
-------------------

.. _classbasedactions:


You can define an action by implementing the :py:class:`Action <burr.core.action.Action>` class:

.. code-block:: python
Expand Down
43 changes: 22 additions & 21 deletions docs/concepts/parallelism.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ This looks as follows -- in this case we're running the same LLM over different
class TestMultiplePrompts(MapStates):
def action(self) -> Action | Callable | RunnableGraph:
def action(self, state: State, inputs: Dict[str, Any]) -> Action | Callable | RunnableGraph:
# make sure to add a name to the action
# This is not necessary for subgraphs, as actions will already have names
return query_llm.with_name("query_llm")
def states(self, state: State) -> Generator[State, None, None]:
def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]:
# You could easily have a list_prompts upstream action that writes to "prompts" in state
# And loop through those
# This hardcodes for simplicity
Expand Down Expand Up @@ -146,7 +146,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction
from burr.core import action, state
from burr.core.parallelism import MapActions, RunnableGraph
from typing import Callable, Generator, List
from typing import Callable, Generator, List, Dict, Any
@action(reads=["prompt", "model"], writes=["llm_output"])
def query_llm(state: State, model: str) -> State:
Expand All @@ -155,7 +155,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction
class TestMultipleModels(MapActions):
def actions(self, state: State) -> Generator[Action | Callable | RunnableGraph, None, None]:
def actions(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[Action | Callable | RunnableGraph, None, None]:
# Make sure to add a name to the action if you use bind() with a function,
# note that these can be different actions, functions, etc...
# in this case we're using `.bind()` to create multiple actions, but we can use some mix of
Expand All @@ -167,7 +167,7 @@ For case (2) (mapping actions over the same state) you implement the ``MapAction
]
yield action
def state(self, state: State) -> State:
def state(self, state: State, inputs: Dict[str, Any]) -> State:
return state.update(prompt="What is the meaning of life?")
def reduce(self, state: State, states: Generator[State, None, None]) -> State:
Expand Down Expand Up @@ -314,7 +314,7 @@ This might look as follows -- say we have a simple subflow that takes in a raw p
def action(self, state: State, inputs: Dict[str, Any]) -> Action | Callable | RunnableGraph:
return runnable_graph
def states(self, state: State) -> Generator[State, None, None]:
def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[State, None, None]:
for prompt in [
"What is the meaning of life?",
"What is the airspeed velocity of an unladen swallow?",
Expand All @@ -332,15 +332,11 @@ it can run just as the single prompt we did above. Note this is also doable for
Passing inputs
--------------

.. note::

Should ``MapOverInputs`` be its own class? Or should we have ``bind_from_state(prompt="prompt_field_in_state")`` that allows you to pass it in as
state and just use the mapping capabilities?

Each of these can (optionally) produce ``inputs`` by yielding/returning a tuple from the ``states``/``actions`` function.

This is useful if you want to vary the inputs. Note this is the same as passing ``inputs=`` to ``app.run``.
Parallel actions can accept inputs in the same way that class-based actions do. In order to accept inputs you have to declare them in the class. As we're using the :ref:`class-based API <classbasedactions>`,
this is done by declaring the ``inputs`` property -- a list of strings that are used in inputs. Note you have to use the superclasses
inputs as well to ensure it has everything it needs -- we will likely be automating this.

This looks as follows:

.. code-block:: python
Expand Down Expand Up @@ -379,17 +375,22 @@ This is useful if you want to vary the inputs. Note this is the same as passing
def action(self) -> Action | Callable | RunnableGraph:
return runnable_graph
def states(self, state: State) -> Generator[Tuple[State, dict], None, None]:
for prompt in [
"What is the meaning of life?",
"What is the airspeed velocity of an unladen swallow?",
"What is the best way to cook a steak?",
]:
yield state.update(prompt=prompt), {"model": "gpt-4"} # pass in the model as an input
@property
def inputs(self) -> List[str]:
return ["prompts"] + super().inputs # make sure to include the superclass inputs
def states(self, state: State, context: ApplicationContext, inputs: Dict[str, Any]) -> Generator[Tuple[State, dict], None, None]:
for prompt in inputs["prompts"]:
yield state.update(prompt=prompt)
... # same as above
.. note::

Should ``MapOverInputs`` be its own class? Or should we have ``bind_from_state(prompt="prompt_field_in_state")`` that allows you to pass it in as
state and just use the mapping capabilities? Or are we happy as it currently is because we can pass in inputs through `MapStates`/`MapActions` (as shown above).


Lower-level API
===============
Expand Down

0 comments on commit ddd675a

Please sign in to comment.