Skip to content

Commit

Permalink
Adds basic tracking component to Burr.
Browse files Browse the repository at this point in the history
This is local, and leverages with_tracker() in the application builder.
It uses pydantic to track models, this will be shared between the
tracker and the UI.
  • Loading branch information
elijahbenizzy committed Feb 16, 2024
1 parent e41623c commit d721985
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install -e ".[tests]"
python -m pip install -e ".[tests,tracking-client]"
- name: Run tests
run: |
python -m pytest tests
3 changes: 3 additions & 0 deletions burr/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .client import LocalTrackingClient

__all__ = ["LocalTrackingClient"]
133 changes: 133 additions & 0 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import datetime
import json
import logging
import os
import traceback
import uuid
from typing import Any, Optional

from burr.core import Action, ApplicationGraph, State
from burr.integrations.base import require_plugin
from burr.lifecycle import PostRunStepHook, PreRunStepHook
from burr.lifecycle.base import PostApplicationCreateHook
from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel

logger = logging.getLogger(__name__)

try:
import pydantic
except ImportError as e:
require_plugin(
e,
["pydantic"],
"tracking-client",
)


def _format_exception(exception: Exception) -> Optional[str]:
if exception is None:
return None
return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))


class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStepHook):
"""Tracker to track locally -- goes along with the Burr UI. Writes
down the following:
1. The whole application + debugging information (e.g. source code) to a file
2. A line for the start/end of each step"""

GRAPH_FILENAME = "graph.json"
LOG_FILENAME = "log.jsonl"

def __init__(
self,
project: str,
storage_dir: str = "~/.burr",
app_id: Optional[str] = None,
):
"""Instantiates a local tracking client. This will create the following directories, if they don't exist:
1. The base directory (defaults to ~/.burr)
2. The project directory (defaults to ~/.burr/<project>)
3. The application directory (defaults to ~/.burr/<project>/<app_id>) on each
On application create, it will write the state machine to the application directory.
On pre/post run step, it will write the start/end of each step to the application directory.
:param project: Project name -- if this already exists it will be used, otherwise it will be created.
:param storage_dir: Storage directory
:param app_id: Unique application ID. If not provided, a random one will be generated. If this already exists,
it will use that one/append to the files in that one.
"""
if app_id is None:
app_id = f"app_{str(uuid.uuid4())}"
storage_dir = os.path.join(os.path.expanduser(storage_dir), project)
self.app_id = app_id
self.storage_dir = storage_dir
self._ensure_dir_structure()
self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a")

def _ensure_dir_structure(self):
if not os.path.exists(self.storage_dir):
logger.info(f"Creating storage directory: {self.storage_dir}")
os.makedirs(self.storage_dir)
application_path = os.path.join(self.storage_dir, self.app_id)
if not os.path.exists(application_path):
logger.info(f"Creating application directory: {application_path}")
os.makedirs(application_path)

def post_application_create(
self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any
):
path = os.path.join(self.storage_dir, self.app_id, self.GRAPH_FILENAME)
if os.path.exists(path):
logger.info(f"Graph already exists at {path}. Not overwriting.")
return
graph = ApplicationModel.from_application_graph(application_graph).model_dump()
with open(path, "w") as f:
json.dump(graph, f)

def _append_write_line(self, model: pydantic.BaseModel):
self.f.write(model.model_dump_json() + "\n")
self.f.flush()

def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any):
pre_run_entry = BeginEntryModel(
start_time=datetime.datetime.now(),
action=action.name,
inputs={},
)
self._append_write_line(pre_run_entry)

def post_run_step(
self,
*,
state: State,
action: Action,
result: Optional[dict],
exception: Exception,
**future_kwargs: Any,
):
post_run_entry = EndEntryModel(
end_time=datetime.datetime.now(),
action=action.name,
result=result,
exception=_format_exception(exception),
state=state.get_all(),
)
self._append_write_line(post_run_entry)

def __del__(self):
self.f.close()


# TODO -- implement async version
# class AsyncTrackingClient(PreRunStepHookAsync, PostRunStepHookAsync, PostApplicationCreateHook):
# def post_application_create(self, *, state: State, state_graph: ApplicationGraph, **future_kwargs: Any):
# pass
#
# async def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any):
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step")
#
# async def post_run_step(self, *, state: State, action: Action, result: Optional[dict], exception: Exception, **future_kwargs: Any):
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step")
#
104 changes: 104 additions & 0 deletions burr/tracking/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import datetime
import inspect
from typing import Any, Dict, Optional

from burr.core import Action
from burr.core.action import FunctionBasedAction
from burr.core.application import ApplicationGraph, Transition
from burr.integrations.base import require_plugin

try:
import pydantic
except ImportError as e:
require_plugin(
e,
["pydantic"],
"tracking-client",
)


class IdentifyingModel(pydantic.BaseModel):
model_type: str


class ActionModel(IdentifyingModel):
"""Pydantic model that represents an action for storing/visualization in the UI"""

name: str
reads: list[str]
writes: list[str]
code: str
model_type: str = "action"

@staticmethod
def from_action(action: Action) -> "ActionModel":
"""Creates an ActionModel from an action.
:param action: Action to create the model from
:return:
"""
if isinstance(action, FunctionBasedAction):
code = inspect.getsource(action.fn)
else:
code = inspect.getsource(action.__class__)
return ActionModel(
name=action.name,
reads=list(action.reads),
writes=list(action.writes),
code=code,
)


class TransitionModel(IdentifyingModel):
"""Pydantic model that represents a transition for storing/visualization in the UI"""

from_: str
to: str
condition: str
model_type: str = "transition"

@staticmethod
def from_transition(transition: Transition) -> "TransitionModel":
return TransitionModel(
from_=transition.from_.name, to=transition.to.name, condition=transition.condition.name
)


class ApplicationModel(IdentifyingModel):
"""Pydantic model that represents an application for storing/visualization in the UI"""

entrypoint: str
actions: list[ActionModel]
transitions: list[TransitionModel]
model_type: str = "application"

@staticmethod
def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationModel":
return ApplicationModel(
entrypoint=application_graph.entrypoint.name,
actions=[ActionModel.from_action(action) for action in application_graph.actions],
transitions=[
TransitionModel.from_transition(transition)
for transition in application_graph.transitions
],
)


class BeginEntryModel(IdentifyingModel):
"""Pydantic model that represents an entry for the beginning of a step"""

start_time: datetime.datetime
action: str
inputs: Dict[str, Any]
model_type: str = "begin_entry"


class EndEntryModel(IdentifyingModel):
"""Pydantic model that represents an entry for the end of a step"""

end_time: datetime.datetime
action: str
result: Optional[dict]
exception: Optional[str]
state: Dict[str, Any] # TODO -- consider logging updates to the state so we can recreate
model_type: str = "end_entry"
3 changes: 2 additions & 1 deletion examples/counter/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def application(count_up_to: int = 10, log_file: str = None):
("result", "counter", expr("counter == 0")), # when we've reset, manually
)
.with_entrypoint("counter")
.with_tracker("counter")
.with_hooks(*[StateAndResultsFullLogger(log_file)] if log_file else [])
.build()
)
Expand All @@ -31,6 +32,6 @@ def application(count_up_to: int = 10, log_file: str = None):
if __name__ == "__main__":
app = application(log_file="counter.jsonl")
state, result = app.run(until=["result"])
app.visualize(output_file_path="digraph", include_conditions=True, view=True, format="png")
app.visualize(output_file_path="digraph", include_conditions=True, view=False, format="png")
assert state["counter"] == 10
print(state["counter"])
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ documentation = [
"sphinx-toolbox"
]

tracking-client = [
"pydantic"
]

[tool.poetry.packages]
py_modules = ["burr"]

Expand Down
105 changes: 105 additions & 0 deletions tests/tracking/test_local_tracking_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import json
import os
import uuid
from typing import Tuple

import pytest

import burr
from burr.core import Result, State, action, default, expr
from burr.tracking import LocalTrackingClient
from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel


@action(reads=["counter", "break_at"], writes=["counter"])
def counter(state: State) -> Tuple[dict, State]:
result = {"counter": state["counter"] + 1}
if state["break_at"] == result["counter"]:
raise ValueError("Broken")
return result, state.update(**result)


def sample_application(project_name: str, log_dir: str, app_id: str, broken: bool = False):
return (
burr.core.ApplicationBuilder()
.with_state(counter=0, break_at=2 if broken else -1)
.with_actions(counter=counter, result=Result(["counter"]))
.with_transitions(
("counter", "counter", expr("counter < 2")), # just count to two for testing
("counter", "result", default),
)
.with_entrypoint("counter")
.with_tracker(
project_name, tracker="local", params={"storage_dir": log_dir, "app_id": app_id}
)
.build()
)


def test_application_tracks_end_to_end(tmpdir: str):
app_id = str(uuid.uuid4())
log_dir = os.path.join(tmpdir, "tracking")
project_name = "test_application_tracks_end_to_end"
app = sample_application(project_name, log_dir, app_id)
app.run(until=["result"])
results_dir = os.path.join(log_dir, project_name, app_id)
assert os.path.exists(results_dir)
assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME))
assert os.path.exists(
graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME)
)
with open(log_output) as f:
log_contents = [json.loads(item) for item in f.readlines()]
with open(graph_output) as f:
graph_contents = json.load(f)
assert graph_contents["model_type"] == "application"
app_model = ApplicationModel.parse_obj(graph_contents)
assert app_model.entrypoint == "counter"
assert app_model.actions[0].name == "counter"
assert app_model.actions[1].name == "result"
pre_run = [
BeginEntryModel.parse_obj(line)
for line in log_contents
if line["model_type"] == "begin_entry"
]
post_run = [
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry"
]
assert len(pre_run) == 3
assert len(post_run) == 3
assert not any(item.exception for item in post_run)


def test_application_tracks_end_to_end_broken(tmpdir: str):
app_id = str(uuid.uuid4())
log_dir = os.path.join(tmpdir, "tracking")
project_name = "test_application_tracks_end_to_end"
app = sample_application(project_name, log_dir, app_id, broken=True)
with pytest.raises(ValueError):
app.run(until=["result"])
results_dir = os.path.join(log_dir, project_name, app_id)
assert os.path.exists(results_dir)
assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME))
assert os.path.exists(
graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME)
)
with open(log_output) as f:
log_contents = [json.loads(item) for item in f.readlines()]
with open(graph_output) as f:
graph_contents = json.load(f)
assert graph_contents["model_type"] == "application"
app_model = ApplicationModel.parse_obj(graph_contents)
assert app_model.entrypoint == "counter"
assert app_model.actions[0].name == "counter"
assert app_model.actions[1].name == "result"
pre_run = [
BeginEntryModel.parse_obj(line)
for line in log_contents
if line["model_type"] == "begin_entry"
]
post_run = [
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry"
]
assert len(pre_run) == 2
assert len(post_run) == 2
assert len(post_run[-1].exception) > 0 and "Broken" in post_run[-1].exception

0 comments on commit d721985

Please sign in to comment.