Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Mar 2, 2024
1 parent c78170e commit 181c242
Show file tree
Hide file tree
Showing 9 changed files with 749 additions and 11 deletions.
3 changes: 3 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params, **run_kwargs)

def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
70 changes: 63 additions & 7 deletions burr/core/application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import dataclasses
import functools
import logging
import pprint
from typing import (
Expand All @@ -16,6 +17,7 @@
Union,
)

from burr import visibility
from burr.core.action import (
Action,
Condition,
Expand Down Expand Up @@ -221,21 +223,30 @@ def __init__(
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_application_create", state=self._state, application_graph=self._graph
)
# TODO -- consider adding global inputs + global input factories to the builder
self.dependency_factory = {
"__tracer": functools.partial(
visibility.tracing.TracerFactory, lifecycle_adapters=self._adapter_set
)
}

def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]:
"""Performs a single step, advancing the state machine along.
This returns a tuple of the action that was run, the result of running
the action, and the new state.
Use this if you just want to do something with the state and not rely on generators.
E.G. press forward/backwards, hnuman in the loop, etc... Odds are this is not
E.G. press forward/backwards, human in the loop, etc... Odds are this is not
the method you want -- you'll want iterate() (if you want to see the state/
results along the way), or run() (if you just want the final state/results).
:param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world :param __run_hooks:
:param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world
:return: Tuple[Function, dict, State] -- the function that was just ran, the result of running it, and the new state
"""
return self._step(inputs=inputs, _run_hooks=True)

out = self._step(inputs=inputs, _run_hooks=True)
self._increment_sequence_id()
return out

def _step(
self, inputs: Optional[Dict[str, Any]] = None, _run_hooks: bool = True
Expand All @@ -247,6 +258,7 @@ def _step(
return None
if inputs is None:
inputs = {}
inputs = self._process_inputs(inputs, next_action)
if _run_hooks:
self._adapter_set.call_all_lifecycle_hooks_sync(
"pre_run_step", action=next_action, state=self._state, inputs=inputs
Expand Down Expand Up @@ -283,12 +295,37 @@ def update_internal_state_value(self, new_state: State, next_action: Action) ->
new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
# make it a string for future proofing
SEQUENCE_ID: str(int(self._state.get(SEQUENCE_ID, 0)) + 1),
}
)
return new_state

def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]:
inputs = inputs.copy()
processed_inputs = {}
for key in list(inputs.keys()):
if key in action.inputs:
processed_inputs[key] = inputs.pop(key)
if len(inputs) > 0:
raise ValueError(
f"Keys {inputs.keys()} were passed in as inputs to action "
f"{action.name}, but not declared by the action as an input! "
f"Action needs: {action.inputs}"
)
missing_inputs = set(action.inputs) - set(processed_inputs.keys())
for required_input in list(missing_inputs):
# if we can find it in the dependency factory, we'll use that
if required_input in self.dependency_factory:
processed_inputs[required_input] = self.dependency_factory[required_input](
action, self.sequence_id
)
missing_inputs.remove(required_input)
if len(missing_inputs) > 0:
raise ValueError(
f"Action {action.name} is missing required inputs: {missing_inputs}. "
f"Has inputs: {processed_inputs}"
)
return processed_inputs

async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, dict, State]]:
"""Asynchronous version of step.
Expand All @@ -298,10 +335,11 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
:return: Tuple[Function, dict, State] -- the action that was just ran, the result of running it, and the new state
"""
next_action = self.get_next_action()
if inputs is None:
inputs = {}
if next_action is None:
return None
if inputs is None:
inputs = {}
inputs = self._process_inputs(inputs, next_action)
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
"pre_run_step", action=next_action, state=self._state, inputs=inputs
)
Expand Down Expand Up @@ -335,6 +373,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
"post_run_step", action=next_action, state=new_state, result=result, exception=exc
)
self._set_state(new_state)
self._increment_sequence_id()
return next_action, result, new_state

def _clean_iterate_params(
Expand Down Expand Up @@ -515,6 +554,9 @@ async def arun(
"""
prior_action = None
result = None
halt_before, halt_after, inputs = self._clean_iterate_params(
halt_before, halt_after, inputs
)
async for prior_action, result, state in self.aiterate(
halt_before=halt_before, halt_after=halt_after, inputs=inputs
):
Expand Down Expand Up @@ -647,6 +689,20 @@ def graph(self) -> ApplicationGraph:
"""
return self._graph

@property
def sequence_id(self) -> Optional[int]:
"""gives the sequence ID of the current (next) action.
This is incremented after every step is taken -- meaning that incremeneting
it is the very last action that is done. Any logging, etc... will use the current
step's sequence ID
:return:
"""
return self._state.get(SEQUENCE_ID, 0)

def _increment_sequence_id(self):
self._state = self._state.update(**{SEQUENCE_ID: self.sequence_id + 1})


def _assert_set(value: Optional[Any], field: str, method: str):
if value is None:
Expand Down
71 changes: 71 additions & 0 deletions burr/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
if TYPE_CHECKING:
# type-checking-only for a circular import
from burr.core import State, Action, ApplicationGraph
from burr.visibility import ActionSpan

from burr.lifecycle.internal import lifecycle

Expand Down Expand Up @@ -113,6 +114,72 @@ def post_application_create(
pass


@lifecycle.base_hook("pre_start_span")
class PreStartSpanHook(abc.ABC):
"""Hook that runs before a span is started in the tracing API.
This can be either a context manager or a logger."""

@abc.abstractmethod
def pre_start_span(
self,
*,
action: str,
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
**future_kwargs: Any,
):
pass


@lifecycle.base_hook("pre_start_span")
class PreStartSpanHookAsync(abc.ABC):
@abc.abstractmethod
async def pre_start_span(
self,
*,
action: str,
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
**future_kwargs: Any,
):
pass


@lifecycle.base_hook("post_end_span")
class PostEndSpanHook(abc.ABC):
"""Hook that runs after a span is ended in the tracing API.
This can be either a context manager or a logger."""

@abc.abstractmethod
def post_end_span(
self,
*,
action: str,
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
**future_kwargs: Any,
):
pass


@lifecycle.base_hook("post_end_span")
class PostEndSpanHookAsync(abc.ABC):
@abc.abstractmethod
async def post_end_span(
self,
*,
action: str,
action_sequence_id: int,
span: "ActionSpan",
span_dependencies: list[str],
**future_kwargs: Any,
):
pass


# THESE ARE NOT IN USE
# TODO -- implement/decide how to use them
@lifecycle.base_hook("pre_run_application")
Expand Down Expand Up @@ -159,4 +226,8 @@ async def post_run_application(
PostRunApplicationHook,
PostRunApplicationHookAsync,
PostApplicationCreateHook,
PreStartSpanHook,
PreStartSpanHookAsync,
PostEndSpanHook,
PostEndSpanHookAsync,
]
Empty file.
14 changes: 12 additions & 2 deletions burr/tracking/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationM
)


INPUT_FILTERLIST = {"__tracer"}


def _filter_inputs(d: dict) -> dict:
return {k: v for k, v in d.items() if k not in INPUT_FILTERLIST}


class BeginEntryModel(IdentifyingModel):
"""Pydantic model that represents an entry for the beginning of a step"""

Expand All @@ -94,6 +101,10 @@ class BeginEntryModel(IdentifyingModel):
inputs: Dict[str, Any]
type: str = "begin_entry"

@field_serializer("inputs")
def serialize_inputs(self, inputs):
return _serialize_object(_filter_inputs(inputs))


def _serialize_object(d: object) -> Union[dict, list, object]:
if isinstance(d, list):
Expand All @@ -104,8 +115,7 @@ def _serialize_object(d: object) -> Union[dict, list, object]:
return d.model_dump()
elif hasattr(d, "to_json"):
return d.to_json()
else:
return d
return d


class EndEntryModel(IdentifyingModel):
Expand Down
4 changes: 4 additions & 0 deletions burr/visibility/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from burr.visibility.tracing import ActionSpan
from burr.visibility.tracing import TracerFactory as Tracer

__all__ = ["Tracer", "ActionSpan"]
Loading

0 comments on commit 181c242

Please sign in to comment.