Skip to content

Commit

Permalink
Updates code in response to PR
Browse files Browse the repository at this point in the history
  • Loading branch information
skrawcz authored and elijahbenizzy committed Mar 1, 2024
1 parent 92d3aba commit c78170e
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 44 deletions.
29 changes: 14 additions & 15 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,7 @@ def _step(
result = _run_function(next_action, self._state, inputs)
new_state = _run_reducer(next_action, self._state, result, next_action.name)

new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
# make it a string for future proofing
SEQUENCE_ID: str(int(self._state.get(SEQUENCE_ID, 0)) + 1),
}
)
new_state = self.update_internal_state_value(new_state, next_action)
self._set_state(new_state)
except Exception as e:
exc = e
Expand All @@ -284,6 +278,17 @@ def _step(
)
return next_action, result, new_state

def update_internal_state_value(self, new_state: State, next_action: Action) -> State:
    """Stamp the framework's bookkeeping keys onto ``new_state``.

    :param new_state: the state to annotate (callers pass the reducer output).
    :param next_action: the action whose ``name`` is recorded under ``PRIOR_STEP``.
    :return: the result of ``new_state.update(...)`` with ``PRIOR_STEP`` and
        ``SEQUENCE_ID`` set.
    """
    prior_step = next_action.name
    # The counter is read from the application's current state (self._state),
    # not from new_state, and is stored as a string for future proofing.
    sequence_id = str(int(self._state.get(SEQUENCE_ID, 0)) + 1)
    return new_state.update(**{PRIOR_STEP: prior_step, SEQUENCE_ID: sequence_id})

async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, dict, State]]:
"""Asynchronous version of step.
Expand Down Expand Up @@ -320,13 +325,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
else:
result = await _arun_function(next_action, self._state, inputs=inputs)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
# make it a string for future proofing
SEQUENCE_ID: str(int(self._state.get(SEQUENCE_ID, 0)) + 1),
}
)
new_state = self.update_internal_state_value(new_state, next_action)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state, inputs))
Expand Down Expand Up @@ -794,7 +793,7 @@ def with_tracker(
):
"""Adds a "tracker" to the application. The tracker specifies
a project name (used for disambiguating groups of tracers), and plugs into the
Burr UI. Currently, the only supported tracker is local, which takes in the params
Burr UI. Currently the only supported tracker is local, which takes in the params
`storage_dir` and `app_id`, which have automatic defaults.
:param project: Project name
Expand Down
59 changes: 34 additions & 25 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,31 +62,35 @@ def __init__(
"""
if app_id is None:
app_id = f"app_{str(uuid.uuid4())}"
storage_dir = self.get_storage_path(project, storage_dir)
storage_dir = LocalTrackingClient.get_storage_path(project, storage_dir)
self.app_id = app_id
self.storage_dir = storage_dir
self._ensure_dir_structure()
self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a")

@staticmethod
def get_storage_path(project, storage_dir):
@classmethod
def get_storage_path(cls, project, storage_dir):
return os.path.join(os.path.expanduser(storage_dir), project)

@classmethod
def get_state(
def load_state(
cls,
project: str,
app_id: str,
sequence_no: int = -1,
storage_dir: str = DEFAULT_STORAGE_DIR,
) -> tuple[dict, str]:
"""Initialize the state to debug from an exception.
"""Function to load state from what the tracking client got.
It defaults to loading the last state, but you can supply a sequence number.
We will make loading state more ergonomic, but at this time this is what you get.
:param project:
:param app_id:
:param sequence_no:
:param storage_dir:
:return:
:param project: the name of the project
:param app_id: the application instance id
:param sequence_no: the sequence number of the state to load. Defaults to last index (i.e. -1).
:param storage_dir: the storage directory.
:return: the state as a dictionary, and the entry point as a string.
"""
if sequence_no is None:
sequence_no = -1 # get the last one
Expand All @@ -95,28 +99,33 @@ def get_state(
raise ValueError(f"No logs found for {project}/{app_id} under {storage_dir}")
with open(path, "r") as f:
json_lines = f.readlines()
# load as JSON
json_lines = [json.loads(js_line) for js_line in json_lines]
# filter to only end_entry
json_lines = [js_line for js_line in json_lines if js_line["type"] == "end_entry"]
line = {}
if sequence_no < 0:
try:
line = json_lines[sequence_no]
else:
found_line = False
for line in json_lines:
if line["sequence_no"] == sequence_no:
found_line = True
break
if not found_line:
raise ValueError(f"Sequence number {sequence_no} not found for {project}/{app_id}.")
state = line["state"]
except IndexError:
raise ValueError(f"Sequence number {sequence_no} not found for {project}/{app_id}.")
# check sequence number matches if non-negative
line_seq = int(line["sequence_no"])
if -1 < sequence_no != line_seq:
logger.warning(
f"Sequence number mismatch. For {project}/{app_id}: "
f"actual:{line_seq} != expected:{sequence_no}"
)
# get the prior state
prior_state = line["state"]
entry_point = line["action"]
# delete internally stuff. We can't loop over the keys and delete them in the same loop
to_delete = []
for key in state.keys():
for key in prior_state.keys():
# remove any internal "__" state
if key.startswith("__"):
to_delete.append(key)
for key in to_delete:
del state[key]
entry_point = line["action"]
return state, entry_point
del prior_state[key]
return prior_state, entry_point

def _ensure_dir_structure(self):
if not os.path.exists(self.storage_dir):
Expand Down
6 changes: 6 additions & 0 deletions docs/concepts/state.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,9 @@ after the action is complete. Pseudocode:
current_state = current_state.merge(new_state)
If you're used to thinking about version control, this is a bit like a commit/checkout/merge mechanism.

Reloading Prior State
---------------------
Note that if state is serializable, it can be stored and later reloaded. This is useful for
reloading state from a previous run (for debugging or as part of the application), or for storing state in a database.
We are building more capabilities here; for now, for debugging purposes, see the :ref:`tracking <trackingclientref>` section.
31 changes: 31 additions & 0 deletions docs/concepts/tracking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ The data model for tracking is simple:
3. **Steps** are the individual steps that are executed in the state machine. The Burr UI will show the state of the
state machine at the time of the step execution, as well as the input to and results of the step.

.. _trackingclientref:

---------------
Tracking Client
---------------
Expand All @@ -35,6 +37,35 @@ This currently defaults to (and only supports) the :py:class:`LocalTrackingClien
writes to a local file system, although we will be making it pluggable in the future. It will, by default, write to the directory
``~/.burr``.

Debugging via Reloading Prior State
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Because the tracking client writes to the file system, you can reload the state of the state machine at any time. This is
useful for debugging, because you can quickly recreate an issue by re-running the state machine from the same point in time.

To do so, use the classmethod ``load_state()`` on the :py:class:`LocalTrackingClient <burr.tracking.LocalTrackingClient>`.

For example, as you initialize the Burr Application, you'd have some control flow like this:

.. code-block:: python

    from burr.tracking import client

    project_name = "demo:hamilton-multi-agent"
    if app_instance_id:
        initial_state, entry_point = client.LocalTrackingClient.load_state(
            project_name, app_instance_id
        )
        # TODO: any custom logic for re-creating the state if it's some object that needs to be re-instantiated
    else:
        initial_state, entry_point = default_state_and_entry_point()
    app = (
        ApplicationBuilder()
        .with_state(**initial_state)
        .with_entry_point(entry_point)
        # ... etc fill in the rest here
    )

---------------
Tracking Server
---------------
Expand Down
5 changes: 3 additions & 2 deletions examples/multi-agent-collaboration/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ def default_state_and_entry_point() -> tuple[dict, str]:
def main(app_instance_id: str = None):
project_name = "demo:hamilton-multi-agent"
if app_instance_id:
state, entry_point = burr_tclient.LocalTrackingClient.get_state(
project_name, app_instance_id
state, entry_point = burr_tclient.LocalTrackingClient.load_state(
project_name,
app_instance_id,
)
else:
state, entry_point = default_state_and_entry_point()
Expand Down
2 changes: 1 addition & 1 deletion examples/multi-agent-collaboration/hamilton_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def default_state_and_entry_point() -> tuple[dict, str]:
def main(app_instance_id: str = None):
project_name = "demo:hamilton-multi-agent-v1"
if app_instance_id:
state, entry_point = burr_tclient.LocalTrackingClient.get_state(
state, entry_point = burr_tclient.LocalTrackingClient.load_state(
project_name, app_instance_id
)
else:
Expand Down
2 changes: 1 addition & 1 deletion examples/multi-agent-collaboration/lcel_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def default_state_and_entry_point() -> tuple[dict, str]:
def main(app_instance_id: str = None):
project_name = "demo:hamilton-multi-agent"
if app_instance_id:
initial_state, entry_point = burr_tclient.LocalTrackingClient.get_state(
initial_state, entry_point = burr_tclient.LocalTrackingClient.load_state(
project_name, app_instance_id
)
# TODO: rehydrate langchain objects from JSON
Expand Down

0 comments on commit c78170e

Please sign in to comment.