Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Oct 27, 2024
1 parent 4238c72 commit 153742b
Show file tree
Hide file tree
Showing 7 changed files with 872 additions and 198 deletions.
2 changes: 1 addition & 1 deletion burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, A
BASE_ERROR_MESSAGE
+ f"Inputs starting with a double underscore ({starting_with_double_underscore}) "
f"are reserved for internal use/injected inputs."
"Please do not use keys"
"Please do not directly pass keys starting with a double underscore."
)
inputs = inputs.copy()
processed_inputs = {}
Expand Down
5 changes: 1 addition & 4 deletions burr/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def _validate_actions(actions: Optional[List[Action]]):
def _validate_transitions(
transitions: Optional[List[Tuple[str, str, Condition]]], actions: Set[str]
):
assert_set(transitions, "_transitions", "with_transitions")
exhausted = {} # items for which we have seen a default transition
for from_, to, condition in transitions:
if from_ not in actions:
Expand Down Expand Up @@ -235,7 +234,7 @@ class GraphBuilder:

def __init__(self):
"""Initializes the graph builder."""
self.transitions: Optional[List[Tuple[str, str, Condition]]] = None
self.transitions: Optional[List[Tuple[str, str, Condition]]] = []
self.actions: Optional[List[Action]] = None

def with_actions(
Expand Down Expand Up @@ -283,8 +282,6 @@ def with_transitions(
:param transitions: Transitions to add
:return: The application builder for future chaining.
"""
if self.transitions is None:
self.transitions = []
for transition in transitions:
from_, to_, *conditions = transition
if len(conditions) > 0:
Expand Down
75 changes: 61 additions & 14 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import abc
import asyncio
import dataclasses
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Generator, List, Tuple, Union

from burr.core import Action, Application, ApplicationBuilder, ApplicationContext, Graph, State
Expand Down Expand Up @@ -58,6 +61,11 @@ def _create_app(self, parent_context: ApplicationContext) -> Application:
partition_key=parent_context.partition_key,
)
.with_tracker(parent_context.tracker.copy()) # We have to copy
# TODO -- handle persistence...
.with_identifiers(
app_id=self.application_id,
partition_key=parent_context.partition_key, # cascade the partition key
)
.build()
)

Expand Down Expand Up @@ -101,20 +109,60 @@ def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
:param run_kwargs:
:return:
"""
context: ApplicationContext = run_kwargs.get("__context")
if context is None:
raise ValueError("This action requires a context to run")
state_without_internals = state.wipe(
delete=[item for item in state.keys() if item.startswith("__")]
)
task_generator = self.tasks(state_without_internals, context, run_kwargs)

# TODO -- run in parallel
def state_generator():
for task in task_generator:
yield task.run(run_kwargs["__context"])
def _run_and_update():
context: ApplicationContext = run_kwargs.get("__context")
if context is None:
raise ValueError("This action requires a context to run")
state_without_internals = state.wipe(
delete=[item for item in state.keys() if item.startswith("__")]
)
task_generator = self.tasks(state_without_internals, context, run_kwargs)

def execute_task(task):
return task.run(run_kwargs["__context"])

return {}, self.reduce(state_without_internals, state_generator())
# TODO -- take the threadpool executor out and make it generic
with ThreadPoolExecutor() as executor:
# Directly map the generator to the executor
results = list(executor.map(execute_task, task_generator))

def state_generator() -> Generator[Any, None, None]:
yield from results

return {}, self.reduce(state_without_internals, state_generator())

async def _arun_and_update():
context: ApplicationContext = run_kwargs.get("__context")
if context is None:
raise ValueError("This action requires a context to run")
state_without_internals = state.wipe(
delete=[item for item in state.keys() if item.startswith("__")]
)
task_generator = self.tasks(state_without_internals, context, run_kwargs)

# TODO -- run in parallel
async def state_generator():
"""This makes it easier on the user -- if they don't have an async generator we can still exhause it
This way we run through all of the task generators. These correspond to the task generation capabilities above (the map*/task generation stuff)
"""
if inspect.isasyncgen(task_generator):
coroutines = [task.arun(context) async for task in task_generator]
else:
coroutines = [task.arun(context) for task in task_generator]
results = await asyncio.gather(*coroutines)
# TODO -- yield in order...
for result in results:
yield result

return {}, await self.reduce(state_without_internals, state_generator())

if self.is_async():
return _arun_and_update() # type: ignore
return _run_and_update()

def is_async(self) -> bool:
return False

@property
def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
Expand Down Expand Up @@ -288,8 +336,7 @@ def actions(
:param inputs:
:return:
"""
for sub_state in self.states(state, context, inputs):
yield self.action(sub_state, inputs)
yield self.action(state, inputs)

@abc.abstractmethod
def reduce(self, state: State, results: Generator[State, None, None]) -> State:
Expand Down
7 changes: 7 additions & 0 deletions burr/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,13 @@ def subset(self, *keys: str, ignore_missing: bool = True) -> "State[StateType]":
)

def __getitem__(self, __k: str) -> Any:
if __k not in self._state:
raise KeyError(
f"Key \"{__k}\" not found in state. Keys state knows about are: {[key for key in self._state.keys() if not key.startswith('__')]}. "
"If you hit this within the context of an application, you want to "
"(a) ensure that an upstream action has produced this state/it is set as an initial state value and "
"(b) ensure that your action declares this as a read key."
)
return self._state[__k]

def __len__(self) -> int:
Expand Down
229 changes: 229 additions & 0 deletions examples/recursive/application_parallel_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
from typing import Any, Callable, Dict, Generator, List, Tuple, Union

import openai

from burr.core import Action, Application, ApplicationBuilder, Condition, State, action
from burr.core.application import ApplicationContext
from burr.core.graph import GraphBuilder
from burr.core.parallelism import MapStates, RunnableGraph


# full agent
def _query_llm(prompt: str) -> str:
"""Simple wrapper around the OpenAI API."""
client = openai.Client()
return (
client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": prompt},
],
)
.choices[0]
.message.content
)


@action(
reads=["feedback", "current_draft", "poem_type", "prompt"],
writes=["current_draft", "draft_history", "num_drafts"],
)
def write(state: State) -> Tuple[dict, State]:
"""Writes a draft of a poem."""
poem_subject = state["prompt"]
poem_type = state["poem_type"]
current_draft = state.get("current_draft")
feedback = state.get("feedback")

parts = [
f'You are an AI poet. Create a {poem_type} poem on the following subject: "{poem_subject}". '
"It is absolutely imperative that you respond with only the poem and no other text."
]

if current_draft:
parts.append(f'Here is the current draft of the poem: "{current_draft}".')

if feedback:
parts.append(f'Please incorporate the following feedback: "{feedback}".')

parts.append(
f"Ensure the poem is creative, adheres to the style of a {poem_type}, and improves upon the previous draft."
)

prompt = "\n".join(parts)

draft = _query_llm(prompt)

return {"draft": draft}, state.update(
current_draft=draft,
draft_history=state.get("draft_history", []) + [draft],
).increment(num_drafts=1)


@action(reads=["current_draft", "poem_type", "prompt"], writes=["feedback"])
def edit(state: State) -> Tuple[dict, State]:
"""Edits a draft of a poem, providing feedback"""
poem_subject = state["prompt"]
poem_type = state["poem_type"]
current_draft = state["current_draft"]

prompt = f"""
You are an AI poetry critic. Review the following {poem_type} poem based on the subject: "{poem_subject}".
Here is the current draft of the poem: "{current_draft}".
Provide detailed feedback to improve the poem. If the poem is already excellent and needs no changes, simply respond with an empty string.
"""

feedback = _query_llm(prompt)

return {"feedback": feedback}, state.update(feedback=feedback)


@action(reads=["current_draft"], writes=["final_draft"])
def final_draft(state: State) -> Tuple[dict, State]:
return {"final_draft": state["current_draft"]}, state.update(final_draft=state["current_draft"])


#
#
# def _create_sub_application(
# max_num_drafts: int,
# spawning_application_context: ApplicationContext,
# poem_type: str,
# prompt: str,
# ) -> Application:
# """Utility to create sub-application -- note"""
# out = (
# ApplicationBuilder()
# .with_actions(
# edit,
# write,
# final_draft,
# )
# .with_transitions(
# ("write", "edit", Condition.expr(f"num_drafts < {max_num_drafts}")),
# ("write", "final_draft"),
# ("edit", "final_draft", Condition.expr("len(feedback) == 0")),
# ("edit", "write"),
# )
# .with_tracker(spawning_application_context.tracker.copy()) # remember to do `copy()` here!
# .with_spawning_parent(
# spawning_application_context.app_id,
# spawning_application_context.sequence_id,
# spawning_application_context.partition_key,
# )
# .with_entrypoint("write")
# .with_state(
# current_draft=None,
# poem_type=poem_type,
# prompt=prompt,
# feedback=None,
# )
# .build()
# )
# return out

sub_application_graph = (
GraphBuilder()
.with_actions(
edit,
write,
final_draft,
)
.with_transitions(
("write", "edit", Condition.expr("num_drafts < max_num_drafts")),
("write", "final_draft"),
("edit", "final_draft", Condition.expr("len(feedback) == 0")),
("edit", "write"),
)
.build()
)


# full agent
@action(
reads=[],
writes=[
"max_drafts",
"poem_types",
"poem_subject",
],
)
def user_input(
state: State, max_drafts: int, poem_types: List[str], poem_subject: str
) -> Tuple[dict, State]:
"""Collects user input for the poem generation process."""
return {
"max_drafts": max_drafts,
"poem_types": poem_types,
"poem_subject": poem_subject,
}, state.update(max_drafts=max_drafts, poem_types=poem_types, poem_subject=poem_subject)


class GenerateAllPoems(MapStates):
def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
) -> Generator[State, None, None]:
for poem_type in state["poem_types"]:
yield state.update(poem_type=poem_type, prompt=state["poem_subject"], max_num_drafts=2)

def action(
self, state: State, inputs: Dict[str, Any]
) -> Union[Action, Callable, RunnableGraph]:
return RunnableGraph(sub_application_graph, entrypoint="write", halt_after=["final_draft"])

def reduce(self, state: State, results: Generator[State, None, None]) -> State:
new_state = state
for output_state in results:
new_state = new_state.append(proposals=output_state["final_draft"])
return new_state

@property
def writes(self) -> list[str]:
return ["proposals"]

@property
def reads(self) -> list[str]:
return ["max_drafts", "poem_types", "poem_subject"]


@action(reads=["proposals", "prompts"], writes=["final_results"])
def final_results(state: State) -> Tuple[dict, State]:
# joins them into a string
proposals = state["proposals"]
final_results = "\n\n".join(
[f"{poem_type}:\n{proposal}" for poem_type, proposal in zip(state["poem_types"], proposals)]
)
return {"final_results": final_results}, state.update(final_results=final_results)


def application() -> Application:
return (
ApplicationBuilder()
.with_actions(user_input, final_results, generate_all_poems=GenerateAllPoems())
.with_transitions(
("user_input", "generate_all_poems"),
("generate_all_poems", "final_results"),
)
.with_tracker(project="demo:parallelism_poem_generation")
.with_entrypoint("user_input")
.build()
)


if __name__ == "__main__":
app = application()
app.visualize(output_file_path="statemachine", format="png")
app.run(
halt_after=["final_results"],
inputs={
"max_drafts": 2,
"poem_types": [
"sonnet",
"limerick",
"haiku",
"acrostic",
],
"poem_subject": "state machines",
},
)
Binary file modified examples/recursive/statemachine.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 153742b

Please sign in to comment.