Skip to content

Commit

Permalink
Adds function to get state to tracker
Browse files Browse the repository at this point in the history
So that we can pull the last state given an ID from the UI.

Updates examples and adds a new one that mirrors the
LCEL example with the tool calling. This variant works.
Whereas the Hamilton DAG variant, which also does tool calling, somehow doesn't work quite
as well. Still digging in there. Need a run diff tool now... or a way to compare runs...
  • Loading branch information
skrawcz committed Feb 29, 2024
1 parent c878ca0 commit e9264d9
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 106 deletions.
20 changes: 17 additions & 3 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Transition:
TerminationCondition = Literal["any_complete", "all_complete"]

PRIOR_STEP = "__PRIOR_STEP"
SEQUENCE_ID = "__SEQUENCE_ID"


def _run_function(function: Function, state: State, inputs: Dict[str, Any]) -> dict:
Expand Down Expand Up @@ -259,7 +260,14 @@ def _step(
else:
result = _run_function(next_action, self._state, inputs)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = new_state.update(**{PRIOR_STEP: next_action.name})

new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
# make it a string for future proofing
SEQUENCE_ID: str(int(self._state.get(SEQUENCE_ID, 0)) + 1),
}
)
self._set_state(new_state)
except Exception as e:
exc = e
Expand Down Expand Up @@ -312,7 +320,13 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
else:
result = await _arun_function(next_action, self._state, inputs=inputs)
new_state = _run_reducer(next_action, self._state, result, next_action.name)
new_state = new_state.update(**{PRIOR_STEP: next_action.name})
new_state = new_state.update(
**{
PRIOR_STEP: next_action.name,
# make it a string for future proofing
SEQUENCE_ID: str(int(self._state.get(SEQUENCE_ID, 0)) + 1),
}
)
except Exception as e:
exc = e
logger.exception(_format_error_message(next_action, self._state, inputs))
Expand Down Expand Up @@ -780,7 +794,7 @@ def with_tracker(
):
"""Adds a "tracker" to the application. The tracker specifies
a project name (used for disambiguating groups of tracers), and plugs into the
Burr UI. Currently the only supported tracker is local, which takes in the params
Burr UI. Currently, the only supported tracker is local, which takes in the params
`storage_dir` and `app_id`, which have automatic defaults.
:param project: Project name
Expand Down
55 changes: 53 additions & 2 deletions burr/tracking/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ class LocalTrackingClient(PostApplicationCreateHook, PreRunStepHook, PostRunStep

GRAPH_FILENAME = "graph.json"
LOG_FILENAME = "log.jsonl"
DEFAULT_STORAGE_DIR = "~/.burr"

def __init__(
self,
project: str,
storage_dir: str = "~/.burr",
storage_dir: str = DEFAULT_STORAGE_DIR,
app_id: Optional[str] = None,
):
"""Instantiates a local tracking client. This will create the following directories, if they don't exist:
Expand All @@ -61,12 +62,62 @@ def __init__(
"""
if app_id is None:
app_id = f"app_{str(uuid.uuid4())}"
storage_dir = os.path.join(os.path.expanduser(storage_dir), project)
storage_dir = self.get_storage_path(project, storage_dir)
self.app_id = app_id
self.storage_dir = storage_dir
self._ensure_dir_structure()
self.f = open(os.path.join(self.storage_dir, self.app_id, self.LOG_FILENAME), "a")

@staticmethod
def get_storage_path(project, storage_dir):
return os.path.join(os.path.expanduser(storage_dir), project)

@classmethod
def get_state(
    cls,
    project: str,
    app_id: str,
    sequence_no: int = -1,
    storage_dir: str = DEFAULT_STORAGE_DIR,
) -> tuple[dict, str]:
    """Load a logged state, and the action it was logged with, from the tracker's log file.

    Only log records whose ``"type"`` is ``"end_entry"`` are considered. Keys of the
    stored state that start with ``"__"`` are stripped from the returned state.

    :param project: Project name the application was tracked under.
    :param app_id: ID of the application run whose log should be read.
    :param sequence_no: Which end entry to use. A negative value is used as a Python-style
        index into the list of end entries (``-1``, the default, is the last one); ``None``
        is treated as ``-1``. A non-negative value selects the first end entry whose
        ``"sequence_no"`` field equals it.
    :param storage_dir: Base storage directory the tracker wrote to.
    :return: Tuple of (state dict without ``"__"``-prefixed keys, value of the entry's
        ``"action"`` field).
    :raises ValueError: If the log file does not exist, or no end entry matches
        ``sequence_no`` (including an empty log).
    """
    if sequence_no is None:
        sequence_no = -1  # get the last one
    path = os.path.join(cls.get_storage_path(project, storage_dir), app_id, cls.LOG_FILENAME)
    if not os.path.exists(path):
        raise ValueError(f"No logs found for {project}/{app_id} under {storage_dir}")
    with open(path, "r") as f:
        # Skip blank lines (e.g. a stray trailing newline) so json.loads does not fail on them.
        entries = [json.loads(raw_line) for raw_line in f if raw_line.strip()]
    end_entries = [entry for entry in entries if entry["type"] == "end_entry"]
    if sequence_no < 0:
        try:
            line = end_entries[sequence_no]
        except IndexError:
            # Empty log, or fewer end entries than requested: report it the same way as
            # the non-negative lookup below instead of leaking a bare IndexError.
            raise ValueError(
                f"Sequence number {sequence_no} not found for {project}/{app_id}."
            ) from None
    else:
        line = next(
            (entry for entry in end_entries if entry["sequence_no"] == sequence_no), None
        )
        if line is None:
            raise ValueError(f"Sequence number {sequence_no} not found for {project}/{app_id}.")
    # Drop internal bookkeeping keys (those prefixed with "__") from the returned state.
    state = {key: value for key, value in line["state"].items() if not key.startswith("__")}
    entry_point = line["action"]
    return state, entry_point

def _ensure_dir_structure(self):
if not os.path.exists(self.storage_dir):
logger.info(f"Creating storage directory: {self.storage_dir}")
Expand Down
132 changes: 40 additions & 92 deletions examples/multi-agent-collaboration/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import json
import os.path

import func_agent
from hamilton import driver
from langchain_community.tools.tavily_search import TavilySearchResults
Expand All @@ -10,32 +7,10 @@
from burr.core import Action, ApplicationBuilder, State
from burr.core.action import action
from burr.lifecycle import PostRunStepHook
from burr.tracking import client as burr_tclient

# @action(reads=["code"], writes=["code_result"])
# def run_code(state: State) -> tuple[dict, State]:
# _code = state["code"]
# try:
# result = repl.run(_code)
# except BaseException as e:
# _code_result = {"status": "error", "result": f"Failed to execute. Error: {repr(e)}"
# return {"result": f"Failed to execute. Error: {repr(e)}"}, state.update(code_result=_code_result)
# _code_result = {"status": "success", "result": result}
# return {"status": "success", "result": result}, state.update(code_result=_code_result)
#
# @action(reads=["code"], writes=["code_result"])
# def run_tavily(state: State) -> tuple[dict, State]:
# _code = state["code"]
# try:
# result = repl.run(_code)
# except BaseException as e:
# _code_result = {"status": "error", "result": f"Failed to execute. Error: {repr(e)}"
# return {"result": f"Failed to execute. Error: {repr(e)}"}, state.update(code_result=_code_result)
# _code_result = {"status": "success", "result": result}
# return {"status": "success", "result": result}, state.update(code_result=_code_result)


# Initialize some things needed for tools.
tool_dag = driver.Builder().with_modules(func_agent).build()

repl = PythonREPL()


Expand Down Expand Up @@ -126,26 +101,6 @@ def post_run_step(self, *, state: "State", action: "Action", **future_kwargs):
print("state======\n", state)


def initialize_state_from_logs(tracker_name: str, app_id: str) -> tuple[dict, str]:
    """Recover the state and action logged alongside the first recorded exception.

    Reads ``~/.burr/{tracker_name}/{app_id}/log.jsonl`` and scans its JSON lines in
    order, stopping at the first record that contains an ``"exception"`` key.

    :param tracker_name: Name of the tracker directory under ``~/.burr``.
    :param app_id: ID of the application run whose log should be read.
    :return: Tuple of the matching record's ``"state"`` and ``"action"`` values.
    :raises ValueError: If no record in the log contains an ``"exception"`` key.
    """
    log_path = f"{os.path.expanduser('~/')}/.burr/{tracker_name}/{app_id}/log.jsonl"
    with open(log_path, "r") as log_file:
        raw_lines = log_file.readlines()
    # Parse lazily so lines after the first match are never decoded.
    parsed_records = (json.loads(raw_line) for raw_line in raw_lines)
    failed_record = next((record for record in parsed_records if "exception" in record), None)
    if failed_record is None:
        raise ValueError(f"No exception found in logs for {tracker_name}/{app_id}")
    return failed_record["state"], failed_record["action"]


def default_state_and_entry_point() -> tuple[dict, str]:
return {
"messages": [],
Expand All @@ -157,9 +112,11 @@ def default_state_and_entry_point() -> tuple[dict, str]:


def main(app_instance_id: str = None):
tracker_name = "hamilton-multi-agent"
project_name = "demo:hamilton-multi-agent"
if app_instance_id:
state, entry_point = initialize_state_from_logs(tracker_name, app_instance_id)
state, entry_point = burr_tclient.LocalTrackingClient.get_state(
project_name, app_instance_id
)
else:
state, entry_point = default_state_and_entry_point()

Expand Down Expand Up @@ -189,54 +146,45 @@ def main(app_instance_id: str = None):
)
.with_entrypoint(entry_point)
.with_hooks(PrintStepHook())
.with_tracker(tracker_name)
.with_tracker(project_name)
.build()
)
app.visualize(
output_file_path="hamilton-multi-agent", include_conditions=True, view=True, format="png"
)
app.run(halt_after=["terminal"])
# return app


if __name__ == "__main__":
_app_id = "app_4d1618d2-79d1-4d89-8e3f-70c216c71e63"
# Add an app_id to restart from last sequence in that state
# e.g. find the ID in the UI and then put it in here "app_4d1618d2-79d1-4d89-8e3f-70c216c71e63"
_app_id = None
main(_app_id)
import sys

sys.exit(0)
"""TODO:
1. need to figure out the messages state.
2. the current design is each DAG run also calls the tool.
3. so need to then update messages history appropriately
so that the "agents" can iterate until they are done.
Note: https://github.com/langchain-ai/langgraph/blob/main/examples/multi_agent/multi-agent-collaboration.ipynb
Is a little messy to figure out. So best to approach from first
principles what is actually going on.
"""
repl = PythonREPL()
tavily_tool = TavilySearchResults(max_results=5)
result = tool_dag.execute(
["executed_tool_calls"],
inputs={
"tools": [tavily_tool],
"system_message": "You should provide accurate data for the chart generator to use.",
"user_query": "Fetch the UK's GDP over the past 5 years,"
" then draw a line graph of it."
" Once you have written code for the graph, finish.",
},
)
import pprint

pprint.pprint(result)

result = tool_dag.execute(
["executed_tool_calls"],
inputs={
"tools": [python_repl],
"system_message": "Any charts you display will be visible by the user.",
"user_query": "Draw a simple line graph of y = x",
},
)
import pprint

pprint.pprint(result)
# some test code
# tavily_tool = TavilySearchResults(max_results=5)
# result = tool_dag.execute(
# ["executed_tool_calls"],
# inputs={
# "tools": [tavily_tool],
# "system_message": "You should provide accurate data for the chart generator to use.",
# "user_query": "Fetch the UK's GDP over the past 5 years,"
# " then draw a line graph of it."
# " Once you have written code for the graph, finish.",
# },
# )
# import pprint
#
# pprint.pprint(result)
#
# result = tool_dag.execute(
# ["executed_tool_calls"],
# inputs={
# "tools": [python_repl],
# "system_message": "Any charts you display will be visible by the user.",
# "user_query": "Draw a simple line graph of y = x",
# },
# )
# import pprint
#
# pprint.pprint(result)
4 changes: 2 additions & 2 deletions examples/multi-agent-collaboration/func_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def base_system_prompt(tool_names: list[str], system_message: str) -> str:
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK, another assistant with different tools "
" will help where you left off. Execute what you can to make progress."
" will help where you left off. Execute what you can to make progress.\n\n"
" If you or any of the other assistants have the final answer or deliverable,"
" prefix your response with FINAL ANSWER so the team knows to stop."
" prefix your response with FINAL ANSWER so the team knows to stop.\n\n"
f" You have access to the following tools: {tool_names}.\n{system_message}"
)

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit e9264d9

Please sign in to comment.