Parallelism #370

Merged: 5 commits merged on Nov 14, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -144,7 +144,7 @@ but realized that it has a wide array of applications and decided to release it

While Burr is stable and well-tested, we have quite a few tools/features on our roadmap!

1. Parallelism -- support for recursive "sub-agents" through an ergonomic API (not: this is already feasible, see [recursive applications](http://localhost:8000/concepts/recursion/)).
1. Parallelism -- support for recursive "sub-agents" through an ergonomic API (note: this is already feasible, see [recursive applications](https://burr.dagworks.io/recursion/)).
2. Testing & eval curation. Curating data with annotations and being able to export these annotations to create unit & integration tests.
3. Various efficiency/usability improvements for the core library (see [planned capabilities](https://burr.dagworks.io/concepts/planned-capabilities/) for more details). This includes:
1. First-class support for retries + exception management
44 changes: 43 additions & 1 deletion burr/core/application.py
@@ -7,6 +7,7 @@
import logging
import pprint
import uuid
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import AbstractContextManager
from typing import (
TYPE_CHECKING,
@@ -497,6 +498,7 @@ class ApplicationContext(AbstractContextManager):
partition_key: Optional[str]
sequence_id: Optional[int]
tracker: Optional["TrackingClient"]
parallel_executor_factory: Callable[[], Executor]

@staticmethod
def get() -> Optional["ApplicationContext"]:
@@ -683,6 +685,10 @@ def post_run_step(
StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any])


def _create_default_executor() -> Executor:
return ThreadPoolExecutor()


class Application(Generic[ApplicationStateType]):
def __init__(
self,
@@ -697,6 +703,7 @@ def __init__(
fork_parent_pointer: Optional[burr_types.ParentPointer] = None,
spawning_parent_pointer: Optional[burr_types.ParentPointer] = None,
tracker: Optional["TrackingClient"] = None,
parallel_executor_factory: Optional[Callable[[], Executor]] = None,
):
"""Instantiates an Application. This is an internal API -- use the builder!
@@ -731,6 +738,11 @@ def __init__(
self._set_sequence_id(sequence_id)
self._builder = builder
self._parent_pointer = fork_parent_pointer
self._parallel_executor_factory = (
parallel_executor_factory
if parallel_executor_factory is not None
else _create_default_executor
)
self._dependency_factory = {
"__tracer": functools.partial(
visibility.tracing.TracerFactory,
@@ -780,6 +792,7 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationContext:
tracker=self._tracker,
partition_key=self._partition_key,
sequence_id=sequence_id,
parallel_executor_factory=self._parallel_executor_factory,
)

def _step(
@@ -862,7 +875,7 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]:
BASE_ERROR_MESSAGE
+ f"Inputs starting with a double underscore ({starting_with_double_underscore}) "
f"are reserved for internal use/injected inputs."
"Please do not use keys"
"Please do not directly pass keys starting with a double underscore."
)
inputs = inputs.copy()
processed_inputs = {}
@@ -1922,6 +1935,7 @@ def __init__(self):
self.graph_builder = None
self.prebuilt_graph = None
self.typing_system = None
self._parallel_executor_factory = None

def with_identifiers(
self, app_id: str = None, partition_key: str = None, sequence_id: int = None
@@ -2015,6 +2029,33 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
self.prebuilt_graph = graph
return self

def with_parallel_executor(self, executor_factory: Callable[[], Executor]) -> "ApplicationBuilder[StateType]":
"""Assigns a default executor to be used for recursive/parallel sub-actions. This effectively allows
for executing multiple Burr apps in parallel. See https://burr.dagworks.io/pull/concepts/parallelism/
for more details.
This will default to a simple threadpool executor, meaning that you will be bound by the number of threads
your computer can handle. If you want to use a more advanced executor, you can pass it in here -- any subclass
of concurrent.futures.Executor will work.
If you specify executors for specific tasks, this will default to that.
Note that, if you are using asyncio, you cannot specify an executor. It will default to using
asyncio.gather with asyncio's event loop.
:param executor:
:return:
"""
if self._parallel_executor_factory is not None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You have already set an executor. You cannot set multiple executors. Current executor is:"
f"{self._parallel_executor_factory}"
)

self._parallel_executor_factory = executor_factory
return self

def _ensure_no_prebuilt_graph(self):
if self.prebuilt_graph is not None:
raise ValueError(
@@ -2365,4 +2406,5 @@ def build(self) -> Application[StateType]:
if self.spawn_from_app_id is not None
else None
),
parallel_executor_factory=self._parallel_executor_factory,
)
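To illustrate the new builder API above, here is a minimal usage sketch (not part of this diff). It assumes a prebuilt graph named my_graph with a "start" entrypoint; the choice of ProcessPoolExecutor and max_workers=4 is purely illustrative:

from concurrent.futures import ProcessPoolExecutor

from burr.core import ApplicationBuilder

# The factory is zero-argument so each parallel sub-application run can
# construct a fresh executor rather than sharing a single instance.
app = (
    ApplicationBuilder()
    .with_graph(my_graph)  # assumed: a Graph built elsewhere with GraphBuilder
    .with_entrypoint("start")
    .with_parallel_executor(lambda: ProcessPoolExecutor(max_workers=4))
    .build()
)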
5 changes: 1 addition & 4 deletions burr/core/graph.py
@@ -31,7 +31,6 @@ def _validate_actions(actions: Optional[List[Action]]):
def _validate_transitions(
transitions: Optional[List[Tuple[str, str, Condition]]], actions: Set[str]
):
assert_set(transitions, "_transitions", "with_transitions")
exhausted = {} # items for which we have seen a default transition
for from_, to, condition in transitions:
if from_ not in actions:
@@ -235,7 +234,7 @@ class GraphBuilder:

def __init__(self):
"""Initializes the graph builder."""
self.transitions: Optional[List[Tuple[str, str, Condition]]] = None
self.transitions: Optional[List[Tuple[str, str, Condition]]] = []
self.actions: Optional[List[Action]] = None

def with_actions(
Expand Down Expand Up @@ -283,8 +282,6 @@ def with_transitions(
:param transitions: Transitions to add
:return: The application builder for future chaining.
"""
if self.transitions is None:
self.transitions = []
for transition in transitions:
from_, to_, *conditions = transition
if len(conditions) > 0:
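Side note on the graph.py change above: because transitions now initialize to an empty list, a graph no longer requires with_transitions before building. A minimal sketch of the effect (the action name and body are hypothetical, not taken from this PR):

from burr.core import State, action
from burr.core.graph import GraphBuilder

@action(reads=[], writes=["result"])
def do_work(state: State) -> tuple[dict, State]:
    # Trivial single-step action used only to demonstrate graph construction.
    result = {"result": "done"}
    return result, state.update(**result)

# A single-action graph with no transitions now builds without error,
# which is the shape recursive/parallel sub-applications rely on.
graph = GraphBuilder().with_actions(do_work).build()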