Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Mar 1, 2024
1 parent 6615c6d commit bc44d23
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 7 deletions.
3 changes: 3 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params, **run_kwargs)

def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
45 changes: 42 additions & 3 deletions burr/core/application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import dataclasses
import functools
import logging
import pprint
from typing import (
Expand All @@ -16,6 +17,7 @@
Union,
)

from burr import visibility
from burr.core.action import (
Action,
Condition,
Expand Down Expand Up @@ -220,20 +222,27 @@ def __init__(
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_application_create", state=self._state, application_graph=self._graph
)
# TODO -- consider adding global inputs + global input factories to the builder
self.dependency_factory = {
"__tracer": functools.partial(
visibility.tracing.TracerFactory, lifecycle_adapters=self._adapter_set
)
}

def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action, dict, State]]:
"""Performs a single step, advancing the state machine along.
This returns a tuple of the action that was run, the result of running
the action, and the new state.
Use this if you just want to do something with the state and not rely on generators.
E.G. press forward/backwards, hnuman in the loop, etc... Odds are this is not
E.G. press forward/backwards, human in the loop, etc... Odds are this is not
the method you want -- you'll want iterate() (if you want to see the state/
results along the way), or run() (if you just want the final state/results).
:param inputs: Inputs to the action -- this is if this action requires an input that is passed in from the outside world :param __run_hooks:
:return: Tuple[Function, dict, State] -- the function that was just ran, the result of running it, and the new state
"""

return self._step(inputs=inputs, _run_hooks=True)

def _step(
Expand All @@ -246,6 +255,7 @@ def _step(
return None
if inputs is None:
inputs = {}
inputs = self._process_inputs(inputs, next_action)
if _run_hooks:
self._adapter_set.call_all_lifecycle_hooks_sync(
"pre_run_step", action=next_action, state=self._state, inputs=inputs
Expand Down Expand Up @@ -276,6 +286,31 @@ def _step(
)
return next_action, result, new_state

def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, Any]:
inputs = inputs.copy()
processed_inputs = {}
for key in list(inputs.keys()):
if key in action.inputs:
processed_inputs[key] = inputs.pop(key)
if len(inputs) > 0:
raise ValueError(
f"Keys {inputs.keys()} were passed in as inputs to action "
f"{action.name}, but not declared by the action as an input! "
f"Action needs: {action.inputs}"
)
missing_inputs = set(action.inputs) - set(processed_inputs.keys())
for required_input in list(missing_inputs):
# if we can find it in the dependency factory, we'll use that
if required_input in self.dependency_factory:
processed_inputs[required_input] = self.dependency_factory[required_input](action)
missing_inputs.remove(required_input)
if len(missing_inputs) > 0:
raise ValueError(
f"Action {action.name} is missing required inputs: {missing_inputs}. "
f"Has inputs: {processed_inputs}"
)
return processed_inputs

async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, dict, State]]:
"""Asynchronous version of step.
Expand All @@ -285,10 +320,11 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
:return: Tuple[Function, dict, State] -- the action that was just ran, the result of running it, and the new state
"""
next_action = self.get_next_action()
if inputs is None:
inputs = {}
if next_action is None:
return None
if inputs is None:
inputs = {}
inputs = self._process_inputs(inputs, next_action)
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
"pre_run_step", action=next_action, state=self._state, inputs=inputs
)
Expand Down Expand Up @@ -502,6 +538,9 @@ async def arun(
"""
prior_action = None
result = None
halt_before, halt_after, inputs = self._clean_iterate_params(
halt_before, halt_after, inputs
)
async for prior_action, result, state in self.aiterate(
halt_before=halt_before, halt_after=halt_after, inputs=inputs
):
Expand Down
Empty file.
14 changes: 12 additions & 2 deletions burr/tracking/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationM
)


INPUT_FILTERLIST = {"__tracer"}


def _filter_inputs(d: dict) -> dict:
return {k: v for k, v in d.items() if k not in INPUT_FILTERLIST}


class BeginEntryModel(IdentifyingModel):
"""Pydantic model that represents an entry for the beginning of a step"""

Expand All @@ -94,6 +101,10 @@ class BeginEntryModel(IdentifyingModel):
inputs: Dict[str, Any]
type: str = "begin_entry"

@field_serializer("inputs")
def serialize_inputs(self, inputs):
return _serialize_object(_filter_inputs(inputs))


def _serialize_object(d: object) -> Union[dict, list, object]:
if isinstance(d, list):
Expand All @@ -104,8 +115,7 @@ def _serialize_object(d: object) -> Union[dict, list, object]:
return d.model_dump()
elif hasattr(d, "to_json"):
return d.to_json()
else:
return d
return d


class EndEntryModel(IdentifyingModel):
Expand Down
3 changes: 3 additions & 0 deletions burr/visibility/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from burr.visibility.tracing import TracerFactory as Tracer

__all__ = ["Tracer"]
100 changes: 100 additions & 0 deletions burr/visibility/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import dataclasses
import uuid
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from contextvars import ContextVar
from typing import Any, List, Optional, Type

from burr.lifecycle.internal import LifecycleAdapterSet


@dataclasses.dataclass
class ActionSpan:
action: str
name: str
id: str
parent: Optional["ActionSpan"]

@classmethod
def create(cls, action: str, name: str, parent: Optional["ActionSpan"]) -> "ActionSpan":
return ActionSpan(action=action, name=name, id=create_span_id(), parent=parent)


execution_context_var: ContextVar[Optional[ActionSpan]] = ContextVar(
"execution_context",
default=None,
)


def create_span_id() -> str:
return f"span_{str(uuid.uuid4())}"


class ActionSpanTracer(AbstractContextManager, AbstractAsyncContextManager):
"""Context manager for use within tracing actions"""

def __init__(
self,
action: str,
span_name: str,
lifecycle_adapters: LifecycleAdapterSet,
span_dependencies: List[str],
):
self.action = action
self.lifecycle_adapters = lifecycle_adapters
self.span_name = span_name
self.span_dependencies = span_dependencies

async def __aexit__(
self,
__exc_type: Type[BaseException],
__exc_value: Optional[Type[BaseException]],
__traceback: Optional[Type[BaseException]],
) -> Optional[bool]:
raise NotImplementedError

def __enter__(self):
current_execution_context = execution_context_var.get()
execution_context_var.set(
ActionSpan.create(
action=self.action, name=self.span_name, parent=current_execution_context
)
)

def __exit__(
self,
__exc_type: Type[BaseException],
__exc_value: Optional[Type[BaseException]],
__traceback: Optional[Type[BaseException]],
) -> Optional[bool]:
"""Raise any exception triggered within the runtime context."""
current_execution_context = execution_context_var.get()
execution_context_var.set(current_execution_context.parent)
return None

def log_artifact(self, name: str, value: Any):
raise NotImplementedError

async def __aenter__(self) -> "ActionSpanTracer":
raise NotImplementedError


class TracerFactory:
def __init__(self, action: str, lifecycle_adapters: LifecycleAdapterSet):
self.action = action
self.lifecycle_adapters = lifecycle_adapters

def __call__(
self, span_name: str, span_dependencies: Optional[List[str]] = None
) -> ActionSpanTracer:
return ActionSpanTracer(
action=self.action,
span_name=span_name,
lifecycle_adapters=self.lifecycle_adapters,
span_dependencies=span_dependencies,
)


# with TracerFactory("action", "span", lifecycle.LifecycleAdapterSet()) as tracer:
# with tracer("span", ["edge"]) as span:
# pass
# pass
4 changes: 2 additions & 2 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_app_step_with_inputs_missing():
state=State({"count": 0, "tracker": []}),
initial_step="counter",
)
with pytest.raises(ValueError, match="Missing the following inputs"):
with pytest.raises(ValueError, match="missing required inputs"):
app.step(inputs={})


Expand Down Expand Up @@ -447,7 +447,7 @@ async def test_app_astep_with_inputs_missing():
state=State({"count": 0, "tracker": []}),
initial_step="counter_async",
)
with pytest.raises(ValueError, match="Missing the following inputs"):
with pytest.raises(ValueError, match="missing required inputs"):
await app.astep(inputs={})


Expand Down

0 comments on commit bc44d23

Please sign in to comment.