Skip to content

Commit

Permalink
Type-Checking the Zip operator on Dictionary should not crash (#149)
Browse files Browse the repository at this point in the history
* feat: Add test for dictionary with Zip key

Add a new test case to check that type following works from a dictionary through a Zip operation. The test creates a lambda expression that builds a dictionary of jet 'pt' and 'eta' sequences obtained from an event object, applies Zip to that dictionary, and then selects the 'pt' attribute from the result. The test ensures that type checking this expression does not raise a "key not found" error for `Zip` and that the resulting expression type is `Any`.

Update the test_type_based_replacement.py file to include the new test_dictionary_Zip_key() function.

* Fix up type and spelling errors

* Fix up more spelling and type checking errors

* Ignore `zip` when looking at dictionary attributes

* Fix up bad addition
  • Loading branch information
gordonwatts authored Sep 12, 2024
1 parent 95b14c9 commit 29a3406
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 16 deletions.
10 changes: 5 additions & 5 deletions func_adl/ast/function_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def visit_SelectMany_of_SelectMany(self, parent: ast.Call, selection: ast.Lambda
"SelectMany", [cast(ast.AST, captured_body), cast(ast.AST, func_g)]
)
new_select_lambda = lambda_build(captured_arg, new_select)
new_selectmany = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)])
return new_selectmany
new_select_many = function_call("SelectMany", [seq, cast(ast.AST, new_select_lambda)])
return new_select_many

def call_SelectMany(self, node: ast.Call, args: List[ast.AST]):
r"""
Expand Down Expand Up @@ -442,7 +442,7 @@ def visit_Subscript_Tuple(self, v: ast.Tuple, s: Union[ast.Num, ast.Constant, as
# Get the value out - this is due to supporting python 3.7-3.9
n = _get_value_from_index(s)
if n is None:
return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore
assert isinstance(n, int), "Programming error: index is not an integer in tuple subscript"
if n >= len(v.elts):
raise FuncADLIndexError(
Expand All @@ -460,7 +460,7 @@ def visit_Subscript_List(self, v: ast.List, s: Union[ast.Num, ast.Constant, ast.
"""
n = _get_value_from_index(s)
if n is None:
return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore
if n >= len(v.elts):
raise FuncADLIndexError(
f"Attempt to access the {n}th element of a tuple"
Expand All @@ -484,7 +484,7 @@ def visit_Subscript_Dict_with_value(self, v: ast.Dict, s: Union[str, int]):
if _get_value_from_index(value) == s:
return v.values[index]

return ast.Subscript(v, s, ast.Load())
return ast.Subscript(v, s, ast.Load()) # type: ignore

def visit_Subscript_Of_First(self, first: ast.AST, s):
"""
Expand Down
24 changes: 13 additions & 11 deletions func_adl/type_based_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def func_adl_callable(
Callable[[ObjectStream[W], ast.Call], Tuple[ObjectStream[W], ast.AST]]
] = None
):
"""Dectorator that will declare a function that can be used inline in
"""Decorator that will declare a function that can be used inline in
a `func_adl` expression. The body of the function, what the backend
translates it to, must be given by another route (e.g. via `MetaData`
and the `processor` argument).
Expand Down Expand Up @@ -388,7 +388,7 @@ def _fill_in_default_arguments(func: Callable, call: ast.Call) -> Tuple[ast.Call
)
return_type = Any

return call, return_type
return call, return_type # type: ignore


def fixup_ast_from_modifications(transformed_ast: ast.AST, original_ast: ast.Call) -> ast.Call:
Expand Down Expand Up @@ -582,7 +582,7 @@ def type_follow_in_callbacks(
# If this is a known collection class, we can use call-backs to follow it.
if get_origin(call_site_info.obj_type) in _g_collection_classes:
rtn_value = self.process_method_call_on_stream_obj(
_g_collection_classes[get_origin(call_site_info.obj_type)],
_g_collection_classes[get_origin(call_site_info.obj_type)], # type: ignore
m_name,
r_node,
get_args(call_site_info.obj_type)[0],
Expand Down Expand Up @@ -631,7 +631,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST:
Args:
node (ast.Call): The ast node
obj_type (type): The object type this method call is occuring against
obj_type (type): The object type this method call is occurring against
Returns:
ast.AST: An updated ast that is the new method call (with default args, etc.)
Expand All @@ -643,7 +643,7 @@ def process_method_call(self, node: ast.Call, obj_type: type) -> ast.AST:
base_obj_list_all = [obj_type]
if is_iterable(obj_type):
item_type = unwrap_iterable(obj_type)
base_obj_list_all += [c[item_type] for c in _g_collection_classes]
base_obj_list_all += [c[item_type] for c in _g_collection_classes] # type: ignore

assert isinstance(r_node.func, ast.Attribute)
m_name = r_node.func.attr
Expand Down Expand Up @@ -758,7 +758,7 @@ def process_function_call(self, node: ast.Call, func_info: _FuncAdlFunction) ->
f"function {func_info.function.__name__} ({str(e)})"
) from e

def process_paramaterized_method_call(
def process_parameterized_method_call(
self,
node: ast.Call,
obj_type: Type,
Expand Down Expand Up @@ -815,7 +815,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
if isinstance(t_node.func.value, ast.Attribute):
found_type = self.lookup_type(t_node.func.value.value)
if found_type is not None:
t_node = self.process_paramaterized_method_call(
t_node = self.process_parameterized_method_call(
t_node,
found_type,
t_node.func.value.attr,
Expand Down Expand Up @@ -941,17 +941,17 @@ def visit_Constant(self, node: ast.Constant) -> Any:
return node

def visit_Num(self, node: ast.Num) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
self._found_types[node] = type(node.n)
return node

def visit_Str(self, node: ast.Str) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
self._found_types[node] = str
return node

def visit_NameConstant(self, node: ast.NameConstant) -> Any: # pragma: no cover
"3.7 compatability"
"3.7 compatibility"
if node.value is None:
raise ValueError("Do not know how to work with pythons None")
self._found_types[node] = bool
Expand All @@ -968,6 +968,8 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
e for e, k in enumerate(t_node.value.keys) if k.value == key # type: ignore
]
if len(key_index) == 0:
if t_node.attr.lower() == "zip":
return t_node
raise ValueError(f"Key {key} not found in dict expression!!")
value = t_node.value.values[key_index[0]]
self._found_types[node] = self.lookup_type(value)
Expand Down Expand Up @@ -999,7 +1001,7 @@ def remap_from_lambda(
orig_type = o_stream.item_type
var_name = l_func.args.args[0].arg
stream, new_body, return_type = remap_by_types(o_stream, var_name, orig_type, l_func.body)
return stream, ast.Lambda(l_func.args, new_body), return_type
return stream, ast.Lambda(l_func.args, new_body), return_type # type: ignore


def reset_global_functions():
Expand Down
17 changes: 17 additions & 0 deletions tests/test_type_based_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,23 @@ def test_dictionary_bad_key():
assert "jetsss" in str(e)


def test_dictionary_Zip_key():
    """Check that type following works from a dictionary through a `Zip` call.

    The dictionary literal only has the keys `jet_pt` and `jet_eta`; `Zip` is
    applied to it as an attribute, and the expression type is expected to
    come back as `Any`.
    """
    # Expression under test: dictionary of two jet sequences, zipped, then selected.
    zip_expression = ast_lambda(
        """({
'jet_pt': e.Jets().Select(lambda j: j.pt()),
'jet_eta': e.Jets().Select(lambda j: j.eta())}
.Zip()
.Select(lambda j: j.pt()))"""
    )
    event_name = ast.Name(id="e", ctx=ast.Load())
    event_stream = ObjectStream[Event](event_name)

    # Only the inferred type is inspected; the returned stream and AST are unused.
    _, _, expr_type = remap_by_types(event_stream, "e", Event, zip_expression)

    assert expr_type == Any


def test_dictionary_through_Select():
"""Make sure the Select statement carries the typing all the way through"""

Expand Down

0 comments on commit 29a3406

Please sign in to comment.