Skip to content

Commit

Permalink
Fixes up typing for streaming result containers
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Sep 10, 2024
1 parent 7c52920 commit 35d4257
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
33 changes: 22 additions & 11 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,10 @@ def is_async(self) -> bool:
return True


class StreamingResultContainer(Iterator[dict], Generic[StateType]):
StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any])


class StreamingResultContainer(Iterator[StreamResultType], Generic[StateType, StreamResultType]):
"""Container for a streaming result. This allows you to:
1. Iterate over the result as it comes in
Expand All @@ -758,13 +761,15 @@ class StreamingResultContainer(Iterator[dict], Generic[StateType]):

@staticmethod
def pass_through(
results: dict, final_state: State[StateType]
) -> "StreamingResultContainer[StateType]":
results: StreamResultType, final_state: State[StateType]
) -> "StreamingResultContainer[StreamResultType, StateType]":
"""Instantiates a streaming result container that just passes through the given results
This is to be used internally -- it allows us to wrap non-streaming action results in a streaming
result container."""

def empty_generator() -> Generator[Tuple[dict, Optional[State[StateType]]], None, None]:
def empty_generator() -> (
Generator[Tuple[StreamResultType, Optional[State[StateType]]], None, None]
):
yield results, final_state

return StreamingResultContainer(
Expand Down Expand Up @@ -836,7 +841,9 @@ def get(self) -> StreamType:
return self._result


class AsyncStreamingResultContainer(typing.AsyncIterator[dict], Generic[StateType]):
class AsyncStreamingResultContainer(
typing.AsyncIterator[StreamResultType], Generic[StateType, StreamResultType]
):
"""Container for an async streaming result. This allows you to:
1. Iterate over the result as it comes in
2. Await the final result/state at the end
Expand All @@ -863,9 +870,11 @@ def __init__(
self,
streaming_result_generator: AsyncGeneratorReturnType,
initial_state: State[StateType],
process_result: Callable[[dict, State[StateType]], tuple[dict, State[StateType]]],
process_result: Callable[
[StreamResultType, State[StateType]], tuple[StreamResultType, State[StateType]]
],
callback: Callable[
[Optional[dict], State[StateType], Optional[Exception]],
[Optional[StreamResultType], State[StateType], Optional[Exception]],
typing.Coroutine[None, None, None],
],
):
Expand Down Expand Up @@ -918,7 +927,7 @@ async def gen_fn():
# return it as `__aiter__` cannot be async/have awaits :/
return gen_fn()

async def get(self) -> tuple[Optional[dict], State[StateType]]:
async def get(self) -> tuple[Optional[StreamResultType], State[StateType]]:
# exhaust the generator
async for _ in self:
pass
Expand All @@ -927,18 +936,20 @@ async def get(self) -> tuple[Optional[dict], State[StateType]]:

@staticmethod
def pass_through(
results: dict, final_state: State[StateType]
results: StreamResultType, final_state: State[StateType]
) -> "AsyncStreamingResultContainer[StateType]":
"""Creates a streaming result container that just passes through the given results.
This is not a public facing API."""

async def just_results() -> AsyncGeneratorReturnType:
yield results, final_state

async def empty_callback(result: Optional[dict], state: State, exc: Optional[Exception]):
async def empty_callback(
result: Optional[StreamResultType], state: State, exc: Optional[Exception]
):
pass

return AsyncStreamingResultContainer[StateType](
return AsyncStreamingResultContainer[StateType, StreamResultType](
just_results(), final_state, lambda result, state: (result, state), empty_callback
)

Expand Down
9 changes: 5 additions & 4 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,7 @@ def post_run_step(


ApplicationStateType = TypeVar("ApplicationStateType")
StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any])


class Application(Generic[ApplicationStateType]):
Expand Down Expand Up @@ -1205,9 +1206,9 @@ async def arun(
def stream_result(
self,
halt_after: list[str],
halt_before: list[str] = None,
halt_before: Optional[list[str]] = None,
inputs: Optional[Dict[str, Any]] = None,
) -> Tuple[Action, StreamingResultContainer[ApplicationStateType]]:
) -> Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]]:
"""Streams a result out.
:param halt_after: The list of actions to halt after execution of. It will halt on the first one.
Expand Down Expand Up @@ -1454,9 +1455,9 @@ def callback(
async def astream_result(
self,
halt_after: list[str],
halt_before: list[str] = None,
halt_before: Optional[list[str]] = None,
inputs: Optional[Dict[str, Any]] = None,
) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType]]:
) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]]:
"""Streams a result out in an asynchronous manner.
:param halt_after: The list of actions to halt after execution of. It will halt on the first one.
Expand Down

0 comments on commit 35d4257

Please sign in to comment.