diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 28990ea2..cc55cd08 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -33,7 +33,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install -e ".[tests]" + python -m pip install -e ".[tests,tracking-client]" - name: Run tests run: | python -m pytest tests diff --git a/burr/core/__init__.py b/burr/core/__init__.py index 8140573e..3445bdff 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -1,5 +1,5 @@ from burr.core.action import Action, Condition, Result, action, default, expr, when -from burr.core.application import Application, ApplicationBuilder +from burr.core.application import Application, ApplicationBuilder, ApplicationGraph from burr.core.state import State __all__ = [ @@ -7,6 +7,7 @@ "Action", "Application", "ApplicationBuilder", + "ApplicationGraph", "Condition", "default", "expr", diff --git a/burr/core/application.py b/burr/core/application.py index 3571404c..d9e84ab5 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -6,6 +6,7 @@ Any, AsyncGenerator, Callable, + Dict, Generator, List, Literal, @@ -95,7 +96,7 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta f"Action {name} attempted to write to keys {extra_keys} " f"that it did not declare. It declared: ({reducer.writes})!" ) - return state.merge(new_state.update(**{PRIOR_STEP: name})) + return state.merge(new_state) def _create_dict_string(kwargs: dict) -> str: @@ -156,6 +157,20 @@ async def _arun_single_step_action(action: SingleStepAction, state: State) -> Tu return result, state.merge(new_state.subset(*action.writes)) # we just want the writes action +@dataclasses.dataclass +class ApplicationGraph: + """User-facing representation of the state machine. This has + + #. All the action objects + #. All the transition objects + #. The entrypoint action + """ + + actions: List[Action] + transitions: List[Transition] + entrypoint: Action + + class Application: def __init__( self, @@ -172,6 +187,10 @@ def __init__( self._initial_step = initial_step self._state = state self._adapter_set = adapter_set if adapter_set is not None else LifecycleAdapterSet() + self._graph = self._create_graph() + self._adapter_set.call_all_lifecycle_hooks_sync( + "post_application_create", state=self._state, application_graph=self._graph + ) def step(self) -> Optional[Tuple[Action, dict, State]]: """Performs a single step, advancing the state machine along. @@ -200,6 +219,7 @@ def step(self) -> Optional[Tuple[Action, dict, State]]: else: result = _run_function(next_action, self._state) new_state = _run_reducer(next_action, self._state, result, next_action.name) + new_state = new_state.update(**{PRIOR_STEP: next_action.name}) self._set_state(new_state) except Exception as e: exc = e @@ -238,6 +258,7 @@ async def astep(self) -> Optional[Tuple[Action, dict, State]]: else: result = await _arun_function(next_action, self._state) new_state = _run_reducer(next_action, self._state, result, next_action.name) + new_state = new_state.update(**{PRIOR_STEP: next_action.name}) except Exception as e: exc = e logger.exception(_format_error_message(next_action, self._state)) @@ -280,10 +301,10 @@ def iterate( ) while not condition(): - result = self.step() - if result is None: + state_output = self.step() + if state_output is None: break - action, result, state = result + action, result, state = state_output if action.name in until: seen_results.add(action.name) results[index_by_name[action.name]] = result @@ -462,9 +483,23 @@ def state(self) -> State: """ return self._state + def _create_graph(self) -> ApplicationGraph: + """Internal-facing utility function for creating an ApplicationGraph""" + all_actions = {action.name: action for action in self._actions} + return ApplicationGraph( + actions=self._actions, + transitions=self._transitions, + entrypoint=all_actions[self._initial_step], + ) + @property - def actions(self) -> List[Action]: - return self._actions + def graph(self) -> ApplicationGraph: + """Application graph object -- if you want to inspect, visualize, etc.. + this is what you want. + + :return: The application graph object + """ + return self._graph def _assert_set(value: Optional[Any], field: str, method: str): @@ -524,6 +559,14 @@ def __init__(self): self.lifecycle_adapters: List[LifecycleAdapter] = list() def with_state(self, **kwargs) -> "ApplicationBuilder": + """Sets initial values in the state. If you want to load from a prior state, + you can do so here and pass the values in. + + TODO -- enable passing in a `state` object instead of `**kwargs` + + :param kwargs: Key-value pairs to set in the state + :return: The application builder for future chaining. + """ if self.state is not None: self.state = self.state.update(**kwargs) else: @@ -534,8 +577,8 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder": """Adds an entrypoint to the application. This is the action that will be run first. This can only be called once. - :param action: - :return: + :param action: The name of the action to set as the entrypoint + :return: The application builder for future chaining. """ # TODO -- validate only called once self.start = action @@ -599,7 +642,34 @@ def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder": self.lifecycle_adapters.extend(adapters) return self + def with_tracker( + self, project: str, tracker: Literal["local"] = "local", params: Dict[str, Any] = None + ): + """Adds a "tracker" to the application. The tracker specifies + a project name (used for disambiguating groups of tracers), and plugs into the + Burr UI. Currently the only supported tracker is local, which takes in the params + `storage_dir` and `app_id`, which have automatic defaults. + + :param project: Project name + :param tracker: Tracker to use, currently only ``local`` is available + :param params: Parameters to pass to the tracker + :return: The application builder for future chaining. + """ + if params is None: + params = {} + if tracker == "local": + from burr.tracking.client import LocalTrackingClient + + self.lifecycle_adapters.append(LocalTrackingClient(project=project, **params)) + else: + raise ValueError(f"Tracker {tracker} not supported") + return self + def build(self) -> Application: + """Builds the application. + + :return: The application object + """ _validate_actions(self.actions) actions_by_name = {action.name: action for action in self.actions} all_actions = set(actions_by_name.keys()) diff --git a/burr/integrations/streamlit.py b/burr/integrations/streamlit.py index f7f77565..2cba0b67 100644 --- a/burr/integrations/streamlit.py +++ b/burr/integrations/streamlit.py @@ -173,7 +173,7 @@ def render_action(state: AppState): """ app: Application = state.app current_node = state.current_action - actions = {action.name: action for action in app.actions} + actions = {action.name: action for action in app.graph.actions} if current_node is None: st.markdown("No current action.") return diff --git a/burr/lifecycle/base.py b/burr/lifecycle/base.py index 390d86ab..43644713 100644 --- a/burr/lifecycle/base.py +++ b/burr/lifecycle/base.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: # type-checking-only for a circular import - from burr.core import State, Action + from burr.core import State, Action, ApplicationGraph from burr.lifecycle.internal import lifecycle @@ -88,6 +88,25 @@ async def post_run_step( pass +@lifecycle.base_hook("post_application_create") +class PostApplicationCreateHook(abc.ABC): + """Synchronous hook that runs post instantiation of an ``Application`` + object (after ``.build()`` is called on the ``ApplicationBuilder`` object.)""" + + @abc.abstractmethod + def post_application_create( + self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any + ): + """Runs after an "application" object is instantiated. This is run by the Application, in its constructor, + as the last step. + + :param state: Current state of the application + :param application_graph: Application graph of the application, representing the state machine + :param future_kwargs: Future keyword arguments for backwards compatibility + """ + pass + + # THESE ARE NOT IN USE # TODO -- implement/decide how to use them @lifecycle.base_hook("pre_run_application") @@ -133,4 +152,5 @@ async def post_run_application( PreRunApplicationHookAsync, PostRunApplicationHook, PostRunApplicationHookAsync, + PostApplicationCreateHook, ] diff --git a/burr/tracking/__init__.py b/burr/tracking/__init__.py new file mode 100644 index 00000000..22b63ae1 --- /dev/null +++ b/burr/tracking/__init__.py @@ -0,0 +1,3 @@ +from .client import LocalTrackingClient + +__all__ = ["LocalTrackingClient"] diff --git a/burr/tracking/client.py b/burr/tracking/client.py new file mode 100644 index 00000000..890b4378 --- /dev/null +++ b/burr/tracking/client.py @@ -0,0 +1,134 @@ +import datetime +import json +import logging +import os +import traceback +import uuid +from typing import Any, Optional + +from burr.core import Action, ApplicationGraph, State +from burr.integrations.base import require_plugin +from burr.lifecycle import PostRunStepHook, PreRunStepHook +from burr.lifecycle.base import PostApplicationCreateHook +from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel + +logger = logging.getLogger(__name__) + +try: + import pydantic +except ImportError as e: + require_plugin( + e, + ["pydantic"], + "tracking-client", + ) + + +def _format_exception(exception: Exception) -> Optional[str]: + if exception is None: + return None + return "".join(traceback.format_exception(type(exception), exception, exception.__traceback__)) + + +class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStepHook): + """Tracker to track locally -- goes along with the Burr UI. Writes + down the following: + #. The whole application + debugging information (e.g. source code) to a file + #. A line for the start/end of each step + """ + + GRAPH_FILENAME = "graph.json" + LOG_FILENAME = "log.jsonl" + + def __init__( + self, + project: str, + storage_dir: str = "~/.burr", + app_id: Optional[str] = None, + ): + """Instantiates a local tracking client. This will create the following directories, if they don't exist: + #. The base directory (defaults to ~/.burr) + #. The project directory (defaults to ~/.burr/) + #. The application directory (defaults to ~/.burr//) on each + + On application create, it will write the state machine to the application directory. + On pre/post run step, it will write the start/end of each step to the application directory. + + :param project: Project name -- if this already exists it will be used, otherwise it will be created. + :param storage_dir: Storage directory + :param app_id: Unique application ID. If not provided, a random one will be generated. If this already exists, + it will use that one/append to the files in that one. + """ + if app_id is None: + app_id = f"app_{str(uuid.uuid4())}" + storage_dir = os.path.join(os.path.expanduser(storage_dir), project) + self.app_id = app_id + self.storage_dir = storage_dir + self._ensure_dir_structure() + self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a") + + def _ensure_dir_structure(self): + if not os.path.exists(self.storage_dir): + logger.info(f"Creating storage directory: {self.storage_dir}") + os.makedirs(self.storage_dir) + application_path = os.path.join(self.storage_dir, self.app_id) + if not os.path.exists(application_path): + logger.info(f"Creating application directory: {application_path}") + os.makedirs(application_path) + + def post_application_create( + self, *, state: "State", application_graph: "ApplicationGraph", **future_kwargs: Any + ): + path = os.path.join(self.storage_dir, self.app_id, self.GRAPH_FILENAME) + if os.path.exists(path): + logger.info(f"Graph already exists at {path}. Not overwriting.") + return + graph = ApplicationModel.from_application_graph(application_graph).model_dump() + with open(path, "w") as f: + json.dump(graph, f) + + def _append_write_line(self, model: pydantic.BaseModel): + self.f.write(model.model_dump_json() + "\n") + self.f.flush() + + def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): + pre_run_entry = BeginEntryModel( + start_time=datetime.datetime.now(), + action=action.name, + inputs={}, + ) + self._append_write_line(pre_run_entry) + + def post_run_step( + self, + *, + state: State, + action: Action, + result: Optional[dict], + exception: Exception, + **future_kwargs: Any, + ): + post_run_entry = EndEntryModel( + end_time=datetime.datetime.now(), + action=action.name, + result=result, + exception=_format_exception(exception), + state=state.get_all(), + ) + self._append_write_line(post_run_entry) + + def __del__(self): + self.f.close() + + +# TODO -- implement async version +# class AsyncTrackingClient(PreRunStepHookAsync, PostRunStepHookAsync, PostApplicationCreateHook): +# def post_application_create(self, *, state: State, state_graph: ApplicationGraph, **future_kwargs: Any): +# pass +# +# async def pre_run_step(self, *, state: State, action: Action, **future_kwargs: Any): +# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") +# +# async def post_run_step(self, *, state: State, action: Action, result: Optional[dict], exception: Exception, **future_kwargs: Any): +# raise NotImplementedError(f"TODO: {self.__class__.__name__}.pre_run_step") +# diff --git a/burr/tracking/models.py b/burr/tracking/models.py new file mode 100644 index 00000000..935b9af3 --- /dev/null +++ b/burr/tracking/models.py @@ -0,0 +1,104 @@ +import datetime +import inspect +from typing import Any, Dict, Optional + +from burr.core import Action +from burr.core.action import FunctionBasedAction +from burr.core.application import ApplicationGraph, Transition +from burr.integrations.base import require_plugin + +try: + import pydantic +except ImportError as e: + require_plugin( + e, + ["pydantic"], + "tracking-client", + ) + + +class IdentifyingModel(pydantic.BaseModel): + model_type: str + + +class ActionModel(IdentifyingModel): + """Pydantic model that represents an action for storing/visualization in the UI""" + + name: str + reads: list[str] + writes: list[str] + code: str + model_type: str = "action" + + @staticmethod + def from_action(action: Action) -> "ActionModel": + """Creates an ActionModel from an action. + + :param action: Action to create the model from + :return: + """ + if isinstance(action, FunctionBasedAction): + code = inspect.getsource(action.fn) + else: + code = inspect.getsource(action.__class__) + return ActionModel( + name=action.name, + reads=list(action.reads), + writes=list(action.writes), + code=code, + ) + + +class TransitionModel(IdentifyingModel): + """Pydantic model that represents a transition for storing/visualization in the UI""" + + from_: str + to: str + condition: str + model_type: str = "transition" + + @staticmethod + def from_transition(transition: Transition) -> "TransitionModel": + return TransitionModel( + from_=transition.from_.name, to=transition.to.name, condition=transition.condition.name + ) + + +class ApplicationModel(IdentifyingModel): + """Pydantic model that represents an application for storing/visualization in the UI""" + + entrypoint: str + actions: list[ActionModel] + transitions: list[TransitionModel] + model_type: str = "application" + + @staticmethod + def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationModel": + return ApplicationModel( + entrypoint=application_graph.entrypoint.name, + actions=[ActionModel.from_action(action) for action in application_graph.actions], + transitions=[ + TransitionModel.from_transition(transition) + for transition in application_graph.transitions + ], + ) + + +class BeginEntryModel(IdentifyingModel): + """Pydantic model that represents an entry for the beginning of a step""" + + start_time: datetime.datetime + action: str + inputs: Dict[str, Any] + model_type: str = "begin_entry" + + +class EndEntryModel(IdentifyingModel): + """Pydantic model that represents an entry for the end of a step""" + + end_time: datetime.datetime + action: str + result: Optional[dict] + exception: Optional[str] + state: Dict[str, Any] # TODO -- consider logging updates to the state so we can recreate + model_type: str = "end_entry" diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index 86b3b3f9..838a972e 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -14,4 +14,5 @@ Overview of the concepts -- read these to get a mental model for how Burr works. actions transitions hooks + tracking planned-capabilities diff --git a/docs/concepts/state-machine.rst b/docs/concepts/state-machine.rst index a34cc324..1bcc5000 100644 --- a/docs/concepts/state-machine.rst +++ b/docs/concepts/state-machine.rst @@ -91,3 +91,14 @@ Run just calls out to ``iterate`` and returns the final state. Currently the ``until`` variable is a ``or`` gate (E.G. ``any_complete``), although we will be adding an ``and`` gate (E.G. ``all_complete``), and the ability to run until the state machine naturally executes (``until=None``). + +---------- +Inspection +---------- + +You can ask various questions of the state machine using publicly-supported APIs: + +- ``application.graph`` will give you a static reprsentation of the state machine with enough information to visualize +- ``application.state`` will give you the current state of the state machine. Note that if you modify it the results will not show up -- state is immutable! + +See the :ref:`application docs ` diff --git a/docs/concepts/tracking.rst b/docs/concepts/tracking.rst new file mode 100644 index 00000000..0face0c6 --- /dev/null +++ b/docs/concepts/tracking.rst @@ -0,0 +1,27 @@ +.. _tracking: + +============= +Tracking Burr +============= + +Burr comes with a telemetry system that allows tracking a variety of information for debugging, +both in development and production. Note this is a WIP. + +--------------- +Tracking Client +--------------- + +When you use :py:meth:`burr.core.application.ApplicationBuilder.with_tracker`, you add a tracker to Burr. +This is a lifecycle hook that does the following: + +#. Logs the static representation of the state machine +#. Logs any information before/after step execution, including + - The step name + - The step input + - The state at time of execution + - The timestamps + +This currently defaults to (and only supports) the :py:class:`burr.tracking.LocalTrackingClient` class, which +writes to a local file system, althoguh we will be making it pluggable in the future. + +This will be used with the UI, which can serve out of the specified directory. More coming soon! diff --git a/docs/reference/application.rst b/docs/reference/application.rst index 87455732..21aa7399 100644 --- a/docs/reference/application.rst +++ b/docs/reference/application.rst @@ -9,7 +9,12 @@ and not the ``Application`` class directly. .. autoclass:: burr.core.application.ApplicationBuilder :members: +.. _applicationref: + .. autoclass:: burr.core.application.Application :members: .. automethod:: __init__ + +.. autoclass:: burr.core.application.ApplicationGraph + :members: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 6e22292c..f6184f4b 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -15,5 +15,6 @@ need functionality that is not publicly exposed, please open an issue and we can actions state conditions + tracking lifecycle integrations/index diff --git a/docs/reference/lifecycle.rst b/docs/reference/lifecycle.rst index bf2778d7..7287d3b3 100644 --- a/docs/reference/lifecycle.rst +++ b/docs/reference/lifecycle.rst @@ -19,6 +19,9 @@ and add instances to the application builder to customize your state machines's .. autoclass:: burr.lifecycle.base.PostRunStepHookAsync :members: +.. autoclass:: burr.lifecycle.base.PostApplicationCreateHook + :members: + These hooks are available for you to use: .. autoclass:: burr.lifecycle.default.StateAndResultsFullLogger diff --git a/docs/reference/tracking.rst b/docs/reference/tracking.rst new file mode 100644 index 00000000..5af4094d --- /dev/null +++ b/docs/reference/tracking.rst @@ -0,0 +1,12 @@ +======== +Tracking +======== + +Reference on the Tracking API. This is purely for information -- you should not use this directly. +Rather, you should use this through :py:meth:`burr.core.application.ApplicationBuilder.with_tracker` + + +.. autoclass:: burr.tracking.LocalTrackingClient + :members: + + .. automethod:: __init__ diff --git a/examples/counter/application.py b/examples/counter/application.py index dceba7d3..82e2eb62 100644 --- a/examples/counter/application.py +++ b/examples/counter/application.py @@ -23,6 +23,7 @@ def application(count_up_to: int = 10, log_file: str = None): ("result", "counter", expr("counter == 0")), # when we've reset, manually ) .with_entrypoint("counter") + .with_tracker("counter") .with_hooks(*[StateAndResultsFullLogger(log_file)] if log_file else []) .build() ) @@ -31,6 +32,6 @@ def application(count_up_to: int = 10, log_file: str = None): if __name__ == "__main__": app = application(log_file="counter.jsonl") state, result = app.run(until=["result"]) - app.visualize(output_file_path="digraph", include_conditions=True, view=True, format="png") + app.visualize(output_file_path="digraph", include_conditions=True, view=False, format="png") assert state["counter"] == 10 print(state["counter"]) diff --git a/pyproject.toml b/pyproject.toml index 3327cb5e..ebb02a0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,10 @@ documentation = [ "sphinx-toolbox" ] +tracking-client = [ + "pydantic" +] + [tool.poetry.packages] py_modules = ["burr"] diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 48db4bd9..d1b111a7 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -7,6 +7,7 @@ from burr.core import State from burr.core.action import Action, Condition, Result, SingleStepAction, default from burr.core.application import ( + PRIOR_STEP, Application, ApplicationBuilder, Transition, @@ -26,6 +27,7 @@ PreRunStepHookAsync, internal, ) +from burr.lifecycle.base import PostApplicationCreateHook class PassedInAction(Action): @@ -202,6 +204,7 @@ def test_app_step(): action, result, state = app.step() assert action.name == "counter" assert result == {"counter": 1} + assert state[PRIOR_STEP] == "counter" # internal contract, not part of the public API def test_app_step_broken(caplog): @@ -241,6 +244,7 @@ async def test_app_astep(): action, result, state = await app.astep() assert action.name == "counter_async" assert result == {"counter": 1} + assert state[PRIOR_STEP] == "counter_async" # internal contract, not part of the public API async def test_app_astep_broken(caplog): @@ -542,7 +546,7 @@ def test_application_builder_unset(): ApplicationBuilder().build() -def test_application_runs_hooks_sync(): +def test_application_run_step_hooks_sync(): class ActionTracker(PreRunStepHook, PostRunStepHook): def __init__(self): self.pre_called = [] @@ -574,7 +578,7 @@ def post_run_step(self, *, action: Action, **future_kwargs): assert len(tracker.post_called) == 11 -async def test_application_runs_hooks_async(): +async def test_application_run_step_hooks_async(): class ActionTrackerAsync(PreRunStepHookAsync, PostRunStepHookAsync): def __init__(self): self.pre_called = [] @@ -606,3 +610,49 @@ async def post_run_step(self, *, action: Action, **future_kwargs): assert set(tracker.post_called) == {"counter", "result"} assert len(tracker.pre_called) == 11 assert len(tracker.post_called) == 11 + + +def test_application_post_application_create_hook(): + class PostApplicationCreateTracker(PostApplicationCreateHook): + def __init__(self): + self.called_args = None + self.call_count = 0 + + def post_application_create(self, **kwargs): + self.called_args = kwargs + self.call_count += 1 + + tracker = PostApplicationCreateTracker() + counter_action = base_counter_action.with_name("counter") + result_action = Result(fields=["counter"]).with_name("result") + Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, result_action, Condition.expr("counter >= 10")), + Transition(counter_action, counter_action, default), + ], + state=State({}), + initial_step="counter", + adapter_set=internal.LifecycleAdapterSet(tracker), + ) + assert "state" in tracker.called_args + assert "application_graph" in tracker.called_args + assert tracker.call_count == 1 + + +async def test_application_gives_graph(): + counter_action = base_counter_action.with_name("counter") + result_action = Result(fields=["counter"]).with_name("result") + app = Application( + actions=[counter_action, result_action], + transitions=[ + Transition(counter_action, result_action, Condition.expr("counter >= 10")), + Transition(counter_action, counter_action, default), + ], + state=State({}), + initial_step="counter", + ) + graph = app.graph + assert len(graph.actions) == 2 + assert len(graph.transitions) == 2 + assert graph.entrypoint.name == "counter" diff --git a/tests/tracking/test_local_tracking_client.py b/tests/tracking/test_local_tracking_client.py new file mode 100644 index 00000000..d8b62259 --- /dev/null +++ b/tests/tracking/test_local_tracking_client.py @@ -0,0 +1,105 @@ +import json +import os +import uuid +from typing import Tuple + +import pytest + +import burr +from burr.core import Result, State, action, default, expr +from burr.tracking import LocalTrackingClient +from burr.tracking.models import ApplicationModel, BeginEntryModel, EndEntryModel + + +@action(reads=["counter", "break_at"], writes=["counter"]) +def counter(state: State) -> Tuple[dict, State]: + result = {"counter": state["counter"] + 1} + if state["break_at"] == result["counter"]: + raise ValueError("Broken") + return result, state.update(**result) + + +def sample_application(project_name: str, log_dir: str, app_id: str, broken: bool = False): + return ( + burr.core.ApplicationBuilder() + .with_state(counter=0, break_at=2 if broken else -1) + .with_actions(counter=counter, result=Result(["counter"])) + .with_transitions( + ("counter", "counter", expr("counter < 2")), # just count to two for testing + ("counter", "result", default), + ) + .with_entrypoint("counter") + .with_tracker( + project_name, tracker="local", params={"storage_dir": log_dir, "app_id": app_id} + ) + .build() + ) + + +def test_application_tracks_end_to_end(tmpdir: str): + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_application_tracks_end_to_end" + app = sample_application(project_name, log_dir, app_id) + app.run(until=["result"]) + results_dir = os.path.join(log_dir, project_name, app_id) + assert os.path.exists(results_dir) + assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) + assert os.path.exists( + graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) + ) + with open(log_output) as f: + log_contents = [json.loads(item) for item in f.readlines()] + with open(graph_output) as f: + graph_contents = json.load(f) + assert graph_contents["model_type"] == "application" + app_model = ApplicationModel.parse_obj(graph_contents) + assert app_model.entrypoint == "counter" + assert app_model.actions[0].name == "counter" + assert app_model.actions[1].name == "result" + pre_run = [ + BeginEntryModel.parse_obj(line) + for line in log_contents + if line["model_type"] == "begin_entry" + ] + post_run = [ + EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" + ] + assert len(pre_run) == 3 + assert len(post_run) == 3 + assert not any(item.exception for item in post_run) + + +def test_application_tracks_end_to_end_broken(tmpdir: str): + app_id = str(uuid.uuid4()) + log_dir = os.path.join(tmpdir, "tracking") + project_name = "test_application_tracks_end_to_end" + app = sample_application(project_name, log_dir, app_id, broken=True) + with pytest.raises(ValueError): + app.run(until=["result"]) + results_dir = os.path.join(log_dir, project_name, app_id) + assert os.path.exists(results_dir) + assert os.path.exists(log_output := os.path.join(results_dir, LocalTrackingClient.LOG_FILENAME)) + assert os.path.exists( + graph_output := os.path.join(results_dir, LocalTrackingClient.GRAPH_FILENAME) + ) + with open(log_output) as f: + log_contents = [json.loads(item) for item in f.readlines()] + with open(graph_output) as f: + graph_contents = json.load(f) + assert graph_contents["model_type"] == "application" + app_model = ApplicationModel.parse_obj(graph_contents) + assert app_model.entrypoint == "counter" + assert app_model.actions[0].name == "counter" + assert app_model.actions[1].name == "result" + pre_run = [ + BeginEntryModel.parse_obj(line) + for line in log_contents + if line["model_type"] == "begin_entry" + ] + post_run = [ + EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry" + ] + assert len(pre_run) == 2 + assert len(post_run) == 2 + assert len(post_run[-1].exception) > 0 and "Broken" in post_run[-1].exception