From d7219859986b168aa5bddbfefaa3c51237dd13b8 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Thu, 15 Feb 2024 12:57:47 -0800 Subject: [PATCH] Adds basic tracking component to Burr. This is local, and leverages with_tracker() in the application builder. It uses pydantic to track models, this will be shared between the tracker and the UI. --- .github/workflows/python-package.yml | 2 +- burr/tracking/__init__.py | 3 + burr/tracking/client.py | 133 +++++++++++++++++++ burr/tracking/models.py | 104 +++++++++++++++ examples/counter/application.py | 3 +- pyproject.toml | 4 + tests/tracking/test_local_tracking_client.py | 105 +++++++++++++++ 7 files changed, 352 insertions(+), 2 deletions(-) create mode 100644 burr/tracking/__init__.py create mode 100644 burr/tracking/client.py create mode 100644 burr/tracking/models.py create mode 100644 tests/tracking/test_local_tracking_client.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 28990ea2..cc55cd08 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,7 +33,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install -e ".[tests]" + python -m pip install -e ".[tests,tracking-client]" - name: Run tests run: | python -m pytest tests diff --git a/burr/tracking/__init__.py b/burr/tracking/__init__.py new file mode 100644 index 00000000..22b63ae1 --- /dev/null +++ b/burr/tracking/__init__.py @@ -0,0 +1,3 @@ +from .client import LocalTrackingClient + +__all__ = ["LocalTrackingClient"] diff --git a/burr/tracking/client.py b/burr/tracking/client.py new file mode 100644 index 00000000..22cc0aac --- /dev/null +++ b/burr/tracking/client.py @@ -0,0 +1,133 @@ +import datetime +import json +import logging +import os +import traceback +import uuid +from typing import Any, Optional + +from burr.core import Action, ApplicationGraph, State +from burr.integrations.base import require_plugin +from burr.lifecycle import PostRunStepHook, PreRunStepHook +from burr.lifecycle.base import PostApplicationCreateHook +from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel + +logger = logging.getLogger(__name__) + +try: + import pydantic +except ImportError as e: + require_plugin( + e, + ["pydantic"], + "tracking-client", + ) + + +def _format_exception(exception: Exception) -> Optional[str]: + if exception is None: + return None + return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__)) + + +class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStepHook): + """Tracker to track locally -- goes along with the Burr UI. Writes + down the following: + 1. The whole application + debugging information (e.g. source code) to a file + 2. A line for the start/end of each step""" + + GRAPH_FILENAME = "graph.json" + LOG_FILENAME = "log.jsonl" + + def __init__( + self, + project: str, + storage_dir: str = "~/.burr", + app_id: Optional[str] = None, + ): + """Instantiates a local tracking client. This will create the following directories, if they don't exist: + 1. The base directory (defaults to ~/.burr) + 2. The project directory (defaults to ~/.burr/) + 3. The application directory (defaults to ~/.burr//) on each + + On application create, it will write the state machine to the application directory. + On pre/post run step, it will write the start/end of each step to the application directory. + + :param project: Project name -- if this already exists it will be used, otherwise it will be created. + :param storage_dir: Storage directory + :param app_id: Unique application ID. If not provided, a random one will be generated. If this already exists, + it will use that one/append to the files in that one. + """ + if app_id is None: + app_id = f"app_{str(uuid.uuid4())}" + storage_dir = os.path.join(os.path.expanduser(storage_dir), project) + self.app_id = app_id + self.storage_dir = storage_dir + self._ensure_dir_structure() + self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a") + + def _ensure_dir_structure(self): + if not os.path.exists(self.storage_dir): + logger.info(f"Creating storage directory: {self.storage_dir}") + os.makedirs(self.storage_dir) + application_path = os.path.join(self.storage_dir, self.app_id) + if not os.path.exists(application_path): + logger.info(f"Creating application directory: {application_path}") + os.makedirs(application_path) + + def post_application_create( + self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any + ): + path = os.path.join(self.storage_dir, self.app_id, self.GRAPH_FILENAME) + if os.path.exists(path): + logger.info(f"Graph already exists at {path}. Not overwriting.") + return + graph = ApplicationModel.from_application_graph(application_graph).model_dump() + with open(path, "w") as f: + json.dump(graph, f) + + def _append_write_line(self, model: pydantic.BaseModel): + self.f.write(model.model_dump_json() + "\n") + self.f.flush() + + def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): + pre_run_entry = BeginEntryModel( + start_time=datetime.datetime.now(), + action=action.name, + inputs={}, + ) + self._append_write_line(pre_run_entry) + + def post_run_step( + self, + *, + state: State, + action: Action, + result: Optional[dict], + exception: Exception, + **future_kwargs: Any, + ): + post_run_entry = EndEntryModel( + end_time=datetime.datetime.now(), + action=action.name, + result=result, + exception=_format_exception(exception), + state=state.get_all(), + ) + self._append_write_line(post_run_entry) + + def __del__(self): + self.f.close() + + +# TODO -- implement async version +# class AsyncTrackingClient(PreRunStepHookAsync, PostRunStepHookAsync, PostApplicationCreateHook): +# def post_application_create(self, *, state: State, state_graph: ApplicationGraph, **future_kwargs: Any): +# pass +# +# async def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): +# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") +# +# async def post_run_step(self, *, state: State, action: Action, result: Optional[dict], exception: Exception, **future_kwargs: Any): +# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") +# diff --git a/burr/tracking/models.py b/burr/tracking/models.py new file mode 100644 index 00000000..935b9af3 --- /dev/null +++ b/burr/tracking/models.py @@ -0,0 +1,104 @@ +import datetime +import inspect +from typing import Any, Dict, Optional + +from burr.core import Action +from burr.core.action import FunctionBasedAction +from burr.core.application import ApplicationGraph, Transition +from burr.integrations.base import require_plugin + +try: + import pydantic +except ImportError as e: + require_plugin( + e, + ["pydantic"], + "tracking-client", + ) + + +class IdentifyingModel(pydantic.BaseModel): + model_type: str + + +class ActionModel(IdentifyingModel): + """Pydantic model that represents an action for storing/visualization in the UI""" + + name: str + reads: list[str] + writes: list[str] + code: str + model_type: str = "action" + + @staticmethod + def from_action(action: Action) -> "ActionModel": + """Creates an ActionModel from an action. + + :param action: Action to create the model from + :return: + """ + if isinstance(action, FunctionBasedAction): + code = inspect.getsource(action.fn) + else: + code = inspect.getsource(action.__class__) + return ActionModel( + name=action.name, + reads=list(action.reads), + writes=list(action.writes), + code=code, + ) + + +class TransitionModel(IdentifyingModel): + """Pydantic model that represents a transition for storing/visualization in the UI""" + + from_: str + to: str + condition: str + model_type: str = "transition" + + @staticmethod + def from_transition(transition: Transition) -> "TransitionModel": + return TransitionModel( + from_=transition.from_.name, to=transition.to.name, condition=transition.condition.name + ) + + +class ApplicationModel(IdentifyingModel): + """Pydantic model that represents an application for storing/visualization in the UI""" + + entrypoint: str + actions: list[ActionModel] + transitions: list[TransitionModel] + model_type: str = "application" + + @staticmethod + def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationModel": + return ApplicationModel( + entrypoint=application_graph.entrypoint.name, + actions=[ActionModel.from_action(action) for action in application_graph.actions], + transitions=[ + TransitionModel.from_transition(transition) + for transition in application_graph.transitions + ], + ) + + +class BeginEntryModel(IdentifyingModel): + """Pydantic model that represents an entry for the beginning of a step""" + + start_time: datetime.datetime + action: str + inputs: Dict[str, Any] + model_type: str = "begin_entry" + + +class EndEntryModel(IdentifyingModel): + """Pydantic model that represents an entry for the end of a step""" + + end_time: datetime.datetime + action: str + result: Optional[dict] + exception: Optional[str] + state: Dict[str, Any] # TODO -- consider logging updates to the state so we can recreate + model_type: str = "end_entry" diff --git a/examples/counter/application.py b/examples/counter/application.py index dceba7d3..82e2eb62 100644 --- a/examples/counter/application.py +++ b/examples/counter/application.py @@ -23,6 +23,7 @@ def application(count_up_to: int = 10, log_file: str = None): ("result", "counter", expr("counter == 0")), # when we've reset, manually ) .with_entrypoint("counter") + .with_tracker("counter") .with_hooks(*[StateAndResultsFullLogger(log_file)] if log_file else []) .build() ) @@ -31,6 +32,6 @@ def application(count_up_to: int = 10, log_file: str = None): if __name__ == "__main__": app = application(log_file="counter.jsonl") state, result = app.run(until=["result"]) - app.visualize(output_file_path="digraph", include_conditions=True, view=True, format="png") + app.visualize(output_file_path="digraph", include_conditions=True, view=False, format="png") assert state["counter"] == 10 print(state["counter"]) diff --git a/pyproject.toml b/pyproject.toml index 3327cb5e..ebb02a0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,10 @@ documentation = [ "sphinx-toolbox" ] +tracking-client = [ + "pydantic" +] + [tool.poetry.packages] py_modules = ["burr"] diff --git a/tests/tracking/test_local_tracking_client.py b/tests/tracking/test_local_tracking_client.py new file mode 100644 index 00000000..d8b62259 --- /dev/null +++ b/tests/tracking/test_local_tracking_client.py @@ -0,0 +1,105 @@ +import json +import os +import uuid +from typing import Tuple + +import pytest + +import burr +from burr.core import Result, State, action, default, expr +from burr.tracking import LocalTrackingClient +from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel + + +@action(reads=["counter", "break_at"], writes=["counter"]) +def counter(state: State) -> Tuple[dict, State]: + result = {"counter": state["counter"] + 1} + if state["break_at"] == result["counter"]: + raise ValueError("Broken") + return result, state.update(**result) + + +def sample_application(project_name: str, log_dir: str, app_id: str, broken: bool = False): + return ( + burr.core.ApplicationBuilder() + .with_state(counter=0, break_at=2 if broken else -1) + .with_actions(counter=counter, result=Result(["counter"])) + .with_transitions( + ("counter", "counter", expr("counter < 2")), # just count to two for testing + ("counter", "result", default), + ) + .with_entrypoint("counter") + .with_tracker( + project_name, tracker="local", params={"storage_dir": log_dir, "app_id": app_id} + ) + .build() + ) + + +def test_application_tracks_end_to_end(tmpdir: str): + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_application_tracks_end_to_end" + app = sample_application(project_name, log_dir, app_id) + app.run(until=["result"]) + results_dir = os.path.join(log_dir, project_name, app_id) + assert os.path.exists(results_dir) + assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) + assert os.path.exists( + graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) + ) + with open(log_output) as f: + log_contents = [json.loads(item) for item in f.readlines()] + with open(graph_output) as f: + graph_contents = json.load(f) + assert graph_contents["model_type"] == "application" + app_model = ApplicationModel.parse_obj(graph_contents) + assert app_model.entrypoint == "counter" + assert app_model.actions[0].name == "counter" + assert app_model.actions[1].name == "result" + pre_run = [ + BeginEntryModel.parse_obj(line) + for line in log_contents + if line["model_type"] == "begin_entry" + ] + post_run = [ + EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" + ] + assert len(pre_run) == 3 + assert len(post_run) == 3 + assert not any(item.exception for item in post_run) + + +def test_application_tracks_end_to_end_broken(tmpdir: str): + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_application_tracks_end_to_end" + app = sample_application(project_name, log_dir, app_id, broken=True) + with pytest.raises(ValueError): + app.run(until=["result"]) + results_dir = os.path.join(log_dir, project_name, app_id) + assert os.path.exists(results_dir) + assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) + assert os.path.exists( + graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) + ) + with open(log_output) as f: + log_contents = [json.loads(item) for item in f.readlines()] + with open(graph_output) as f: + graph_contents = json.load(f) + assert graph_contents["model_type"] == "application" + app_model = ApplicationModel.parse_obj(graph_contents) + assert app_model.entrypoint == "counter" + assert app_model.actions[0].name == "counter" + assert app_model.actions[1].name == "result" + pre_run = [ + BeginEntryModel.parse_obj(line) + for line in log_contents + if line["model_type"] == "begin_entry" + ] + post_run = [ + EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" + ] + assert len(pre_run) == 2 + assert len(post_run) == 2 + assert len(post_run[-1].exception) > 0 and "Broken" in post_run[-1].exception