Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Burr tracking client #18

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install -e ".[tests]"
python -m pip install -e ".[tests,tracking-client]"
- name: Run tests
run: |
python -m pytest tests
3 changes: 2 additions & 1 deletion burr/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from burr.core.action import Action, Condition, Result, action, default, expr, when
from burr.core.application import Application, ApplicationBuilder
from burr.core.application import Application, ApplicationBuilder, ApplicationGraph
from burr.core.state import State

__all__ = [
"action",
"Action",
"Application",
"ApplicationBuilder",
"ApplicationGraph",
"Condition",
"default",
"expr",
Expand Down
86 changes: 78 additions & 8 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
List,
Literal,
Expand Down Expand Up @@ -95,7 +96,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
f"Action {name} attempted to write to keys {extra_keys} "
f"that it did not declare. It declared: ({reducer.writes})!"
)
return state.merge(new_state.update(**{PRIOR_STEP: name}))
return state.merge(new_state)


def _create_dict_string(kwargs: dict) -> str:
Expand Down Expand Up @@ -156,6 +157,20 @@ async def _arun_single_step_action(action: SingleStepAction, state: State) -> Tu
return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action


@dataclasses.dataclass
class ApplicationGraph:
"""User-facing representation of the state machine. This has

#. All the action objects
#. All the transition objects
#. The entrypoint action
"""

actions: List[Action]
transitions: List[Transition]
entrypoint: Action


class Application:
def __init__(
self,
Expand All @@ -172,6 +187,10 @@ def __init__(
self._initial_step = initial_step
self._state = state
self._adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet()
self._graph = self._create_graph()
self._adapter_set.call_all_lifecycle_hooks_sync(
"post_application_create", state=self._state, application_graph=self._graph
)

def step(self) -> Optional[Tuple[Action, dict, State]]:
"""Performs a single step, advancing the state machine along.
Expand Down Expand Up @@ -200,6 +219,7 @@ def step(self) -> Optional[Tuple[Action, dict, State]]:
else:
result = _run_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = new_state.update(**{PRIOR_STEP: next_action.name})
self._set_state(new_state)
except Exception as e:
exc = e
Expand Down Expand Up @@ -238,6 +258,7 @@ async def astep(self) -> Optional[Tuple[Action, dict, State]]:
else:
result = await _arun_function(next_action, self._state)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = new_state.update(**{PRIOR_STEP: next_action.name})
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
Expand Down Expand Up @@ -280,10 +301,10 @@ def iterate(
)

while not condition():
result = self.step()
if result is None:
state_output = self.step()
if state_output is None:
break
action, result, state = result
action, result, state = state_output
if action.name in until:
seen_results.add(action.name)
results[index_by_name[action.name]] = result
Expand Down Expand Up @@ -462,9 +483,23 @@ def state(self) -> State:
"""
return self._state

def _create_graph(self) -> ApplicationGraph:
"""Internal-facing utility function for creating an ApplicationGraph"""
all_actions = {action.name: action for action in self._actions}
return ApplicationGraph(
actions=self._actions,
transitions=self._transitions,
entrypoint=all_actions[self._initial_step],
)

@property
def actions(self) -> List[Action]:
return self._actions
def graph(self) -> ApplicationGraph:
"""Application graph object -- if you want to inspect, visualize, etc..
this is what you want.

:return: The application graph object
"""
return self._graph


def _assert_set(value: Optional[Any], field: str, method: str):
Expand Down Expand Up @@ -524,6 +559,14 @@ def __init__(self):
self.lifecycle_adapters: List[LifecycleAdapter] = list()

def with_state(self, **kwargs) -> "ApplicationBuilder":
"""Sets initial values in the state. If you want to load from a prior state,
you can do so here and pass the values in.

TODO -- enable passing in a `state` object instead of `**kwargs`

:param kwargs: Key-value pairs to set in the state
:return: The application builder for future chaining.
"""
if self.state is not None:
self.state = self.state.update(**kwargs)
else:
Expand All @@ -534,8 +577,8 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder":
"""Adds an entrypoint to the application. This is the action that will be run first.
This can only be called once.

:param action:
:return:
:param action: The name of the action to set as the entrypoint
:return: The application builder for future chaining.
"""
# TODO -- validate only called once
self.start = action
Expand Down Expand Up @@ -599,7 +642,34 @@ def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder":
self.lifecycle_adapters.extend(adapters)
return self

def with_tracker(
self, project: str, tracker: Literal["local"] = "local", params: Dict[str, Any] = None
):
"""Adds a "tracker" to the application. The tracker specifies
a project name (used for disambiguating groups of tracers), and plugs into the
Burr UI. Currently the only supported tracker is local, which takes in the params
`storage_dir` and `app_id`, which have automatic defaults.

:param project: Project name
:param tracker: Tracker to use, currently only ``local`` is available
:param params: Parameters to pass to the tracker
:return: The application builder for future chaining.
"""
if params is None:
params = {}
if tracker == "local":
from burr.tracking.client import LocalTrackingClient

self.lifecycle_adapters.append(LocalTrackingClient(project=project, **params))
else:
raise ValueError(f"Tracker {tracker} not supported")
return self

def build(self) -> Application:
"""Builds the application.

:return: The application object
"""
_validate_actions(self.actions)
actions_by_name = {action.name: action for action in self.actions}
all_actions = set(actions_by_name.keys())
Expand Down
2 changes: 1 addition & 1 deletion burr/integrations/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def render_action(state: AppState):
"""
app: Application = state.app
current_node = state.current_action
actions = {action.name: action for action in app.actions}
actions = {action.name: action for action in app.graph.actions}
if current_node is None:
st.markdown("No current action.")
return
Expand Down
22 changes: 21 additions & 1 deletion burr/lifecycle/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

if TYPE_CHECKING:
# type-checking-only for a circular import
from burr.core import State, Action
from burr.core import State, Action, ApplicationGraph

from burr.lifecycle.internal import lifecycle

Expand Down Expand Up @@ -88,6 +88,25 @@ async def post_run_step(
pass


@lifecycle.base_hook("post_application_create")
class PostApplicationCreateHook(abc.ABC):
"""Synchronous hook that runs post instantiation of an ``Application``
object (after ``.build()`` is called on the ``ApplicationBuilder`` object.)"""

@abc.abstractmethod
def post_application_create(
self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any
):
"""Runs after an "application" object is instantiated. This is run by the Application, in its constructor,
as the last step.

:param state: Current state of the application
:param application_graph: Application graph of the application, representing the state machine
:param future_kwargs: Future keyword arguments for backwards compatibility
"""
pass


# THESE ARE NOT IN USE
# TODO -- implement/decide how to use them
@lifecycle.base_hook("pre_run_application")
Expand Down Expand Up @@ -133,4 +152,5 @@ async def post_run_application(
PreRunApplicationHookAsync,
PostRunApplicationHook,
PostRunApplicationHookAsync,
PostApplicationCreateHook,
]
3 changes: 3 additions & 0 deletions burr/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .client import LocalTrackingClient

__all__ = ["LocalTrackingClient"]
134 changes: 134 additions & 0 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import datetime
import json
import logging
import os
import traceback
import uuid
from typing import Any, Optional

from burr.core import Action, ApplicationGraph, State
from burr.integrations.base import require_plugin
from burr.lifecycle import PostRunStepHook, PreRunStepHook
from burr.lifecycle.base import PostApplicationCreateHook
from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel

logger = logging.getLogger(__name__)

try:
import pydantic
except ImportError as e:
require_plugin(
e,
["pydantic"],
"tracking-client",
)


def _format_exception(exception: Exception) -> Optional[str]:
if exception is None:
return None
return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))


class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStepHook):
"""Tracker to track locally -- goes along with the Burr UI. Writes
down the following:
#. The whole application + debugging information (e.g. source code) to a file
#. A line for the start/end of each step
"""

GRAPH_FILENAME = "graph.json"
LOG_FILENAME = "log.jsonl"

def __init__(
self,
project: str,
storage_dir: str = "~/.burr",
app_id: Optional[str] = None,
):
"""Instantiates a local tracking client. This will create the following directories, if they don't exist:
#. The base directory (defaults to ~/.burr)
#. The project directory (defaults to ~/.burr/<project>)
#. The application directory (defaults to ~/.burr/<project>/<app_id>) on each

On application create, it will write the state machine to the application directory.
On pre/post run step, it will write the start/end of each step to the application directory.

:param project: Project name -- if this already exists it will be used, otherwise it will be created.
:param storage_dir: Storage directory
:param app_id: Unique application ID. If not provided, a random one will be generated. If this already exists,
it will use that one/append to the files in that one.
"""
if app_id is None:
app_id = f"app_{str(uuid.uuid4())}"
storage_dir = os.path.join(os.path.expanduser(storage_dir), project)
self.app_id = app_id
self.storage_dir = storage_dir
self._ensure_dir_structure()
self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a")

def _ensure_dir_structure(self):
if not os.path.exists(self.storage_dir):
logger.info(f"Creating storage directory: {self.storage_dir}")
os.makedirs(self.storage_dir)
application_path = os.path.join(self.storage_dir, self.app_id)
if not os.path.exists(application_path):
logger.info(f"Creating application directory: {application_path}")
os.makedirs(application_path)

def post_application_create(
self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any
):
path = os.path.join(self.storage_dir, self.app_id, self.GRAPH_FILENAME)
if os.path.exists(path):
logger.info(f"Graph already exists at {path}. Not overwriting.")
return
graph = ApplicationModel.from_application_graph(application_graph).model_dump()
with open(path, "w") as f:
json.dump(graph, f)

def _append_write_line(self, model: pydantic.BaseModel):
self.f.write(model.model_dump_json() + "\n")
self.f.flush()

def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any):
pre_run_entry = BeginEntryModel(
start_time=datetime.datetime.now(),
action=action.name,
inputs={},
)
self._append_write_line(pre_run_entry)

def post_run_step(
self,
*,
state: State,
action: Action,
result: Optional[dict],
exception: Exception,
**future_kwargs: Any,
):
post_run_entry = EndEntryModel(
end_time=datetime.datetime.now(),
action=action.name,
result=result,
exception=_format_exception(exception),
state=state.get_all(),
)
self._append_write_line(post_run_entry)

def __del__(self):
self.f.close()


# TODO -- implement async version
# class AsyncTrackingClient(PreRunStepHookAsync, PostRunStepHookAsync, PostApplicationCreateHook):
# def post_application_create(self, *, state: State, state_graph: ApplicationGraph, **future_kwargs: Any):
# pass
#
# async def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any):
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step")
#
# async def post_run_step(self, *, state: State, action: Action, result: Optional[dict], exception: Exception, **future_kwargs: Any):
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step")
#
Loading
Loading