Skip to content

Commit

Permalink
Adds schema tracking to action
Browse files Browse the repository at this point in the history
This is just a quick wrapper class to represent a schema. Note that this is currently used internally,
just to store the appropriate information. This does not validate or do conversion; currently that
is done within the pydantic model state typing system (which is also internal in its implementation).

We will likely centralize that logic at some point when we get more -- it would look something like this:
1. Action is passed an ActionSchema
2. Action is parameterized on the ActionSchema types
3. Action takes state, validates the type and converts to StateInputType
4. Action runs, returns intermediate result + state
5. Action validates intermediate result type (or converts to dict? Probably just keeps it)
6. Action converts StateOutputType to State

Also we don't have this split out into two classes (input/output schema), but we will likely do that shortly.
  • Loading branch information
elijahbenizzy committed Sep 10, 2024
1 parent a0cc9c3 commit 7c52920
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 24 deletions.
34 changes: 34 additions & 0 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Self

from burr.core.state import State
from burr.core.typing import ActionSchema


class Function(abc.ABC):
Expand Down Expand Up @@ -130,6 +131,20 @@ def update(self, result: dict, state: State) -> State:
pass


class DefaultSchema(ActionSchema):
    """Fallback schema for actions that do not supply their own.

    It carries no state typing information: asking for the state input or
    output type raises ``NotImplementedError``. The intermediate result type
    is a plain ``dict``.
    """

    def state_input_type(self) -> type[State]:
        """Unavailable on the default schema.

        Raises:
            NotImplementedError: always, since no input type is declared.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not define a state input type."
        )

    def state_output_type(self) -> type[State]:
        """Unavailable on the default schema.

        Raises:
            NotImplementedError: always, since no output type is declared.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not define a state output type."
        )

    def intermediate_result_type(self) -> type[dict]:
        """Return ``dict``, the intermediate result type of the default schema."""
        return dict


DEFAULT_SCHEMA = DefaultSchema()


class Action(Function, Reducer, abc.ABC):
def __init__(self):
"""Represents an action in a state machine. This is the base class from which
Expand Down Expand Up @@ -173,6 +188,10 @@ def single_step(self) -> bool:
def streaming(self) -> bool:
return False

@property
def schema(self) -> ActionSchema:
    """Schema associated with this action.

    This implementation always returns the module-level ``DEFAULT_SCHEMA``.
    """
    return DEFAULT_SCHEMA

def get_source(self) -> str:
"""Returns the source code of the action. This will default to
the source code of the class in which the action is implemented,
Expand Down Expand Up @@ -524,6 +543,7 @@ def __init__(
bound_params: Optional[dict] = None,
input_spec: Optional[tuple[list[str], list[str]]] = None,
originating_fn: Optional[Callable] = None,
schema: ActionSchema = DEFAULT_SCHEMA,
):
"""Instantiates a function-based action with the given function, reads, and writes.
The function must take in a state and return a tuple of (result, new_state).
Expand All @@ -548,6 +568,7 @@ def __init__(
[item for item in input_spec[1] if item not in self._bound_params],
)
)
self._schema = schema

@property
def fn(self) -> Callable:
Expand All @@ -565,6 +586,10 @@ def writes(self) -> list[str]:
def inputs(self) -> tuple[list[str], list[str]]:
return self._inputs

@property
def schema(self) -> ActionSchema:
    """Schema stored on this action instance (read from ``self._schema``)."""
    return self._schema

def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
"""Binds parameters to the function.
Note that there is no reason to call this by the user. This *could*
Expand All @@ -580,6 +605,8 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
self._writes,
{**self._bound_params, **kwargs},
input_spec=self._inputs,
originating_fn=self._originating_fn,
schema=self._schema,
)

def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
Expand Down Expand Up @@ -981,6 +1008,7 @@ def __init__(
bound_params: Optional[dict] = None,
input_spec: Optional[tuple[list[str], list[str]]] = None,
originating_fn: Optional[Callable] = None,
schema: ActionSchema = DEFAULT_SCHEMA,
):
"""Instantiates a function-based streaming action with the given function, reads, and writes.
The function must take in a state (and inputs) and return a generator of (result, new_state).
Expand All @@ -1003,6 +1031,7 @@ def __init__(
)
)
self._originating_fn = originating_fn if originating_fn is not None else fn
self._schema = schema

async def _a_stream_run_and_update(
self, state: State, **run_kwargs
Expand Down Expand Up @@ -1046,6 +1075,7 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
{**self._bound_params, **kwargs},
input_spec=self._inputs,
originating_fn=self._originating_fn,
schema=self._schema,
)

@property
Expand All @@ -1056,6 +1086,10 @@ def inputs(self) -> tuple[list[str], list[str]]:
def fn(self) -> Union[StreamingFn, StreamingFnAsync]:
return self._fn

@property
def schema(self) -> ActionSchema:
    """Schema stored on this streaming action instance (read from ``self._schema``)."""
    return self._schema

def is_async(self) -> bool:
return inspect.isasyncgenfunction(self._fn)

Expand Down
39 changes: 23 additions & 16 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from burr.common import types as burr_types
from burr.core import persistence, validation
from burr.core.action import (
DEFAULT_SCHEMA,
Action,
AsyncStreamingAction,
AsyncStreamingResultContainer,
Expand All @@ -44,15 +45,16 @@
from burr.core.graph import Graph, GraphBuilder
from burr.core.persistence import BaseStateLoader, BaseStateSaver
from burr.core.state import State
from burr.core.typing import DictBasedTypingSystem
from burr.core.typing import ActionSchema, DictBasedTypingSystem, TypingSystem
from burr.core.validation import BASE_ERROR_MESSAGE
from burr.lifecycle.base import ExecuteMethod, LifecycleAdapter, PostRunStepHook, PreRunStepHook
from burr.lifecycle.internal import LifecycleAdapterSet
from burr.visibility import tracing
from burr.visibility.tracing import tracer_factory_context_var

if TYPE_CHECKING:
from burr.core.typing import TypingSystem
# TODO -- push type-checking check back from here
# OR just put everything under type-checking...
from burr.tracking.base import TrackingClient

logger = logging.getLogger(__name__)
Expand All @@ -64,8 +66,13 @@
StateTypeToSet = TypeVar("StateTypeToSet")


def _validate_result(result: dict, name: str) -> None:
if not isinstance(result, dict):
def _validate_result(result: Any, name: str, schema: ActionSchema = DEFAULT_SCHEMA) -> None:
# TODO -- validate the output type is action schema's output type...
# TODO -- split out the action schema into input/output schema types
# Then action schema will have both
# we'll just need to ensure we pass the right ones
result_type = schema.intermediate_result_type()
if not isinstance(result, result_type):
raise ValueError(
f"Action {name} returned a non-dict result: {result}. "
f"All results must be dictionaries."
Expand All @@ -80,13 +87,15 @@ def _raise_fn_return_validation_error(output: Any, action_name: str):
)


def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_name: str):
def _adjust_single_step_output(
output: Union[State, Tuple[dict, State]], action_name: str, action_schema: ActionSchema
):
"""Adjusts the output of a single step action to be a tuple of (result, state) or just state"""

if isinstance(output, tuple):
if not len(output) == 2:
_raise_fn_return_validation_error(output, action_name)
_validate_result(output[0], action_name)
_validate_result(output[0], action_name, action_schema)
if not isinstance(output[1], State):
_raise_fn_return_validation_error(output, action_name)
return output
Expand Down Expand Up @@ -242,11 +251,10 @@ def _run_single_step_action(
# TODO -- guard all reads/writes with a subset of the state
action.validate_inputs(inputs)
result, new_state = _adjust_single_step_output(
action.run_and_update(state, **inputs), action.name
action.run_and_update(state, **inputs), action.name, action.schema
)
_validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
out = result, _state_update(state, new_state)
_validate_result(result, action.name)
_validate_reducer_writes(action, new_state, action.name)
return out

Expand Down Expand Up @@ -300,8 +308,7 @@ def _run_single_step_streaming_action(
f"Action {action.name} did not return a state update. For streaming actions, the last yield "
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
# TODO -- get this back in and use the action's schema (still not set) to validate the result...
# _validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand Down Expand Up @@ -354,7 +361,7 @@ async def _arun_single_step_streaming_action(
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
)
# Validate the final result against the action's schema
# _validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
# TODO -- add guard against zero-length stream
yield result, state_update
Expand Down Expand Up @@ -405,7 +412,7 @@ def _run_multi_step_streaming_action(
count += 1
yield next_result, None
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand Down Expand Up @@ -449,7 +456,7 @@ async def _arun_multi_step_streaming_action(
count += 1
yield next_result, None
state_update = _run_reducer(action, state, result, action.name)
_validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, state_update, action.name)
yield result, state_update

Expand All @@ -461,9 +468,9 @@ async def _arun_single_step_action(
state_to_use = state
action.validate_inputs(inputs)
result, new_state = _adjust_single_step_output(
await action.run_and_update(state_to_use, **inputs), action.name
await action.run_and_update(state_to_use, **inputs), action.name, action.schema
)
_validate_result(result, action.name)
_validate_result(result, action.name, action.schema)
_validate_reducer_writes(action, new_state, action.name)
return result, _state_update(state, new_state)

Expand Down
41 changes: 41 additions & 0 deletions burr/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,47 @@ def construct_state(self, data: BaseType) -> State[BaseType]:
"""


StateInputType = TypeVar("StateInputType")
StateOutputType = TypeVar("StateOutputType")
IntermediateResultType = TypeVar("IntermediateResultType")


class ActionSchema(
    abc.ABC,
    Generic[
        StateInputType,
        StateOutputType,
        IntermediateResultType,
    ],
):
    """Quick wrapper class to represent a schema.

    Note that this is currently used internally, just to store the appropriate
    information. This does not validate or do conversion; currently that is
    done within the pydantic model state typing system (which is also internal
    in its implementation).

    We will likely centralize that logic at some point when we get more -- it
    would look something like this:

    1. Action is passed an ActionSchema
    2. Action is parameterized on the ActionSchema types
    3. Action takes state, validates the type and converts to StateInputType
    4. Action runs, returns intermediate result + state
    5. Action validates intermediate result type (or converts to dict?
       Probably just keeps it)
    6. Action converts StateOutputType to State
    """

    # The abstract methods take ``self`` so that their signatures match the
    # instance-method overrides that concrete schemas provide.

    @abc.abstractmethod
    def state_input_type(self) -> Type[StateInputType]:
        """Return the type associated with the action's input state."""

    @abc.abstractmethod
    def state_output_type(self) -> Type[StateOutputType]:
        """Return the type associated with the action's output state."""

    @abc.abstractmethod
    def intermediate_result_type(self) -> Type[IntermediateResultType]:
        """Return the type of the action's intermediate result."""


class DictBasedTypingSystem(TypingSystem[dict]):
"""Effectively a no-op. State is backed by a dictionary, which allows every state item
to... be a dictionary."""
Expand Down
54 changes: 51 additions & 3 deletions burr/integrations/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from burr.core import Action, Graph, State
from burr.core.action import FunctionBasedAction, FunctionBasedStreamingAction, bind, get_inputs
from burr.core.typing import TypingSystem
from burr.core.typing import ActionSchema, TypingSystem

Inputs = ParamSpec("Inputs")

Expand Down Expand Up @@ -133,8 +133,37 @@ def _validate_keys(model: Type[pydantic.BaseModel], keys: List[str], fn: Callabl
)


StateInputType = TypeVar("StateInputType", bound=pydantic.BaseModel)
StateOutputType = TypeVar("StateOutputType", bound=pydantic.BaseModel)
IntermediateResultType = TypeVar("IntermediateResultType", bound=Union[pydantic.BaseModel, dict])


class PydanticActionSchema(ActionSchema[StateInputType, StateOutputType, IntermediateResultType]):
    """Action schema that stores the three types it is constructed with.

    The types are kept as given and handed back by the accessor methods; no
    validation or conversion happens in this class.
    """

    def __init__(
        self,
        input_type: Type[StateInputType],
        output_type: Type[StateOutputType],
        intermediate_result_type: Type[IntermediateResultType],
    ):
        """Store the schema's types.

        :param input_type: Type returned by :meth:`state_input_type`.
        :param output_type: Type returned by :meth:`state_output_type`.
        :param intermediate_result_type: Type returned by
            :meth:`intermediate_result_type`.
        """
        self._input_type = input_type
        self._output_type = output_type
        self._intermediate_result_type = intermediate_result_type

    def state_input_type(self) -> Type[StateInputType]:
        """Return the input type passed at construction."""
        return self._input_type

    def state_output_type(self) -> Type[StateOutputType]:
        """Return the output type passed at construction."""
        return self._output_type

    def intermediate_result_type(self) -> Type[IntermediateResultType]:
        """Return the intermediate result type passed at construction."""
        return self._intermediate_result_type


def pydantic_action(
reads: List[str], writes: List[str]
reads: List[str],
writes: List[str],
state_input_type: Optional[Type[pydantic.BaseModel]] = None,
state_output_type: Optional[Type[pydantic.BaseModel]] = None,
) -> Callable[[PydanticActionFunction], PydanticActionFunction]:
"""Action that specifies inputs/outputs using pydantic models.
This should make it easier to develop with guardrails.
Expand All @@ -147,7 +176,15 @@ def pydantic_action(
"""

def decorator(fn: PydanticActionFunction) -> PydanticActionFunction:
itype, otype = _validate_and_extract_signature_types(fn)
if state_input_type is None and state_output_type is None:
itype, otype = _validate_and_extract_signature_types(fn)

elif state_input_type is not None and state_output_type is not None:
itype, otype = state_input_type, state_output_type
else:
raise ValueError(
"If you specify state_input_type or state_output_type, you must specify both."
)
_validate_keys(model=itype, keys=reads, fn=fn)
_validate_keys(model=otype, keys=writes, fn=fn)
SubsetInputType = subset_model(
Expand All @@ -162,6 +199,7 @@ def decorator(fn: PydanticActionFunction) -> PydanticActionFunction:
force_optional_fields=[],
model_name_suffix=f"{fn.__name__}_input",
)
# TODO -- figure out

def action_function(state: State, **kwargs) -> State:
model_to_use = model_from_state(model=SubsetInputType, state=state)
Expand Down Expand Up @@ -191,6 +229,11 @@ async def async_action_function(state: State, **kwargs) -> State:
writes,
input_spec=get_inputs({}, fn),
originating_fn=fn,
schema=PydanticActionSchema(
input_type=SubsetInputType,
output_type=SubsetOutputType,
intermediate_result_type=dict,
),
),
)
setattr(fn, "bind", types.MethodType(bind, fn))
Expand Down Expand Up @@ -316,6 +359,11 @@ async def async_action_generator(
writes,
input_spec=get_inputs({}, fn),
originating_fn=fn,
schema=PydanticActionSchema(
input_type=SubsetInputType,
output_type=SubsetOutputType,
intermediate_result_type=dict,
),
),
)
setattr(fn, "bind", types.MethodType(bind, fn))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "burr"
version = "0.30.0rc4"
version = "0.30.0rc5"
dependencies = [] # yes, there are none
requires-python = ">=3.9"
authors = [
Expand Down
Loading

0 comments on commit 7c52920

Please sign in to comment.