diff --git a/burr/core/action.py b/burr/core/action.py
index eb7fc6939..855a3fdfb 100644
--- a/burr/core/action.py
+++ b/burr/core/action.py
@@ -2,8 +2,21 @@
 import ast
 import copy
 import inspect
+import sys
 import types
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 from burr.core.state import State
 
@@ -128,6 +141,10 @@ def name(self) -> str:
     def single_step(self) -> bool:
         return False
 
+    @property
+    def streaming(self) -> bool:
+        return False
+
     def __repr__(self):
         read_repr = ", ".join(self.reads) if self.reads else "{}"
         write_repr = ", ".join(self.writes) if self.writes else "{}"
@@ -412,6 +429,152 @@ def is_async(self) -> bool:
         return inspect.iscoroutinefunction(self._fn)
 
 
+StreamingResultType = Generator[dict, None, dict]
+
+
+class StreamingAction(Action, abc.ABC):
+    @abc.abstractmethod
+    def stream_run(self, state: State, **run_kwargs) -> StreamingResultType:
+        """A streaming action's ``stream_run`` differs from a standard action's ``run``. It:
+
+        1. Streams in a result (the dict output)
+        2. Returns the final result
+
+        Note that the user, in this case, is responsible for joining the result.
+
+        For instance, you could have:
+
+        .. code-block:: python
+
+            def stream_run(state: State) -> StreamingResultType:
+                buffer = []  # you might want to be more efficient than simple strcat
+                for token in query(state['prompt']):
+                    yield {'response': token}
+                    buffer.append(token)
+                return {'response': "".join(buffer)}
+
+        This utilizes a simple string buffer (implemented by a list) to store the results,
+        then joins and returns them at the end as the final result.
+
+        :param run_kwargs: Any runtime inputs to the action
+        :return: The final, joined result
+        """
+        pass
+
+    def run(self, state: State, **run_kwargs) -> dict:
+        gen = self.stream_run(state, **run_kwargs)
+        while True:
+            try:
+                next(gen)  # if we just run through, we do nothing with the intermediate results
+            except StopIteration as e:
+                return e.value
+
+    @property
+    def streaming(self) -> bool:
+        return True
+
+
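The ``run`` wrapper above leans on a core Python mechanic: a generator's ``return`` value is delivered as the ``value`` attribute of the ``StopIteration`` it raises when exhausted. A minimal standalone sketch of that drain pattern (plain Python, no burr imports; the names are illustrative):

.. code-block:: python

    from typing import Generator

    def counting_stream() -> Generator[dict, None, dict]:
        # yields partial results, then *returns* the joined final result
        buffer = []
        for token in ["a", "b", "c"]:
            yield {"response": token}
            buffer.append(token)
        return {"response": "".join(buffer)}

    def drain(gen) -> dict:
        # same pattern as StreamingAction.run: ignore the yields,
        # capture the generator's return value from StopIteration
        while True:
            try:
                next(gen)
            except StopIteration as e:
                return e.value

    assert drain(counting_stream()) == {"response": "abc"}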
+# TODO -- documentation for this
+class StreamingResultContainer(Iterator[dict]):
+    """Container for a streaming result. This allows you to:
+
+    1. Iterate over the result as it comes in
+    2. Get the final result/state at the end
+
+    If you're familiar with generators/iterators in Python, this is effectively an
+    iterator that caches the final result after running through it. This is meant to be used
+    exclusively with the streaming action calls in `Application`. Note that you will
+    never instantiate this class directly, but you will use it in the API when it is returned
+    by :py:meth:`stream_result <burr.core.application.Application.stream_result>`.
+    For reference, here's how you would use it:
+
+    .. code-block:: python
+
+        streaming_result_container = application.stream_result(...)
+        action_we_just_ran = streaming_result_container.action
+        print(f"getting streaming results for action={action_we_just_ran.name}")
+
+        for result_component in streaming_result_container:
+            print(result_component['response'])  # this assumes you have a response key in your result
+
+        final_result, final_state = streaming_result_container.get()
+    """
+
+    def __next__(self):
+        return next(self.generator())
+
+    def __init__(
+        self,
+        streaming_result_generator: Generator[dict, None, Tuple[dict, State]],
+        action: Action,
+        initial_state: State,
+        process_result: Callable[[dict, State], Tuple[dict, State]],
+        callback: Callable[[Optional[dict], State, Optional[Exception]], None],
+    ):
+        self.streaming_result_generator = streaming_result_generator
+        self._action = action
+        self._callback = callback
+        self._process_result = process_result
+        self._initial_state = initial_state
+        self._result = None, self._initial_state
+        self._callback_realized = False
+
+    @property
+    def action(self) -> Action:
+        """Gives you the action that this iterator is running."""
+        return self._action
+
+    def __iter__(self):
+        return self.generator()
+
+    def generator(self):
+        """Generator that proxies the underlying streaming result generator,
+        capturing and processing the final result once it completes."""
+        try:
+            while True:
+                yield next(self.streaming_result_generator)
+        except StopIteration as e:
+            if self._result[0] is not None:
+                return
+            output = e.value
+            self._result = self._process_result(*output)
+        finally:
+            exc = sys.exc_info()[1]
+            # For now this will not be the right exception type (generator close),
+            # but it's OK -- the exception is outside of our control flow
+            if not self._callback_realized:
+                self._callback_realized = True
+                self._callback(*self._result, exc)
+
+    def get(self) -> Tuple[Optional[dict], State]:
+        """Blocking call to get the final result of the streaming action. This will
+        run through the entire generator (or until an exception is raised) and return
+        the final result.
+
+        :return: A tuple of the result and the new state
+        """
+        for _ in self:
+            pass
+        return self._result
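To make the container's contract concrete, here is a self-contained sketch mirroring the ``sample_generator`` helper in the tests below -- iterate for the partial results, then call ``get()`` for the cached final result. In real usage the container comes from ``Application.stream_result`` rather than direct construction:

.. code-block:: python

    from typing import Generator, Tuple

    from burr.core import State
    from burr.core.action import StreamingResultContainer

    def sample_generator(chars: str) -> Generator[dict, None, Tuple[dict, State]]:
        buffer = []
        for c in chars:
            buffer.append(c)
            yield {"response": c}
        joined = "".join(buffer)
        return {"response": joined}, State({"response": joined})

    container = StreamingResultContainer(
        sample_generator("abc"),
        action=None,  # a real Action instance in practice
        initial_state=State(),
        process_result=lambda result, state: (result, state),
        callback=lambda result, state, exc: None,
    )
    partials = [chunk["response"] for chunk in container]  # ["a", "b", "c"]
    result, state = container.get()  # final result is cached -- nothing re-runs
    assert result["response"] == "abc"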
+
+
+class SingleStepStreamingAction(SingleStepAction, abc.ABC):
+    @abc.abstractmethod
+    def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> Generator[dict, None, Tuple[dict, State]]:
+        """Streaming version of ``run_and_update`` -- yields intermediate results,
+        then returns the final (result, state) tuple."""
+
+    def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
+        gen = self.stream_run_and_update(state, **run_kwargs)
+        while True:
+            try:
+                next(gen)  # if we just run through, we do nothing with the intermediate results
+            except StopIteration as e:
+                return e.value
+
+    @property
+    def streaming(self) -> bool:
+        return True
+
+
 def _validate_action_function(fn: Callable):
     """Validates that an action has the signature: (state: State) -> Tuple[dict, State]
diff --git a/burr/core/application.py b/burr/core/application.py
index 1f67619c4..533257429 100644
--- a/burr/core/application.py
+++ b/burr/core/application.py
@@ -15,6 +15,7 @@
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 from burr import visibility
@@ -24,6 +25,9 @@
     Function,
     Reducer,
     SingleStepAction,
+    SingleStepStreamingAction,
+    StreamingAction,
+    StreamingResultContainer,
     create_action,
     default,
 )
@@ -182,6 +186,24 @@ def _run_single_step_action(
     return out
 
 
+def _run_single_step_streaming_action(
+    action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
+) -> Generator[dict, None, Tuple[dict, State]]:
+    action.validate_inputs(inputs)
+    generator = action.stream_run_and_update(state, **inputs)
+    return (yield from generator)
+
+
+def _run_multi_step_streaming_action(
+    action: StreamingAction, state: State, inputs: Optional[Dict[str, Any]]
+) -> Generator[dict, None, Tuple[dict, State]]:
+    action.validate_inputs(inputs)
+    generator = action.stream_run(state, **inputs)
+    result = yield from generator
+    new_state = _run_reducer(action, state, result, action.name)
+    return result, _state_update(state, new_state)
+
+
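Both helpers rely on ``yield from`` delegation (PEP 380): every chunk the inner generator yields is re-yielded to the caller, while the inner generator's ``return`` value is captured as the value of the ``yield from`` expression. A small standalone sketch with illustrative names:

.. code-block:: python

    from typing import Generator, Tuple

    def inner() -> Generator[dict, None, Tuple[dict, str]]:
        yield {"chunk": 1}
        yield {"chunk": 2}
        return {"final": 3}, "new-state"

    def outer() -> Generator[dict, None, Tuple[dict, str]]:
        # re-yields every chunk, then captures inner's return value --
        # the same shape as _run_multi_step_streaming_action, which
        # post-processes the captured result with _run_reducer
        result, state = yield from inner()
        return result, state

    gen = outer()
    assert next(gen) == {"chunk": 1}
    assert next(gen) == {"chunk": 2}
    try:
        next(gen)
    except StopIteration as e:
        assert e.value == ({"final": 3}, "new-state")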
 async def _arun_single_step_action(
     action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
 ) -> Tuple[dict, State]:
@@ -254,7 +276,7 @@ def step(self, inputs: Optional[Dict[str, Any]] = None) -> Optional[Tuple[Action
             self._increment_sequence_id()
 
     def _step(
-        self, inputs: Optional[Dict[str, Any]] = None, _run_hooks: bool = True
+        self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
     ) -> Optional[Tuple[Action, dict, State]]:
         """Internal-facing version of step.
 
        This is the same as step, but with an additional parameter to hide hook execution
        so async can leverage it."""
@@ -282,7 +304,7 @@ def _step(
                 result = _run_function(next_action, self._state, inputs)
                 new_state = _run_reducer(next_action, self._state, result, next_action.name)
-            new_state = self.update_internal_state_value(new_state, next_action)
+            new_state = self._update_internal_state_value(new_state, next_action)
             self._set_state(new_state)
         except Exception as e:
             exc = e
@@ -300,7 +322,7 @@ def _step(
         )
         return next_action, result, new_state
 
-    def update_internal_state_value(self, new_state: State, next_action: Action) -> State:
+    def _update_internal_state_value(self, new_state: State, next_action: Action) -> State:
         """Updates the internal state values of the new state."""
         new_state = new_state.update(
             **{
@@ -377,7 +399,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
             else:
                 result = await _arun_function(next_action, self._state, inputs=inputs)
             new_state = _run_reducer(next_action, self._state, result, next_action.name)
-            new_state = self.update_internal_state_value(new_state, next_action)
+            new_state = self._update_internal_state_value(new_state, next_action)
             self._set_state(new_state)
         except Exception as e:
             exc = e
@@ -512,7 +534,7 @@ async def aiterate(
         halt_before: list[str] = None,
         halt_after: list[str] = None,
         inputs: Optional[Dict[str, Any]] = None,
-    ) -> AsyncGenerator[Tuple[Action, dict, State], Tuple[Action, Optional[dict], State]]:
+    ) -> AsyncGenerator[Tuple[Action, dict, State], None]:
         """Returns a generator that calls step() in a row, enabling you to see the state
         of the system as it updates. This is the asynchronous version so it has no capability of t
@@ -584,6 +606,180 @@ async def arun(
             pass
         return self._return_value_iterate(halt_before, halt_after, prior_action, result)
 
+    def _validate_streaming_inputs(self, halt_after: list[str]):
+        if len(halt_after) == 0:
+            raise ValueError("halt_after conditions for streaming must be non-empty.")
+        missing_actions = set(halt_after) - {action.name for action in self._actions}
+        # TODO -- implement this check elsewhere as well, break out into further utility functions
+        if len(missing_actions) > 0:
+            raise ValueError(
+                f"Actions {missing_actions} were passed in as halt_after conditions, but were not found in the actions list! "
+                f"Actions found: {[action.name for action in self._actions]}"
+            )
+        action_objects = {self._action_map[action_name] for action_name in halt_after}
+        non_streaming_actions = {action.name for action in action_objects if not action.streaming}
+        if len(non_streaming_actions) > 0:
+            raise ValueError(
+                f"Actions {non_streaming_actions} were passed in as halt_after conditions, but are not streaming actions! "
+                f"Actions found: {[action.name for action in action_objects]}. We are "
+                f"planning to support this, but do not do so yet."  # TODO -- link to docs on streaming
+            )
+
+    def stream_result(
+        self, halt_after: list[str], inputs: Optional[Dict[str, Any]] = None
+    ) -> StreamingResultContainer:
+        """Streams a result out. Note that this has to be used with a streaming action -- e.g. one that
+        is implemented as a generator. This allows you to stream data out for the final action in the sequence.
+        Use this if you care about time-to-first-result/token, or if you want to stream data out of the system.
+        This is specifically meant to be finite, and is constructed so that state, etc. is all updated after
+        this is called and the generator is exhausted.
+        Hooks are also called after the generator is exhausted (or fails), unless exceptions
+        occur prior to generation (in which case they are called before).
+
+        :param halt_after: The list of actions to halt after execution of. It will halt on the first one.
+        :param inputs: Inputs to the action -- use this if the action requires an input that is passed in from the outside world
+        :return: A streaming result container -- a generator that will yield results as they come in, as well as cache the final result, and update state accordingly.
+
+        To see how this works, let's take the following action (simplified as a single-node workflow) as an example:
+
+        .. code-block:: python
+
+            @streaming_action(reads=[], writes=['response'])
+            def streaming_response(state: State, prompt: str) -> Generator[dict, None, Tuple[dict, State]]:
+                response = client.chat.completions.create(
+                    model='gpt-3.5-turbo',
+                    messages=[{
+                        'role': 'user',
+                        'content': prompt
+                    }],
+                    temperature=0,
+                )
+                buffer = []
+                for chunk in response:
+                    delta = chunk.choices[0].delta.content
+                    buffer.append(delta)
+                    # yield partial results
+                    yield {'response': delta}
+                full_response = ''.join(buffer)
+                # return the final result
+                return {'response': full_response}, state.update(response=full_response)
+
+        To use stream_result, you pass the names of streaming actions (such as the one above) to the halt_after
+        parameter:
+
+        .. code-block:: python
+
+            application = ApplicationBuilder().with_actions(streaming_response=streaming_response)...build()
+            prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
+            output = application.stream_result(halt_after=["streaming_response"], inputs={"prompt": prompt})
+            for result in output:
+                print(result['response'])  # one by one
+
+            result, state = output.get()
+            print(result)  # all at once
+
+        Note that if you have multiple halt_after conditions, you can use the ``.action`` attribute to get the action that
+        was run.
+
+        .. code-block:: python
+
+            application = ApplicationBuilder().with_actions(
+                streaming_response=streaming_response,
+                error=error  # another function that outputs an error, streaming
+            )...build()
+            prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
+            output = application.stream_result(halt_after=["streaming_response", "error"], inputs={"prompt": prompt})
+            color = "red" if output.action.name == "error" else "green"
+            for result in output:
+                print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape
+
+        (Still a WIP -- this does not work yet) You can stream non-streaming responses, you just won't get anything in the generator:
+
+        .. code-block:: python
+
+            application = ApplicationBuilder().with_actions(
+                streaming_response=streaming_response,
+                error=non_streaming_error  # a non-streaming function that outputs an error
+            )...build()
+            prompt = "Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ..."
+            output = application.stream_result(halt_after=["streaming_response", "error"], inputs={"prompt": prompt})
+            color = "red" if output.action.name == "error" else "green"
+            if output.action.name == "streaming_response":  # can also use the ``.streaming`` attribute of the action
+                for result in output:
+                    print(format(result['response'], color))  # assumes that error and streaming_response both have the same output shape
+            else:
+                result, state = output.get()
+                print(format(result['response'], color))
+        """
+        self._validate_streaming_inputs(halt_after)
+        next_action = self.get_next_action()
+        if next_action is None:
+            raise ValueError(
+                f"Cannot stream result -- no next action found! Prior action was: {self._state[PRIOR_STEP]}"
+            )
+        if inputs is None:
+            inputs = {}
+
+        if next_action.name not in halt_after:
+            # fast forward until we get to a halt_after action
+            self.run(halt_before=halt_after, inputs=inputs)
+            next_action = self.get_next_action()  # refresh -- we have moved past the original action
+            inputs = {}  # inputs always go to the first action, we want to wipe them afterwards
+        self._adapter_set.call_all_lifecycle_hooks_sync(
+            "pre_run_step",
+            action=next_action,
+            state=self._state,
+            inputs=inputs,
+            sequence_id=self.sequence_id,
+        )
+        # we need to track if there's any exceptions that occur during this
+        try:
+
+            def process_result(result: dict, state: State) -> Tuple[Dict[str, Any], State]:
+                new_state = self._update_internal_state_value(state, next_action)
+                self._set_state(new_state)
+                return result, new_state
+
+            def callback(
+                result: Optional[dict],
+                state: State,
+                exc: Optional[Exception] = None,
+                seq_id=self.sequence_id,
+            ):
+                self._adapter_set.call_all_lifecycle_hooks_sync(
+                    "post_run_step",
+                    action=next_action,
+                    state=state,
+                    result=result,
+                    sequence_id=seq_id,
+                    exception=exc,
+                )
+                # we want to increment regardless of failure
+                self._increment_sequence_id()
+
+            if next_action.single_step:
+                next_action = cast(SingleStepStreamingAction, next_action)
+                generator = _run_single_step_streaming_action(next_action, self._state, inputs)
+                return StreamingResultContainer(
+                    generator, next_action, self._state, process_result, callback
+                )
+            else:
+                next_action = cast(StreamingAction, next_action)
+                generator = _run_multi_step_streaming_action(next_action, self._state, inputs)
+        except Exception as e:
+            # We only want to raise this in the case of an exception --
+            # otherwise, this will get delegated to the finally
+            # block of the streaming result container
+            self._adapter_set.call_all_lifecycle_hooks_sync(
+                "post_run_step",
+                action=next_action,
+                state=self._state,
+                result=None,
+                sequence_id=self.sequence_id,
+                exception=e,
+            )
+            self._increment_sequence_id()
+            raise
+        return StreamingResultContainer(
+            generator, next_action, self._state, process_result, callback
+        )
+
     def visualize(
         self,
         output_file_path: Optional[str],
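End to end, the new API looks like this -- a condensed, self-contained sketch assuming the ``StreamingAction`` subclass pattern above. The action and app here are illustrative; it constructs ``Application``/``Transition`` directly the same way the tests below do:

.. code-block:: python

    from burr.core import State
    from burr.core.action import StreamingAction, StreamingResultType, default
    from burr.core.application import Application, Transition

    class EchoStream(StreamingAction):
        """Hypothetical action: streams a stored prompt back, character by character."""

        def stream_run(self, state: State, **run_kwargs) -> StreamingResultType:
            buffer = []
            for char in state["prompt"]:
                yield {"response": char}
                buffer.append(char)
            return {"response": "".join(buffer)}

        @property
        def reads(self) -> list[str]:
            return ["prompt"]

        @property
        def writes(self) -> list[str]:
            return ["response"]

        def update(self, result: dict, state: State) -> State:
            return state.update(**result)

    echo = EchoStream().with_name("echo")
    app = Application(
        actions=[echo],
        transitions=[Transition(echo, echo, default)],
        state=State({"prompt": "hello"}),
        initial_step="echo",
    )
    container = app.stream_result(halt_after=["echo"])
    for chunk in container:
        print(chunk["response"], end="")  # h, e, l, l, o as they arrive
    result, state = container.get()
    assert result["response"] == "hello"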
diff --git a/examples/gpt/application.py b/examples/gpt/application.py
index 9dc60d829..70f141b23 100644
--- a/examples/gpt/application.py
+++ b/examples/gpt/application.py
@@ -73,7 +73,7 @@ def prompt_for_more(state: State) -> Tuple[dict, State]:
     return result, state.update(**result)
 
 
-@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+@action(reads=["prompt", "chat_history", "mode"], writes=["chat_history"])
 def chat_response(
     state: State, prepend_prompt: str, display_type: str = "text", model: str = "gpt-3.5-turbo"
 ) -> Tuple[dict, State]:
@@ -92,36 +92,49 @@ def chat_response(
         messages=chat_history_api_format,
     )
     response = result.choices[0].message.content
-    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
-    return result, state.update(**result)
+    chat_item = {"content": response, "type": MODES[state["mode"]], "role": "assistant"}
+    return {"response": chat_item}, state.append(chat_history=chat_item)
 
 
-@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
+@action(reads=["prompt", "chat_history", "mode"], writes=["chat_history"])
 def image_response(state: State, model: str = "dall-e-2") -> Tuple[dict, State]:
     client = _get_openai_client()
     result = client.images.generate(
         model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
     )
     response = result.data[0].url
-    result = {"response": {"content": response, "type": MODES[state["mode"]], "role": "assistant"}}
-    return result, state.update(**result)
+    chat_item = {"content": response, "type": MODES[state["mode"]], "role": "assistant"}
+    return {"response": chat_item}, state.update(response=chat_item).append(chat_history=chat_item)
 
 
-@action(reads=["response", "safe", "mode"], writes=["chat_history"])
-def response(state: State) -> Tuple[dict, State]:
-    if not state["safe"]:
-        result = {
-            "chat_item": {
-                "role": "assistant",
-                "content": "I'm sorry, I can't respond to that.",
-                "type": "text",
-            }
+@action(reads=["prompt", "chat_history", "mode"], writes=["chat_history"])
+def unsafe(state: State) -> Tuple[dict, State]:
+    result = {
+        "chat_item": {
+            "role": "assistant",
+            "content": "I'm sorry, I can't respond to that.",
+            "type": "text",
         }
-    else:
-        result = {"chat_item": state["response"]}
+    }
     return result, state.append(chat_history=result["chat_item"])
 
 
+#
+# @action(reads=["response", "safe", "mode"], writes=["chat_history"])
+# def response(state: State) -> Tuple[dict, State]:
+#     if not state["safe"]:
+#         result = {
+#             "chat_item": {
+#                 "role": "assistant",
+#                 "content": "I'm sorry, I can't respond to that.",
+#                 "type": "text",
+#             }
+#         }
+#     else:
+#         result = {"chat_item": state["response"]}
+#     return result, state.append(chat_history=result["chat_item"])
+
+
 # TODO -- add in error handling
 # @action(reads=["error"], writes=["chat_history"])
 # def error(state: State) -> Tuple[dict, State]:
@@ -146,14 +159,14 @@ def base_application(hooks: List[LifecycleAdapter], app_id: str, storage_dir: st
             prepend_prompt="Please answer the following question:",
         ),
         prompt_for_more=prompt_for_more,
-        response=response,
+        unsafe=unsafe,
     )
     .with_entrypoint("prompt")
     .with_state(chat_history=[])
     .with_transitions(
         ("prompt", "check_safety", default),
         ("check_safety", "decide_mode", when(safe=True)),
-        ("check_safety", "response", default),
+        ("check_safety", "unsafe", default),
         ("decide_mode", "generate_image", when(mode="generate_image")),
         ("decide_mode", "generate_code", when(mode="generate_code")),
         ("decide_mode", "answer_question", when(mode="answer_question")),
diff --git a/examples/gpt/simple_streamlit_app.py b/examples/gpt/simple_streamlit_app.py
index 1a221fb4f..eed8b8af0 100644
--- a/examples/gpt/simple_streamlit_app.py
+++ b/examples/gpt/simple_streamlit_app.py
@@ -31,12 +31,12 @@ def main():
     st.title("Chatbot example with Burr")
     app = initialize_app()
 
-    prompt = st.chat_input("Ask me a question!", key="chat_input")
+    user_input = st.chat_input("Ask me a question!", key="chat_input")
     for chat_message in app.state.get("chat_history", []):
         render_chat_message(chat_message)
-    if prompt:
+    if user_input:
         for action, result, state in app.iterate(
-            inputs={"prompt": prompt}, halt_after=["response"]
+            inputs={"prompt": user_input}, halt_after=["response"]
         ):
             if action.name in ["prompt", "response"]:
                 render_chat_message(result["chat_item"])
diff --git a/tests/core/test_action.py b/tests/core/test_action.py
index 45b9be731..d86f4700f 100644
--- a/tests/core/test_action.py
+++ b/tests/core/test_action.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Tuple
+from typing import Generator, Tuple
 
 import pytest
 
@@ -11,11 +11,15 @@
     Input,
     Result,
     SingleStepAction,
+    StreamingAction,
+    StreamingResultContainer,
+    StreamingResultType,
     _validate_action_function,
     action,
     create_action,
     default,
 )
+from burr.core.implementations import Placeholder
 
 
 def test_is_async_true():
@@ -344,3 +348,116 @@ def correct_signature(state: State) -> Tuple[dict, State]:
         pass
 
     _validate_action_function(correct_signature)
+
+
+def test_streaming_action_stream_run():
+    class SimpleStreamingAction(StreamingAction):
+        def stream_run(self, state: State, **run_kwargs) -> StreamingResultType:
+            buffer = []
+            for char in state["echo"]:
+                yield {"response": char}
+                buffer.append(char)
+            return {"response": "".join(buffer)}
+
+        @property
+        def reads(self) -> list[str]:
+            return []
+
+        @property
+        def writes(self) -> list[str]:
+            return ["response"]
+
+        def update(self, result: dict, state: State) -> State:
+            return state.update(**result)
+
+    action = SimpleStreamingAction()
+    STR = "test streaming action"
+    assert action.run(State({"echo": STR}))["response"] == STR
+
+
+def sample_generator(chars: str) -> Generator[dict, None, Tuple[dict, State]]:
+    buffer = []
+    for c in chars:
+        buffer.append(c)
+        yield {"response": c}
+    joined = "".join(buffer)
+    return {"response": joined}, State({"response": joined})
+
+
+def test_streaming_result_container_iterate():
+    string_value = "test streaming action"
+    container = StreamingResultContainer(
+        sample_generator(string_value),
+        action=None,
+        initial_state=State(),
+        process_result=lambda r, s: (r, s),
+        callback=lambda r, s, e: None,
+    )
+    assert [item["response"] for item in container] == list(string_value)
+    result, state = container.get()
+    assert result["response"] == string_value
+
+
+def test_streaming_result_get_runs_through():
+    string_value = "test streaming action"
+    container = StreamingResultContainer(
+        sample_generator(string_value),
+        action=None,
+        initial_state=State(),
+        process_result=lambda r, s: (r, s),
+        callback=lambda r, s, e: None,
+    )
+    result, state = container.get()
+    assert result["response"] == string_value
+
+
+def test_streaming_result_callback_called():
+    called = []
+    string_value = "test streaming action"
+
+    container = StreamingResultContainer(
+        sample_generator(string_value),
+        action=None,
+        # initial state is here solely for debugging, so we can hand a state
+        # back to the user in the case of failure
+        initial_state=State({"foo": "bar"}),
+        process_result=lambda r, s: (r, s),
+        callback=lambda r, s, e: called.append((r, s, e)),
+    )
+    container.get()
+    assert len(called) == 1
+    result, state, error = called[0]
+    assert result["response"] == string_value
+    assert state["response"] == string_value
+    assert error is None
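The error-path test below leans on how Python closes generators: when a suspended generator is closed (explicitly or on garbage collection), ``GeneratorExit`` is raised at the paused ``yield`` and any ``finally`` block runs -- which is what fires the container's callback. A standalone sketch of that mechanic:

.. code-block:: python

    def guarded():
        try:
            yield 1
            yield 2
        finally:
            # runs on normal exhaustion, on close(), and when the
            # generator is garbage collected mid-iteration
            print("cleanup ran")

    gen = guarded()
    assert next(gen) == 1
    gen.close()  # raises GeneratorExit inside guarded(); "cleanup ran" prints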
+
+
+def test_streaming_result_callback_error():
+    """This tests whether the callback is called when an error occurs in the generator.
+    Note the try/except blocks -- these are required so we can delegate to the
+    generator's closing capability."""
+
+    class SentinelError(Exception):
+        pass
+
+    try:
+        called = []
+        string_value = "test streaming action"
+        container = StreamingResultContainer(
+            sample_generator(string_value),
+            action=Placeholder(reads=[], writes=[]),
+            initial_state=State({"foo": "bar"}),
+            process_result=lambda r, s: (r, s),
+            callback=lambda r, s, e: called.append((r, s, e)),
+        )
+        try:
+            next(container)
+            raise SentinelError("error")
+        finally:
+            assert len(called) == 1
+            ((result, state, error),) = called
+            assert state["foo"] == "bar"
+            assert result is None
+            # The exception is currently not exactly what we want, so we won't assert on it.
+            # See note in StreamingResultContainer
+    except SentinelError:
+        pass
diff --git a/tests/core/test_application.py b/tests/core/test_application.py
index acf8d4c01..e0e87130c 100644
--- a/tests/core/test_application.py
+++ b/tests/core/test_application.py
@@ -1,11 +1,20 @@
 import asyncio
 import logging
-from typing import Awaitable, Callable, Tuple
+from typing import Awaitable, Callable, Generator, Tuple
 
 import pytest
 
 from burr.core import State
-from burr.core.action import Action, Condition, Result, SingleStepAction, default
+from burr.core.action import (
+    Action,
+    Condition,
+    Result,
+    SingleStepAction,
+    SingleStepStreamingAction,
+    StreamingAction,
+    StreamingResultType,
+    default,
+)
 from burr.core.application import (
     PRIOR_STEP,
     Application,
@@ -15,8 +24,10 @@
     _arun_single_step_action,
     _assert_set,
     _run_function,
+    _run_multi_step_streaming_action,
     _run_reducer,
     _run_single_step_action,
+    _run_single_step_streaming_action,
     _validate_actions,
     _validate_start,
     _validate_transitions,
@@ -254,11 +265,56 @@ def inputs(self) -> list[str]:
         return ["additional_increment"]
 
 
+class StreamingCounter(StreamingAction):
+    def stream_run(self, state: State, **run_kwargs) -> StreamingResultType:
+        steps_per_count = run_kwargs.get("granularity", 10)
+        count = state["count"]
+        for i in range(steps_per_count):
+            yield {"count": count + ((i + 1) / 10)}
+        return {"count": count + 1}
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+    def update(self, result: dict, state: State) -> State:
+        return state.update(**result).append(tracker=result["count"])
+
+
+class SingleStepStreamingCounter(SingleStepStreamingAction):
+    def stream_run_and_update(
+        self, state: State, **run_kwargs
+    ) -> Generator[dict, None, Tuple[dict, State]]:
+        steps_per_count = run_kwargs.get("granularity", 10)
+        count = state["count"]
+        for i in range(steps_per_count):
+            yield {"count": count + ((i + 1) / 10)}
+        return {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1)
+
+    @property
+    def reads(self) -> list[str]:
+        return ["count"]
+
+    @property
+    def writes(self) -> list[str]:
+        return ["count", "tracker"]
+
+
 base_single_step_counter = SingleStepCounter()
 base_single_step_counter_async = SingleStepCounterAsync()
 base_single_step_counter_with_inputs = SingleStepCounterWithInputs()
 base_single_step_counter_with_inputs_async = SingleStepCounterWithInputsAsync()
 
+base_streaming_counter = StreamingCounter()
+base_streaming_single_step_counter = SingleStepStreamingCounter()
+
 
 def test__run_single_step_action():
     action = base_single_step_counter.with_name("counter")
@@ -328,6 +384,38 @@ def 
test__run_single_step_action_deletes_state():
     assert "to_delete" not in state
 
 
+def test__run_multi_step_streaming_action():
+    action = base_streaming_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_multi_step_streaming_action(action, state, inputs={})
+    last_result = -1
+    try:
+        while True:
+            result = next(generator)
+            assert result["count"] > last_result
+            last_result = result["count"]
+    except StopIteration as e:
+        result, state = e.value
+        assert result == {"count": 1}
+        assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
+
+
+def test__run_single_step_streaming_action():
+    action = base_streaming_single_step_counter.with_name("counter")
+    state = State({"count": 0, "tracker": []})
+    generator = _run_single_step_streaming_action(action, state, inputs={})
+    last_result = -1
+    try:
+        while True:
+            result = next(generator)
+            assert result["count"] > last_result
+            last_result = result["count"]
+    except StopIteration as e:
+        result, state = e.value
+        assert result == {"count": 1}
+        assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
+
+
 class SingleStepActionWithDeletionAsync(SingleStepActionWithDeletion):
     async def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
         return {}, state.wipe(delete=["to_delete"])
@@ -768,6 +856,44 @@ async def test_app_a_run_async_and_sync():
     assert result["counter"] > 20
 
 
+def test_stream_result_halt_after():
+    counter_action = base_streaming_counter.with_name("counter")
+    counter_action_2 = base_streaming_counter.with_name("counter_2")
+    app = Application(
+        actions=[counter_action, counter_action_2],
+        transitions=[
+            Transition(counter_action, counter_action_2, default),
+        ],
+        state=State({"count": 0}),
+        initial_step="counter",
+    )
+    streaming_container = app.stream_result(halt_after=["counter_2"])
+    results = list(streaming_container)
+    assert len(results) == 10
+    result, state = streaming_container.get()
+    assert result["count"] == state["count"] == 2
+    assert state["tracker"] == [1, 2]
+
+
+def test_stream_result_halt_after_single_step():
+    counter_action = base_streaming_single_step_counter.with_name("counter")
+    counter_action_2 = base_streaming_single_step_counter.with_name("counter_2")
+    app = Application(
+        actions=[counter_action, counter_action_2],
+        transitions=[
+            Transition(counter_action, counter_action_2, default),
+        ],
+        state=State({"count": 0}),
+        initial_step="counter",
+    )
+    streaming_container = app.stream_result(halt_after=["counter_2"])
+    results = list(streaming_container)
+    assert len(results) == 10
+    result, state = streaming_container.get()
+    assert result["count"] == state["count"] == 2
+    assert state["tracker"] == [1, 2]
+
+
 def test_app_set_state():
     counter_action = base_counter_action.with_name("counter")
     app = Application(