Skip to content

Commit

Permalink
Adds helpful error message for action failure
Browse files Browse the repository at this point in the history
This is only during execution of the action, anything around it (hooks,
et...) will not be covered. This reuses code from hamilton to cleanly
display the error messages and show the state at the time.

See #7
  • Loading branch information
elijahbenizzy committed Feb 14, 2024
1 parent 29121ec commit 3bff5e6
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
37 changes: 37 additions & 0 deletions burr/core/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import dataclasses
import logging
import pprint
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -88,6 +89,40 @@ def _run_reducer(reducer: Reducer, state: State, result: dict, name: str) -> Sta
return state.merge(new_state.update(**{PRIOR_STEP: name}))


def _create_dict_string(kwargs: dict) -> str:
"""This is a utility function to create a string representation of a dict.
This is the state that was passed into the function usually. This is useful for debugging,
as it can be printed out to see what the state was.
:param kwargs: The inputs to the function that errored.
:return: The string representation of the inputs, truncated appropriately.
"""
pp = pprint.PrettyPrinter(width=80)
inputs = {}
for k, v in kwargs.items():
item_repr = repr(v)
if len(item_repr) > 50:
item_repr = item_repr[:50] + "..."
else:
item_repr = v
inputs[k] = item_repr
input_string = pp.pformat(inputs)
if len(input_string) > 1000:
input_string = input_string[:1000] + "..."
return input_string


def _format_error_message(action: Action, input_state: State) -> str:
"""Formats the error string, given that we're inside an action"""
message = f"> Action: {action.name} encountered an error!"
padding = " " * (80 - len(message) - 1)
message += padding + "<"
input_string = _create_dict_string(input_state.get_all())
message += "\n> State (at time of action):\n" + input_string
border = "*" * 80
logger.exception("\n" + border + "\n" + message + "\n" + border)


class Application:
def __init__(
self,
Expand Down Expand Up @@ -132,6 +167,7 @@ def step(self) -> Optional[Tuple[Action, dict, State]]:
self._set_state(new_state)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
raise e
finally:
self._adapter_set.call_all_lifecycle_hooks_sync(
Expand Down Expand Up @@ -165,6 +201,7 @@ async def astep(self) -> Optional[Tuple[Action, dict, State]]:
new_state = _run_reducer(next_action, self._state, result, next_action.name)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state))
raise e
finally:
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
Expand Down
50 changes: 50 additions & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
from typing import Awaitable, Callable

import pytest
Expand Down Expand Up @@ -90,6 +91,25 @@ async def _counter_update_async(state: State) -> dict:
)


class BrokenStepException(Exception):
pass


base_broken_action = PassedInAction(
reads=[],
writes=[],
fn=lambda x: exec("raise(BrokenStepException(x))"),
update_fn=lambda result, state: state,
)

base_broken_action_async = PassedInActionAsync(
reads=[],
writes=[],
fn=lambda x: exec("raise(BrokenStepException(x))"),
update_fn=lambda result, state: state,
)


def test_run_function():
"""Tests that we can run a function"""
action = base_counter_action
Expand Down Expand Up @@ -128,6 +148,21 @@ def test_app_step():
assert result == {"counter": 1}


def test_app_step_broken(caplog):
"""Tests that we can run a step in an app"""
broken_action = base_broken_action.with_name("broken_action_unique_name")
app = Application(
actions=[broken_action],
transitions=[Transition(broken_action, broken_action, default)],
state=State({}),
initial_step="broken_action_unique_name",
)
with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now
with pytest.raises(BrokenStepException):
app.step()
assert "broken_action_unique_name" in caplog.text


def test_app_step_done():
"""Tests that when we cannot run a step, we return None"""
counter_action = base_counter_action.with_name("counter")
Expand All @@ -152,6 +187,21 @@ async def test_app_astep():
assert result == {"counter": 1}


async def test_app_astep_broken(caplog):
"""Tests that we can run a step in an app"""
broken_action = base_broken_action_async.with_name("broken_action_unique_name")
app = Application(
actions=[broken_action],
transitions=[Transition(broken_action, broken_action, default)],
state=State({}),
initial_step="broken_action_unique_name",
)
with caplog.at_level(logging.ERROR): # it should say the name, that's the only contract for now
with pytest.raises(BrokenStepException):
await app.astep()
assert "broken_action_unique_name" in caplog.text


async def test_app_astep_done():
"""Tests that when we cannot run a step, we return None"""
counter_action = base_counter_action_async.with_name("counter_async")
Expand Down

0 comments on commit 3bff5e6

Please sign in to comment.