Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default state for map actions + other minor testing-related fixes #466

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions burr/core/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,15 +628,17 @@ def actions(
:return: Generator of actions to run
"""

@abc.abstractmethod
def state(self, state: State, inputs: Dict[str, Any]):
"""Gives the state for each of the actions
"""Gives the state for each of the actions.
By default, this will give out the current state. That said,
you may want to adjust this -- E.G. to translate state into
a format the sub-actions would expect.

:param state: State at the time of running the action
:param inputs: Runtime inputs to the action
:return: State for the action
"""
pass
return state

def states(
self, state: State, context: ApplicationContext, inputs: Dict[str, Any]
Expand Down
12 changes: 6 additions & 6 deletions burr/tracking/server/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,30 +93,30 @@ def from_logs(log_lines: List[bytes]) -> List["Step"]:
json_line = safe_json_load(line)
# TODO -- make these into constants
if json_line["type"] == "begin_entry":
begin_step = BeginEntryModel.parse_obj(json_line)
begin_step = BeginEntryModel.model_validate(json_line)
steps_by_sequence_id[begin_step.sequence_id].step_start_log = begin_step
elif json_line["type"] == "end_entry":
step_end_log = EndEntryModel.parse_obj(json_line)
step_end_log = EndEntryModel.model_validate(json_line)
steps_by_sequence_id[step_end_log.sequence_id].step_end_log = step_end_log
elif json_line["type"] == "begin_span":
span = BeginSpanModel.parse_obj(json_line)
span = BeginSpanModel.model_validate(json_line)
spans_by_id[span.span_id] = PartialSpan(
begin_entry=span,
end_entry=None,
)
elif json_line["type"] == "end_span":
end_span = EndSpanModel.parse_obj(json_line)
end_span = EndSpanModel.model_validate(json_line)
span = spans_by_id[end_span.span_id]
span.end_entry = end_span
elif json_line["type"] == "attribute":
attribute = AttributeModel.parse_obj(json_line)
attribute = AttributeModel.model_validate(json_line)
attributes_by_step[attribute.action_sequence_id].append(attribute)
elif json_line["type"] in ["begin_stream", "first_item_stream", "end_stream"]:
streaming_event = {
"begin_stream": InitializeStreamModel,
"first_item_stream": FirstItemStreamModel,
"end_stream": EndStreamModel,
}[json_line["type"]].parse_obj(json_line)
}[json_line["type"]].model_validate(json_line)
steps_by_sequence_id[streaming_event.sequence_id].streaming_events.append(
streaming_event
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ documentation = [
]

tracking-client = [
"pydantic"
"pydantic>1"
]

tracking-client-s3 = [
Expand Down
22 changes: 11 additions & 11 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3333,7 +3333,7 @@ def load(
builder.with_state_persister(persister)


class TestActionWithoutContext(Action):
class ActionWithoutContext(Action):
def run(self, other_param, foo):
pass

Expand All @@ -3352,23 +3352,23 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
return ["other_param", "foo"]


class TestActionWithContext(TestActionWithoutContext):
class ActionWithContext(ActionWithoutContext):
def run(self, __context, other_param, foo):
pass

def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
return ["other_param", "foo", "__context"]


class TestActionWithKwargs(TestActionWithoutContext):
class ActionWithKwargs(ActionWithoutContext):
def run(self, other_param, foo, **kwargs):
pass

def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
return ["other_param", "foo", "__context"]


class TestActionWithContextTracer(TestActionWithoutContext):
class ActionWithContextTracer(ActionWithoutContext):
def run(self, __context, other_param, foo, __tracer):
pass

Expand All @@ -3377,27 +3377,27 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:


def test_remap_context_variable_with_mangled_context_kwargs():
_action = TestActionWithKwargs()
_action = ActionWithKwargs()

inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_with_mangled_context():
_action = TestActionWithContext()
_action = ActionWithContext()

inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
expected = {
f"_{TestActionWithContext.__name__}__context": "context_value",
f"_{ActionWithContext.__name__}__context": "context_value",
"other_key": "other_value",
"foo": "foo_value",
}
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_with_mangled_contexttracer():
_action = TestActionWithContextTracer()
_action = ActionWithContextTracer()

inputs = {
"__context": "context_value",
Expand All @@ -3406,16 +3406,16 @@ def test_remap_context_variable_with_mangled_contexttracer():
"foo": "foo_value",
}
expected = {
f"_{TestActionWithContextTracer.__name__}__context": "context_value",
f"_{ActionWithContextTracer.__name__}__context": "context_value",
"other_key": "other_value",
"foo": "foo_value",
f"_{TestActionWithContextTracer.__name__}__tracer": "tracer_value",
f"_{ActionWithContextTracer.__name__}__tracer": "tracer_value",
}
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_without_mangled_context():
_action = TestActionWithoutContext()
_action = ActionWithoutContext()
inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected
22 changes: 22 additions & 0 deletions tests/core/test_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,28 @@ def _group_events_by_app_id(
return grouped_events


def test_map_actions_default_state():
class MapActionsAllApproaches(MapActions):
def actions(
self, state: State, inputs: Dict[str, Any], context: ApplicationContext
) -> Generator[Union[Action, Callable, RunnableGraph], None, None]:
...

def reduce(self, state: State, states: Generator[State, None, None]) -> State:
...

@property
def writes(self) -> list[str]:
return []

@property
def reads(self) -> list[str]:
return []

state_to_test = State({"foo": "bar", "baz": "qux"})
assert MapActionsAllApproaches().state(state_to_test, {}).get_all() == state_to_test.get_all()


def test_e2e_map_actions_sync_subgraph():
"""Tests map actions over multiple action types (runnable graph, function, action class...)"""

Expand Down
5 changes: 2 additions & 3 deletions tests/integrations/test_burr_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pydantic
import pytest
from pydantic import BaseModel, EmailStr, Field
from pydantic import BaseModel, ConfigDict, EmailStr, Field
from pydantic.fields import FieldInfo

from burr.core import expr
Expand Down Expand Up @@ -110,8 +110,7 @@ class MyModelWithConfig(pydantic.BaseModel):
foo: int
arbitrary: Arbitrary

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
skrawcz marked this conversation as resolved.
Show resolved Hide resolved

SubsetModel = subset_model(MyModelWithConfig, ["foo", "bar"], [], "Subset")
assert SubsetModel.__name__ == "MyModelWithConfigSubset"
Expand Down
Loading