Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Feb 17, 2024
1 parent 514fed8 commit 70bb602
Show file tree
Hide file tree
Showing 10 changed files with 657 additions and 115 deletions.
99 changes: 89 additions & 10 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import inspect
import types
from typing import Any, Callable, List, Protocol, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, TypeVar, Union

from burr.core.state import State

Expand All @@ -21,13 +21,44 @@ def reads(self) -> list[str]:
pass

@abc.abstractmethod
def run(self, state: State) -> dict:
def run(self, state: State, **run_kwargs) -> dict:
"""Runs the function on the given state and returns the result.
The result is jsut a key/value dictionary.
:param state: State to run the function on
:param run_kwargs: Additional arguments to the function passed at runtime.
:return: Result of the function
"""
pass

@property
def inputs(self) -> list[str]:
"""Represents inputs that are required for this to run.
These correspond to the `**function_kwargs` in `run` above.
:return:
"""
return []

def validate_inputs(self, inputs: Optional[Dict[str, Any]]) -> None:
"""Validates the inputs to the function. This is a convenience method
to allow for validation of inputs before running the function.
:param inputs: Inputs to validate
:raises ValueError: If the inputs are invalid
"""
if inputs is None:
inputs = {}
required_inputs = set(self.inputs)
given_inputs = set(inputs.keys())
missing_inputs = required_inputs - given_inputs
additional_inputs = given_inputs - required_inputs
if missing_inputs or additional_inputs:
raise ValueError(
f"Inputs to function {self} are invalid. "
f"Missing inputs: {missing_inputs}, "
f"Additional inputs: {additional_inputs}"
)

def is_async(self) -> bool:
"""Convenience method to check if the function is async or not.
Expand Down Expand Up @@ -148,7 +179,7 @@ def condition_func(state: State) -> bool:

return Condition(keys, condition_func, name=expr)

def run(self, state: State) -> dict:
def run(self, state: State, **kwargs) -> dict:
return {Condition.KEY: self._resolver(state)}

@property
Expand Down Expand Up @@ -197,7 +228,7 @@ def name(self) -> str:


class Result(Action):
def __init__(self, fields: list[str]):
def __init__(self, *fields: str):
"""Represents a result action. This is purely a convenience class to
pull data from state and give it out to the result. It does nothing to
the state itself.
Expand All @@ -215,13 +246,43 @@ def update(self, result: dict, state: State) -> State:

@property
def reads(self) -> list[str]:
return self._fields
return list(self._fields)

@property
def writes(self) -> list[str]:
return []


class Input(Action):
def __init__(self, *fields: str):
"""Represents an input action -- this reads something from an input
then writes that directly to state. This is a convenience class for when you don't
need to process the input and just want to put it in state for later use.
:param fields: Fields to pull from the state
"""
super(Input, self).__init__()
self._fields = fields

@property
def reads(self) -> list[str]:
return [] # nothing from state

def run(self, state: State, **run_kwargs) -> dict:
return {key: run_kwargs[key] for key in self._fields}

@property
def writes(self) -> list[str]:
return list(self._fields)

@property
def inputs(self) -> list[str]:
return list(self._fields)

def update(self, result: dict, state: State) -> State:
return state.update(**result)


class SingleStepAction(Action, abc.ABC):
"""Internal representation of a "single-step" action. While most actions will have
a run and an update, this is a convenience class for actions that return them both at the same time.
Expand All @@ -241,11 +302,12 @@ def single_step(self) -> bool:
return True

@abc.abstractmethod
def run_and_update(self, state: State) -> Tuple[dict, State]:
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
"""Performs a run/update at the same time.
:param state:
:return:
:param state: State to run the action on
:param run_kwargs: Additional arguments to the function passed at runtime.
:return: Result of the action and the new state
"""
pass

Expand All @@ -268,6 +330,15 @@ def update(self, result: dict, state: State) -> State:
"SingleStepAction.update should never be called independently -- use run_and_update instead."
)

def is_async(self) -> bool:
"""Convenience method to check if the function is async or not.
We'll want to clean up the class hierarchy, but this is all internal.
See note on ``run`` and ``update`` above
:return: True if the function is async, False otherwise
"""
return inspect.iscoroutinefunction(self.run_and_update)


class FunctionBasedAction(SingleStepAction):
ACTION_FUNCTION = "action_function"
Expand Down Expand Up @@ -304,6 +375,14 @@ def reads(self) -> list[str]:
def writes(self) -> list[str]:
return self._writes

@property
def inputs(self) -> list[str]:
return [
param
for param in inspect.signature(self._fn).parameters
if param != "state" and param not in self._bound_params
]

def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
"""Binds parameters to the function.
Note that there is no reason to call this by the user. This *could*
Expand All @@ -317,8 +396,8 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
new_action._bound_params = {**self._bound_params, **kwargs}
return new_action

def run_and_update(self, state: State) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params)
def run_and_update(self, state: State, **run_kwargs) -> Tuple[dict, State]:
return self._fn(state, **self._bound_params, **run_kwargs)


def _validate_action_function(fn: Callable):
Expand Down
Loading

0 comments on commit 70bb602

Please sign in to comment.