Skip to content

Commit

Permalink
Fixes __tracer for class based actions
Browse files Browse the repository at this point in the history
Refactors __context treatment to also handle __tracer.
This now enables one to pass through/request the tracer
object in a class based action.
  • Loading branch information
skrawcz authored and elijahbenizzy committed Dec 4, 2024
1 parent 0555f6c commit 86f4cb2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
36 changes: 21 additions & 15 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,28 +103,34 @@ def _adjust_single_step_output(
_raise_fn_return_validation_error(output, action_name)


def _remap_context_variable(run_method: Callable, inputs: Dict[str, Any]) -> dict:
"""This is a utility function to remap the __context variable to the mangled variable in the function signature.
def _remap_dunder_parameters(
run_method: Callable, inputs: Dict[str, Any], vars_to_remap: List[str]
) -> dict:
"""This is a utility function to remap the __dunder parameters to the mangled version in the function signature.
Python mangles the variable name in the function signature, so we need to remap it to the correct variable name.
Python mangles __parameter names in the function signature, so we need to remap it to the correct parameter name.
:param run_method: the run method to inspect.
:param inputs: the inputs to inspect
:param vars_to_remap: the variables to remap
:return: potentially new dict with the remapped variable, else the original dict.
"""
# Get the signature of the method being run. This should be Function.run() or similar.
signature = inspect.signature(run_method)
mangled_params: Dict[str, Optional[str]] = {v: None for v in vars_to_remap}
# Find the name-mangled __context variable
mangled_context_name = None
for param in signature.parameters.values():
if param.name.endswith("__context"):
mangled_context_name = param.name
break

# If a mangled __context variable is found, remap the value in inputs
if mangled_context_name and "__context" in inputs:
for dunder_param in mangled_params.keys():
for param in signature.parameters.values():
if param.name.endswith(dunder_param):
mangled_params[dunder_param] = param.name
break

# If any mangled __parameter is found, remap the value in inputs
if any(mangled_params.values()):
inputs = inputs.copy()
inputs[mangled_context_name] = inputs.pop("__context")
for dunder_param, mangled_name in mangled_params.items():
if mangled_name and dunder_param in inputs:
inputs[mangled_name] = inputs.pop(dunder_param)
return inputs


Expand All @@ -146,9 +152,9 @@ def _run_function(function: Function, state: State, inputs: Dict[str, Any], name
)
state_to_use = state.subset(*function.reads)
function.validate_inputs(inputs)
if "__context" in inputs:
# potentially need to remap the __context variable
inputs = _remap_context_variable(function.run, inputs)
if "__context" in inputs or "__tracer" in inputs:
# potentially need to remap the __context & __tracer variables
inputs = _remap_dunder_parameters(function.run, inputs, ["__context", "__tracer"])
result = function.run(state_to_use, **inputs)
_validate_result(result, name)
return result
Expand Down
34 changes: 30 additions & 4 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
_arun_multi_step_streaming_action,
_arun_single_step_action,
_arun_single_step_streaming_action,
_remap_context_variable,
_remap_dunder_parameters,
_run_function,
_run_multi_step_streaming_action,
_run_reducer,
Expand Down Expand Up @@ -3368,12 +3368,20 @@ def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
return ["other_param", "foo", "__context"]


class TestActionWithContextTracer(TestActionWithoutContext):
def run(self, __context, other_param, foo, __tracer):
pass

def inputs(self) -> Union[list[str], tuple[list[str], list[str]]]:
return ["other_param", "foo", "__context", "__tracer"]


def test_remap_context_variable_with_mangled_context_kwargs():
_action = TestActionWithKwargs()

inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
assert _remap_context_variable(_action.run, inputs) == expected
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_with_mangled_context():
Expand All @@ -3385,11 +3393,29 @@ def test_remap_context_variable_with_mangled_context():
"other_key": "other_value",
"foo": "foo_value",
}
assert _remap_context_variable(_action.run, inputs) == expected
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_with_mangled_contexttracer():
_action = TestActionWithContextTracer()

inputs = {
"__context": "context_value",
"__tracer": "tracer_value",
"other_key": "other_value",
"foo": "foo_value",
}
expected = {
f"_{TestActionWithContextTracer.__name__}__context": "context_value",
"other_key": "other_value",
"foo": "foo_value",
f"_{TestActionWithContextTracer.__name__}__tracer": "tracer_value",
}
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected


def test_remap_context_variable_without_mangled_context():
_action = TestActionWithoutContext()
inputs = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
expected = {"__context": "context_value", "other_key": "other_value", "foo": "foo_value"}
assert _remap_context_variable(_action.run, inputs) == expected
assert _remap_dunder_parameters(_action.run, inputs, ["__context", "__tracer"]) == expected

0 comments on commit 86f4cb2

Please sign in to comment.