Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
elijahbenizzy committed Sep 22, 2024
1 parent d6f58f6 commit 7bccbf4
Show file tree
Hide file tree
Showing 14 changed files with 657 additions and 75 deletions.
81 changes: 58 additions & 23 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,40 @@ def run(self, state: State, **run_kwargs) -> dict:
pass

@property
def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
def inputs(
self,
) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]:
"""Represents inputs that are used for this to run.
These correspond to the ``**run_kwargs`` in `run` above.
Note that this has two possible return values:
1. A list of strings -- these are the keys that are required to run the function
2. A tuple of two lists of strings -- the first list is the required keys, the second is the optional keys
1. A list of strings/dict of string -> ypes -- these are the keys that are required to run the function
2. A tuple of two lists of strings/dict of strings -> types -- the first list is the required keys, the second is the optional keys
:return: Either a list of strings (required inputs) or a tuple of two lists of strings (required and optional inputs)
"""
return []
return [], []

@property
def input_schema(self) -> Tuple[dict[str, Type[Type]], dict[str, Type[Type]]]:
"""Returns the input schema for the function.
The input schema is a type that can be used to validate the input to the function.
Note that this is separate from inputs() for backwards compatibility --
inputs() can return a schema *or* a tuple of required and optional inputs.
:return: Tuple of required inputs and optional inputs with attached types
"""
inputs = self.inputs
if len(inputs) == 1:
inputs = (inputs[0], {})
out = []
for input_spec in inputs:
if isinstance(input_spec, list):
out.append({key: Any for key in input_spec})
else:
out.append(input_spec)
return tuple(out)

@property
def optional_and_required_inputs(self) -> tuple[set[str], set[str]]:
Expand Down Expand Up @@ -213,11 +236,6 @@ def get_source(self) -> str:
to display a different source"""
return inspect.getsource(self.__class__)

def input_schema(self) -> Any:
"""Returns the input schema for the action.
The input schema is a type that can be used to validate the input to the action"""
return None

def __repr__(self):
read_repr = ", ".join(self.reads) if self.reads else "{}"
write_repr = ", ".join(self.writes) if self.writes else "{}"
Expand Down Expand Up @@ -534,7 +552,9 @@ def is_async(self) -> bool:

# the following exist to share implementation between FunctionBasedStreamingAction and FunctionBasedAction
# TODO -- think through the class hierarchy to simplify, for now this is OK
def derive_inputs_from_fn(bound_params: dict, fn: Callable) -> tuple[list[str], list[str]]:
def derive_inputs_from_fn(
bound_params: dict, fn: Callable
) -> tuple[dict[str, Type[Type]], dict[str, Type[Type]]]:
"""Derives inputs from the function, given the bound parameters. This assumes that the function
has inputs named `state`, as well as any number of other kwarg-boundable parameters.
Expand All @@ -543,20 +563,29 @@ def derive_inputs_from_fn(bound_params: dict, fn: Callable) -> tuple[list[str],
:return: Required and optional inputs
"""
sig = inspect.signature(fn)
required_inputs, optional_inputs = [], []
required_inputs, optional_inputs = {}, {}
for param_name, param in sig.parameters.items():
if param_name != "state" and param_name not in bound_params:
if param.default is inspect.Parameter.empty:
# has no default means its required
required_inputs.append(param_name)
required_inputs[param_name] = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
else:
# has a default means its optional
optional_inputs.append(param_name)
optional_inputs[param_name] = (
param.annotation if param.annotation != inspect.Parameter.empty else Any
)
return required_inputs, optional_inputs


FunctionBasedActionType = Union["FunctionBasedAction", "FunctionBasedStreamingAction"]

InputSpec = Tuple[Union[list[str], Dict[str, Type[Type]], Union[list[str], Dict[str, Type[Type]]]]]


OptionalInputType = Optional[InputSpec]


class FunctionBasedAction(SingleStepAction):
ACTION_FUNCTION = "action_function"
Expand All @@ -567,7 +596,9 @@ def __init__(
reads: List[str],
writes: List[str],
bound_params: Optional[dict] = None,
input_spec: Optional[tuple[list[str], list[str]]] = None,
input_spec: Optional[
Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]
] = None,
originating_fn: Optional[Callable] = None,
schema: ActionSchema = DEFAULT_SCHEMA,
):
Expand All @@ -589,11 +620,9 @@ def __init__(
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
if input_spec is None
else (
[item for item in input_spec[0] if item not in self._bound_params],
[item for item in input_spec[1] if item not in self._bound_params],
)
else input_spec
)
print(self._inputs, self._name)
self._schema = schema

@property
Expand All @@ -609,7 +638,9 @@ def writes(self) -> list[str]:
return self._writes

@property
def inputs(self) -> tuple[list[str], list[str]]:
def inputs(
self,
) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]:
return self._inputs

@property
Expand All @@ -630,7 +661,7 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
self._reads,
self._writes,
{**self._bound_params, **kwargs},
input_spec=self._inputs,
input_spec=self.input_schema,
originating_fn=self._originating_fn,
schema=self._schema,
)
Expand Down Expand Up @@ -1044,7 +1075,9 @@ def __init__(
reads: List[str],
writes: List[str],
bound_params: Optional[dict] = None,
input_spec: Optional[tuple[list[str], list[str]]] = None,
input_spec: Optional[
Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]
] = None,
originating_fn: Optional[Callable] = None,
schema: ActionSchema = DEFAULT_SCHEMA,
):
Expand Down Expand Up @@ -1111,13 +1144,15 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
self._reads,
self._writes,
{**self._bound_params, **kwargs},
input_spec=self._inputs,
input_spec=self.input_schema,
originating_fn=self._originating_fn,
schema=self._schema,
)

@property
def inputs(self) -> tuple[list[str], list[str]]:
def inputs(
self,
) -> Tuple[Union[list[str], Dict[str, Type[Type]]], Union[list[str], Dict[str, Type[Type]]]]:
return self._inputs

@property
Expand Down
Loading

0 comments on commit 7bccbf4

Please sign in to comment.