Implements parallelism capabilities according to spec
elijahbenizzy committed Nov 14, 2024
1 parent e057b56 commit 9aed277
Showing 10 changed files with 1,733 additions and 5 deletions.
44 changes: 43 additions & 1 deletion burr/core/application.py
@@ -7,6 +7,7 @@
import logging
import pprint
import uuid
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import AbstractContextManager
from typing import (
TYPE_CHECKING,
@@ -497,6 +498,7 @@ class ApplicationContext(AbstractContextManager):
partition_key: Optional[str]
sequence_id: Optional[int]
tracker: Optional["TrackingClient"]
parallel_executor_factory: Callable[[], Executor]

@staticmethod
def get() -> Optional["ApplicationContext"]:
@@ -683,6 +685,10 @@ def post_run_step(
StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any])


def _create_default_executor() -> Executor:
return ThreadPoolExecutor()


class Application(Generic[ApplicationStateType]):
def __init__(
self,
@@ -697,6 +703,7 @@ def __init__(
fork_parent_pointer: Optional[burr_types.ParentPointer] = None,
spawning_parent_pointer: Optional[burr_types.ParentPointer] = None,
tracker: Optional["TrackingClient"] = None,
parallel_executor_factory: Optional[Callable[[], Executor]] = None,
):
"""Instantiates an Application. This is an internal API -- use the builder!
@@ -731,6 +738,11 @@ def __init__(
self._set_sequence_id(sequence_id)
self._builder = builder
self._parent_pointer = fork_parent_pointer
self._parallel_executor_factory = (
parallel_executor_factory
if parallel_executor_factory is not None
else _create_default_executor
)
self._dependency_factory = {
"__tracer": functools.partial(
visibility.tracing.TracerFactory,
@@ -780,6 +792,7 @@ def _context_factory(self, action: Action, sequence_id: int) -> ApplicationConte
tracker=self._tracker,
partition_key=self._partition_key,
sequence_id=sequence_id,
parallel_executor_factory=self._parallel_executor_factory,
)

def _step(
@@ -862,7 +875,7 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, A
BASE_ERROR_MESSAGE
+ f"Inputs starting with a double underscore ({starting_with_double_underscore}) "
f"are reserved for internal use/injected inputs."
"Please do not use keys"
"Please do not directly pass keys starting with a double underscore."
)
inputs = inputs.copy()
processed_inputs = {}
@@ -1922,6 +1935,7 @@ def __init__(self):
self.graph_builder = None
self.prebuilt_graph = None
self.typing_system = None
self._parallel_executor_factory = None

def with_identifiers(
self, app_id: str = None, partition_key: str = None, sequence_id: int = None
@@ -2015,6 +2029,33 @@ def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]":
self.prebuilt_graph = graph
return self

def with_parallel_executor(self, executor_factory: Callable[[], Executor]) -> "ApplicationBuilder[StateType]":
"""Assigns a default executor to be used for recursive/parallel sub-actions. This effectively allows
for executing multiple Burr apps in parallel. See https://burr.dagworks.io/pull/concepts/parallelism/
for more details.
This will default to a simple threadpool executor, meaning that you will be bound by the number of threads
your computer can handle. If you want to use a more advanced executor, you can pass it in here -- any subclass
of concurrent.futures.Executor will work.
If you specify executors for specific tasks, this will default to that.
Note that, if you are using asyncio, you cannot specify an executor. It will default to using
asyncio.gather with asyncio's event loop.
:param executor:
:return:
"""
if self._parallel_executor_factory is not None:
raise ValueError(
BASE_ERROR_MESSAGE
+ "You have already set an executor. You cannot set multiple executors. Current executor is:"
f"{self._parallel_executor_factory}"
)

self._parallel_executor_factory = executor_factory
return self

def _ensure_no_prebuilt_graph(self):
if self.prebuilt_graph is not None:
raise ValueError(
@@ -2365,4 +2406,5 @@ def build(self) -> Application[StateType]:
if self.spawn_from_app_id is not None
else None
),
parallel_executor_factory=self._parallel_executor_factory,
)
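
For orientation, here is a minimal usage sketch of the new builder hook. Only with_parallel_executor comes from this commit; the action definition and the other builder calls (with_actions, with_transitions, with_entrypoint) are assumed from the existing Burr builder API rather than shown in this diff, and the ProcessPoolExecutor choice is purely illustrative.

from concurrent.futures import ProcessPoolExecutor

from burr.core import ApplicationBuilder, State, action


@action(reads=[], writes=["result"])
def do_work(state: State) -> State:
    # Placeholder action -- a real application would fan out parallel sub-actions here.
    return state.update(result="done")


app = (
    ApplicationBuilder()
    .with_actions(do_work)
    .with_transitions(("do_work", "do_work"))
    .with_entrypoint("do_work")
    # Pass a factory (not an executor instance); it is invoked when parallel
    # sub-actions need an executor, replacing the default ThreadPoolExecutor.
    .with_parallel_executor(lambda: ProcessPoolExecutor(max_workers=4))
    .build()
)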
5 changes: 1 addition & 4 deletions burr/core/graph.py
@@ -31,7 +31,6 @@ def _validate_actions(actions: Optional[List[Action]]):
def _validate_transitions(
transitions: Optional[List[Tuple[str, str, Condition]]], actions: Set[str]
):
assert_set(transitions, "_transitions", "with_transitions")
exhausted = {} # items for which we have seen a default transition
for from_, to, condition in transitions:
if from_ not in actions:
@@ -235,7 +234,7 @@ class GraphBuilder:

def __init__(self):
"""Initializes the graph builder."""
self.transitions: Optional[List[Tuple[str, str, Condition]]] = None
self.transitions: Optional[List[Tuple[str, str, Condition]]] = []
self.actions: Optional[List[Action]] = None

def with_actions(
@@ -283,8 +282,6 @@ def with_transitions(
:param transitions: Transitions to add
:return: The application builder for future chaining.
"""
if self.transitions is None:
self.transitions = []
for transition in transitions:
from_, to_, *conditions = transition
if len(conditions) > 0:
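
The graph.py change above relaxes validation so that transitions are optional: GraphBuilder now starts with an empty transition list instead of asserting that with_transitions was called. A sketch of what that permits -- the one-action graph below and the GraphBuilder.build() call are assumed usage, not part of this diff:

from burr.core import State, action
from burr.core.graph import GraphBuilder


@action(reads=[], writes=["result"])
def solo(state: State) -> State:
    return state.update(result="ok")


# A single-action graph with no with_transitions() call -- previously this
# failed validation via assert_set; with transitions defaulting to [], it builds.
graph = GraphBuilder().with_actions(solo).build()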