-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds basic tracking component to Burr.
This is local, and leverages with_tracker() in the application builder. It uses pydantic to track models, this will be shared between the tracker and the UI.
- Loading branch information
1 parent
e41623c
commit d721985
Showing
7 changed files
with
352 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .client import LocalTrackingClient | ||
|
||
__all__ = ["LocalTrackingClient"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import datetime | ||
import json | ||
import logging | ||
import os | ||
import traceback | ||
import uuid | ||
from typing import Any, Optional | ||
|
||
from burr.core import Action, ApplicationGraph, State | ||
from burr.integrations.base import require_plugin | ||
from burr.lifecycle import PostRunStepHook, PreRunStepHook | ||
from burr.lifecycle.base import PostApplicationCreateHook | ||
from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
try: | ||
import pydantic | ||
except ImportError as e: | ||
require_plugin( | ||
e, | ||
["pydantic"], | ||
"tracking-client", | ||
) | ||
|
||
|
||
def _format_exception(exception: Exception) -> Optional[str]: | ||
if exception is None: | ||
return None | ||
return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__)) | ||
|
||
|
||
class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStepHook): | ||
"""Tracker to track locally -- goes along with the Burr UI. Writes | ||
down the following: | ||
1. The whole application + debugging information (e.g. source code) to a file | ||
2. A line for the start/end of each step""" | ||
|
||
GRAPH_FILENAME = "graph.json" | ||
LOG_FILENAME = "log.jsonl" | ||
|
||
def __init__( | ||
self, | ||
project: str, | ||
storage_dir: str = "~/.burr", | ||
app_id: Optional[str] = None, | ||
): | ||
"""Instantiates a local tracking client. This will create the following directories, if they don't exist: | ||
1. The base directory (defaults to ~/.burr) | ||
2. The project directory (defaults to ~/.burr/<project>) | ||
3. The application directory (defaults to ~/.burr/<project>/<app_id>) on each | ||
On application create, it will write the state machine to the application directory. | ||
On pre/post run step, it will write the start/end of each step to the application directory. | ||
:param project: Project name -- if this already exists it will be used, otherwise it will be created. | ||
:param storage_dir: Storage directory | ||
:param app_id: Unique application ID. If not provided, a random one will be generated. If this already exists, | ||
it will use that one/append to the files in that one. | ||
""" | ||
if app_id is None: | ||
app_id = f"app_{str(uuid.uuid4())}" | ||
storage_dir = os.path.join(os.path.expanduser(storage_dir), project) | ||
self.app_id = app_id | ||
self.storage_dir = storage_dir | ||
self._ensure_dir_structure() | ||
self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a") | ||
|
||
def _ensure_dir_structure(self): | ||
if not os.path.exists(self.storage_dir): | ||
logger.info(f"Creating storage directory: {self.storage_dir}") | ||
os.makedirs(self.storage_dir) | ||
application_path = os.path.join(self.storage_dir, self.app_id) | ||
if not os.path.exists(application_path): | ||
logger.info(f"Creating application directory: {application_path}") | ||
os.makedirs(application_path) | ||
|
||
def post_application_create( | ||
self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any | ||
): | ||
path = os.path.join(self.storage_dir, self.app_id, self.GRAPH_FILENAME) | ||
if os.path.exists(path): | ||
logger.info(f"Graph already exists at {path}. Not overwriting.") | ||
return | ||
graph = ApplicationModel.from_application_graph(application_graph).model_dump() | ||
with open(path, "w") as f: | ||
json.dump(graph, f) | ||
|
||
def _append_write_line(self, model: pydantic.BaseModel): | ||
self.f.write(model.model_dump_json() + "\n") | ||
self.f.flush() | ||
|
||
def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): | ||
pre_run_entry = BeginEntryModel( | ||
start_time=datetime.datetime.now(), | ||
action=action.name, | ||
inputs={}, | ||
) | ||
self._append_write_line(pre_run_entry) | ||
|
||
def post_run_step( | ||
self, | ||
*, | ||
state: State, | ||
action: Action, | ||
result: Optional[dict], | ||
exception: Exception, | ||
**future_kwargs: Any, | ||
): | ||
post_run_entry = EndEntryModel( | ||
end_time=datetime.datetime.now(), | ||
action=action.name, | ||
result=result, | ||
exception=_format_exception(exception), | ||
state=state.get_all(), | ||
) | ||
self._append_write_line(post_run_entry) | ||
|
||
def __del__(self): | ||
self.f.close() | ||
|
||
|
||
# TODO -- implement async version | ||
# class AsyncTrackingClient(PreRunStepHookAsync, PostRunStepHookAsync, PostApplicationCreateHook): | ||
# def post_application_create(self, *, state: State, state_graph: ApplicationGraph, **future_kwargs: Any): | ||
# pass | ||
# | ||
# async def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): | ||
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") | ||
# | ||
# async def post_run_step(self, *, state: State, action: Action, result: Optional[dict], exception: Exception, **future_kwargs: Any): | ||
# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import datetime | ||
import inspect | ||
from typing import Any, Dict, Optional | ||
|
||
from burr.core import Action | ||
from burr.core.action import FunctionBasedAction | ||
from burr.core.application import ApplicationGraph, Transition | ||
from burr.integrations.base import require_plugin | ||
|
||
try: | ||
import pydantic | ||
except ImportError as e: | ||
require_plugin( | ||
e, | ||
["pydantic"], | ||
"tracking-client", | ||
) | ||
|
||
|
||
class IdentifyingModel(pydantic.BaseModel): | ||
model_type: str | ||
|
||
|
||
class ActionModel(IdentifyingModel): | ||
"""Pydantic model that represents an action for storing/visualization in the UI""" | ||
|
||
name: str | ||
reads: list[str] | ||
writes: list[str] | ||
code: str | ||
model_type: str = "action" | ||
|
||
@staticmethod | ||
def from_action(action: Action) -> "ActionModel": | ||
"""Creates an ActionModel from an action. | ||
:param action: Action to create the model from | ||
:return: | ||
""" | ||
if isinstance(action, FunctionBasedAction): | ||
code = inspect.getsource(action.fn) | ||
else: | ||
code = inspect.getsource(action.__class__) | ||
return ActionModel( | ||
name=action.name, | ||
reads=list(action.reads), | ||
writes=list(action.writes), | ||
code=code, | ||
) | ||
|
||
|
||
class TransitionModel(IdentifyingModel): | ||
"""Pydantic model that represents a transition for storing/visualization in the UI""" | ||
|
||
from_: str | ||
to: str | ||
condition: str | ||
model_type: str = "transition" | ||
|
||
@staticmethod | ||
def from_transition(transition: Transition) -> "TransitionModel": | ||
return TransitionModel( | ||
from_=transition.from_.name, to=transition.to.name, condition=transition.condition.name | ||
) | ||
|
||
|
||
class ApplicationModel(IdentifyingModel): | ||
"""Pydantic model that represents an application for storing/visualization in the UI""" | ||
|
||
entrypoint: str | ||
actions: list[ActionModel] | ||
transitions: list[TransitionModel] | ||
model_type: str = "application" | ||
|
||
@staticmethod | ||
def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationModel": | ||
return ApplicationModel( | ||
entrypoint=application_graph.entrypoint.name, | ||
actions=[ActionModel.from_action(action) for action in application_graph.actions], | ||
transitions=[ | ||
TransitionModel.from_transition(transition) | ||
for transition in application_graph.transitions | ||
], | ||
) | ||
|
||
|
||
class BeginEntryModel(IdentifyingModel): | ||
"""Pydantic model that represents an entry for the beginning of a step""" | ||
|
||
start_time: datetime.datetime | ||
action: str | ||
inputs: Dict[str, Any] | ||
model_type: str = "begin_entry" | ||
|
||
|
||
class EndEntryModel(IdentifyingModel): | ||
"""Pydantic model that represents an entry for the end of a step""" | ||
|
||
end_time: datetime.datetime | ||
action: str | ||
result: Optional[dict] | ||
exception: Optional[str] | ||
state: Dict[str, Any] # TODO -- consider logging updates to the state so we can recreate | ||
model_type: str = "end_entry" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import json | ||
import os | ||
import uuid | ||
from typing import Tuple | ||
|
||
import pytest | ||
|
||
import burr | ||
from burr.core import Result, State, action, default, expr | ||
from burr.tracking import LocalTrackingClient | ||
from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel | ||
|
||
|
||
@action(reads=["counter", "break_at"], writes=["counter"]) | ||
def counter(state: State) -> Tuple[dict, State]: | ||
result = {"counter": state["counter"] + 1} | ||
if state["break_at"] == result["counter"]: | ||
raise ValueError("Broken") | ||
return result, state.update(**result) | ||
|
||
|
||
def sample_application(project_name: str, log_dir: str, app_id: str, broken: bool = False): | ||
return ( | ||
burr.core.ApplicationBuilder() | ||
.with_state(counter=0, break_at=2 if broken else -1) | ||
.with_actions(counter=counter, result=Result(["counter"])) | ||
.with_transitions( | ||
("counter", "counter", expr("counter < 2")), # just count to two for testing | ||
("counter", "result", default), | ||
) | ||
.with_entrypoint("counter") | ||
.with_tracker( | ||
project_name, tracker="local", params={"storage_dir": log_dir, "app_id": app_id} | ||
) | ||
.build() | ||
) | ||
|
||
|
||
def test_application_tracks_end_to_end(tmpdir: str): | ||
app_id = str(uuid.uuid4()) | ||
log_dir = os.path.join(tmpdir, "tracking") | ||
project_name = "test_application_tracks_end_to_end" | ||
app = sample_application(project_name, log_dir, app_id) | ||
app.run(until=["result"]) | ||
results_dir = os.path.join(log_dir, project_name, app_id) | ||
assert os.path.exists(results_dir) | ||
assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) | ||
assert os.path.exists( | ||
graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) | ||
) | ||
with open(log_output) as f: | ||
log_contents = [json.loads(item) for item in f.readlines()] | ||
with open(graph_output) as f: | ||
graph_contents = json.load(f) | ||
assert graph_contents["model_type"] == "application" | ||
app_model = ApplicationModel.parse_obj(graph_contents) | ||
assert app_model.entrypoint == "counter" | ||
assert app_model.actions[0].name == "counter" | ||
assert app_model.actions[1].name == "result" | ||
pre_run = [ | ||
BeginEntryModel.parse_obj(line) | ||
for line in log_contents | ||
if line["model_type"] == "begin_entry" | ||
] | ||
post_run = [ | ||
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" | ||
] | ||
assert len(pre_run) == 3 | ||
assert len(post_run) == 3 | ||
assert not any(item.exception for item in post_run) | ||
|
||
|
||
def test_application_tracks_end_to_end_broken(tmpdir: str): | ||
app_id = str(uuid.uuid4()) | ||
log_dir = os.path.join(tmpdir, "tracking") | ||
project_name = "test_application_tracks_end_to_end" | ||
app = sample_application(project_name, log_dir, app_id, broken=True) | ||
with pytest.raises(ValueError): | ||
app.run(until=["result"]) | ||
results_dir = os.path.join(log_dir, project_name, app_id) | ||
assert os.path.exists(results_dir) | ||
assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) | ||
assert os.path.exists( | ||
graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) | ||
) | ||
with open(log_output) as f: | ||
log_contents = [json.loads(item) for item in f.readlines()] | ||
with open(graph_output) as f: | ||
graph_contents = json.load(f) | ||
assert graph_contents["model_type"] == "application" | ||
app_model = ApplicationModel.parse_obj(graph_contents) | ||
assert app_model.entrypoint == "counter" | ||
assert app_model.actions[0].name == "counter" | ||
assert app_model.actions[1].name == "result" | ||
pre_run = [ | ||
BeginEntryModel.parse_obj(line) | ||
for line in log_contents | ||
if line["model_type"] == "begin_entry" | ||
] | ||
post_run = [ | ||
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" | ||
] | ||
assert len(pre_run) == 2 | ||
assert len(post_run) == 2 | ||
assert len(post_run[-1].exception) > 0 and "Broken" in post_run[-1].exception |