Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Feb 6, 2024
1 parent ed410d0 commit c85afff
Show file tree
Hide file tree
Showing 38 changed files with 2,941 additions and 599 deletions.
20 changes: 15 additions & 5 deletions api_examples/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@ def display(response):
def main():
# Help me figure out which loss function to use for my neural network
prompt = Placeholder("prompt", reads=["prompt"], writes=["chat_history"])
plan = Placeholder("plan", reads=["chat_history", "prompt", "search_history"], writes=["plan", "todo"])
plan = Placeholder(
"plan", reads=["chat_history", "prompt", "search_history"], writes=["plan", "todo"]
)

web_search = Placeholder("web_search", reads=["plan"], writes=["search_history"])
stack_overflow_search = Placeholder("stack_overflow_search", reads=["plan"], writes=["responses", "search_history"])
stack_overflow_search = Placeholder(
"stack_overflow_search", reads=["plan"], writes=["responses", "search_history"]
)
mathematica = Placeholder("mathematica", reads=["plan"], writes=["responses", "search_history"])

synthesizer = Placeholder("synthesizer", reads=["responses"], writes=["response"])
response = Placeholder("response", reads=["response"], writes=["chat_history"])

agent = (
ApplicationBuilder()
.with_state(prompt=input(...), responses=[], response=None, chat_history=[...], search_history=[])
.with_state(
prompt=input(...), responses=[], response=None, chat_history=[...], search_history=[]
)
.with_transition(prompt, plan)
.with_transition(plan, web_search, Condition.expr("'web_search' == todo")) # ["web_search", "stack_overflow_search"]
.with_transition(plan, stack_overflow_search, Condition.expr("'stack_overflow_search' == todo"))
.with_transition(
plan, web_search, Condition.expr("'web_search' == todo")
) # ["web_search", "stack_overflow_search"]
.with_transition(
plan, stack_overflow_search, Condition.expr("'stack_overflow_search' == todo")
)
.with_transition(plan, mathematica, Condition.expr("'mathematica' == todo"))
.with_transition([web_search, stack_overflow_search, mathematica], plan)
# done
Expand Down
18 changes: 10 additions & 8 deletions api_examples/model_training.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from hamilton import driver

from burr.core import Result, Condition
from burr.core import Condition, Result
from burr.core.application import ApplicationBuilder
from burr.integrations.hamilton import Hamilton, from_state, from_value, update_state, append_state
from burr.integrations.hamilton import Hamilton, append_state, from_state, from_value, update_state
from hamilton import driver


def main():
Expand Down Expand Up @@ -38,7 +37,7 @@ def main():
inputs={
"model": from_state("model", missing="drop"),
"epochs": from_state("epochs"),
"dataset_path": from_value("./dataset.csv")
"dataset_path": from_value("./dataset.csv"),
},
outputs={
"model": update_state("model"),
Expand All @@ -62,13 +61,16 @@ def main():
)
.build()
)
state, result, = agent.run(["result"])
(
state,
result,
) = agent.run(["result"])
agent.visualize("./out", include_conditions=True)
for step in iter(agent):
report(step) # call out to w&b
report(step) # call out to w&b
if step.name == "result":
break
result, = agent.run(["result"])
(result,) = agent.run(["result"])
# nprint(result)


Expand Down
6 changes: 4 additions & 2 deletions burr/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from burr.core.action import Action, Condition, Result, default, expr, when
from burr.core.application import ApplicationBuilder, Application
from burr.core.state import State
from burr.core.function import Function, Condition, DEFAULT, Result
from burr.core.application import Application, Transition, ApplicationBuilder

__all__ = ["Action", "ApplicationBuilder", "Condition", "Result", "default", "when", "expr", "Application"]
133 changes: 63 additions & 70 deletions burr/core/function.py → burr/core/action.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,92 @@
import abc
import ast
import copy
import inspect
from typing import Callable, List, Optional

from burr.core.state import State


class Function(abc.ABC):
def __init__(self, name: Optional[str]):
"""Represents a function in a state machine. This is the base class from which:
1. Custom functions
2. Conditions
3. Results
All extend this class. Note that name is optional so that APIs can set the
name on these actions as part of instantiation.
When they're used, they must have a name set.
@property
@abc.abstractmethod
def reads(self) -> list[str]:
pass

:param name:
"""
self._name = name
@abc.abstractmethod
def run(self, state: State) -> dict:
pass

def with_name(self, name: str) -> "Function":
"""Returns a copy of the given function with the given name."""
copied = copy.copy(self)
copied.set_name(name)
return copied
def is_async(self):
return inspect.iscoroutinefunction(self.run)

def set_name(self, name: str):
self._name = name

class Reducer(abc.ABC):
@property
def name(self) -> str:
"""Gives the name of this action. This should be unique
across your agent."""
return self._name

@abc.abstractmethod
def run(self, state: State) -> dict:
"""Runs the action, given a state. Returns the result (a dictionary)
of running the action.
:param state:
:return:
"""
def writes(self) -> list[str]:
pass

@abc.abstractmethod
def update(self, result: dict, state: State) -> State:
"""Updates the state given the result of running the action.
pass

Note that if this attempts to access anything not in self.reads()
or writes to anything not in self.writes(), it will fail.

class Action(Function, Reducer, abc.ABC):
def __init__(self):
"""Represents an action in a state machine. This is the base class from which:
:param result: Result of running the action
:param state: State to update
:return: The updated state
1. Custom actions
2. Conditions
3. Results
All extend this class. Note that name is optional so that APIs can set the
name on these actions as part of instantiation.
When they're used, they must have a name set.
"""
self._name = None

@property
@abc.abstractmethod
def reads(self) -> list[str]:
"""The list of keys in the state that this uses as an input.
def with_name(self, name: str) -> "Action":
"""Returns a copy of the given action with the given name. Why do we need this?
We instantiate actions without names, and then set them later. This is a way to
make the API cleaner/consolidate it, and the ApplicationBuilder will end up handling it
for you, in the with_actions(...) method, which is the only way to use actions.
:return: List of keys in the state
Note they can also take in names in the constructor for testing, but otherwise this is
not something users will ever have to think about.
:param name: Name to set
:return: A new action with the given name
"""
if self._name is not None:
raise ValueError(
f"Name of {self} already set to {self._name} -- cannot set name to {name}"
)
# TODO -- ensure that we're not mutating anything later on
# If we are, we may want to copy more intelligently
new_action = copy.copy(self)
new_action._name = name
return new_action

@property
@abc.abstractmethod
def writes(self) -> list[str]:
"""The list of keys in the state that this writes to.
:return:
"""
pass
def name(self) -> str:
"""Gives the name of this action. This should be unique
across your agent."""
return self._name

def __repr__(self):
return f"{self.name}({', '.join(self.reads)}) -> {', '.join(self.writes)}"
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
return f"{self.name}: {read_repr} -> {write_repr}"


class PureFunction(Function, abc.ABC):
def update(self, result: dict, state: State) -> State:
return state

@property
def writes(self) -> list[str]:
return []


class Condition(PureFunction):
class Condition(Function):
KEY = "PROCEED"

def __init__(self, keys: List[str], resolver: Callable[[State], bool], name: str = None):
super().__init__(name=name) # TODO -- find a better way of making this unique
self._resolver = resolver
self._keys = keys
self._name = name

@staticmethod
def expr(expr: str) -> "Condition":
Expand All @@ -117,7 +108,6 @@ def visit_Name(self, node):

# Compile the expression into a callable function
def condition_func(state: State) -> bool:

__globals = state.get_all() # we can get all becuase externally we will subset
return eval(compile(tree, "<string>", "eval"), {}, __globals)

Expand All @@ -126,9 +116,6 @@ def condition_func(state: State) -> bool:
def run(self, state: State) -> dict:
return {Condition.KEY: self._resolver(state)}

def update(self, result: dict, state: State) -> State:
return state

@property
def reads(self) -> list[str]:
return self._keys
Expand All @@ -148,13 +135,19 @@ def condition_func(state: State) -> bool:
name = f"{', '.join(f'{key}={value}' for key, value in sorted(kwargs.items()))}"
return Condition(keys, condition_func, name=name)

@property
def name(self) -> str:
return self._name

DEFAULT = Condition([], lambda _: True, name="default")

default = Condition([], lambda _: True, name="default")
when = Condition.when
expr = Condition.expr

class Result(Function):
def __init__(self, name: str, fields: list[str]):
super().__init__(name)

class Result(Action):
def __init__(self, fields: list[str]):
super(Result, self).__init__()
self._fields = fields

def run(self, state: State) -> dict:
Expand All @@ -169,4 +162,4 @@ def reads(self) -> list[str]:

@property
def writes(self) -> list[str]:
return self._fields
return []
Loading

0 comments on commit c85afff

Please sign in to comment.