From 35d425707a182636320f145ad122d729c9537b00 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Mon, 9 Sep 2024 21:28:12 -0700 Subject: [PATCH] Fixes up typing for streaming result containers --- burr/core/action.py | 33 ++++++++++++++++++++++----------- burr/core/application.py | 9 +++++---- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 598d2263..535da22d 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -732,7 +732,10 @@ def is_async(self) -> bool: return True -class StreamingResultContainer(Iterator[dict], Generic[StateType]): +StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any]) + + +class StreamingResultContainer(Iterator[StreamResultType], Generic[StateType, StreamResultType]): """Container for a streaming result. This allows you to: 1. Iterate over the result as it comes in @@ -758,13 +761,15 @@ class StreamingResultContainer(Iterator[dict], Generic[StateType]): @staticmethod def pass_through( - results: dict, final_state: State[StateType] - ) -> "StreamingResultContainer[StateType]": + results: StreamResultType, final_state: State[StateType] + ) -> "StreamingResultContainer[StreamResultType, StateType]": """Instantiates a streaming result container that just passes through the given results This is to be used internally -- it allows us to wrap non-streaming action results in a streaming result container.""" - def empty_generator() -> Generator[Tuple[dict, Optional[State[StateType]]], None, None]: + def empty_generator() -> ( + Generator[Tuple[StreamResultType, Optional[State[StateType]]], None, None] + ): yield results, final_state return StreamingResultContainer( @@ -836,7 +841,9 @@ def get(self) -> StreamType: return self._result -class AsyncStreamingResultContainer(typing.AsyncIterator[dict], Generic[StateType]): +class AsyncStreamingResultContainer( + typing.AsyncIterator[StreamResultType], Generic[StateType, StreamResultType] +): """Container for an async streaming result. This allows you to: 1. Iterate over the result as it comes in 2. Await the final result/state at the end @@ -863,9 +870,11 @@ def __init__( self, streaming_result_generator: AsyncGeneratorReturnType, initial_state: State[StateType], - process_result: Callable[[dict, State[StateType]], tuple[dict, State[StateType]]], + process_result: Callable[ + [StreamResultType, State[StateType]], tuple[StreamResultType, State[StateType]] + ], callback: Callable[ - [Optional[dict], State[StateType], Optional[Exception]], + [Optional[StreamResultType], State[StateType], Optional[Exception]], typing.Coroutine[None, None, None], ], ): @@ -918,7 +927,7 @@ async def gen_fn(): # return it as `__aiter__` cannot be async/have awaits :/ return gen_fn() - async def get(self) -> tuple[Optional[dict], State[StateType]]: + async def get(self) -> tuple[Optional[StreamResultType], State[StateType]]: # exhaust the generator async for _ in self: pass @@ -927,7 +936,7 @@ async def get(self) -> tuple[Optional[dict], State[StateType]]: @staticmethod def pass_through( - results: dict, final_state: State[StateType] + results: StreamResultType, final_state: State[StateType] ) -> "AsyncStreamingResultContainer[StateType]": """Creates a streaming result container that just passes through the given results. This is not a public facing API.""" @@ -935,10 +944,12 @@ def pass_through( async def just_results() -> AsyncGeneratorReturnType: yield results, final_state - async def empty_callback(result: Optional[dict], state: State, exc: Optional[Exception]): + async def empty_callback( + result: Optional[StreamResultType], state: State, exc: Optional[Exception] + ): pass - return AsyncStreamingResultContainer[StateType]( + return AsyncStreamingResultContainer[StateType, StreamResultType]( just_results(), final_state, lambda result, state: (result, state), empty_callback ) diff --git a/burr/core/application.py b/burr/core/application.py index 0b4c8c77..f223f472 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -682,6 +682,7 @@ def post_run_step( ApplicationStateType = TypeVar("ApplicationStateType") +StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any]) class Application(Generic[ApplicationStateType]): @@ -1205,9 +1206,9 @@ async def arun( def stream_result( self, halt_after: list[str], - halt_before: list[str] = None, + halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, StreamingResultContainer[ApplicationStateType]]: + ) -> Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out. :param halt_after: The list of actions to halt after execution of. It will halt on the first one. @@ -1454,9 +1455,9 @@ def callback( async def astream_result( self, halt_after: list[str], - halt_before: list[str] = None, + halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType]]: + ) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out in an asynchronous manner. :param halt_after: The list of actions to halt after execution of. It will halt on the first one.