diff --git a/burr/core/__init__.py b/burr/core/__init__.py index 3d61c144..c371ac47 100644 --- a/burr/core/__init__.py +++ b/burr/core/__init__.py @@ -5,6 +5,7 @@ ApplicationContext, ApplicationGraph, ) +from burr.core.graph import Graph from burr.core.state import State __all__ = [ @@ -20,4 +21,5 @@ "Result", "State", "when", + "Graph", ] diff --git a/burr/core/action.py b/burr/core/action.py index 61381afc..cb9a7ac5 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -6,18 +6,22 @@ import sys import types import typing +from collections.abc import AsyncIterator from typing import ( + TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Dict, Generator, + Generic, Iterator, List, Optional, Protocol, Tuple, + Type, TypeVar, Union, ) @@ -28,6 +32,18 @@ from typing import Self from burr.core.state import State +from burr.core.typing import ActionSchema + +# This is here to make accessing the pydantic actions easier +# we just attach them to action so you can call `@action.pyddantic...` +# The IDE will like it better and thus be able to auto-complete/type-check +# TODO - come up with a better way to attach integrations to core objects +imported_pydantic = False +if TYPE_CHECKING: + try: + from pydantic import BaseModel + except ImportError: + pass class Function(abc.ABC): @@ -96,9 +112,11 @@ def validate_inputs(self, inputs: Optional[Dict[str, Any]]) -> None: f"Inputs to function {self} are invalid. " + f"Missing the following inputs: {', '.join(missing_inputs)}." if missing_inputs - else "" f"Additional inputs: {','.join(additional_inputs)}." - if additional_inputs - else "" + else ( + "" f"Additional inputs: {','.join(additional_inputs)}." + if additional_inputs + else "" + ) ) def is_async(self) -> bool: @@ -127,6 +145,20 @@ def update(self, result: dict, state: State) -> State: pass +class DefaultSchema(ActionSchema): + def state_input_type(self) -> type[State]: + raise NotImplementedError + + def state_output_type(self) -> type[State]: + raise NotImplementedError + + def intermediate_result_type(self) -> type[dict]: + return dict + + +DEFAULT_SCHEMA = DefaultSchema() + + class Action(Function, Reducer, abc.ABC): def __init__(self): """Represents an action in a state machine. This is the base class from which @@ -170,6 +202,10 @@ def single_step(self) -> bool: def streaming(self) -> bool: return False + @property + def schema(self) -> ActionSchema: + return DEFAULT_SCHEMA + def get_source(self) -> str: """Returns the source code of the action. This will default to the source code of the class in which the action is implemented, @@ -177,6 +213,11 @@ def get_source(self) -> str: to display a different source""" return inspect.getsource(self.__class__) + def input_schema(self) -> Any: + """Returns the input schema for the action. + The input schema is a type that can be used to validate the input to the action""" + return None + def __repr__(self): read_repr = ", ".join(self.reads) if self.reads else "{}" write_repr = ", ".join(self.writes) if self.writes else "{}" @@ -493,7 +534,14 @@ def is_async(self) -> bool: # the following exist to share implementation between FunctionBasedStreamingAction and FunctionBasedAction # TODO -- think through the class hierarchy to simplify, for now this is OK -def _get_inputs(bound_params: dict, fn: Callable) -> tuple[list[str], list[str]]: +def derive_inputs_from_fn(bound_params: dict, fn: Callable) -> tuple[list[str], list[str]]: + """Derives inputs from the function, given the bound parameters. This assumes that the function + has inputs named `state`, as well as any number of other kwarg-boundable parameters. + + :param bound_params: Parameters that are already bound to the function + :param fn: Function to derive inputs from + :return: Required and optional inputs + """ sig = inspect.signature(fn) required_inputs, optional_inputs = [], [] for param_name, param in sig.parameters.items(): @@ -507,9 +555,7 @@ def _get_inputs(bound_params: dict, fn: Callable) -> tuple[list[str], list[str]] return required_inputs, optional_inputs -FunctionBasedActionType = TypeVar( - "FunctionBasedActionType", bound=Union["FunctionBasedAction", "FunctionBasedStreamingAction"] -) +FunctionBasedActionType = Union["FunctionBasedAction", "FunctionBasedStreamingAction"] class FunctionBasedAction(SingleStepAction): @@ -520,21 +566,35 @@ def __init__( fn: Callable, reads: List[str], writes: List[str], - bound_params: dict = None, + bound_params: Optional[dict] = None, + input_spec: Optional[tuple[list[str], list[str]]] = None, + originating_fn: Optional[Callable] = None, + schema: ActionSchema = DEFAULT_SCHEMA, ): """Instantiates a function-based action with the given function, reads, and writes. The function must take in a state and return a tuple of (result, new_state). - :param fn: - :param reads: - :param writes: + :param fn: Function to run + :param reads: Keys that the function reads from the state + :param writes: Keys that the function writes to the state + :param bound_params: Prior bound parameters + :param input_spec: Specification for inputs. Will derive from function if not provided. """ super(FunctionBasedAction, self).__init__() + self._originating_fn = originating_fn if originating_fn is not None else fn self._fn = fn self._reads = reads self._writes = writes self._bound_params = bound_params if bound_params is not None else {} - self._inputs = _get_inputs(self._bound_params, self._fn) + self._inputs = ( + derive_inputs_from_fn(self._bound_params, self._fn) + if input_spec is None + else ( + [item for item in input_spec[0] if item not in self._bound_params], + [item for item in input_spec[1] if item not in self._bound_params], + ) + ) + self._schema = schema @property def fn(self) -> Callable: @@ -552,6 +612,10 @@ def writes(self) -> list[str]: def inputs(self) -> tuple[list[str], list[str]]: return self._inputs + @property + def schema(self) -> ActionSchema: + return self._schema + def with_params(self, **kwargs: Any) -> "FunctionBasedAction": """Binds parameters to the function. Note that there is no reason to call this by the user. This *could* @@ -562,7 +626,13 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction": :return: """ return FunctionBasedAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + {**self._bound_params, **kwargs}, + input_spec=self._inputs, + originating_fn=self._originating_fn, + schema=self._schema, ) def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]: @@ -573,10 +643,12 @@ def is_async(self) -> bool: def get_source(self) -> str: """Return the source of the code for this action.""" - return inspect.getsource(self._fn) + return inspect.getsource(self._originating_fn) + +StateType = TypeVar("StateType") -StreamType = Tuple[dict, Optional[State]] +StreamType = Tuple[dict, Optional[State[StateType]]] GeneratorReturnType = Generator[StreamType, None, None] AsyncGeneratorReturnType = AsyncGenerator[StreamType, None] @@ -590,7 +662,7 @@ class StreamingAction(Action, abc.ABC): they run in multiple passes (run -> update)""" @abc.abstractmethod - def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, None]: + def stream_run(self, state: State[StateType], **run_kwargs) -> Generator[dict, None, None]: """Streaming action ``stream_run`` is different than standard action run. It: 1. streams in an intermediate result (the dict output) 2. yields the final result at the end @@ -617,7 +689,7 @@ def stream_run(state: State) -> Generator[dict, None, dict]: """ pass - def run(self, state: State, **run_kwargs) -> dict: + def run(self, state: State[StateType], **run_kwargs) -> dict: """Runs the streaming action through to completion.""" gen = self.stream_run(state, **run_kwargs) last_result = None @@ -636,7 +708,7 @@ class AsyncStreamingAction(Action, abc.ABC): Note this is the "multi-step" variant, in which run/update are separate.""" @abc.abstractmethod - async def stream_run(self, state: State, **run_kwargs) -> AsyncGenerator[dict, None]: + async def stream_run(self, state, **run_kwargs) -> AsyncGenerator[dict, None]: """Asynchronous streaming action ``stream_run`` is different than the standard action run. It: 1. streams in an intermediate result (the dict output) 2. yields the final result at the end @@ -663,7 +735,7 @@ async def stream_run(state: State) -> Generator[dict, None, dict]: """ pass - async def run(self, state: State, **run_kwargs) -> dict: + async def run(self, state: State[StateType], **run_kwargs) -> dict: """Runs the streaming action through to completion. Returns the final result. This is used if we want a streaming action as an intermediate. @@ -686,7 +758,10 @@ def is_async(self) -> bool: return True -class StreamingResultContainer(Iterator[dict]): +StreamResultType = TypeVar("StreamResultType") + + +class StreamingResultContainer(Generic[StateType, StreamResultType], Iterator[StreamResultType]): """Container for a streaming result. This allows you to: 1. Iterate over the result as it comes in @@ -711,12 +786,16 @@ class StreamingResultContainer(Iterator[dict]): """ @staticmethod - def pass_through(results: dict, final_state: State) -> "StreamingResultContainer": + def pass_through( + results: StreamResultType, final_state: State[StateType] + ) -> "StreamingResultContainer[StreamResultType, StateType]": """Instantiates a streaming result container that just passes through the given results This is to be used internally -- it allows us to wrap non-streaming action results in a streaming result container.""" - def empty_generator() -> GeneratorReturnType: + def empty_generator() -> ( + Generator[Tuple[StreamResultType, Optional[State[StateType]]], None, None] + ): yield results, final_state return StreamingResultContainer( @@ -729,7 +808,7 @@ def empty_generator() -> GeneratorReturnType: def __init__( self, streaming_result_generator: GeneratorReturnType, - initial_state: State, + initial_state: State[StateType], process_result: Callable[[dict, State], tuple[dict, State]], callback: Callable[[Optional[dict], State, Optional[Exception]], None], ): @@ -750,7 +829,7 @@ def __init__( self._result = None self._callback_realized = False - def __next__(self): + def __next__(self) -> StreamResultType: if self._result is not None: # we're done, and we've run through it raise StopIteration @@ -760,7 +839,7 @@ def __next__(self): raise StopIteration return result - def __iter__(self): + def __iter__(self) -> Iterator[StreamResultType]: def gen_fn(): try: while True: @@ -780,7 +859,7 @@ def gen_fn(): # as the async version return gen_fn() - def get(self) -> StreamType: + def get(self) -> Tuple[StreamResultType, State[StateType]]: # exhaust the generator for _ in self: pass @@ -788,7 +867,10 @@ def get(self) -> StreamType: return self._result -class AsyncStreamingResultContainer(typing.AsyncIterator[dict]): +class AsyncStreamingResultContainer( + Generic[StateType, StreamResultType], + AsyncIterator[StreamResultType], +): """Container for an async streaming result. This allows you to: 1. Iterate over the result as it comes in 2. Await the final result/state at the end @@ -814,10 +896,13 @@ class AsyncStreamingResultContainer(typing.AsyncIterator[dict]): def __init__( self, streaming_result_generator: AsyncGeneratorReturnType, - initial_state: State, - process_result: Callable[[dict, State], tuple[dict, State]], + initial_state: State[StateType], + process_result: Callable[ + [StreamResultType, State[StateType]], tuple[StreamResultType, State[StateType]] + ], callback: Callable[ - [Optional[dict], State, Optional[Exception]], typing.Coroutine[None, None, None] + [Optional[StreamResultType], State[StateType], Optional[Exception]], + typing.Coroutine[None, None, None], ], ): """Initializes an async streaming result container. User will never call directly. @@ -836,7 +921,7 @@ def __init__( self._result = None self._callback_realized = False - async def __anext__(self): + async def __anext__(self) -> StreamResultType: """Moves to the next state in the streaming result""" if self._result is not None: # we're done, and we've run through it @@ -847,7 +932,7 @@ async def __anext__(self): raise StopAsyncIteration return result - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[StreamResultType]: """Gives the iterator. Just calls anext, assigning the result in the finally block. Note this may not be perfect due to the complexity of callbacks for async generators, but it works in most cases.""" @@ -869,7 +954,7 @@ async def gen_fn(): # return it as `__aiter__` cannot be async/have awaits :/ return gen_fn() - async def get(self) -> tuple[Optional[dict], State]: + async def get(self) -> tuple[Optional[StreamResultType], State[StateType]]: # exhaust the generator async for _ in self: pass @@ -877,17 +962,21 @@ async def get(self) -> tuple[Optional[dict], State]: return self._result @staticmethod - def pass_through(results: dict, final_state: State) -> "AsyncStreamingResultContainer": + def pass_through( + results: StreamResultType, final_state: State[StateType] + ) -> "AsyncStreamingResultContainer[StateType]": """Creates a streaming result container that just passes through the given results. This is not a public facing API.""" async def just_results() -> AsyncGeneratorReturnType: yield results, final_state - async def empty_callback(result: Optional[dict], state: State, exc: Optional[Exception]): + async def empty_callback( + result: Optional[StreamResultType], state: State, exc: Optional[Exception] + ): pass - return AsyncStreamingResultContainer( + return AsyncStreamingResultContainer[StateType, StreamResultType]( just_results(), final_state, lambda result, state: (result, state), empty_callback ) @@ -954,7 +1043,10 @@ def __init__( ], reads: List[str], writes: List[str], - bound_params: dict = None, + bound_params: Optional[dict] = None, + input_spec: Optional[tuple[list[str], list[str]]] = None, + originating_fn: Optional[Callable] = None, + schema: ActionSchema = DEFAULT_SCHEMA, ): """Instantiates a function-based streaming action with the given function, reads, and writes. The function must take in a state (and inputs) and return a generator of (result, new_state). @@ -968,6 +1060,16 @@ def __init__( self._reads = reads self._writes = writes self._bound_params = bound_params if bound_params is not None else {} + self._inputs = ( + derive_inputs_from_fn(self._bound_params, self._fn) + if input_spec is None + else ( + [item for item in input_spec[0] if item not in self._bound_params], + [item for item in input_spec[1] if item not in self._bound_params], + ) + ) + self._originating_fn = originating_fn if originating_fn is not None else fn + self._schema = schema async def _a_stream_run_and_update( self, state: State, **run_kwargs @@ -1005,23 +1107,33 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction": :return: """ return FunctionBasedStreamingAction( - self._fn, self._reads, self._writes, {**self._bound_params, **kwargs} + self._fn, + self._reads, + self._writes, + {**self._bound_params, **kwargs}, + input_spec=self._inputs, + originating_fn=self._originating_fn, + schema=self._schema, ) @property def inputs(self) -> tuple[list[str], list[str]]: - return _get_inputs(self._bound_params, self._fn) + return self._inputs @property def fn(self) -> Union[StreamingFn, StreamingFnAsync]: return self._fn + @property + def schema(self) -> ActionSchema: + return self._schema + def is_async(self) -> bool: return inspect.isasyncgenfunction(self._fn) def get_source(self) -> str: """Return the source of the code for this action""" - return inspect.getsource(self._fn) + return inspect.getsource(self._originating_fn) C = TypeVar("C", bound=Callable) # placeholder for any Callable @@ -1069,77 +1181,148 @@ def my_action(state: State, z: int) -> tuple[dict, State]: return self -def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]: - """Decorator to create a function-based action. This is user-facing. - Note that, in the future, with typed state, we may not need this for - all cases. +class action: + @staticmethod + def pydantic( + reads: List[str], + writes: List[str], + state_input_type: Optional[Type["BaseModel"]] = None, + state_output_type: Optional[Type["BaseModel"]] = None, + ) -> Callable: + """Action that specifies inputs/outputs using pydantic models. + This should make it easier to develop with guardrails. + + :param reads: keys that this model reads. Note that this will be a subset of the pydantic model with which this is decorated. + We will be validating that the keys are present in the model. + :param writes: keys that this model writes. Note that this will be a subset of the pydantic model with which this is decorated. + We will be validating that the keys are present in the model. + :param state_input_type: The pydantic model type that is used to represent the input state. + If this is None it will attempt to derive from the signature. + :param state_output_type: The pydantic model type that is used to represent the output state. + If this is None it will attempt to derive from the signature. + :return: + """ + try: + from burr.integrations.pydantic import pydantic_action + except ImportError: + raise ImportError( + "Please install pydantic to use the pydantic decorator. pip install burr[pydantic]" + ) - If parameters are not bound, they will be interpreted as inputs and must - be passed in at runtime. If they have default values, they will be recorded - as optional inputs. These can (optionally) be provided at runtime. + return pydantic_action( + reads=reads, + writes=writes, + state_input_type=state_input_type, + state_output_type=state_output_type, + ) - :param reads: Items to read from the state - :param writes: Items to write to the state - :return: The decorator to assign the function as an action - """ + def __init__(self, reads: List[str], writes: List[str]): + """Decorator to create a function-based action. This is user-facing. + Note that, in the future, with typed state, we may not need this for + all cases. + + If parameters are not bound, they will be interpreted as inputs and must + be passed in at runtime. If they have default values, they will be recorded + as optional inputs. These can (optionally) be provided at runtime. + + :param reads: Items to read from the state + :param writes: Items to write to the state + :return: The decorator to assign the function as an action + """ + self.reads = reads + self.writes = writes - def decorator(fn) -> FunctionRepresentingAction: - setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes)) + def __call__(self, fn) -> FunctionRepresentingAction: + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedAction(fn, self.reads, self.writes), + ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn - return decorator +class streaming_action: + @staticmethod + def pydantic( + reads: List[str], + writes: List[str], + state_input_type: Type["BaseModel"], + state_output_type: Type["BaseModel"], + stream_type: Union[Type["BaseModel"], Type[dict]], + ) -> Callable: + """Creates a streaming action that uses pydantic models. + + :param reads: The fields this consumes from the state. + :param writes: The fields this writes to the state. + :param stream_type: The pydantic model or dictionary type that is used to represent the partial results. + Use a dict if you want this untyped. + :param state_input_type: The pydantic model type that is used to represent the input state. + :param state_output_type: The pydantic model type that is used to represent the output state. + :return: The same function, decorated function. + """ + try: + from burr.integrations.pydantic import pydantic_streaming_action + except ImportError: + raise ImportError( + "Please install pydantic to use the pydantic decorator. pip install 'burr[pydantic]'" + ) + + return pydantic_streaming_action( + reads=reads, + writes=writes, + state_input_type=state_input_type, + state_output_type=state_output_type, + stream_type=stream_type, + ) -def streaming_action( - reads: List[str], writes: List[str] -) -> Callable[[Callable], FunctionRepresentingAction]: - """Decorator to create a streaming function-based action. This is user-facing. + def __init__(self, reads: List[str], writes: List[str]): + """Decorator to create a streaming function-based action. This is user-facing. - If parameters are not bound, they will be interpreted as inputs and must be passed in at runtime. + If parameters are not bound, they will be interpreted as inputs and must be passed in at runtime. - See the following example for how to use this decorator -- this reads ``prompt`` from the state and writes - ``response`` back out, yielding all intermediate chunks. + See the following example for how to use this decorator -- this reads ``prompt`` from the state and writes + ``response`` back out, yielding all intermediate chunks. - Note that this *must* return a value. If it does not, we will not know how to update the state, and - we will error out. + Note that this *must* return a value. If it does not, we will not know how to update the state, and + we will error out. - .. code-block:: python + .. code-block:: python - @streaming_action(reads=["prompt"], writes=['response']) - def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]]: - response = client.chat.completions.create( - model='gpt-3.5-turbo', - messages=[{ - 'role': 'user', - 'content': state["prompt"] - }], - temperature=0, - ) - buffer = [] - for chunk in response: - delta = chunk.choices[0].delta.content - buffer.append(delta) - # yield partial results - yield {'response': delta}, None - full_response = ''.join(buffer) - # return the final result - return {'response': full_response}, state.update(response=full_response) + @streaming_action(reads=["prompt"], writes=['response']) + def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]]: + response = client.chat.completions.create( + model='gpt-3.5-turbo', + messages=[{ + 'role': 'user', + 'content': state["prompt"] + }], + temperature=0, + ) + buffer = [] + for chunk in response: + delta = chunk.choices[0].delta.content + buffer.append(delta) + # yield partial results + yield {'response': delta}, None + full_response = ''.join(buffer) + # return the final result + return {'response': full_response}, state.update(response=full_response) - """ + """ + self.reads = reads + self.writes = writes - def wrapped(fn) -> FunctionRepresentingAction: + def __call__(self, fn: Callable) -> FunctionRepresentingAction: fn = copy_func(fn) setattr( fn, FunctionBasedAction.ACTION_FUNCTION, - FunctionBasedStreamingAction(fn, reads, writes), + FunctionBasedStreamingAction(fn, self.reads, self.writes), ) setattr(fn, "bind", types.MethodType(bind, fn)) return fn - return wrapped - ActionT = TypeVar("ActionT", bound=Action) diff --git a/burr/core/application.py b/burr/core/application.py index 6ae01f7d..30c92b9a 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextvars import dataclasses import functools @@ -13,6 +15,7 @@ Callable, Dict, Generator, + Generic, List, Literal, Optional, @@ -27,6 +30,7 @@ from burr.common import types as burr_types from burr.core import persistence, validation from burr.core.action import ( + DEFAULT_SCHEMA, Action, AsyncStreamingAction, AsyncStreamingResultContainer, @@ -41,6 +45,7 @@ from burr.core.graph import Graph, GraphBuilder from burr.core.persistence import BaseStateLoader, BaseStateSaver from burr.core.state import State +from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem from burr.core.validation import BASE_ERROR_MESSAGE from burr.lifecycle.base import ExecuteMethod, LifecycleAdapter, PostRunStepHook, PreRunStepHook from burr.lifecycle.internal import LifecycleAdapterSet @@ -48,6 +53,8 @@ from burr.visibility.tracing import tracer_factory_context_var if TYPE_CHECKING: + # TODO -- figure out whether we want to just do if TYPE_CHECKING + # for all first-class imports as Ruff suggests... from burr.tracking.base import TrackingClient logger = logging.getLogger(__name__) @@ -55,11 +62,17 @@ PRIOR_STEP = "__PRIOR_STEP" SEQUENCE_ID = "__SEQUENCE_ID" +StateType = TypeVar("StateType") +StateTypeToSet = TypeVar("StateTypeToSet") + -def _validate_result(result: dict, name: str) -> None: - if not isinstance(result, dict): +def _validate_result(result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA) -> None: + # TODO -- split out the action schema into input/output schema types + # Currently they're tied together, but this doesn't make as much sense for single-step actions + result_type = schema.intermediate_result_type() + if not isinstance(result, result_type): raise ValueError( - f"Action {name} returned a non-dict result: {result}. " + f"Action {name} returned a non-{result_type.__name__} result: {result}. " f"All results must be dictionaries." ) @@ -72,13 +85,15 @@ def _raise_fn_return_validation_error(output: Any, action_name: str): ) -def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_name: str): +def _adjust_single_step_output( + output: Union[State, Tuple[dict, State]], action_name: str, action_schema: ActionSchema +): """Adjusts the output of a single step action to be a tuple of (result, state) or just state""" if isinstance(output, tuple): if not len(output) == 2: _raise_fn_return_validation_error(output, action_name) - _validate_result(output[0], action_name) + _validate_result(output[0], action_name, action_schema) if not isinstance(output[1], State): _raise_fn_return_validation_error(output, action_name) return output @@ -234,11 +249,10 @@ def _run_single_step_action( # TODO -- guard all reads/writes with a subset of the state action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( - action.run_and_update(state, **inputs), action.name + action.run_and_update(state, **inputs), action.name, action.schema ) - _validate_result(result, action.name) + _validate_result(result, action.name, action.schema) out = result, _state_update(state, new_state) - _validate_result(result, action.name) _validate_reducer_writes(action, new_state, action.name) return out @@ -292,7 +306,7 @@ def _run_single_step_streaming_action( f"Action {action.name} did not return a state update. For streaming actions, the last yield " f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')" ) - _validate_result(result, action.name) + _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) yield result, state_update @@ -344,7 +358,8 @@ async def _arun_single_step_streaming_action( f"Action {action.name} did not return a state update. For async actions, the last yield " f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')" ) - _validate_result(result, action.name) + # TODO -- add back in validation when we have a schema + _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) # TODO -- add guard against zero-length stream yield result, state_update @@ -395,7 +410,7 @@ def _run_multi_step_streaming_action( count += 1 yield next_result, None state_update = _run_reducer(action, state, result, action.name) - _validate_result(result, action.name) + _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) yield result, state_update @@ -439,7 +454,7 @@ async def _arun_multi_step_streaming_action( count += 1 yield next_result, None state_update = _run_reducer(action, state, result, action.name) - _validate_result(result, action.name) + _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, state_update, action.name) yield result, state_update @@ -451,9 +466,9 @@ async def _arun_single_step_action( state_to_use = state action.validate_inputs(inputs) result, new_state = _adjust_single_step_output( - await action.run_and_update(state_to_use, **inputs), action.name + await action.run_and_update(state_to_use, **inputs), action.name, action.schema ) - _validate_result(result, action.name) + _validate_result(result, action.name, action.schema) _validate_reducer_writes(action, new_state, action.name) return result, _state_update(state, new_state) @@ -664,11 +679,15 @@ def post_run_step( del self.token_pointer_map[(app_id, sequence_id)] -class Application: +ApplicationStateType = TypeVar("ApplicationStateType") +StreamResultType = TypeVar("StreamResultType", bound=Union[dict, Any]) + + +class Application(Generic[ApplicationStateType]): def __init__( self, graph: Graph, - state: State, + state: State[ApplicationStateType], partition_key: Optional[str], uid: str, entrypoint: str, @@ -824,7 +843,9 @@ def reset_to_entrypoint(self) -> None: in your graph, but this will do the trick if you need it!""" self._set_state(self._state.wipe(delete=[PRIOR_STEP])) - def _update_internal_state_value(self, new_state: State, next_action: Action) -> State: + def _update_internal_state_value( + self, new_state: State[ApplicationStateType], next_action: Action + ) -> State[ApplicationStateType]: """Updates the internal state values of the new state.""" new_state = new_state.update( **{ @@ -889,7 +910,9 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, A # @telemetry.capture_function_usage # ditto with step() @_call_execute_method_pre_post(ExecuteMethod.astep) - async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, dict, State]]: + async def astep( + self, inputs: Optional[Dict[str, Any]] = None + ) -> Optional[Tuple[Action, dict, State[ApplicationStateType]]]: """Asynchronous version of step. :param inputs: Inputs to the action -- this is if this action @@ -1022,7 +1045,7 @@ def _return_value_iterate( halt_after: list[str], prior_action: Optional[Action], result: Optional[dict], - ) -> Tuple[Optional[Action], Optional[dict], State]: + ) -> Tuple[Optional[Action], Optional[dict], State[ApplicationStateType]]: """Utility function to decide what to return for iterate/arun. Note that run() will delegate to the return value of iterate, whereas arun cannot delegate to the return value of aiterate (as async generators cannot return a value). We put the code centrally to clean up the logic. @@ -1057,7 +1080,11 @@ def iterate( halt_before: list[str] = None, halt_after: list[str] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Generator[Tuple[Action, dict, State], None, Tuple[Action, Optional[dict], State]]: + ) -> Generator[ + Tuple[Action, dict, State[ApplicationStateType]], + None, + Tuple[Action, Optional[dict], State[ApplicationStateType]], + ]: """Returns a generator that calls step() in a row, enabling you to see the state of the system as it updates. Note this returns a generator, and also the final result (for convenience). @@ -1095,7 +1122,7 @@ async def aiterate( halt_before: list[str] = None, halt_after: list[str] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> AsyncGenerator[Tuple[Action, dict, State], None]: + ) -> AsyncGenerator[Tuple[Action, dict, State[ApplicationStateType]], None]: """Returns a generator that calls step() in a row, enabling you to see the state of the system as it updates. This is the asynchronous version so it has no capability of t @@ -1125,7 +1152,7 @@ def run( halt_before: list[str] = None, halt_after: list[str] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, Optional[dict], State]: + ) -> Tuple[Action, Optional[dict], State[ApplicationStateType]]: """Runs your application through until completion. Does not give access to the state along the way -- if you want that, use iterate(). @@ -1177,9 +1204,9 @@ async def arun( def stream_result( self, halt_after: list[str], - halt_before: list[str] = None, + halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, StreamingResultContainer]: + ) -> Tuple[Action, StreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out. :param halt_after: The list of actions to halt after execution of. It will halt on the first one. @@ -1426,9 +1453,9 @@ def callback( async def astream_result( self, halt_after: list[str], - halt_before: list[str] = None, + halt_before: Optional[list[str]] = None, inputs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Action, AsyncStreamingResultContainer]: + ) -> Tuple[Action, AsyncStreamingResultContainer[ApplicationStateType, Union[dict, Any]]]: """Streams a result out in an asynchronous manner. :param halt_after: The list of actions to halt after execution of. It will halt on the first one. @@ -1680,7 +1707,7 @@ async def callback( sequence_id=self.sequence_id, exception=e, ) - await call_execute_method_wrapper.acall_post(self, None, e) + await call_execute_method_wrapper.acall_post(self, e) await self._adapter_set.call_all_lifecycle_hooks_sync_and_async( "post_end_stream", action=next_action.name, @@ -1726,13 +1753,13 @@ def visualize( **engine_kwargs, ) - def _set_state(self, new_state: State): + def _set_state(self, new_state: State[ApplicationStateType]): self._state = new_state def get_next_action(self) -> Optional[Action]: return self._graph.get_next_node(self._state.get(PRIOR_STEP), self._state, self.entrypoint) - def update_state(self, new_state: State): + def update_state(self, new_state: State[ApplicationStateType]): """Updates state -- this is meant to be called if you need to do anything with the state. For example: 1. Reset it (after going through a loop) @@ -1744,7 +1771,7 @@ def update_state(self, new_state: State): self._state = new_state @property - def state(self) -> State: + def state(self) -> State[ApplicationStateType]: """Gives the state. Recall that state is purely immutable -- anything you do with this state will not be persisted unless you subsequently call update_state. @@ -1837,7 +1864,7 @@ def partition_key(self) -> Optional[str]: return self._partition_key @property - def builder(self) -> Optional["ApplicationBuilder"]: + def builder(self) -> Optional["ApplicationBuilder[ApplicationStateType]"]: """Returns the application builder that was used to build this application. Note that this asusmes the application was built using the builder. Otherwise, @@ -1873,10 +1900,10 @@ def _validate_start(start: Optional[str], actions: Set[str]): ) -class ApplicationBuilder: +class ApplicationBuilder(Generic[StateType]): def __init__(self): self.start = None - self.state: Optional[State] = None + self.state: Optional[State[StateType]] = None self.lifecycle_adapters: List[LifecycleAdapter] = list() self.app_id: str = str(uuid.uuid4()) self.partition_key: Optional[str] = None @@ -1894,10 +1921,11 @@ def __init__(self): self.tracker = None self.graph_builder = None self.prebuilt_graph = None + self.typing_system = None def with_identifiers( self, app_id: str = None, partition_key: str = None, sequence_id: int = None - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Assigns various identifiers to the application. This is used for tracking, persistence, etc... :param app_id: Application ID -- this will be assigned to a uuid if not set. @@ -1915,7 +1943,21 @@ def with_identifiers( self.sequence_id = sequence_id return self - def with_state(self, **kwargs) -> "ApplicationBuilder": + def with_typing( + self, typing_system: TypingSystem[StateTypeToSet] + ) -> "ApplicationBuilder[StateTypeToSet]": + """Sets the typing system for the application. This is used to enforce typing on the state. + + :param typing_system: Typing system to use + :return: The application builder for future chaining. + """ + if typing_system is not None: + self.typing_system = typing_system + return self # type: ignore + + def with_state( + self, state: Optional[Union[State, StateTypeToSet]] = None, **kwargs + ) -> "ApplicationBuilder[StateType]": """Sets initial values in the state. If you want to load from a prior state, you can do so here and pass the values in. @@ -1930,13 +1972,30 @@ def with_state(self, **kwargs) -> "ApplicationBuilder": "the .initialize_from() API. Either allow the persister to set the " "state, or set the state manually." ) - if self.state is not None: + if state is not None: + if self.state is not None: + raise ValueError( + BASE_ERROR_MESSAGE + + "State items have already been set -- you cannot use the type-based API as well." + " Either set state with with_state(**kwargs) or pass in a state/typed object." + ) + if isinstance(state, State): + self.state = state + elif self.typing_system is not None: + self.state = self.typing_system.construct_state(state) + else: + raise ValueError( + BASE_ERROR_MESSAGE + + "You have not set a typing system, and you are passing in a typed state object." + " Please set a typing system using with_typing before doing so." + ) + elif self.state is not None: self.state = self.state.update(**kwargs) else: self.state = State(kwargs) return self - def with_graph(self, graph: Graph) -> "ApplicationBuilder": + def with_graph(self, graph: Graph) -> "ApplicationBuilder[StateType]": """Adds a prebuilt graph -- this is an alternative to using the with_actions and with_transitions methods. While you will likely use with_actions and with_transitions, you may want this in a few cases: @@ -1969,7 +2028,7 @@ def _initialize_graph_builder(self): if self.graph_builder is None: self.graph_builder = GraphBuilder() - def with_entrypoint(self, action: str) -> "ApplicationBuilder": + def with_entrypoint(self, action: str) -> "ApplicationBuilder[StateType]": """Adds an entrypoint to the application. This is the action that will be run first. This can only be called once. @@ -1988,7 +2047,7 @@ def with_entrypoint(self, action: str) -> "ApplicationBuilder": def with_actions( self, *action_list: Union[Action, Callable], **action_dict: Union[Action, Callable] - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Adds an action to the application. The actions are granted names (using the with_name) method post-adding, using the kw argument. If it already has a name (or you wish to use the function name, raw, and it is a function-based-action), then you can use the *args* parameter. This is the only supported way to add actions. @@ -2007,7 +2066,7 @@ def with_transitions( *transitions: Union[ Tuple[Union[str, list[str]], str], Tuple[Union[str, list[str]], str, Condition] ], - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Adds transitions to the application. Transitions are specified as tuples of either: 1. (from, to, condition) 2. (from, to) -- condition is set to DEFAULT (which is a fallback) @@ -2024,7 +2083,7 @@ def with_transitions( self.graph_builder = self.graph_builder.with_transitions(*transitions) return self - def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder": + def with_hooks(self, *adapters: LifecycleAdapter) -> "ApplicationBuilder[StateType]": """Adds a lifecycle adapter to the application. This is a way to add hooks to the application so that they are run at the appropriate times. You can use this to synchronize state out, log results, etc... @@ -2093,7 +2152,7 @@ def initialize_from( fork_from_app_id: str = None, fork_from_partition_key: str = None, fork_from_sequence_id: int = None, - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Initializes the application we will build from some prior state object. Note (1) that you can *either* call this or use `with_state` and `with_entrypoint`. @@ -2133,7 +2192,7 @@ def initialize_from( def with_state_persister( self, persister: Union[BaseStateSaver, LifecycleAdapter], on_every: str = "step" - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Adds a state persister to the application. This is a way to persist state out to a database, file, etc... at the specified interval. This is one of two options: @@ -2156,7 +2215,7 @@ def with_state_persister( def with_spawning_parent( self, app_id: str, sequence_id: int, partition_key: Optional[str] = None - ) -> "ApplicationBuilder": + ) -> "ApplicationBuilder[StateType]": """Sets the 'spawning' parent application that created this app. This is used for tracking purposes. Doing this creates a parent/child relationship. There can be many spawned children from a single sequence ID (just as there can be many forks of an app). @@ -2251,14 +2310,14 @@ def _get_built_graph(self) -> Graph: raise ValueError( BASE_ERROR_MESSAGE + "You must set the graph using with_graph, or use with_entrypoint, with_actions, and with_transitions" - "to build the graph." + " to build the graph." ) if self.graph_builder is not None: return self.graph_builder.build() return self.prebuilt_graph @telemetry.capture_function_usage - def build(self) -> Application: + def build(self) -> Application[StateType]: """Builds the application. This function is a bit messy as we iron out the exact logic and rigor we want around things. @@ -2274,6 +2333,10 @@ def build(self) -> Application: self._load_from_persister() graph = self._get_built_graph() _validate_start(self.start, {action.name for action in graph.actions}) + typing_system: TypingSystem[StateType] = ( + self.typing_system if self.typing_system is not None else DictBasedTypingSystem() + ) # type: ignore + self.state = self.state.with_typing_system(typing_system=typing_system) return Application( graph=graph, state=self.state, @@ -2283,19 +2346,23 @@ def build(self) -> Application: entrypoint=self.start, adapter_set=LifecycleAdapterSet(*self.lifecycle_adapters), builder=self, - fork_parent_pointer=burr_types.ParentPointer( - app_id=self.fork_from_app_id, - partition_key=self.fork_from_partition_key, - sequence_id=self.fork_from_sequence_id, - ) - if self.loaded_from_fork - else None, + fork_parent_pointer=( + burr_types.ParentPointer( + app_id=self.fork_from_app_id, + partition_key=self.fork_from_partition_key, + sequence_id=self.fork_from_sequence_id, + ) + if self.loaded_from_fork + else None + ), tracker=self.tracker, - spawning_parent_pointer=burr_types.ParentPointer( - app_id=self.spawn_from_app_id, - partition_key=self.spawn_from_partition_key, - sequence_id=self.spawn_from_sequence_id, - ) - if self.spawn_from_app_id is not None - else None, + spawning_parent_pointer=( + burr_types.ParentPointer( + app_id=self.spawn_from_app_id, + partition_key=self.spawn_from_partition_key, + sequence_id=self.spawn_from_sequence_id, + ) + if self.spawn_from_app_id is not None + else None + ), ) diff --git a/burr/core/state.py b/burr/core/state.py index 62c9a583..67037b9f 100644 --- a/burr/core/state.py +++ b/burr/core/state.py @@ -4,9 +4,10 @@ import importlib import inspect import logging -from typing import Any, Callable, Dict, Iterator, Mapping, Union +from typing import Any, Callable, Dict, Generic, Iterator, Mapping, Optional, TypeVar, Union from burr.core import serde +from burr.core.typing import DictBasedTypingSystem, TypingSystem logger = logging.getLogger(__name__) @@ -142,7 +143,7 @@ def apply_mutate(self, inputs: dict): if key not in inputs: inputs[key] = [] if not isinstance(inputs[key], list): - raise ValueError(f"Cannot append to non-list value {key}={inputs[self.key]}") + raise ValueError(f"Cannot append to non-list value {key}={inputs[key]}") inputs[key].append(value) def validate(self, input_state: Dict[str, Any]): @@ -211,22 +212,51 @@ def apply_mutate(self, inputs: dict): inputs.pop(key, None) -class State(Mapping): +StateType = TypeVar("StateType", bound=Union[Dict[str, Any], Any]) +AssignedStateType = TypeVar("AssignedStateType") + + +class State(Mapping, Generic[StateType]): """An immutable state object. This is the only way to interact with state in Burr.""" - def __init__(self, initial_values: Dict[str, Any] = None): + def __init__( + self, + initial_values: Optional[Dict[str, Any]] = None, + typing_system: Optional[TypingSystem[StateType]] = None, + ): if initial_values is None: initial_values = dict() + self._typing_system = ( + typing_system if typing_system is not None else DictBasedTypingSystem() # type: ignore + ) self._state = initial_values - def apply_operation(self, operation: StateDelta) -> "State": + @property + def typing_system(self) -> TypingSystem[StateType]: + return self._typing_system + + def with_typing_system( + self, typing_system: TypingSystem[AssignedStateType] + ) -> "State[AssignedStateType]": + """Copies state with a specific typing system""" + return State(self._state, typing_system=typing_system) + + @property + def data(self) -> StateType: + return self._typing_system.construct_data(self) # type: ignore + + def apply_operation(self, operation: StateDelta) -> "State[StateType]": """Applies a given operation to the state, returning a new state""" - new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys + + # Moved to copy.copy instead of copy.deepcopy + # TODO -- just copy the ones that have changed + # And if they can't be copied then we use the same ones... + new_state = copy.copy(self._state) # TODO -- restrict to just the read keys operation.validate(new_state) operation.apply_mutate( new_state ) # todo -- validate that the write keys are the only different ones - return State(new_state) + return State(new_state, typing_system=self._typing_system) def get_all(self) -> Dict[str, Any]: """Returns the entire state, realize as a dictionary. This is a copy.""" @@ -251,7 +281,7 @@ def _serialize(k, v, **extrakwargs) -> Union[dict, str]: return {k: _serialize(k, v, **kwargs) for k, v in _dict.items()} @classmethod - def deserialize(cls, json_dict: dict, **kwargs) -> "State": + def deserialize(cls, json_dict: dict, **kwargs) -> "State[StateType]": """Converts a dictionary representing a JSON object back into a state""" def _deserialize(k, v: Union[str, dict], **extrakwargs) -> Callable: @@ -262,7 +292,7 @@ def _deserialize(k, v: Union[str, dict], **extrakwargs) -> Callable: return State({k: _deserialize(k, v, **kwargs) for k, v in json_dict.items()}) - def update(self, **updates: Any) -> "State": + def update(self, **updates: Any) -> "State[StateType]": """Updates the state with a set of key-value pairs Does an upsert operation (if the keys exist their value will be overwritten, otherwise they will be created) @@ -277,7 +307,7 @@ def update(self, **updates: Any) -> "State": """ return self.apply_operation(SetFields(updates)) - def append(self, **updates: Any) -> "State": + def append(self, **updates: Any) -> "State[StateType]": """Appends to the state with a set of key-value pairs. Each one must correspond to a list-like object, or an error will be raised. @@ -295,7 +325,7 @@ def append(self, **updates: Any) -> "State": return self.apply_operation(AppendFields(updates)) - def increment(self, **updates: int) -> "State": + def increment(self, **updates: int) -> "State[StateType]": """Increments the state with a set of key-value pairs. Each one must correspond to an integer, or an error will be raised. @@ -304,7 +334,7 @@ def increment(self, **updates: int) -> "State": """ "" return self.apply_operation(IncrementFields(updates)) - def wipe(self, delete: list[str] = None, keep: list[str] = None): + def wipe(self, delete: Optional[list[str]] = None, keep: Optional[list[str]] = None): """Wipes the state, either by deleting the keys in delete and keeping everything else or keeping the keys in keep. and deleting everything else. If you pass nothing in it will delete the whole thing. @@ -324,14 +354,17 @@ def wipe(self, delete: list[str] = None, keep: list[str] = None): fields_to_delete = [key for key in self._state if key not in keep] return self.apply_operation(DeleteField(fields_to_delete)) - def merge(self, other: "State") -> "State": + def merge(self, other: "State") -> "State[StateType]": """Merges two states together, overwriting the values in self with those in other.""" - return State({**self.get_all(), **other.get_all()}) + return State({**self.get_all(), **other.get_all()}, self.typing_system) - def subset(self, *keys: str, ignore_missing: bool = True) -> "State": + def subset(self, *keys: str, ignore_missing: bool = True) -> "State[StateType]": """Returns a subset of the state, with only the given keys""" - return State({key: self[key] for key in keys if key in self or not ignore_missing}) + return State( + {key: self[key] for key in keys if key in self or not ignore_missing}, + self.typing_system, + ) def __getitem__(self, __k: str) -> Any: return self._state[__k] diff --git a/burr/core/typing.py b/burr/core/typing.py new file mode 100644 index 00000000..39de1757 --- /dev/null +++ b/burr/core/typing.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Generic, Type, TypeVar + +BaseType = TypeVar("BaseType") +# SpecificType = TypeVar('SpecificType', bound=BaseType) + +if TYPE_CHECKING: + from burr.core import Action, Graph, State + +try: + from typing import Self +except ImportError: + Self = "TypingSystem" + + +class TypingSystem(abc.ABC, Generic[BaseType]): + @abc.abstractmethod + def state_type(self) -> Type[BaseType]: + """Gives the type that represents the state of the + application at any given time. Note that this must have + adequate support for Optionals (E.G. non-required values). + + :return: + """ + + @abc.abstractmethod + def state_pre_action_run_type(self, action: Action, graph: Graph) -> Type[BaseType]: + """Gives the type that represents the state after an action has completed. + Note that this could be smart -- E.g. it should have all possible upstream + types filled in. + + :param action: + :return: + """ + + @abc.abstractmethod + def state_post_action_run_type(self, action: Action, graph: Graph) -> Type[BaseType]: + """Gives the type that represents the state after an action has completed. + Note that this could be smart -- E.g. it should have all possible upstream + types filled in. + + :param action: + :return: + """ + + def validate_state(self, state: State) -> None: + """Validates the state to ensure it is valid. + + :param state: + :return: + """ + + @abc.abstractmethod + def construct_data(self, state: State[BaseType]) -> BaseType: + """Constructs a type based on the arguments passed in. + + :param kwargs: + :return: + """ + + @abc.abstractmethod + def construct_state(self, data: BaseType) -> State[BaseType]: + """Constructs a state based on the arguments passed in. + + :param kwargs: + :return: + """ + + +StateInputType = TypeVar("StateInputType") +StateOutputType = TypeVar("StateOutputType") +IntermediateResultType = TypeVar("IntermediateResultType") + + +class ActionSchema( + abc.ABC, + Generic[ + StateInputType, + StateOutputType, + IntermediateResultType, + ], +): + """Quick wrapper class to represent a schema. Note that this is currently used internally, + just to store the appropriate information. This does not validate or do conversion, currently that + is done within the pydantic model state typing system (which is also internal in its implementation). + + + + We will likely centralize that logic at some point when we get more -- it would look something like this: + 1. Action is passed an ActionSchema + 2. Action is parameterized on the ActionSchema types + 3. Action takes state, validates the type and converts to StateInputType + 4. Action runs, returns intermediate result + state + 5. Action validates intermediate result type (or converts to dict? Probably just keeps it + 6. Action converts StateOutputType to State + """ + + @abc.abstractmethod + def state_input_type() -> Type[StateInputType]: + pass + + @abc.abstractmethod + def state_output_type() -> Type[StateOutputType]: + pass + + @abc.abstractmethod + def intermediate_result_type() -> Type[IntermediateResultType]: + pass + + +class DictBasedTypingSystem(TypingSystem[dict]): + """Effectively a no-op. State is backed by a dictionary, which allows every state item + to... be a dictionary.""" + + def state_type(self) -> Type[dict]: + return dict + + def state_pre_action_run_type(self, action: Action, graph: Graph) -> Type[dict]: + return dict + + def state_post_action_run_type(self, action: Action, graph: Graph) -> Type[dict]: + return dict + + def construct_data(self, state: State[dict]) -> dict: + return state.get_all() + + def construct_state(self, data: dict) -> State[dict]: + return State(data, typing_system=self) diff --git a/burr/core/validation.py b/burr/core/validation.py index e1e95514..16baf2a0 100644 --- a/burr/core/validation.py +++ b/burr/core/validation.py @@ -12,5 +12,5 @@ def assert_set(value: Optional[Any], field: str, method: str): if value is None: raise ValueError( BASE_ERROR_MESSAGE - + f"Must set `{field}` before building application! Do so with ApplicationBuilder.{method}" + + f"Must call `{method}` before building application! Do so with ApplicationBuilder." ) diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py new file mode 100644 index 00000000..ebe8a63a --- /dev/null +++ b/burr/integrations/pydantic.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import copy +import inspect +import types +from typing import ( + AsyncGenerator, + Awaitable, + Callable, + Generator, + List, + Optional, + ParamSpec, + Tuple, + Type, + TypeVar, + Union, +) + +import pydantic +from pydantic_core import PydanticUndefined + +from burr.core import Action, Graph, State +from burr.core.action import ( + FunctionBasedAction, + FunctionBasedStreamingAction, + bind, + derive_inputs_from_fn, +) +from burr.core.typing import ActionSchema, TypingSystem + +Inputs = ParamSpec("Inputs") + +PydanticActionFunction = Callable[..., Union[pydantic.BaseModel, Awaitable[pydantic.BaseModel]]] + + +def model_to_dict(model: pydantic.BaseModel, include: Optional[List[str]] = None) -> dict: + """Utility function to convert a pydantic model to a dictionary.""" + keys = model.model_fields.keys() + keys = keys if include is None else [item for item in include if item in model.model_fields] + return {key: getattr(model, key) for key in keys} + + +ModelType = TypeVar("ModelType", bound=pydantic.BaseModel) + + +def subset_model( + model: Type[ModelType], + fields: List[str], + force_optional_fields: List[str], + model_name_suffix: str, +) -> Type[ModelType]: + """Creates a new pydantic model that is a subset of the original model. + This is just to make it more efficient, as we can dynamically alter pydantic models + + :param fields: Fields that we want to include in the new model. + :param force_optional_fields: Fields that we want to include in the new model, but that will always be optional. + :param model: The model type to subset. + :param model_name_suffix: The suffix to add to the model name. + :return: The new model type. + """ + new_fields = {} + + for name, field_info in model.model_fields.items(): + if name in fields: + # copy directly + # TODO -- handle cross-field validation + new_fields[name] = (field_info.annotation, field_info) + elif name in force_optional_fields: + new_field_info = copy.deepcopy(field_info) + if new_field_info.default_factory is None and ( + new_field_info.default is PydanticUndefined + ): + # in this case we can set to None + new_field_info.default = None + annotation = field_info.annotation + if annotation is not None: + new_field_info.annotation = Optional[annotation] # type: ignore + new_fields[name] = (new_field_info.annotation, new_field_info) + return pydantic.create_model(model.__name__ + model_name_suffix, **new_fields) # type: ignore + + +def merge_to_state(model: pydantic.BaseModel, write_keys: List[str], state: State) -> State: + """Merges a pydantic model that is a subset of the new state back into the state + TODO -- implement + TODO -- consider validating that the entire state is correct + TODO -- consider validating just the deltas (if that's possible) + """ + write_dict = model_to_dict(model=model, include=write_keys) + return state.update(**write_dict) + + +def model_from_state(model: Type[ModelType], state: State) -> ModelType: + """Creates a model from the state object -- capturing just the fields that are relevant to the model itself. + + :param model: model type to create + :param state: state object to create from + :return: model object + """ + keys = [item for item in model.model_fields.keys() if item in state] + return model(**{key: state[key] for key in keys}) + + +def _validate_and_extract_signature_types( + fn: PydanticActionFunction, +) -> Tuple[Type[pydantic.BaseModel], Type[pydantic.BaseModel]]: + sig = inspect.signature(fn) + if "state" not in sig.parameters: + raise ValueError( + f"Function fn: {fn.__qualname__} is not a valid pydantic action. " + "The first argument of a pydantic " + "action must be the state object. Got signature: {sig}." + ) + state_model = sig.parameters["state"].annotation + if state_model is inspect.Parameter.empty or not issubclass(state_model, pydantic.BaseModel): + raise ValueError( + f"Function fn: {fn.__qualname__} is not a valid pydantic action. " + "a type annotation of a type extending: pydantic.BaseModel. Got parameter " + "state: {state_model.__qualname__}." + ) + if sig.return_annotation is inspect.Parameter.empty or not issubclass( + sig.return_annotation, pydantic.BaseModel + ): + raise ValueError( + f"Function fn: {fn.__qualname__} is not a valid pydantic action. " + "The return type must be a subclass of pydantic" + ".BaseModel. Got return type: {sig.return_annotation}." + ) + return state_model, sig.return_annotation + + +def _validate_keys(model: Type[pydantic.BaseModel], keys: List[str], fn: Callable) -> None: + missing_keys = [key for key in keys if key not in model.model_fields] + if missing_keys: + raise ValueError( + f"Function fn: {fn.__qualname__} is not a valid pydantic action. " + f"The keys: {missing_keys} are not present in the model: {model.__qualname__}." + ) + + +StateInputType = TypeVar("StateInputType", bound=pydantic.BaseModel) +StateOutputType = TypeVar("StateOutputType", bound=pydantic.BaseModel) +IntermediateResultType = TypeVar("IntermediateResultType", bound=Union[pydantic.BaseModel, dict]) + + +class PydanticActionSchema(ActionSchema[StateInputType, StateOutputType, IntermediateResultType]): + def __init__( + self, + input_type: Type[StateInputType], + output_type: Type[StateOutputType], + intermediate_result_type: Type[IntermediateResultType], + ): + self._input_type = input_type + self._output_type = output_type + self._intermediate_result_type = intermediate_result_type + + def state_input_type(self) -> Type[StateInputType]: + return self._input_type + + def state_output_type(self) -> Type[StateOutputType]: + return self._output_type + + def intermediate_result_type(self) -> type[IntermediateResultType]: + return self._intermediate_result_type + + +def pydantic_action( + reads: List[str], + writes: List[str], + state_input_type: Optional[Type[pydantic.BaseModel]] = None, + state_output_type: Optional[Type[pydantic.BaseModel]] = None, +) -> Callable[[PydanticActionFunction], PydanticActionFunction]: + """See docstring for @action.pydantic""" + + def decorator(fn: PydanticActionFunction) -> PydanticActionFunction: + if state_input_type is None and state_output_type is None: + itype, otype = _validate_and_extract_signature_types(fn) + + elif state_input_type is not None and state_output_type is not None: + itype, otype = state_input_type, state_output_type + else: + raise ValueError( + "If you specify state_input_type or state_output_type, you must specify both." + ) + _validate_keys(model=itype, keys=reads, fn=fn) + _validate_keys(model=otype, keys=writes, fn=fn) + SubsetInputType = subset_model( + model=itype, + fields=reads, + force_optional_fields=[item for item in writes if item not in reads], + model_name_suffix=f"{fn.__name__}_input", + ) + SubsetOutputType = subset_model( + model=otype, + fields=writes, + force_optional_fields=[], + model_name_suffix=f"{fn.__name__}_input", + ) + # TODO -- figure out + + def action_function(state: State, **kwargs) -> State: + model_to_use = model_from_state(model=SubsetInputType, state=state) + result = fn(state=model_to_use, **kwargs) + # TODO -- validate that we can always construct this from the dict... + # We really want a copy-type function + output = SubsetOutputType(**model_to_dict(result, include=writes)) + return merge_to_state(model=output, write_keys=writes, state=state) + + async def async_action_function(state: State, **kwargs) -> State: + model_to_use = model_from_state(model=SubsetInputType, state=state) + result = await fn(state=model_to_use, **kwargs) + output = SubsetOutputType(**model_to_dict(result, include=writes)) + return merge_to_state(model=output, write_keys=writes, state=state) + + is_async = inspect.iscoroutinefunction(fn) + # This recreates the @action decorator + # TODO -- use the @action decorator directly + # TODO -- ensure that the function is the right one -- specifically it probably won't show code in the UI + # now + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedAction( + async_action_function if is_async else action_function, + reads, + writes, + input_spec=derive_inputs_from_fn({}, fn), + originating_fn=fn, + schema=PydanticActionSchema( + input_type=SubsetInputType, + output_type=SubsetOutputType, + intermediate_result_type=dict, + ), + ), + ) + setattr(fn, "bind", types.MethodType(bind, fn)) + # TODO -- figure out typing + # It's not smart enough to know that we have satisfied the type signature, + # as we dynamically apply it using setattr + return fn + + return decorator + + +PartialType = Union[Type[pydantic.BaseModel], Type[dict]] + +PydanticStreamingActionFunctionSync = Callable[ + ..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None] +] + +PydanticStreamingActionFunctionAsync = Callable[ + ..., AsyncGenerator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None] +] + +PydanticStreamingActionFunction = Union[ + PydanticStreamingActionFunctionSync, PydanticStreamingActionFunctionAsync +] + +PydanticStreamingActionFunctionVar = TypeVar( + "PydanticStreamingActionFunctionVar", bound=PydanticStreamingActionFunction +) + + +def _validate_and_extract_signature_types_streaming( + fn: PydanticStreamingActionFunction, + stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]], + state_input_type: Optional[Type[pydantic.BaseModel]] = None, + state_output_type: Optional[Type[pydantic.BaseModel]] = None, +) -> Tuple[ + Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]] +]: + if stream_type is None: + # TODO -- derive from the signature + raise ValueError(f"stream_type is required for function: {fn.__qualname__}") + if state_input_type is None: + # TODO -- derive from the signature + raise ValueError(f"state_input_type is required for function: {fn.__qualname__}") + if state_output_type is None: + # TODO -- derive from the signature + raise ValueError(f"state_output_type is required for function: {fn.__qualname__}") + return state_input_type, state_output_type, stream_type + + +def pydantic_streaming_action( + reads: List[str], + writes: List[str], + state_input_type: Type[pydantic.BaseModel], + state_output_type: Type[pydantic.BaseModel], + stream_type: PartialType, +) -> Callable[[PydanticStreamingActionFunction], PydanticStreamingActionFunction]: + """See docstring for @streaming_action.pydantic""" + + def decorator(fn: PydanticStreamingActionFunctionVar) -> PydanticStreamingActionFunctionVar: + itype, otype, stream_type_processed = _validate_and_extract_signature_types_streaming( + fn, stream_type, state_input_type=state_input_type, state_output_type=state_output_type + ) + _validate_keys(model=itype, keys=reads, fn=fn) + _validate_keys(model=otype, keys=writes, fn=fn) + SubsetInputType = subset_model( + model=itype, + fields=reads, + force_optional_fields=[item for item in writes if item not in reads], + model_name_suffix=f"{fn.__name__}_input", + ) + SubsetOutputType = subset_model( + model=otype, + fields=writes, + force_optional_fields=[], + model_name_suffix=f"{fn.__name__}_input", + ) + # PartialModelType = stream_type_processed # TODO -- attach to action + # We don't currently use this, but we will be passing to the action to validate + + def action_generator( + state: State, **kwargs + ) -> Generator[tuple[PartialType, Optional[State]], None, None]: + model_to_use = model_from_state(model=SubsetInputType, state=state) + for partial, state_update in fn(state=model_to_use, **kwargs): + if state_update is None: + yield partial, None + else: + output = SubsetOutputType(**model_to_dict(state_update, include=writes)) + yield partial, merge_to_state(model=output, write_keys=writes, state=state) + + async def async_action_generator( + state: State, **kwargs + ) -> AsyncGenerator[tuple[dict, Optional[State]], None]: + model_to_use = model_from_state(model=SubsetInputType, state=state) + async for partial, state_update in fn(state=model_to_use, **kwargs): + if state_update is None: + yield partial, None + else: + output = SubsetOutputType(**model_to_dict(state_update, include=writes)) + yield partial, merge_to_state(model=output, write_keys=writes, state=state) + + is_async = inspect.isasyncgenfunction(fn) + # This recreates the @streaming_action decorator + # TODO -- use the @streaming_action decorator directly + setattr( + fn, + FunctionBasedAction.ACTION_FUNCTION, + FunctionBasedStreamingAction( + async_action_generator if is_async else action_generator, + reads, + writes, + input_spec=derive_inputs_from_fn({}, fn), + originating_fn=fn, + schema=PydanticActionSchema( + input_type=SubsetInputType, + output_type=SubsetOutputType, + intermediate_result_type=stream_type_processed, + ), + ), + ) + setattr(fn, "bind", types.MethodType(bind, fn)) + return fn + + return decorator + + +StateModel = TypeVar("StateModel", bound=pydantic.BaseModel) + + +class PydanticTypingSystem(TypingSystem[StateModel]): + """Typing system for pydantic models. + + :param TypingSystem: Parameterized on the state model type. + """ + + def __init__(self, model_type: Type[StateModel]): + self.model_type = model_type + + def state_type(self) -> Type[StateModel]: + return self.model_type + + def state_pre_action_run_type(self, action: Action, graph: Graph) -> Type[pydantic.BaseModel]: + raise NotImplementedError( + "TODO -- crawl through" + "the graph to figure out what can possibly be optional and what can't..." + "First get all " + ) + + def state_post_action_run_type(self, action: Action, graph: Graph) -> Type[pydantic.BaseModel]: + raise NotImplementedError( + "TODO -- crawl through" + "the graph to figure out what can possibly be optional and what can't..." + "First get all " + ) + + def construct_data(self, state: State) -> StateModel: + return model_from_state(model=self.model_type, state=state) + + def construct_state(self, data: StateModel) -> State: + return State(model_to_dict(data)) diff --git a/burr/telemetry.py b/burr/telemetry.py index 57879eed..0fb47fc0 100644 --- a/burr/telemetry.py +++ b/burr/telemetry.py @@ -24,7 +24,7 @@ import platform import threading import uuid -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable, List, TypeVar from urllib import request if TYPE_CHECKING: @@ -256,7 +256,10 @@ def create_and_send_cli_event(command: str): send_event_json(event) -def capture_function_usage(call_fn: Callable) -> Callable: +CallableT = TypeVar("CallableT", bound=Callable) + + +def capture_function_usage(call_fn: CallableT) -> CallableT: """Decorator to wrap some application functions for telemetry capture. We want to use this for non-execute functions. diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index d5f76c6d..42405173 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -18,6 +18,7 @@ Overview of the concepts -- read these to get a mental model for how Burr works. state-persistence serde streaming-actions + state-typing hooks additional-visibility recursion diff --git a/docs/reference/actions.rst b/docs/reference/actions.rst index ce473ac8..72d52e1d 100644 --- a/docs/reference/actions.rst +++ b/docs/reference/actions.rst @@ -22,7 +22,12 @@ Actions .. automethod:: __init__ -.. autodecorator:: burr.core.action.action +.. autoclass:: burr.core.action.action + :members: + + .. automethod:: __init__ + + .. autofunction:: burr.core.action.bind diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 0613b28b..f6e96eec 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -1,3 +1,4 @@ + .. _reference: ======================== @@ -20,5 +21,6 @@ need functionality that is not publicly exposed, please open an issue and we can tracking visibility lifecycle + typing integrations/index telemetry diff --git a/docs/reference/tracking.rst b/docs/reference/tracking.rst index 73c9c4fa..ea317156 100644 --- a/docs/reference/tracking.rst +++ b/docs/reference/tracking.rst @@ -3,7 +3,7 @@ Tracking ======== Reference on the Tracking/Telemetry API. -Rather, you should use this throug/in conjunction with :py:meth:`burr.core.application.ApplicationBuilder.with_tracker`. +Rather, you should use this through/in conjunction with :py:meth:`burr.core.application.ApplicationBuilder.with_tracker`. .. autoclass:: burr.tracking.LocalTrackingClient diff --git a/examples/typed-state/README.md b/examples/typed-state/README.md new file mode 100644 index 00000000..fde60cce --- /dev/null +++ b/examples/typed-state/README.md @@ -0,0 +1,264 @@ +# Typed State + +This example goes over how to use the typed state features in Burr with pydantic. + +It will cover the following concepts: + +1. Why you might want to use typing for your state in the first place +1. IDE setup to make use of typing. +1. Defining/altering state at the action level +1. Defining types for streaming contexts +1. Wiring that through to a FastAPI app + +This README will contain snippets + link out to the code. + +This adapts the [instructor + youtube example](../youtube-to-social-media-post/). This +takes in a youtube video, creates a transcript, and uses OpenAI to generate a social media post based on that transcript. + +## Why type state? + +Burr originally came without typing -- typing is an additional (optional) plugin that requires some complexity. So, why use typing at all? And why use Pydantic? Lots of good reasons: + +### Typing provides guard-rails + +Typing state ensures you have guarentees about what the data in your state is, and what shape it takes. This makes it easier to work with/manipulate. + +### Typing makes development easier + +IDEs have integrations for handling type annotations, enabling you to +avoid the cognitive burden of tracking types in your head. You get auto-completion on typed classes, and errors if you attempt to access or assign to a field that does not exist. + +### Typing makes downstream integration easier + +Multiple tools use types (especially with pydantic) to make interacting with data easier. In this example we use [instructor](https://python.useinstructor.com/blog/), as well as [FastAPI](https://fastapi.tiangolo.com/) to leverage pydantic models that Burr also uses. + +### Typing provides a form of documentation + +Type-annotation in python allows you to read your code and get some sense of what it is actually doing. This can be a warm introduction to python for those who came from the world of java initially (the authors of this library included), and make reasoning about a complex codebase simpler. + +## Setting up your IDE + +VSCode (or an editor with a similar interface) is generally optimal for this. It has +pluggable typing (e.g. pylance), which handles generics cleanly. Note: pycharm +is often behind on typing support. See issues like [this](https://youtrack.jetbrains.com/issue/PY-44364) and [this](https://youtrack.jetbrains.com/issue/PY-27627/Support-explicitly-parametrized-generic-class-instantiation-syntax). + +While it will still work in pycharm, you will not get some of the better auto-completion capabilities until those issues are resolved. + +## Defining typed state at the application level + +The code for this is in [application.py](application.py). + +First, define a pydantic model -- make it as nested/recursive as you want. This will represent +your entire state. In this case, we're going to have a transcript of a youtube video that +was given by the user, as well as the social media post. The high-level is here -- the rest is +in the code: + +```python +class ApplicationState(BaseModel): + # Make these have defaults as they are only set in actions + transcript: Optional[str] = Field( + description="The full transcript of the YouTube video.", default=None + ) + post: Optional[SocialMediaPost] = Field( + description="The generated social media post.", default=None + ) +``` + +Note that this should exactly model your state -- we need to make things optional, +as there are points in time when running your application where we have not populated all fields in the application state. For our example the transcript/post will not have been assigned. + +Next, we add it in to the application object, both the initial value(s) and the typing system. +The typing system is what's responsible for managing the schema: + +```python +app = ( + ApplicationBuilder() + .with_actions( + ... + ) + .with_transitions( + ... + ) + .with_entrypoint(...) + .with_typing(PydanticTypingSystem(ApplicationState)) + .with_state(ApplicationState()) + .build() + ) +``` + +That ensures that application and the application's state object are parameterized on the state type. + +To get the typed state object you defined, you can use `.data` on any state object returned by the application: + +```python +# just from the application +print(app.state.data.transcript) + +# after execution +_, _, state = app.run(halt_after=..., inputs=...) +print(state.data.transcript) +``` + +## Defining/altering typed state at the action level + +The code for this is in [application.py](application.py). + +In addition to defining state centrally, we can define it at an action level. + +The code is straightforward, but the API is slightly different from standard Burr. Rather than +using the immutable state-based API, we in-place mutate pydantic models in the action function. Don't worry, the state object you are modifying is still immutable, you're just modifying a copy and returning it. + +In this case, we use `@action.pydantic`, which tells that we're using typed state and which fields to read/write to from state. It derives your typed state class(es) from the function annotations, although you can also pass it the pydantic classes as arguments to the decorator if you prefer. + +Note that the reads/writes have to be a subset of your defined state object -- fields that exist in your pydantic model. +This is because the action will pull all its data from state -- it listens to specific sub-fields. +In this case we use the global `ApplicationState` object as described above, although it can use a subset/compatible set of fields (or, if you elect not to use centralized state, it just has to be compatible with upstream/downstream versions). + +Under the hood, burr will subset the state class so it only has the relevant fields (the reads/write) fields. + +```python +@action.pydantic(reads=["transcript"], writes=["post"]) +def generate_post(state: ApplicationState, llm_client) -> ApplicationState: + """Use the Instructor LLM client to generate `SocialMediaPost` from the YouTube transcript.""" + + # read the transcript from state + transcript = state.transcript + + response = llm_client.chat.completions.create( + model="gpt-4o-mini", + response_model=SocialMediaPost, + messages=[...] + ) + # mutate in place + state.post = response + # return the state + return state +``` + +## Typed state for streaming actions + +The code for this is in [application.py](application.py). + +For streaming actions, not only can we type the input/output state, but we can also type the intermediate result. + +In this case, we just use the `SocialMediaPost` as we did in the application state. Instructor will be streaming that in as it gets created. + +`@streaming_action.pydantic` currently requires you to pass in all the pydantic models as classes, although we will be adding the option to derive from the function signature. + +We first call out to OpenAI, then we stream through + +```python + +@streaming_action.pydantic( + reads=["transcript"], + writes=["post"], + state_input_type=ApplicationState, + state_output_type=ApplicationState, + stream_type=SocialMediaPost, +) +def generate_post_streaming( + state: ApplicationState, llm_client +) -> Generator[Tuple[SocialMediaPost, Optional[ApplicationState]], None, None]: + """Streams a post as it's getting created. This allows for interacting data on the UI side of partial + results, using instructor's streaming capabilities for partial responses: + https://python.useinstructor.com/concepts/partial/ + + :param state: input state -- of the shape `ApplicationState` + :param llm_client: the LLM client, we will bind this in the application + :yield: a tuple of the post and the state -- state will be non-null when it's done + """ + + transcript = state.transcript + response = llm_client.chat.completions.create_partial( + model="gpt-4o-mini", + response_model=SocialMediaPost, + messages=[...], + stream=True, + ) + for post in response: + yield post, None + state.post = post + yield post, state +``` + +When we call out to the application we built, we have to add a type-hint to get typing to work +in the IDE (see the line `streaming_container: StreamingResultContainer[...]`), but we still have the same benefits as the non-streaming approach. + +```python +app = build_streaming_application(...) # builder left out for now +_, streaming_container = app.stream_result( + halt_after=["generate_post"], + inputs={"youtube_url": "https://www.youtube.com/watch?v=hqutVJyd3TI"}, +) +# annotate to make type-completion easier +streaming_container: StreamingResultContainer[ApplicationState, SocialMediaPost] +# post is of type SocialMediaPost +for post in streaming_container: + obj = post.model_dump() + console.clear() + console.print(obj) +``` + +## FastAPI integration + +The code for this is in [server.py](server.py). + +To integrate this with FastAPI is easy, and gets easier with the types cascading through. + +### Non-streaming + +For the non-streaming case, we declare an endpoint that returns the entire state. Note you +may want a subset, but for now this is simple as it matches the pydantic models we defined above. + +```python +@app.get("/social_media_post", response_model=SocialMediaPost) +def social_media_post(youtube_url: str = DEFAULT_YOUTUBE_URL) -> SocialMediaPost: + _, _, state = burr_app.run(halt_after=["generate_post"], inputs={"youtube_url": youtube_url}) + return state.data.post +``` + +### Streaming + +The streaming case involves using FastAPI's [StreamingResponse API](https://fastapi.tiangolo.com/advanced/custom-response/#streamingresponse). We define a generator, which simply yields all +intermediate results: + +```python + +@app.get("/social_media_post_streaming", response_class=StreamingResponse) +def social_media_post_streaming(youtube_url: str = DEFAULT_YOUTUBE_URL) -> StreamingResponse: + """Creates a completion for the chat message""" + + def gen(): + _, streaming_container = burr_app_streaming.stream_result( + halt_after=["generate_post"], + inputs={"youtube_url": youtube_url}, + ) # type: ignore + for post in streaming_container: + obj = post.model_dump() + yield json.dumps(obj) + + return StreamingResponse(gen()) +``` + +Note that `StreamingResponse` is not typed, but you have access to the types with the post +object, which corresponds to the stream from above! + +Async streaming is similar. + +You can run `server.py` with `python server.py`, which will open up on port 7443. You can use the `./curls.sh` command to query the server (it will use a default video, modify to pass your own): + +```bash +./curls.sh # default, non-streaming +./curls.sh streaming # streaming endpoint +./sucls.sh streaming_async # streaming async endpoint +``` + +Note you'll have to have [jq](https://jqlang.github.io/jq/) installed for this to work. + +## Caveats + next steps + +Some things we'll be building out shortly: + +1. The ability to derive application level schemas from individual actions +2. The ability to automatically generate a FastAPI application from state + Burr +3. Configurable validation for state -- guardrails to choose when/when not to validate in pydantic diff --git a/examples/typed-state/__init__.py b/examples/typed-state/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/typed-state/application.py b/examples/typed-state/application.py new file mode 100644 index 00000000..88f1fad8 --- /dev/null +++ b/examples/typed-state/application.py @@ -0,0 +1,263 @@ +from typing import AsyncGenerator, Generator, Optional, Tuple, Union + +import instructor +import openai +from pydantic import BaseModel, Field +from pydantic.json_schema import SkipJsonSchema +from rich.console import Console +from youtube_transcript_api import YouTubeTranscriptApi + +from burr.core import Application, ApplicationBuilder, action +from burr.core.action import ( + AsyncStreamingResultContainer, + StreamingResultContainer, + streaming_action, +) +from burr.integrations.pydantic import PydanticTypingSystem + + +class Concept(BaseModel): + term: str = Field(description="A key term or concept mentioned.") + definition: str = Field(description="A brief definition or explanation of the term.") + timestamp: float = Field(description="Timestamp when the concept is explained.") + + +class SocialMediaPost(BaseModel): + """A social media post about a YouTube video generated its transcript""" + + topic: str = Field(description="Main topic discussed.") + hook: str = Field( + description="Statement to grab the attention of the reader and announce the topic." + ) + body: str = Field( + description="The body of the social media post. It should be informative and make the reader curious about viewing the video." + ) + concepts: list[Concept] = Field( + description="Important concepts about Hamilton or Burr mentioned in this post -- please have at least 1", + min_items=0, + max_items=3, + validate_default=False, + ) + key_takeaways: list[str] = Field( + description="A list of informative key takeways for the reader -- please have at least 1", + min_items=0, + max_items=4, + validate_default=False, + ) + youtube_url: SkipJsonSchema[Union[str, None]] = None + + +class ApplicationState(BaseModel): + # Make these have defaults as they are only set in actions + transcript: Optional[str] = Field( + description="The full transcript of the YouTube video.", default=None + ) + post: Optional[SocialMediaPost] = Field( + description="The generated social media post.", default=None + ) + + +@action.pydantic(reads=[], writes=["transcript"]) +def get_youtube_transcript(state: ApplicationState, youtube_url: str) -> ApplicationState: + """Get the official YouTube transcript for a video given its URL""" + _, _, video_id = youtube_url.partition("?v=") + + transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"]) + state.transcript = " ".join([f"ts={entry['start']} - {entry['text']}" for entry in transcript]) + return state + + +@action.pydantic(reads=["transcript"], writes=["post"]) +def generate_post(state: ApplicationState, llm_client) -> ApplicationState: + """Use the Instructor LLM client to generate `SocialMediaPost` from the YouTube transcript.""" + + # read the transcript from state + transcript = state.transcript + + response = llm_client.chat.completions.create( + model="gpt-4o-mini", + response_model=SocialMediaPost, + messages=[ + { + "role": "system", + "content": "Analyze the given YouTube transcript and generate a compelling social media post.", + }, + {"role": "user", "content": transcript}, + ], + ) + state.post = response + + # store the chapters in state + return state + + +@streaming_action.pydantic( + reads=["transcript"], + writes=["post"], + state_input_type=ApplicationState, + state_output_type=ApplicationState, + stream_type=SocialMediaPost, +) +def generate_post_streaming( + state: ApplicationState, llm_client +) -> Generator[Tuple[SocialMediaPost, Optional[ApplicationState]], None, None]: + """Streams a post as it's getting created. This allows for interacting data on the UI side of partial + results, using instructor's streaming capabilities for partial responses: + https://python.useinstructor.com/concepts/partial/ + + :param state: input state -- of the shape `ApplicationState` + :param llm_client: the LLM client, we will bind this in the application + :yield: a tuple of the post and the state -- state will be non-null when it's done + """ + + transcript = state.transcript + response = llm_client.chat.completions.create_partial( + model="gpt-4o-mini", + response_model=SocialMediaPost, + messages=[ + { + "role": "system", + "content": "Analyze the given YouTube transcript and generate a compelling social media post.", + }, + {"role": "user", "content": transcript}, + ], + stream=True, + ) + final_post: SocialMediaPost = None # type: ignore + for post in response: + final_post = post + yield post, None + + yield final_post, state + + +@streaming_action.pydantic( + reads=["transcript"], + writes=["post"], + state_input_type=ApplicationState, + state_output_type=ApplicationState, + stream_type=SocialMediaPost, +) +async def generate_post_streaming_async( + state: ApplicationState, llm_client +) -> AsyncGenerator[Tuple[SocialMediaPost, Optional[ApplicationState]], None]: + """Async implementation of the streaming action above""" + + transcript = state.transcript + response = llm_client.chat.completions.create_partial( + model="gpt-4o-mini", + response_model=SocialMediaPost, + messages=[ + { + "role": "system", + "content": "Analyze the given YouTube transcript and generate a compelling social media post.", + }, + {"role": "user", "content": transcript}, + ], + stream=True, + ) + final_post = None + async for post in response: + final_post = post + yield post, None + + yield final_post, state + + +def build_application() -> Application[ApplicationState]: + """Builds the standard application (non-streaming)""" + llm_client = instructor.from_openai(openai.OpenAI()) + app = ( + ApplicationBuilder() + .with_actions( + get_youtube_transcript, + generate_post.bind(llm_client=llm_client), + ) + .with_transitions( + ("get_youtube_transcript", "generate_post"), + ("generate_post", "get_youtube_transcript"), + ) + .with_entrypoint("get_youtube_transcript") + .with_typing(PydanticTypingSystem(ApplicationState)) + .with_state(ApplicationState()) + .with_tracker(project="youtube-post") + .build() + ) + return app + + +def build_streaming_application() -> Application[ApplicationState]: + """Builds the streaming application -- this uses the generate_post_streaming action""" + llm_client = instructor.from_openai(openai.OpenAI()) + app = ( + ApplicationBuilder() + .with_actions( + get_youtube_transcript, + generate_post=generate_post_streaming.bind(llm_client=llm_client), + ) + .with_transitions( + ("get_youtube_transcript", "generate_post"), + ("generate_post", "get_youtube_transcript"), + ) + .with_entrypoint("get_youtube_transcript") + .with_typing(PydanticTypingSystem(ApplicationState)) + .with_state(ApplicationState()) + .with_tracker(project="youtube-post") + .build() + ) + return app + + +def build_streaming_application_async() -> Application[ApplicationState]: + """Builds the async streaming application -- uses the generate_post_streaming_async action""" + llm_client = instructor.from_openai(openai.AsyncOpenAI()) + app = ( + ApplicationBuilder() + .with_actions( + get_youtube_transcript, + generate_post=generate_post_streaming_async.bind(llm_client=llm_client), + ) + .with_transitions( + ("get_youtube_transcript", "generate_post"), + ("generate_post", "get_youtube_transcript"), + ) + .with_entrypoint("get_youtube_transcript") + .with_typing(PydanticTypingSystem(ApplicationState)) + .with_state(ApplicationState()) + .with_tracker(project="test-youtube-post") + .build() + ) + return app + + +async def run_async(): + """quick function to run async -- this is not called in the mainline, see commented out code""" + console = Console() + app = build_streaming_application_async() + + _, streaming_container = await app.astream_result( + halt_after=["generate_post"], + inputs={"youtube_url": "https://www.youtube.com/watch?v=hqutVJyd3TI"}, + ) # type: ignore + streaming_container: AsyncStreamingResultContainer[ApplicationState, SocialMediaPost] + + async for post in streaming_container: + obj = post.model_dump() + console.clear() + console.print(obj) + + +# mainline -- runs streaming and prints to console +if __name__ == "__main__": + # asyncio.run(run_async()) + console = Console() + app = build_streaming_application() + _, streaming_container = app.stream_result( + halt_after=["generate_post"], + inputs={"youtube_url": "https://www.youtube.com/watch?v=hqutVJyd3TI"}, + ) # type: ignore + streaming_container: StreamingResultContainer[ApplicationState, SocialMediaPost] + for post in streaming_container: + obj = post.model_dump() + console.clear() + console.print(obj) diff --git a/examples/typed-state/curls.sh b/examples/typed-state/curls.sh new file mode 100755 index 00000000..fb479e56 --- /dev/null +++ b/examples/typed-state/curls.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Default to the 'social_media_post' endpoint if no argument is passed +ENDPOINT="social_media_post" +if [[ "$1" == "streaming_async" ]]; then + ENDPOINT="social_media_post_streaming_async" +elif [[ "$1" == "streaming" ]]; then + ENDPOINT="social_media_post_streaming" +fi + +# Perform the curl request to the chosen endpoint +curl -X 'GET' "http://localhost:7443/$ENDPOINT" \ + -s -H 'Accept: application/json' \ + --no-buffer | jq --unbuffered -c '.' | while IFS= read -r line; do + if [[ "$line" != "" ]]; then # Check for non-empty lines + clear + echo "$line" | jq --color-output . + sleep .01 # Add a small delay for visual clarity + fi +done diff --git a/examples/typed-state/notebook.ipynb b/examples/typed-state/notebook.ipynb new file mode 100644 index 00000000..bd14801e --- /dev/null +++ b/examples/typed-state/notebook.ipynb @@ -0,0 +1,478 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8419f04e-f945-491d-9526-6aebbabbad6a", + "metadata": {}, + "source": [ + "# Typed State\n", + "\n", + "In this example we're going to be using state-typing with instructor + Burr to generate social media posts from youtube videos.\n", + "\n", + "First, let's define some pydantic models. Note you'll need the env var `OPENAI_API_KEY` set. " + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "d62b0737-683c-4748-8b55-c15402aa1b2f", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: burr[pydantic] in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (0.24.0)\n", + "Requirement already satisfied: instructor in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (1.4.1)\n", + "Requirement already satisfied: openai in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (1.44.0)\n", + "Requirement already satisfied: rich in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (13.7.1)\n", + "\u001b[33mWARNING: burr 0.24.0 does not provide the extra 'pydantic'\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: aiohttp<4.0.0,>=3.9.1 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (3.9.5)\n", + "Requirement already satisfied: docstring-parser<0.17,>=0.16 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (0.16)\n", + "Requirement already satisfied: jiter<0.6.0,>=0.5.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (0.5.0)\n", + "Requirement already satisfied: pydantic<3.0.0,>=2.8.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (2.8.2)\n", + "Requirement already satisfied: pydantic-core<3.0.0,>=2.18.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (2.20.1)\n", + "Requirement already satisfied: tenacity<9.0.0,>=8.4.1 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (8.5.0)\n", + "Requirement already satisfied: typer<1.0.0,>=0.9.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from instructor) (0.12.3)\n", + "Requirement already satisfied: anyio<5,>=3.5.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (4.4.0)\n", + "Requirement already satisfied: distro<2,>=1.7.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (1.9.0)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (0.27.0)\n", + "Requirement already satisfied: sniffio in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (1.3.1)\n", + "Requirement already satisfied: tqdm>4 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (4.66.4)\n", + "Requirement already satisfied: typing-extensions<5,>=4.11 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from openai) (4.12.2)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from rich) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from rich) (2.18.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from aiohttp<4.0.0,>=3.9.1->instructor) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from aiohttp<4.0.0,>=3.9.1->instructor) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from aiohttp<4.0.0,>=3.9.1->instructor) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from aiohttp<4.0.0,>=3.9.1->instructor) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from aiohttp<4.0.0,>=3.9.1->instructor) (1.9.4)\n", + "Requirement already satisfied: idna>=2.8 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from anyio<5,>=3.5.0->openai) (3.7)\n", + "Requirement already satisfied: certifi in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from httpx<1,>=0.23.0->openai) (2024.7.4)\n", + "Requirement already satisfied: httpcore==1.* in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from httpx<1,>=0.23.0->openai) (1.0.5)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai) (0.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from markdown-it-py>=2.2.0->rich) (0.1.2)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from pydantic<3.0.0,>=2.8.0->instructor) (0.7.0)\n", + "Requirement already satisfied: click>=8.0.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from typer<1.0.0,>=0.9.0->instructor) (8.1.7)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /Users/elijahbenizzy/.pyenv/versions/3.12.0/envs/burr-3-12/lib/python3.12/site-packages (from typer<1.0.0,>=0.9.0->instructor) (1.5.4)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.12 -m pip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install 'burr[pydantic]' instructor openai rich" + ] + }, + { + "cell_type": "markdown", + "id": "66fe6bbc-3c75-4ff6-9974-4d82109dc47a", + "metadata": {}, + "source": [ + "# Imports/setup" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c3da83a6-0047-4599-aa0c-c4d7a0cc7e78", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import AsyncGenerator, Generator, Optional, Tuple, Union\n", + "\n", + "import instructor\n", + "import openai\n", + "from pydantic import BaseModel, Field\n", + "from pydantic.json_schema import SkipJsonSchema\n", + "from rich.console import Console\n", + "from youtube_transcript_api import YouTubeTranscriptApi\n", + "\n", + "from burr.core import Application, ApplicationBuilder, action\n", + "from burr.core.action import (\n", + " AsyncStreamingResultContainer,\n", + " StreamingResultContainer,\n", + " streaming_action,\n", + ")\n", + "from burr.integrations.pydantic import PydanticTypingSystem\n", + "import json\n", + "import time\n", + "from rich import print_json\n", + "from IPython.display import clear_output" + ] + }, + { + "cell_type": "markdown", + "id": "1c5d6d06-3a82-44df-b4fe-82c95b3d3d05", + "metadata": {}, + "source": [ + "# Constructs\n", + "\n", + "Let's define some pydantic models to use -- these will help shape our application" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "90258678-b33e-4e84-9a22-c7a6d29aca5d", + "metadata": {}, + "outputs": [], + "source": [ + "class Concept(BaseModel):\n", + " term: str = Field(description=\"A key term or concept mentioned.\")\n", + " definition: str = Field(description=\"A brief definition or explanation of the term.\")\n", + " timestamp: float = Field(description=\"Timestamp when the concept is explained.\")\n", + "\n", + "class SocialMediaPost(BaseModel):\n", + " \"\"\"A social media post about a YouTube video generated its transcript\"\"\"\n", + "\n", + " topic: str = Field(description=\"Main topic discussed.\")\n", + " hook: str = Field(\n", + " description=\"Statement to grab the attention of the reader and announce the topic.\"\n", + " )\n", + " body: str = Field(\n", + " description=\"The body of the social media post. It should be informative and make the reader curious about viewing the video.\"\n", + " )\n", + " concepts: list[Concept] = Field(\n", + " description=\"Important concepts about Hamilton or Burr mentioned in this post -- please have at least 1\",\n", + " min_items=0,\n", + " max_items=3,\n", + " validate_default=False,\n", + " )\n", + " key_takeaways: list[str] = Field(\n", + " description=\"A list of informative key takeways for the reader -- please have at least 1\",\n", + " min_items=0,\n", + " max_items=4,\n", + " validate_default=False,\n", + " )\n", + " youtube_url: SkipJsonSchema[Union[str, None]] = None" + ] + }, + { + "cell_type": "markdown", + "id": "03534c74-aa1f-4a1d-91c7-7c6328228c8d", + "metadata": {}, + "source": [ + "# State Type\n", + "\n", + "Using those, we'll define a core pydantic model that sets up the central schema for our application. Note these are optional, they won't be set when our application starts!" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5158c112-f054-4881-89dd-12ca7f1e0e79", + "metadata": {}, + "outputs": [], + "source": [ + "class ApplicationState(BaseModel):\n", + " # Make these have defaults as they are only set in actions\n", + " transcript: Optional[str] = Field(\n", + " description=\"The full transcript of the YouTube video.\", default=None\n", + " )\n", + " post: Optional[SocialMediaPost] = Field(\n", + " description=\"The generated social media post.\", default=None\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "f8e0d745-6561-4331-9a5d-86948fee706d", + "metadata": {}, + "source": [ + "# Write an action to transcribe a youtube URL\n", + "\n", + "Note we take in a youtube URL + the state in the format we want, and write to `transcript`. We actually read nothing, as the transcript is an input.\n", + "Different than normal Burr, we actually mutate the model we send in (this allows us to leverage pydantic validation)." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "df35e278-3578-4611-a9de-c3948911ff70", + "metadata": {}, + "outputs": [], + "source": [ + "@action.pydantic(reads=[], writes=[\"transcript\"])\n", + "def get_youtube_transcript(state: ApplicationState, youtube_url: str) -> ApplicationState:\n", + " \"\"\"Get the official YouTube transcript for a video given its URL\"\"\"\n", + " _, _, video_id = youtube_url.partition(\"?v=\")\n", + "\n", + " transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=[\"en\"])\n", + " state.transcript = \" \".join([f\"ts={entry['start']} - {entry['text']}\" for entry in transcript])\n", + " return state" + ] + }, + { + "cell_type": "markdown", + "id": "f5dbcaf5-5471-44b0-82e3-c4b32f01a115", + "metadata": {}, + "source": [ + "# Write an action to stream back pydantic models\n", + "\n", + "We specify the state input type, state output type, and the stream type, streaming it all back using [instructors streaming capability](https://python.useinstructor.com/concepts/partial/)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4737d8b1-e1c1-472c-bc01-a47d328797d8", + "metadata": {}, + "outputs": [], + "source": [ + "@streaming_action.pydantic(\n", + " reads=[\"transcript\"],\n", + " writes=[\"post\"],\n", + " state_input_type=ApplicationState,\n", + " state_output_type=ApplicationState,\n", + " stream_type=SocialMediaPost,\n", + ")\n", + "def generate_post(\n", + " state: ApplicationState, llm_client\n", + ") -> Generator[Tuple[SocialMediaPost, Optional[ApplicationState]], None, None]:\n", + "\n", + " transcript = state.transcript\n", + " response = llm_client.chat.completions.create_partial(\n", + " model=\"gpt-4o-mini\",\n", + " response_model=SocialMediaPost,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Analyze the given YouTube transcript and generate a compelling social media post.\",\n", + " },\n", + " {\"role\": \"user\", \"content\": transcript},\n", + " ],\n", + " stream=True,\n", + " )\n", + " final_post: SocialMediaPost = None # type: ignore\n", + " for post in response:\n", + " final_post = post\n", + " yield post, None\n", + "\n", + " yield final_post, state" + ] + }, + { + "cell_type": "markdown", + "id": "968daa7a-c83f-4890-b775-513dbc347068", + "metadata": {}, + "source": [ + "# Wire together in an application\n", + "\n", + "We specify the application to have type `ApplicationState` as the state, and pass it an initial value" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "11bdf44a-5ff2-4cd3-87b9-175eb68de8e4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "get_youtube_transcript\n", + "\n", + "get_youtube_transcript\n", + "\n", + "\n", + "\n", + "generate_post\n", + "\n", + "generate_post\n", + "\n", + "\n", + "\n", + "get_youtube_transcript->generate_post\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input__youtube_url\n", + "\n", + "input: youtube_url\n", + "\n", + "\n", + "\n", + "input__youtube_url->get_youtube_transcript\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "generate_post->get_youtube_transcript\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_client = instructor.from_openai(openai.OpenAI())\n", + "app = (\n", + " ApplicationBuilder()\n", + " .with_actions(\n", + " get_youtube_transcript,\n", + " generate_post.bind(llm_client=llm_client),\n", + " )\n", + " .with_transitions(\n", + " (\"get_youtube_transcript\", \"generate_post\"),\n", + " (\"generate_post\", \"get_youtube_transcript\"),\n", + " )\n", + " .with_entrypoint(\"get_youtube_transcript\")\n", + " .with_typing(PydanticTypingSystem(ApplicationState))\n", + " .with_state(ApplicationState())\n", + " .with_tracker(project=\"youtube-post\")\n", + " .build()\n", + ")\n", + "# in case we want to access the state\n", + "assert isinstance(app.state.data, ApplicationState)\n", + "app" + ] + }, + { + "cell_type": "markdown", + "id": "29eb49bc-f1c0-4172-831d-bd748ac90548", + "metadata": {}, + "source": [ + "# Run it!\n", + "\n", + "Now we can run it!" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "ba1486aa-f79f-4e02-959c-ffc0ba58b6ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "  \"topic\": \"Burr Framework Overview\",\n",
+       "  \"hook\": \"Ever faced challenges while debugging AI applications? Here's a solution!\",\n",
+       "  \"body\": \"Dive into the world of agent applications with Burr! In this quick overview, we explore how Burr helps you debug failing AI calls and track state effectively. Learn to build a graph that connects actions and states, allowing you to resume your application exactly where you left off. Whether it’s fixing an error mid-run or replaying past actions, Burr's observability features enhance your development workflow. Ready to optimize your debugging process? Check out the full video to unravel the potential of Burr!\",\n",
+       "  \"concepts\": [\n",
+       "    {\n",
+       "      \"term\": \"Agent Application\",\n",
+       "      \"definition\": \"A system that models states and actions to create decision-making processes.\",\n",
+       "      \"timestamp\": 105.479\n",
+       "    },\n",
+       "    {\n",
+       "      \"term\": \"State Object\",\n",
+       "      \"definition\": \"An object that holds the state information for actions to read and write during execution.\",\n",
+       "      \"timestamp\": 145.56\n",
+       "    },\n",
+       "    {\n",
+       "      \"term\": \"Graph Representation\",\n",
+       "      \"definition\": \"A structural representation of actions and their interconnections in an agent system, depicted as nodes and edges.\",\n",
+       "      \"timestamp\": 179.28\n",
+       "    }\n",
+       "  ],\n",
+       "  \"key_takeaways\": [\n",
+       "    \"Burr allows near-instantaneous debugging without restarting from scratch.\",\n",
+       "    \"The framework promotes building a stateful graph structure for actions.\",\n",
+       "    \"Use local tracking to effortlessly monitor and interact with your agent's state.\"\n",
+       "  ],\n",
+       "  \"youtube_url\": null\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"topic\"\u001b[0m: \u001b[32m\"Burr Framework Overview\"\u001b[0m,\n", + " \u001b[1;34m\"hook\"\u001b[0m: \u001b[32m\"Ever faced challenges while debugging AI applications? Here's a solution!\"\u001b[0m,\n", + " \u001b[1;34m\"body\"\u001b[0m: \u001b[32m\"Dive into the world of agent applications with Burr! In this quick overview, we explore how Burr helps you debug failing AI calls and track state effectively. Learn to build a graph that connects actions and states, allowing you to resume your application exactly where you left off. Whether it’s fixing an error mid-run or replaying past actions, Burr's observability features enhance your development workflow. Ready to optimize your debugging process? Check out the full video to unravel the potential of Burr!\"\u001b[0m,\n", + " \u001b[1;34m\"concepts\"\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"term\"\u001b[0m: \u001b[32m\"Agent Application\"\u001b[0m,\n", + " \u001b[1;34m\"definition\"\u001b[0m: \u001b[32m\"A system that models states and actions to create decision-making processes.\"\u001b[0m,\n", + " \u001b[1;34m\"timestamp\"\u001b[0m: \u001b[1;36m105.479\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"term\"\u001b[0m: \u001b[32m\"State Object\"\u001b[0m,\n", + " \u001b[1;34m\"definition\"\u001b[0m: \u001b[32m\"An object that holds the state information for actions to read and write during execution.\"\u001b[0m,\n", + " \u001b[1;34m\"timestamp\"\u001b[0m: \u001b[1;36m145.56\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"term\"\u001b[0m: \u001b[32m\"Graph Representation\"\u001b[0m,\n", + " \u001b[1;34m\"definition\"\u001b[0m: \u001b[32m\"A structural representation of actions and their interconnections in an agent system, depicted as nodes and edges.\"\u001b[0m,\n", + " \u001b[1;34m\"timestamp\"\u001b[0m: \u001b[1;36m179.28\u001b[0m\n", + " \u001b[1m}\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[1;34m\"key_takeaways\"\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[32m\"Burr allows near-instantaneous debugging without restarting from scratch.\"\u001b[0m,\n", + " \u001b[32m\"The framework promotes building a stateful graph structure for actions.\"\u001b[0m,\n", + " \u001b[32m\"Use local tracking to effortlessly monitor and interact with your agent's state.\"\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[1;34m\"youtube_url\"\u001b[0m: \u001b[3;35mnull\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_, streaming_container = app.stream_result(\n", + " halt_after=[\"generate_post\"],\n", + " inputs={\"youtube_url\": \"https://www.youtube.com/watch?v=hqutVJyd3TI\"},\n", + ")\n", + "for post in streaming_container:\n", + " assert isinstance(post, SocialMediaPost)\n", + " clear_output(wait=True)\n", + " obj = post.model_dump()\n", + " json_str = json.dumps(obj, indent=2)\n", + " print_json(json_str)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/typed-state/server.py b/examples/typed-state/server.py new file mode 100644 index 00000000..b5592b55 --- /dev/null +++ b/examples/typed-state/server.py @@ -0,0 +1,109 @@ +import contextlib +import json +import logging + +import fastapi +import uvicorn +from application import ( + ApplicationState, + SocialMediaPost, + build_application, + build_streaming_application, + build_streaming_application_async, +) +from fastapi.responses import StreamingResponse + +from burr.core import Application +from burr.core.action import AsyncStreamingResultContainer, StreamingResultContainer + +logger = logging.getLogger(__name__) + +# define a global `burr_app` variable +burr_app: Application[ApplicationState] = None +# This does streaming, in sync mode +burr_app_streaming: Application[ApplicationState] = None +# And this does streaming, in async mode +burr_app_streaming_async: Application[ApplicationState] = None + +DEFAULT_YOUTUBE_URL = "https://www.youtube.com/watch?v=hqutVJyd3TI" + + +@contextlib.asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + """Instantiate the Burr applications on FastAPI startup.""" + global burr_app, burr_app_streaming, burr_app_streaming_async + burr_app = build_application() + burr_app_streaming = build_streaming_application() + burr_app_streaming_async = build_streaming_application_async() + yield + + +app = fastapi.FastAPI(lifespan=lifespan) + + +@app.get("/social_media_post", response_model=SocialMediaPost) +def social_media_post(youtube_url: str = DEFAULT_YOUTUBE_URL) -> SocialMediaPost: + """Basic, synchronous single-step API. + This just returns the social media post, no streaming response. + + + :param youtube_url: youtube URL for the transcript, defaults to DEFAULT_YOUTUBE_URL + :return: the social media post + """ + # Note that state is of type State[ApplicationState] + # This means that it has a data field of type ApplicationState + # which means that our IDE will happily auto-complete for us + # and that we can get the pydantic model + _, _, state = burr_app.run(halt_after=["generate_post"], inputs={"youtube_url": youtube_url}) + return state.data.post + + +@app.get("/social_media_post_streaming", response_class=StreamingResponse) +def social_media_post_streaming(youtube_url: str = DEFAULT_YOUTUBE_URL) -> StreamingResponse: + """Creates a completion for the chat message""" + + def gen(): + _, streaming_container = burr_app_streaming.stream_result( + halt_after=["generate_post"], + inputs={"youtube_url": youtube_url}, + ) # type: ignore + # We annotate this so we can get the right types cascaded through + streaming_container: StreamingResultContainer[ApplicationState, SocialMediaPost] + # Every post is of type SocialMediaPost, and the IDE (if you're using PyLance or an equivalent) should know + for post in streaming_container: + obj = post.model_dump() + yield json.dumps(obj) + # if we called streaming_container.get(), we would get two objects -- + # This state will have a field data of type ApplicationState, which we can use if we want + + # post, state = streaming_container.get() + # state.data.transcript # valid + auto-completion in the IDE + + # return a final streaming result -- it'll just have strings + # for certain SSE frameworks you may want to delimit with data: + return StreamingResponse(gen()) + + +@app.get("/social_media_post_streaming_async", response_class=StreamingResponse) +async def social_media_post_streaming_async( + youtube_url: str = DEFAULT_YOUTUBE_URL, +) -> StreamingResponse: + """Creates a completion for the chat message""" + + async def gen(): + _, streaming_container = await burr_app_streaming_async.astream_result( + halt_after=["generate_post"], + inputs={"youtube_url": youtube_url}, + ) # type: ignore + # We annotate this so we can get the right types cascaded through + streaming_container: AsyncStreamingResultContainer[ApplicationState, SocialMediaPost] + # Every post is of type SocialMediaPost, and the IDE (if you're using PyLance or an equivalent) should know + async for post in streaming_container: + obj = post.model_dump() + yield json.dumps(obj) + + return StreamingResponse(gen()) + + +if __name__ == "__main__": + uvicorn.run("server:app", host="127.0.0.1", port=7443, reload=True) diff --git a/examples/typed-state/statemachine.png b/examples/typed-state/statemachine.png new file mode 100644 index 00000000..ec7a8b24 Binary files /dev/null and b/examples/typed-state/statemachine.png differ diff --git a/telemetry/ui/src/components/routes/app/StepList.tsx b/telemetry/ui/src/components/routes/app/StepList.tsx index b42911b4..2d297c0d 100644 --- a/telemetry/ui/src/components/routes/app/StepList.tsx +++ b/telemetry/ui/src/components/routes/app/StepList.tsx @@ -123,8 +123,7 @@ const CommonTableRow = (props: { } else { props.setCurrentSelectedIndex(props.sequenceID); } - }} - > + }}> {props.children} ); @@ -176,8 +175,7 @@ const ActionTableRow = (props: { currentSelectedIndex={currentSelectedIndex} step={props.step} setCurrentHoverIndex={setCurrentHoverIndex} - setCurrentSelectedIndex={setCurrentSelectedIndex} - > + setCurrentSelectedIndex={setCurrentSelectedIndex}> {sequenceID}
@@ -189,8 +187,7 @@ const ActionTableRow = (props: { />
+ className={`${props.minimized ? 'w-32' : 'w-72 max-w-72'} flex flex-row justify-start gap-1 items-center`}> + setCurrentSelectedIndex={setCurrentSelectedIndex}> + className={` ${normalText} w-48 min-w-48 max-w-48 truncate pl-9`}>
{ @@ -311,8 +306,7 @@ const LinkSubTable = (props: { `/project/${props.projectId}/${subApp.child.partition_key || 'null'}/${subApp.child.app_id}` ); e.stopPropagation(); - }} - > + }}> {subApp.child.app_id}
@@ -394,25 +388,21 @@ const StepSubTableRow = (props: { currentSelectedIndex={currentSelectedIndex} step={props.step} setCurrentHoverIndex={setCurrentHoverIndex} - setCurrentSelectedIndex={setCurrentSelectedIndex} - > + setCurrentSelectedIndex={setCurrentSelectedIndex}> + className={` ${lightText} w-10 min-w-10 ${props.displaySpanID ? '' : 'text-opacity-0'}`}> {spanIDUniqueToAction} {!props.minimized ? ( <> + className={`${normalText} ${props.minimized ? 'w-32 min-w-32' : 'w-72 max-w-72'} flex flex-col`}>
{[...Array(depth).keys()].map((i) => ( + className={`${i === depth - 1 ? 'opacity-0' : 'opacity-0'} text-lg text-gray-600 w-4 flex-shrink-0`}> ))} { e.stopPropagation(); - }} - > + }}>
+ }}>
) : (
+ }}> {
+ }}> {hoverItem}
} @@ -999,9 +985,7 @@ const ActionSubTable = (props: { isExpanded={isTraceExpanded} setExpanded={setTraceExpanded} allowExpand={ - step.spans.length > 0 || - // step.streaming_events.length > 0 || - step.attributes.length > 0 + step.spans.length > 0 || step.streaming_events.length > 0 || step.attributes.length > 0 } latestTimeSeen={latestTimeSeen} expandNonSpanAttributes={expandNonSpanAttributes} @@ -1159,7 +1143,9 @@ export const StepList = (props: { : new Date(); const MinimizeTableIcon = props.minimized ? ChevronRightIcon : ChevronLeftIcon; const FullScreenIcon = props.fullScreen ? AiOutlineFullscreenExit : AiOutlineFullscreen; - const displaySpansCol = stepsWithEllapsedTime.some((step) => step.spans.length > 0); + const displaySpansCol = stepsWithEllapsedTime.some( + (step) => step.spans.length > 0 || step.streaming_events.length > 0 + ); const displayLinksCol = props.links.length > 0; const linksBySequenceID = props.links.reduce((acc, child) => { const existing = acc.get(child.sequence_id || -1) || []; @@ -1381,8 +1367,7 @@ const ParentLink = (props: {
+ to={`/project/${props.projectId}/${props.parentPointer.partition_key}/${props.parentPointer.app_id}`}> {props.parentPointer.app_id} @ diff --git a/tests/core/test_action.py b/tests/core/test_action.py index 6a477973..675f5dd3 100644 --- a/tests/core/test_action.py +++ b/tests/core/test_action.py @@ -19,6 +19,7 @@ action, create_action, default, + derive_inputs_from_fn, streaming_action, ) @@ -743,3 +744,53 @@ async def callback(r: Optional[dict], s: State, e: Exception): ((result, state, error),) = called assert state["foo"] == "bar" assert result is None + + +def test_derive_inputs_from_fn_state_only(): + def fn(state): + ... + + bound_params = {} + required, optional = derive_inputs_from_fn(bound_params, fn) + assert required == [] + assert optional == [] + + +def test_derive_inputs_from_fn_state_and_required(): + def fn(state, a, b): + ... + + bound_params = {"state": 1} + required, optional = derive_inputs_from_fn(bound_params, fn) + assert required == ["a", "b"] + assert optional == [] + + +def test_derive_inputs_from_fn_state_required_and_optional(): + def fn(state, a, b=2): + ... + + bound_params = {"state": 1} + required, optional = derive_inputs_from_fn(bound_params, fn) + assert required == ["a"] + assert optional == ["b"] + + +def test_derive_inputs_from_fnh_state_and_all_bound_except_state(): + def fn(state, a, b): + ... + + bound_params = {"a": 1, "b": 2} + required, optional = derive_inputs_from_fn(bound_params, fn) + assert required == [] + assert optional == [] + + +def test_non_existent_bound_parameters(): + def fn(state, a): + ... + + bound_params = {"a": 1, "non_existent": 2} + required, optional = derive_inputs_from_fn(bound_params, fn) + assert required == [] + assert optional == [] diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 53daa0e6..11256f16 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -9,6 +9,7 @@ from burr.core import State from burr.core.action import ( + DEFAULT_SCHEMA, Action, AsyncGenerator, AsyncStreamingAction, @@ -41,6 +42,7 @@ ) from burr.core.graph import Graph, GraphBuilder, Transition from burr.core.persistence import BaseStatePersister, DevNullPersister, PersistedStateData +from burr.core.typing import TypingSystem from burr.lifecycle import ( PostRunStepHook, PostRunStepHookAsync, @@ -2425,25 +2427,28 @@ def test__validate_start_not_found(): def test__adjust_single_step_output_result_and_state(): state = State({"count": 1}) result = {"count": 1} - assert _adjust_single_step_output((result, state), "test_action") == (result, state) + assert _adjust_single_step_output((result, state), "test_action", DEFAULT_SCHEMA) == ( + result, + state, + ) def test__adjust_single_step_output_just_state(): state = State({"count": 1}) - assert _adjust_single_step_output(state, "test_action") == ({}, state) + assert _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) == ({}, state) def test__adjust_single_step_output_errors_incorrect_type(): state = "foo" with pytest.raises(ValueError, match="must return either"): - _adjust_single_step_output(state, "test_action") + _adjust_single_step_output(state, "test_action", DEFAULT_SCHEMA) def test__adjust_single_step_output_errors_incorrect_result_type(): state = State() result = "bar" with pytest.raises(ValueError, match="non-dict"): - _adjust_single_step_output((state, result), "test_action") + _adjust_single_step_output((state, result), "test_action", DEFAULT_SCHEMA) def test_application_builder_unset(): @@ -3184,3 +3189,44 @@ def recursive_action(state: State) -> State: len(hook.pre_called) == 62 ) # 63 - the initial one from the call to recursive_action outside the application assert len(hook.post_called) == 62 # ditto + + +class CounterState(State): + count: int + + +class SimpleTypingSystem(TypingSystem[CounterState]): + def state_type(self) -> type[CounterState]: + return CounterState + + def state_pre_action_run_type(self, action: Action, graph: Graph) -> type[Any]: + raise NotImplementedError + + def state_post_action_run_type(self, action: Action, graph: Graph) -> type[Any]: + raise NotImplementedError + + def construct_data(self, state: State[Any]) -> CounterState: + return CounterState({"count": state["count"]}) + + def construct_state(self, data: Any) -> State[Any]: + raise NotImplementedError + + +def test_builder_captures_typing_system(): + """Tests that the typing system is captured correctly""" + counter_action = base_counter_action.with_name("counter") + result_action = Result("count").with_name("result") + app = ( + ApplicationBuilder() + .with_actions(counter_action, result_action) + .with_transitions(("counter", "counter", expr("count < 10"))) + .with_transitions(("counter", "result", default)) + .with_entrypoint("counter") + .with_state(count=0) + .with_typing(SimpleTypingSystem()) + .build() + ) + assert isinstance(app.state.data, CounterState) + _, _, state = app.run(halt_after=["result"]) + assert isinstance(state.data, CounterState) + assert state.data["count"] == 10 diff --git a/tests/core/test_state.py b/tests/core/test_state.py index 4572f8e4..30374d2a 100644 --- a/tests/core/test_state.py +++ b/tests/core/test_state.py @@ -1,6 +1,10 @@ +from typing import Any + import pytest +from burr.core import Action, Graph from burr.core.state import State, register_field_serde +from burr.core.typing import TypingSystem def test_state_access(): @@ -158,3 +162,26 @@ def my_field_serializer(value: str, **kwargs) -> dict: with pytest.raises(ValueError): # deserializer still bad register_field_serde("my_field", my_field_serializer, my_field_deserializer) + + +class SimpleTypingSystem(TypingSystem[Any]): + def state_type(self) -> type[Any]: + raise NotImplementedError + + def state_pre_action_run_type(self, action: Action, graph: Graph) -> type[Any]: + raise NotImplementedError + + def state_post_action_run_type(self, action: Action, graph: Graph) -> type[Any]: + raise NotImplementedError + + def construct_data(self, state: State[Any]) -> Any: + raise NotImplementedError + + def construct_state(self, data: State[Any]) -> State[Any]: + raise NotImplementedError + + +def test_state_apply_keeps_typing_system(): + state = State({"foo": "bar"}, typing_system=SimpleTypingSystem()) + assert state.update(foo="baz").typing_system is state.typing_system + assert state.subset("foo").typing_system is state.typing_system diff --git a/tests/core/test_validation.py b/tests/core/test_validation.py index 21866be6..8303edf9 100644 --- a/tests/core/test_validation.py +++ b/tests/core/test_validation.py @@ -8,5 +8,5 @@ def test__assert_set(): def test__assert_set_unset(): - with pytest.raises(ValueError, match="foo"): + with pytest.raises(ValueError, match="bar"): assert_set(None, "foo", "bar") diff --git a/tests/integrations/test_burr_opentelemetry.py b/tests/integrations/test_burr_opentelemetry.py index d19ed6ea..e3789894 100644 --- a/tests/integrations/test_burr_opentelemetry.py +++ b/tests/integrations/test_burr_opentelemetry.py @@ -7,7 +7,7 @@ from burr.integrations.opentelemetry import convert_to_otel_attribute -class TestModel(pydantic.BaseModel): +class SampleModel(pydantic.BaseModel): foo: int bar: bool @@ -21,7 +21,7 @@ class TestModel(pydantic.BaseModel): ((1.0, 1.0), [1.0, 1.0]), ((True, True), [True, True]), (("hello", "hello"), ["hello", "hello"]), - (TestModel(foo=1, bar=True), json.dumps(serde.serialize(TestModel(foo=1, bar=True)))), + (SampleModel(foo=1, bar=True), json.dumps(serde.serialize(SampleModel(foo=1, bar=True)))), ], ) def test_convert_to_otel_attribute(value, expected):