Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Mar 9, 2024
1 parent c27e2ee commit 195b681
Show file tree
Hide file tree
Showing 6 changed files with 646 additions and 31 deletions.
165 changes: 164 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,21 @@
import ast
import copy
import inspect
import sys
import types
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
)

from burr.core.state import State

Expand Down Expand Up @@ -128,6 +141,10 @@ def name(self) -> str:
def single_step(self) -> bool:
return False

@property
def streaming(self) -> bool:
return False

def __repr__(self):
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
Expand Down Expand Up @@ -412,6 +429,152 @@ def is_async(self) -> bool:
return inspect.iscoroutinefunction(self._fn)


StreamingResultType = Generator[dict, None, dict]


class StreamingAction(Action, abc.ABC):
@abc.abstractmethod
def stream_run(self: State, **run_kwargs) -> StreamingResultType:
"""Streaming action stream_run is different than standard action run. It:
1. streams in a result (the dict output)
2. Returns the final result
Note that the user, in this case, is responsible for joining the result.
For instance, you could have:
.. code-block:: python
def stream_run(state: State) -> StreamingResultType:
buffer = [] # you might want to be more efficient than simple strcat
for token in query(state['prompt']):
yield {'response' : token}
buffer.append(token)
return {'response' : "".join(buffer)}
This would utilize a simple string buffer (implemented by a list) to store the results
and then join them at the end. We return the final result.
:param run_kwargs:
:return:
"""
pass

def run(self, state: State, **run_kwargs) -> dict:
gen = self.stream_run(state, **run_kwargs)
while True:
try:
next(gen) # if we just run through, we do nothing with the result
except StopIteration as e:
return e.value

@property
def streaming(self) -> bool:
return True


# TODO -- documentation for this
class StreamingResultContainer(Iterator[dict]):
"""Container for a streaming result. This allows you to:
1. Iterate over the result as it comes in
2. Get the final result/state at the end
If you're familiar with generators/iterators in python, this is effectively an
iterator that caches the final result after calling it. This is meant to be used
exclusively with the streaming action calls in `Application`. Note that you will
never instantiate this class directly, but you will use it in the API when it is returned
by :py:meth:`stream_result <burr.core.application.Application.stream_result>`.
For reference, here's how you would use it:
.. code-block:: python
streaming_result_container = application.stream_result(...)
action_we_just_ran = streaming_result_container.get()
print(f"getting streaming results for action={action_we_just_ran.name}")
for result_component in streaming_result_container:
print(result_component['response']) # this assumes you have a response key in your result
final_state, final_result = streaming_result_container.get()
"""

def __next__(self):
return next(self.generator())

def __init__(
self,
streaming_result_generator: Generator[dict, None, Tuple[dict, State]],
action: Action,
initial_state: State,
process_result: Callable[[dict, State], Tuple[dict, State]],
callback: Callable[[Optional[dict], State, Optional[Exception]], None],
):
self.streaming_result_generator = streaming_result_generator
self._action = action
self._callback = callback
self._process_result = process_result
self._initial_state = initial_state
self._result = None, self._initial_state
self._callback_realized = False

@property
def action(self) -> Action:
"""Gives you the action that this iterator is running."""
return self._action

def __iter__(self):
return self.generator()

def generator(self):
"""Gets the next result in the iterator"""
try:
while True:
yield next(self.streaming_result_generator)
except StopIteration as e:
if self._result[0] is not None:
return
output = e.value
self._result = self._process_result(*output)
finally:
exc = sys.exc_info()[1]
# For now this will not be the right exception type (Generator close),
# but its OK -- the exception is outside of our control fllow
if not self._callback_realized:
self._callback_realized = True
self._callback(*self._result, exc)

def get(self) -> Tuple[Optional[dict], State]:
"""Blocking call to get the final result of the streaming action. This will
run through the entire generator (or until an exception is raised) and return
the final result.
:return: A tuple of the result and the new state
"""
for _ in self:
pass
return self._result


class SingleStepStreamingAction(SingleStepAction, abc.ABC):
@abc.abstractmethod
def stream_run_and_update(
self, state: State, **run_kwargs
) -> Generator[dict, None, Tuple[dict, State]]:
"""Streaming version of the run and update function"""

def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
gen = self.stream_run_and_update(state, **run_kwargs)
while True:
try:
next(gen) # if we just run through, we do nothing with the result
except StopIteration as e:
return e.value

@property
def streaming(self) -> bool:
return True


def _validate_action_function(fn: Callable):
"""Validates that an action has the signature: (state: State) -> Tuple[dict, State]
Expand Down
Loading

0 comments on commit 195b681

Please sign in to comment.